Mirror of https://github.com/microsoft/nni.git
Merge pull request #5036 from microsoft/promote-retiarii-to-nas
[DO NOT SQUASH] Promote retiarii to NAS
Commit a0fd003671
@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
from .trainer import CdartsTrainer
@@ -1,143 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch

from apex.parallel import DistributedDataParallel  # pylint: disable=import-error
from nni.algorithms.nas.pytorch.darts import DartsMutator  # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice  # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator  # pylint: disable=wrong-import-order


class RegularizedDartsMutator(DartsMutator):
    """
    This is basically :class:`~nni.algorithms.nas.pytorch.darts.DartsMutator`, with two differences.

    1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cut choices will not be used in
       the forward pass and thus consume no memory.

    2. Regularization on choices, to prevent the mutator from overfitting on some choices.
    """

    def reset(self):
        """
        Warnings
        --------
        Renamed to :func:`~reset_with_loss`, which returns the regularization loss on reset.
        """
        raise ValueError("You should probably call `reset_with_loss`.")

    def cut_choices(self, cut_num=2):
        """
        Cut the choices with the smallest weights.
        ``cut_num`` is cumulative: e.g., if the first cut is 2, pass 4 the second time
        to cut another two.

        Parameters
        ----------
        cut_num : int
            Number of choices to cut, so far.

        Warnings
        --------
        Though the parameters are set to :math:`-\infty` to be bypassed, they will still receive a gradient of 0,
        which introduces a ``nan`` problem when calling ``optimizer.step()``. A simple way to solve this issue is to
        reset ``nan`` to :math:`-\infty` each time after the parameters are updated.
        """
        # `cut_choices` is implemented but not used in the current implementation of CdartsTrainer
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                _, idx = torch.topk(-self.choices[mutable.key], cut_num)
                with torch.no_grad():
                    for i in idx:
                        self.choices[mutable.key][i] = -float("inf")

    def reset_with_loss(self):
        """
        Resample and return the loss. If the loss is 0, it returns ``None`` to avoid device issues.

        Currently the loss penalty is proportional to the L1-norm of parameters corresponding
        to modules whose type name contains certain substrings. These substrings include: ``poolwithoutbn``,
        ``identity``, ``dilconv``.
        """
        self._cache, reg_loss = self.sample_search()
        return reg_loss

    def sample_search(self):
        result = super().sample_search()
        loss = []
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                def need_reg(choice):
                    return any(t in str(type(choice)).lower() for t in ["poolwithoutbn", "identity", "dilconv"])

                for i, choice in enumerate(mutable.choices):
                    if need_reg(choice):
                        norm = torch.abs(self.choices[mutable.key][i])
                        if norm < 1E10:
                            loss.append(norm)
        if not loss:
            return result, None
        return result, sum(loss)

    def export(self, logger=None):
        """
        Export an architecture with a logger. The genotype will be printed with the logger.

        Returns
        -------
        dict
            A mapping from mutable keys to decisions.
        """
        result = self.sample_final()
        genotypes = None  # genotype plotting may be skipped; avoid a NameError below
        if hasattr(self.model, "plot_genotype") and logger is not None:
            genotypes = self.model.plot_genotype(result, logger)
        return result, genotypes


class RegularizedMutatorParallel(DistributedDataParallel):
    """
    Parallelize :class:`~RegularizedDartsMutator`.

    This makes the :func:`~RegularizedDartsMutator.reset_with_loss` method parallelized,
    and also allows :func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
    to be easily accessible.
    """
    def reset_with_loss(self):
        """
        Parallelized :func:`~RegularizedDartsMutator.reset_with_loss`.
        """
        result = self.module.reset_with_loss()
        self.callback_queued = False
        return result

    def cut_choices(self, *args, **kwargs):
        """
        Parallelized :func:`~RegularizedDartsMutator.cut_choices`.
        """
        self.module.cut_choices(*args, **kwargs)

    def export(self, logger):
        """
        Parallelized :func:`~RegularizedDartsMutator.export`.
        """
        return self.module.export(logger)


class DartsDiscreteMutator(Mutator):
    """
    A mutator that applies the final sampling result of a parent mutator to another model for training.

    Parameters
    ----------
    model : nn.Module
        The model to apply the mutator to.
    parent_mutator : nni.nas.pytorch.mutator.Mutator
        The mutator that provides the ``sample_final`` method, which will be called to get the architecture.
    """
    def __init__(self, model, parent_mutator):
        super().__init__(model)
        self.__dict__["parent_mutator"] = parent_mutator  # avoid parameters being included

    def sample_search(self):
        return self.parent_mutator.sample_final()
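To make the contract of the mutator above concrete, here is a minimal sketch (not part of the diff) of driving it by hand, including the nan reset that the ``cut_choices`` warning calls for. ``model`` is a hypothetical network containing LayerChoice mutables.

# Illustrative sketch only; `model` is assumed, not defined in this diff.
import torch

mutator = RegularizedDartsMutator(model)
optimizer = torch.optim.Adam(mutator.parameters(), 3e-4)

reg_loss = mutator.reset_with_loss()      # resample; None when there is nothing to regularize
if reg_loss is not None:
    reg_loss.backward()                   # L1 penalty joins the architecture step
optimizer.step()

mutator.cut_choices(cut_num=2)            # bypass the two weakest choices per LayerChoice
# Per the warning in `cut_choices`: cut positions receive 0-gradients that turn into
# nan after optimizer.step(), so reset nan back to -inf after every update.
with torch.no_grad():
    for param in mutator.parameters():
        param[torch.isnan(param)] = float("-inf")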
@@ -1,275 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import apex  # pylint: disable=import-error
from apex.parallel import DistributedDataParallel  # pylint: disable=import-error
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator  # pylint: disable=wrong-import-order
from nni.nas.pytorch.utils import AverageMeterGroup  # pylint: disable=wrong-import-order

from .utils import CyclicIterator, TorchTensorEncoder, accuracy, reduce_metrics

PHASE_SMALL = "small"
PHASE_LARGE = "large"


class InteractiveKLLoss(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        # self.kl_loss = nn.KLDivLoss(reduction = 'batchmean')
        self.kl_loss = nn.KLDivLoss()

    def forward(self, student, teacher):
        return self.kl_loss(F.log_softmax(student / self.temperature, dim=1),
                            F.softmax(teacher / self.temperature, dim=1))


class CdartsTrainer(object):
    """
    CDARTS trainer.

    Parameters
    ----------
    model_small : nn.Module
        PyTorch model to be trained. This is the search network of CDARTS.
    model_large : nn.Module
        PyTorch model to be trained. This is the evaluation network of CDARTS.
    criterion : callable
        Receives logits and ground truth label, returns a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
    loaders : list of torch.utils.data.DataLoader
        List of train data and valid data loaders, for training weights and architecture weights respectively.
    samplers : list of torch.utils.data.Sampler
        List of train data and valid data samplers. These can be PyTorch standard samplers if not distributed.
        In distributed mode, the sampler needs to have a ``set_epoch`` method. Refer to data utils in the CDARTS example for details.
    logger : logging.Logger
        The logger for logging. Will use the NNI logger by default (if logger is ``None``).
    regular_coeff : float
        The coefficient of the regular loss.
    regular_ratio : float
        The ratio of the regular loss.
    warmup_epochs : int
        Number of epochs to warm up the search network.
    fix_head : bool
        ``True`` to fix the parameters of the auxiliary heads, ``False`` to train them.
    epochs : int
        Number of epochs planned for training.
    steps_per_epoch : int
        Steps of one epoch.
    loss_alpha : float
        Coefficient that balances the classification loss and the interactive loss.
    loss_T : float
        Temperature of the interactive (distillation) loss.
    distributed : bool
        ``True`` if using distributed training, else non-distributed training.
    log_frequency : int
        Step count per logging.
    grad_clip : float
        Gradient clipping for weights.
    interactive_type : string
        ``kl`` or ``smoothl1``.
    output_path : string
        Log storage path.
    w_lr : float
        Learning rate of the search network parameters.
    w_momentum : float
        Momentum of the search and the evaluation network.
    w_weight_decay : float
        The weight decay of the search and the evaluation network parameters.
    alpha_lr : float
        Learning rate of the architecture parameters.
    alpha_weight_decay : float
        The weight decay of the architecture parameters.
    nasnet_lr : float
        Learning rate of the evaluation network parameters.
    local_rank : int
        Rank of the current process in distributed training.
    share_module : bool
        ``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
    """
    def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
                 regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
                 epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
                 log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
                 w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
                 nasnet_lr=0.2, local_rank=0, share_module=True):
        if logger is None:
            logger = logging.getLogger(__name__)
        train_loader, valid_loader = loaders
        train_sampler, valid_sampler = samplers
        self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
        self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)

        self.regular_coeff = regular_coeff
        self.regular_ratio = regular_ratio
        self.warmup_epochs = warmup_epochs
        self.fix_head = fix_head
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        if self.steps_per_epoch is None:
            self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
        self.loss_alpha = loss_alpha
        self.grad_clip = grad_clip
        if interactive_type == "kl":
            self.interactive_loss = InteractiveKLLoss(loss_T)
        elif interactive_type == "smoothl1":
            self.interactive_loss = nn.SmoothL1Loss()
        self.loss_T = loss_T
        self.distributed = distributed
        self.log_frequency = log_frequency
        self.main_proc = not distributed or local_rank == 0

        self.logger = logger
        self.checkpoint_dir = output_path
        if self.main_proc:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        if distributed:
            torch.distributed.barrier()

        self.model_small = model_small
        self.model_large = model_large
        if self.fix_head:
            for param in self.model_small.aux_head.parameters():
                param.requires_grad = False
            for param in self.model_large.aux_head.parameters():
                param.requires_grad = False

        self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
        self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
        self.criterion = criterion

        self.optimizer_small = torch.optim.SGD(self.model_small.parameters(), w_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_large = torch.optim.SGD(self.model_large.parameters(), nasnet_lr,
                                               momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_alpha = torch.optim.Adam(self.mutator_small.parameters(), alpha_lr,
                                                betas=(0.5, 0.999), weight_decay=alpha_weight_decay)

        if distributed:
            apex.parallel.convert_syncbn_model(self.model_small)
            apex.parallel.convert_syncbn_model(self.model_large)
            self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
            self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
            self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
            if share_module:
                self.model_small.callback_queued = True
                self.model_large.callback_queued = True
            # the large mutator is never optimized, so it does not need to be parallelized

    def _warmup(self, phase, epoch):
        assert phase in [PHASE_SMALL, PHASE_LARGE]
        if phase == PHASE_SMALL:
            model, optimizer = self.model_small, self.optimizer_small
        elif phase == PHASE_LARGE:
            model, optimizer = self.model_large, self.optimizer_large
        model.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            x, y = next(self.train_loader)
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            logits_main, _ = model(x)
            loss = self.criterion(logits_main, y)
            loss.backward()

            self._clip_grad_norm(model)
            optimizer.step()
            prec1, prec5 = accuracy(logits_main, y, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, phase, meters)

    def _clip_grad_norm(self, model):
        if isinstance(model, DistributedDataParallel):
            nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)

    def _reset_nan(self, parameters):
        with torch.no_grad():
            for param in parameters:
                for i, p in enumerate(param):
                    if p != p:  # equivalent to `isnan(p)`
                        param[i] = float("-inf")

    def _joint_train(self, epoch):
        self.model_large.train()
        self.model_small.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            trn_x, trn_y = next(self.train_loader)
            val_x, val_y = next(self.valid_loader)
            trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
            val_x, val_y = val_x.cuda(), val_y.cuda()

            # step 1. optimize architecture
            self.optimizer_alpha.zero_grad()
            self.optimizer_large.zero_grad()
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            loss_regular = self.mutator_small.reset_with_loss()
            if loss_regular:
                loss_regular *= reg_decay
            logits_search, emsemble_logits_search = self.model_small(val_x)
            logits_main, emsemble_logits_main = self.model_large(val_x)
            loss_cls = (self.criterion(logits_search, val_y) + self.criterion(logits_main, val_y)) / self.loss_alpha
            loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (self.loss_T ** 2) * self.loss_alpha
            loss = loss_cls + loss_interactive + loss_regular
            loss.backward()
            self._clip_grad_norm(self.model_large)
            self.optimizer_large.step()
            self.optimizer_alpha.step()
            # NOTE: need to call `self._reset_nan(self.mutator_small.parameters())` here if `cut_choices` is used

            # step 2. optimize op weights
            self.optimizer_small.zero_grad()
            with torch.no_grad():
                # resample architecture since parameters have been changed
                self.mutator_small.reset_with_loss()
            logits_search_train, _ = self.model_small(trn_x)
            loss_weight = self.criterion(logits_search_train, trn_y)
            loss_weight.backward()
            self._clip_grad_norm(self.model_small)
            self.optimizer_small.step()

            metrics = {"loss_cls": loss_cls, "loss_interactive": loss_interactive,
                       "loss_regular": loss_regular, "loss_weight": loss_weight}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)

    def train(self):
        for epoch in range(self.epochs):
            if epoch < self.warmup_epochs:
                with torch.no_grad():  # otherwise grads will be retained on the architecture params
                    self.mutator_small.reset_with_loss()
                self._warmup(PHASE_SMALL, epoch)
            else:
                with torch.no_grad():
                    self.mutator_large.reset()
                self._warmup(PHASE_LARGE, epoch)
                self._joint_train(epoch)

            self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
                        os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))

    def export(self, file, genotype_file):
        if self.main_proc:
            mutator_export, genotypes = self.mutator_small.export(self.logger)
            with open(file, "w") as f:
                json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
            with open(genotype_file, "w") as f:
                f.write(str(genotypes))
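A hedged usage sketch of the trainer above; the models, loaders, and samplers are placeholders that would come from the CDARTS example, not from this diff.

# Hypothetical wiring; model_small, model_large, train_loader, valid_loader,
# train_sampler, valid_sampler must be supplied by the CDARTS example code.
import torch.nn as nn

trainer = CdartsTrainer(model_small, model_large, nn.CrossEntropyLoss(),
                        loaders=[train_loader, valid_loader],
                        samplers=[train_sampler, valid_sampler],
                        epochs=32, distributed=False)
trainer.train()   # warm up the search net, then alternate architecture/weight steps
# each epoch exports ./outputs/epoch_XX.json and ./outputs/epoch_XX.genotypes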
@@ -1,76 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os

import torch
import torch.distributed as dist


class CyclicIterator:
    def __init__(self, loader, sampler, distributed):
        self.loader = loader
        self.sampler = sampler
        self.epoch = 0
        self.distributed = distributed
        self._next_epoch()

    def _next_epoch(self):
        if self.distributed:
            self.sampler.set_epoch(self.epoch)
        self.iterator = iter(self.loader)
        self.epoch += 1

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            self._next_epoch()
            return next(self.iterator)


class TorchTensorEncoder(json.JSONEncoder):
    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            return o.tolist()
        return super().default(o)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(os.environ["WORLD_SIZE"])
    return rt


def reduce_metrics(metrics, distributed=False):
    if distributed:
        return {k: reduce_tensor(v).item() for k, v in metrics.items()}
    return {k: v.item() for k, v in metrics.items()}
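Since CyclicIterator is the piece that lets the trainer draw a fixed number of steps per epoch, a minimal non-distributed demonstration:

# Minimal demonstration of CyclicIterator; any DataLoader works.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(6.0).view(6, 1)), batch_size=2)
it = CyclicIterator(loader, sampler=None, distributed=False)
for _ in range(5):       # more steps than one epoch holds
    batch = next(it)     # wraps to a new epoch instead of raising StopIteration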
@@ -1,221 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os
import sys

import torch

import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.pytorch.mutator import Mutator

logger = logging.getLogger(__name__)

NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"


def get_and_apply_next_architecture(model):
    """
    Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
    similar to ``get_next_parameter`` for HPO.

    It will generate a search space based on ``model``.
    If the env ``NNI_GEN_SEARCH_SPACE`` exists, this is a dry run for
    generating the search space for the experiment.
    If not, there are still two modes: one is NNI experiment mode, where users
    use ``nnictl`` to start an experiment; the other is standalone mode, where
    users run the trial command directly. Standalone mode chooses the first
    one(s) for each LayerChoice and InputChoice.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """
    ClassicMutator(model)


class ClassicMutator(Mutator):
    """
    This mutator applies the architecture chosen by the tuner.
    It implements the forward function of LayerChoice and InputChoice,
    to activate only the chosen ones.

    Parameters
    ----------
    model : nn.Module
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """

    def __init__(self, model):
        super(ClassicMutator, self).__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()
        if NNI_GEN_SEARCH_SPACE in os.environ:
            # dry run for only generating search space
            self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
            sys.exit(0)

        if trial_env_vars.NNI_PLATFORM is None:
            logger.warning("Running in standalone mode; the chosen architectures are the first one(s).")
            self._chosen_arch = self._standalone_generate_chosen()
        else:
            # get chosen arch from tuner
            self._chosen_arch = nni.get_next_parameter()
            if self._chosen_arch is None:
                if trial_env_vars.NNI_PLATFORM == "unittest":
                    # happens if NNI_PLATFORM is intentionally set, e.g., in UT
                    logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
                    self._chosen_arch = self._standalone_generate_chosen()
                else:
                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
        self.reset()

    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Convert layer choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Item ``idx`` of the list will be selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for the corresponding search space.
        """
        # doesn't support multihot for layer choice yet
        onehot_list = [False] * len(mutable)
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        onehot_list[idx] = True
        return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable

    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Convert input choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            Item ``idx`` of the list will be selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list for the corresponding search space.
        """
        candidate_repr = search_space_item["candidates"]
        multihot_list = [False] * mutable.n_candidates
        for i, v in zip(idx, value):
            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
            assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
            multihot_list[i] = True
        return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable

    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen arch and apply it on the model.
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, (LayerChoice, InputChoice)):
                assert mutable.key in self._chosen_arch, \
                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
                data = self._chosen_arch[mutable.key]
                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                    "'{}' is not a valid choice.".format(data)
                if isinstance(mutable, LayerChoice):
                    result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
                                                                    self._search_space[mutable.key]["_value"])
                elif isinstance(mutable, InputChoice):
                    result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
                                                                    self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return result

    def _standalone_generate_chosen(self):
        """
        Generate the chosen architecture for standalone mode,
        i.e., choose the first one(s) for LayerChoice and InputChoice.
        ::
            { key_name: {"_value": "conv1",
                         "_idx": 0} }
            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                chosen_arch[key] = {"_value": choices[0], "_idx": 0}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch

    def _generate_search_space(self):
        """
        Generate search space from mutables.
        Here is the search space format:
        ::
            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }
            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            the generated search space
        """
        search_space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                key = mutable.key
                val = mutable.names
                search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                search_space[key] = {"_type": INPUT_CHOICE,
                                     "_value": {"candidates": mutable.choose_from,
                                                "n_chosen": mutable.n_chosen}}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return search_space

    def _dump_search_space(self, file_path):
        with open(file_path, "w") as ss_file:
            json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
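To make the standalone path above concrete, here is a sketch of a trial model; the ops are hypothetical and not part of this diff.

# Hypothetical trial model; standalone mode activates the first op of each choice.
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                 nn.Conv2d(3, 8, 5, padding=2)])

    def forward(self, x):
        return self.conv(x)


model = Net()
get_and_apply_next_architecture(model)   # standalone: chooses index 0 of each choice
out = model(torch.randn(1, 3, 32, 32))   # runs with the 3x3 conv activated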
@@ -1,403 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from copy import deepcopy

import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup

from .utils import accuracy, reduce_metrics

logger = logging.getLogger(__name__)


class CreamSupernetTrainer(Trainer):
    """
    This trainer trains a supernet and outputs prioritized architectures that can be used for other tasks.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    val_loss : callable
        Called with logits and targets for validation only. Returns a loss tensor.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raises ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raises ``StopIteration`` when one epoch is exhausted.
    mutator : Mutator
        A mutator object that has been initialized with the model.
    batch_size : int
        Batch size.
    log_frequency : int
        Number of mini-batches to log metrics.
    meta_sta_epoch : int
        Start epoch of using the meta matching network to pick the teacher architecture.
    update_iter : int
        Interval of updating meta matching networks.
    slices : int
        Batch size of mini training data in the process of training the meta matching network.
    pool_size : int
        Board size.
    pick_method : str
        How to pick the teacher network.
    choice_num : int
        Number of operations in the supernet.
    sta_num : tuple of int
        Layer number of each stage in the supernet (5 stages in the supernet).
    acc_gap : int
        Maximum accuracy improvement to omit the limitation of flops.
    flops_dict : dict
        Dictionary of each layer's operations in the supernet.
    flops_fixed : int
        Flops of the fixed part of the supernet.
    local_rank : int
        Index of the current rank.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, val_loss,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, log_frequency=None,
                 meta_sta_epoch=20, update_iter=200, slices=2,
                 pool_size=10, pick_method='meta', choice_num=6,
                 sta_num=(4, 4, 4, 4, 4), acc_gap=5,
                 flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
        assert torch.cuda.is_available()
        super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
                                                   optimizer, num_epochs, None, None,
                                                   batch_size, None, None, log_frequency, callbacks)
        self.model = model
        self.loss = loss
        self.val_loss = val_loss
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.meta_sta_epoch = meta_sta_epoch
        self.update_iter = update_iter
        self.slices = slices
        self.pick_method = pick_method
        self.pool_size = pool_size
        self.local_rank = local_rank
        self.choice_num = choice_num
        self.sta_num = sta_num
        self.acc_gap = acc_gap
        self.flops_dict = flops_dict
        self.flops_fixed = flops_fixed

        self.current_student_arch = None
        self.current_teacher_arch = None
        self.main_proc = (local_rank == 0)
        self.current_epoch = 0

        self.prioritized_board = []

    # size of the prioritized board
    def _board_size(self):
        return len(self.prioritized_board)

    # select teacher architecture according to the logit difference
    def _select_teacher(self):
        self._replace_mutator_cand(self.current_student_arch)

        if self.pick_method == 'top1':
            meta_value, teacher_cand = 0.5, sorted(
                self.prioritized_board, reverse=True)[0][3]
        elif self.pick_method == 'meta':
            meta_value, cand_idx, teacher_cand = -1000000000, -1, None
            for now_idx, item in enumerate(self.prioritized_board):
                inputx = item[4]
                output = torch.nn.functional.softmax(self.model(inputx), dim=1)
                weight = self.model.module.forward_meta(output - item[5])
                if weight > meta_value:
                    meta_value = weight
                    cand_idx = now_idx
                    teacher_cand = self.prioritized_board[cand_idx][3]
            assert teacher_cand is not None
            meta_value = torch.nn.functional.sigmoid(-weight)
        else:
            raise ValueError('Method Not supported')

        return meta_value, teacher_cand

    # check whether to update the prioritized board
    def _isUpdateBoard(self, prec1, flops):
        if self.current_epoch <= self.meta_sta_epoch:
            return False

        if len(self.prioritized_board) < self.pool_size:
            return True

        if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
            return True

        if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
            return True

        return False

    # update the prioritized board
    def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
        if self._isUpdateBoard(prec1, flops):
            val_prec1 = prec1
            training_data = deepcopy(inputs[:self.slices].detach())
            if len(self.prioritized_board) == 0:
                features = deepcopy(outputs[:self.slices].detach())
            else:
                features = deepcopy(
                    teacher_output[:self.slices].detach())
            self.prioritized_board.append(
                (val_prec1,
                 prec1,
                 flops,
                 self.current_student_arch,
                 training_data,
                 torch.nn.functional.softmax(
                     features,
                     dim=1)))
            self.prioritized_board = sorted(
                self.prioritized_board, reverse=True)

        if len(self.prioritized_board) > self.pool_size:
            del self.prioritized_board[-1]

    # only update student network weights
    def _update_student_weights_only(self, grad_1):
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            weight.grad = grad_item
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(self.current_student_arch), 1)
        self.optimizer.step()
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            del weight.grad

    # only update meta network weights
    def _update_meta_weights_only(self, teacher_cand, grad_teacher):
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            weight.grad = grad_item

        # clip gradients
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(
                self.current_student_arch, self.pick_method == 'meta'), 1)

        self.optimizer.step()
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            del weight.grad

    # simulate sgd updating
    def _simulate_sgd_update(self, w, g, optimizer):
        return g * optimizer.param_groups[-1]['lr'] + w

    # split training images into several slices
    def _get_minibatch_input(self, input):  # pylint: disable=redefined-builtin
        slice = self.slices  # pylint: disable=redefined-builtin
        x = deepcopy(input[:slice].clone().detach())
        return x

    # calculate 1st gradient of student architectures
    def _calculate_1st_gradient(self, kd_loss):
        self.optimizer.zero_grad()
        grad = torch.autograd.grad(
            kd_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            create_graph=True)
        return grad

    # calculate 2nd gradient of meta networks
    def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
        self.optimizer.zero_grad()
        grad_student_val = torch.autograd.grad(
            validation_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            retain_graph=True)

        grad_teacher = torch.autograd.grad(
            students_weight[0],
            self.model.module.rand_parameters(
                teacher_cand,
                self.pick_method == 'meta'),
            grad_outputs=grad_student_val)
        return grad_teacher

    # forward training data
    def _forward_training(self, x, meta_value):
        self._replace_mutator_cand(self.current_student_arch)
        output = self.model(x)

        with torch.no_grad():
            self._replace_mutator_cand(self.current_teacher_arch)
            teacher_output = self.model(x)
            soft_label = torch.nn.functional.softmax(teacher_output, dim=1)

        kd_loss = meta_value * \
            self._cross_entropy_loss_with_soft_target(output, soft_label)
        return kd_loss

    # calculate soft target loss
    def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
        logsoftmax = torch.nn.LogSoftmax(dim=1)  # explicit dim; the implicit default is deprecated
        return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))

    # forward validation data
    def _forward_validation(self, input, target):  # pylint: disable=redefined-builtin
        slice = self.slices  # pylint: disable=redefined-builtin
        x = input[slice:slice * 2].clone()

        self._replace_mutator_cand(self.current_student_arch)
        output_2 = self.model(x)

        validation_loss = self.loss(output_2, target[slice:slice * 2])
        return validation_loss

    def _isUpdateMeta(self, batch_idx):
        isUpdate = True
        isUpdate &= (self.current_epoch > self.meta_sta_epoch)
        isUpdate &= (batch_idx > 0)
        isUpdate &= (batch_idx % self.update_iter == 0)
        isUpdate &= (self._board_size() > 0)
        return isUpdate

    def _replace_mutator_cand(self, cand):
        self.mutator._cache = cand

    # update meta matching networks
    def _run_update(self, input, target, batch_idx):  # pylint: disable=redefined-builtin
        if self._isUpdateMeta(batch_idx):
            x = self._get_minibatch_input(input)

            meta_value, teacher_cand = self._select_teacher()

            kd_loss = self._forward_training(x, meta_value)

            # calculate 1st gradient
            grad_1st = self._calculate_1st_gradient(kd_loss)

            # simulate updated student weights
            students_weight = [
                self._simulate_sgd_update(
                    p, grad_item, self.optimizer) for p, grad_item in zip(
                    self.model.module.rand_parameters(self.current_student_arch), grad_1st)]

            # update student weights
            self._update_student_weights_only(grad_1st)

            validation_loss = self._forward_validation(input, target)

            # calculate 2nd gradient
            grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)

            # update meta matching networks
            self._update_meta_weights_only(teacher_cand, grad_teacher)

            # delete internal variables
            del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight

    def _get_cand_flops(self, cand):
        flops = 0
        for block_id, block in enumerate(cand):
            # the original compared the int `block_id` against 'LayerChoice23' (always False);
            # compare the key `block` instead
            if block == 'LayerChoice1' or block == 'LayerChoice23':
                continue
            for idx, choice in enumerate(cand[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed

    def train_one_epoch(self, epoch):
        self.current_epoch = epoch
        meters = AverageMeterGroup()
        self.steps_per_epoch = len(self.train_loader)
        for step, (input_data, target) in enumerate(self.train_loader):
            self.mutator.reset()
            self.current_student_arch = self.mutator._cache

            input_data, target = input_data.cuda(), target.cuda()

            # calculate flops of current architecture
            cand_flops = self._get_cand_flops(self.mutator._cache)

            # update meta matching network
            self._run_update(input_data, target, step)

            if self._board_size() > 0:
                # select teacher architecture
                meta_value, teacher_cand = self._select_teacher()
                self.current_teacher_arch = teacher_cand

            # forward supernet
            if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                loss = self.loss(output, target)
                kd_loss, teacher_output, teacher_cand = None, None, None
            else:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                gt_loss = self.loss(output, target)

                with torch.no_grad():
                    self._replace_mutator_cand(self.current_teacher_arch)
                    teacher_output = self.model(input_data).detach()

                    soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
                kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)

                loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

            # update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)

            # update prioritized board
            self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
                            step + 1, len(self.train_loader), meters)

        if self.main_proc and self.num_epochs == epoch + 1:
            for idx, i in enumerate(self.prioritized_board):
                logger.info("No.%s %s", idx, i[:4])

    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                logits = self.model(x)
                loss = self.val_loss(logits, y)
                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
                metrics = reduce_metrics(metrics)
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.valid_loader), meters)
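As a sanity check on `_cross_entropy_loss_with_soft_target` above: with a one-hot soft target it reduces to ordinary cross entropy, which can be verified numerically.

# Numerical check that the soft-target loss equals F.cross_entropy for one-hot targets.
import torch
import torch.nn.functional as F

pred = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
one_hot = F.one_hot(target, num_classes=10).float()

soft = torch.mean(torch.sum(-one_hot * F.log_softmax(pred, dim=1), 1))
hard = F.cross_entropy(pred, target)
assert torch.allclose(soft, hard, atol=1e-6)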
@@ -1,37 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


import os
import torch.distributed as dist


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res


def reduce_metrics(metrics):
    return {k: reduce_tensor(v).item() for k, v in metrics.items()}


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(os.environ["WORLD_SIZE"])
    return rt
@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import DartsMutator
from .trainer import DartsTrainer
@@ -1,85 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice

_logger = logging.getLogger(__name__)


class DartsMutator(Mutator):
    """
    Connects the model in a DARTS (differentiable) way.

    An extra connection is automatically inserted for each LayerChoice; when this connection is selected,
    no op on the LayerChoice is applied (namely a ``ZeroOp``), in which case every element in the exported
    choice list is ``false`` (not chosen).

    All input choices will be fully connected in the search phase. On exporting, the input choice will choose inputs based
    on keys in ``choose_from``. If a key happens to be the key of a LayerChoice, the top logit of the corresponding LayerChoice
    will join the competition of the input choice against other logits. Otherwise, the logit is assumed to be 0.

    It's possible to cut branches by setting the parameter ``choices`` at a particular position to ``-inf``. After softmax, the
    value will be 0. The framework will ignore 0 values and make no connection. Note that the gradient at the ``-inf`` location will
    be 0. Since arithmetic with ``-inf`` produces ``nan``, you need to handle the gradient update phase carefully.

    Attributes
    ----------
    choices : ParameterDict
        Dict that maps keys of LayerChoices to weighted-connection float tensors.
    """
    def __init__(self, model):
        super().__init__(model)
        self.choices = nn.ParameterDict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))

    def device(self):
        for v in self.choices.values():
            return v.device

    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
            elif isinstance(mutable, InputChoice):
                result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
        return result

    def sample_final(self):
        result = dict()
        edges_max = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
                edges_max[mutable.key] = max_val
                result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
        for mutable in self.mutables:
            if isinstance(mutable, InputChoice):
                if mutable.n_chosen is not None:
                    weights = []
                    for src_key in mutable.choose_from:
                        if src_key not in edges_max:
                            _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
                        weights.append(edges_max.get(src_key, 0.))
                    weights = torch.tensor(weights)  # pylint: disable=not-callable
                    _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
                    selected_multihot = []
                    for i, src_key in enumerate(mutable.choose_from):
                        if i not in topk_edge_indices and src_key in result:
                            # If an edge is never selected, there is no need to calculate any op on this edge.
                            # This is to eliminate redundant calculation.
                            result[src_key] = torch.zeros_like(result[src_key])
                        selected_multihot.append(i in topk_edge_indices)
                    result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
                else:
                    result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
        return result
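The extra zero-op logit is the subtle part of the mutator above; a tiny sketch of what `sample_search` returns for one LayerChoice:

# Each LayerChoice with k ops holds k+1 architecture parameters; softmax runs over all
# k+1 and the last entry (the implicit ZeroOp) is dropped, so the weights may sum to < 1.
import torch
import torch.nn.functional as F

alpha = torch.tensor([1.0, 2.0, 0.5, -1.0])   # 3 real ops + 1 zero-op logit
weights = F.softmax(alpha, dim=-1)[:-1]       # what sample_search yields for this key
print(weights, weights.sum())                 # the remaining mass went to the ZeroOp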
@ -1,214 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from nni.nas.pytorch.trainer import Trainer
|
||||
from nni.nas.pytorch.utils import AverageMeterGroup
|
||||
|
||||
from .mutator import DartsMutator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DartsTrainer(Trainer):
|
||||
"""
|
||||
DARTS trainer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
PyTorch model to be trained.
|
||||
loss : callable
|
||||
Receives logits and ground truth label, return a loss tensor.
|
||||
metrics : callable
|
||||
Receives logits and ground truth label, return a dict of metrics.
|
||||
optimizer : Optimizer
|
||||
The optimizer used for optimizing the model.
|
||||
num_epochs : int
|
||||
Number of epochs planned for training.
|
||||
dataset_train : Dataset
|
||||
Dataset for training. Will be split for training weights and architecture weights.
|
||||
dataset_valid : Dataset
|
||||
Dataset for testing.
|
||||
mutator : DartsMutator
|
||||
Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
|
||||
batch_size : int
|
||||
Batch size.
|
||||
workers : int
|
||||
Workers for data loading.
|
||||
device : torch.device
|
||||
``torch.device("cpu")`` or ``torch.device("cuda")``.
|
||||
log_frequency : int
|
||||
Step count per logging.
|
||||
callbacks : list of Callback
|
||||
list of callbacks to trigger at events.
|
||||
arc_learning_rate : float
|
||||
Learning rate of architecture parameters.
|
||||
unrolled : float
|
||||
``True`` if using second order optimization, else first order optimization.
|
||||
"""
|
||||
def __init__(self, model, loss, metrics,
|
||||
optimizer, num_epochs, dataset_train, dataset_valid,
|
||||
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
|
||||
callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
|
||||
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
|
||||
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
|
||||
batch_size, workers, device, log_frequency, callbacks)
|
||||
|
||||
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
|
||||
weight_decay=1.0E-3)
|
||||
self.unrolled = unrolled
|
||||
|
||||
n_train = len(self.dataset_train)
|
||||
split = n_train // 2
|
||||
indices = list(range(n_train))
|
||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
|
||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
|
||||
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
|
||||
batch_size=batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=workers)
|
||||
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
|
||||
batch_size=batch_size,
|
||||
sampler=valid_sampler,
|
||||
num_workers=workers)
|
||||
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
|
||||
batch_size=batch_size,
|
||||
num_workers=workers)
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
self.model.train()
|
||||
self.mutator.train()
|
||||
meters = AverageMeterGroup()
|
||||
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
|
||||
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
|
||||
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
|
||||
|
||||
# phase 1. architecture step
|
||||
self.ctrl_optim.zero_grad()
|
||||
            if self.unrolled:
                self._unrolled_backward(trn_X, trn_y, val_X, val_y)
            else:
                self._backward(val_X, val_y)
            self.ctrl_optim.step()

            # phase 2: child network step
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss(trn_X, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)

    def validate_one_epoch(self, epoch):
        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            self.mutator.reset()
            for step, (X, y) in enumerate(self.test_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.test_loader), meters)

    def _logits_and_loss(self, X, y):
        self.mutator.reset()
        logits = self.model(X)
        loss = self.loss(logits, y)
        self._write_graph_status()
        return logits, loss

    def _backward(self, val_X, val_y):
        """
        Simple backward with gradient descent.
        """
        _, loss = self._logits_and_loss(val_X, val_y)
        loss.backward()

    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
        """
        Compute unrolled loss and backprop its gradients.
        """
        backup_params = copy.deepcopy(tuple(self.model.parameters()))

        # do virtual step on training data
        lr = self.optimizer.param_groups[0]["lr"]
        momentum = self.optimizer.param_groups[0]["momentum"]
        weight_decay = self.optimizer.param_groups[0]["weight_decay"]
        self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)

        # calculate unrolled loss on validation data
        # keep gradients for the model here, so the hessian can be computed later
        _, loss = self._logits_and_loss(val_X, val_y)
        w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
        d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
        with torch.no_grad():
            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = d - lr * h

        # restore weights
        self._restore_weights(backup_params)

    def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
        """
        Compute unrolled weights w`.
        """
        # no need for zero_grad; autograd computes the gradients directly
        _, loss = self._logits_and_loss(X, y)
        gradients = torch.autograd.grad(loss, self.model.parameters())
        with torch.no_grad():
            for w, g in zip(self.model.parameters(), gradients):
                m = self.optimizer.state[w].get("momentum_buffer", 0.)
                # in-place update, so the virtual (unrolled) weights are actually written back
                w -= lr * (momentum * m + g + weight_decay * w)

    def _restore_weights(self, backup_params):
        with torch.no_grad():
            for param, backup in zip(self.model.parameters(), backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
        """
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            logger.warning("In computing hessian, norm is smaller than 1E-8 (%.6f), which may make eps explode.", norm.item())

        dalphas = []
        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.model.parameters(), dw):
                    p += e * d

            _, loss = self._logits_and_loss(trn_X, trn_y)
            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))

        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
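
# --- Illustrative sketch, not part of the original file ----------------------
# A minimal standalone version of the central-difference Hessian-vector
# product used by ``_unrolled_backward``/``_compute_hessian`` above. The names
# ``weights``, ``alphas`` and ``loss_fn`` are hypothetical stand-ins for the
# child-model parameters, the architecture parameters, and a closure that
# recomputes the training loss on the current weights.
def _hessian_vector_product_sketch(loss_fn, weights, alphas, dw, r=0.01):
    eps = r / torch.cat([d.view(-1) for d in dw]).norm()  # eps = 0.01 / ||dw||
    with torch.no_grad():
        for p, d in zip(weights, dw):
            p += eps * d  # w+ = w + eps * dw
    dalpha_pos = torch.autograd.grad(loss_fn(), alphas)  # dalpha { L(w+) }
    with torch.no_grad():
        for p, d in zip(weights, dw):
            p -= 2. * eps * d  # w- = w - eps * dw
    dalpha_neg = torch.autograd.grad(loss_fn(), alphas)  # dalpha { L(w-) }
    with torch.no_grad():
        for p, d in zip(weights, dw):
            p += eps * d  # restore the original weights
    return [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]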
@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import EnasMutator
from .trainer import EnasTrainer
@ -1,197 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope


class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # the current implementation only supports batch size 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_h, next_c
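
# --- Illustrative sketch, not part of the original file ----------------------
# How ``StackedLSTMCell`` is driven: hidden and cell states are Python lists
# with one (1, size) tensor per stacked layer, matching ``_initialize`` in
# ``EnasMutator`` below. Sizes here are arbitrary example values.
def _stacked_lstm_cell_demo(size=64, layers=2):
    cell = StackedLSTMCell(layers, size, bias=False)
    inputs = torch.randn(1, size)
    h = [torch.zeros(1, size) for _ in range(layers)]
    c = [torch.zeros(1, size) for _ in range(layers)]
    h, c = cell(inputs, (h, c))  # one controller step
    return h, c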


class EnasMutator(Mutator):
    """
    A mutator that mutates the graph with RL.

    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    cell_exit_extra_step : bool
        If true, the RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
        and mark it as the hidden state of this MutableScope. This is to align with the original implementation of the paper.
    skip_target : float
        Target probability that a skip connection will appear.
    temperature : float
        Temperature constant that divides the logits.
    branch_bias : float
        Manual bias applied to make some operations more likely to be chosen.
        Currently this is implemented with a hardcoded match rule that aligns with the original repo.
        If a mutable has a ``reduce`` in its key, all its op choices
        that contain ``conv`` in their typename will receive a bias of ``+self.branch_bias`` initially, while others
        receive a bias of ``-self.branch_bias``.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of a multi-input-choice is reduced.
    """

    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
        super().__init__(model)
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step
        self.skip_target = skip_target
        self.branch_bias = branch_bias

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        self.bias_dict = nn.ParameterDict()

        self.max_layer_choice = 0
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choices to have the same number of candidates."
                # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                if "reduce" in mutable.key:
                    def is_conv(choice):
                        return "conv" in str(type(choice)).lower()
                    bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
                                         for choice in mutable])
                    self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)

        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)

    def sample_search(self):
        self._initialize()
        self._sample(self.mutables)
        return self._choices

    def sample_final(self):
        return self.sample_search()

    def _sample(self, tree):
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            if self.cell_exit_extra_step:
                self._lstm_next_step()
            self._mark_anchor(mutable.key)

    def _initialize(self):
        self._choices = dict()
        self._anchors_hid = dict()
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _mark_anchor(self, key):
        self._anchors_hid[key] = self._h[-1]

    def _sample_layer_choice(self, mutable):
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        self._inputs = self.embedding(branch_id)
        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

    def _sample_input_choice(self, mutable):
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._lstm_next_step()
                self._mark_anchor(label)  # anchor not found; take an extra step and fill it now
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = torch.cat(query, 0)
        query = torch.tanh(query + self.attn_query(self._h[-1]))
        query = self.v_attn(query)
        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * torch.tanh(query)

        if mutable.n_chosen is None:
            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type

            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
        else:
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = query.view(1, -1)
            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
            log_prob = self.cross_entropy_loss(logit, index)
            self._inputs = anchors[index.item()]

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        return skip.bool()
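
# --- Illustrative sketch, not part of the original file ----------------------
# After ``mutator.reset()`` samples an architecture, the trainer turns the
# accumulated ``sample_log_prob``/``sample_entropy``/``sample_skip_penalty``
# into a REINFORCE loss with a moving-average baseline. A minimal sketch;
# the default coefficient values mirror the EnasTrainer in the next file,
# and ``reward`` is assumed to be a plain float.
def _controller_loss_sketch(mutator, reward, baseline,
                            entropy_weight=0.0001, skip_weight=0.8,
                            baseline_decay=0.999):
    reward = reward + entropy_weight * mutator.sample_entropy.item()
    baseline = baseline * baseline_decay + reward * (1 - baseline_decay)
    loss = mutator.sample_log_prob * (reward - baseline)
    loss = loss + skip_weight * mutator.sample_skip_penalty
    return loss, baseline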
@ -1,209 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from itertools import cycle

import torch
import torch.nn as nn
import torch.optim as optim

from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator

logger = logging.getLogger(__name__)


class EnasTrainer(Trainer):
    """
    ENAS trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    reward_function : callable
        Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as the reward.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : EnasMutator
        Use this when customizing your own mutator or a mutator with customized parameters.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Number of steps between each logging.
    callbacks : list of Callback
        List of callbacks to trigger at events.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    child_steps : int
        How many mini-batches for model training per epoch.
    mutator_lr : float
        Learning rate for the RL controller.
    mutator_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for the RL controller.
    mutator_steps : int
        Number of mini-batches for each epoch of RL controller learning.
    aux_weight : float
        Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to the total loss.
    test_arc_per_epoch : int
        How many architectures are chosen for direct test after each epoch.
    """
    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
                 test_arc_per_epoch=1):
        super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.reward_function = reward_function
        self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.batch_size = batch_size
        self.workers = workers

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.mutator_steps_aggregate = mutator_steps_aggregate
        self.mutator_steps = mutator_steps
        self.child_steps = child_steps
        self.aux_weight = aux_weight
        self.test_arc_per_epoch = test_arc_per_epoch

        self.init_dataloader()

    def init_dataloader(self):
        n_train = len(self.dataset_train)
        split = n_train // 10
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=self.workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=self.workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=self.batch_size,
                                                       num_workers=self.workers)
        self.train_loader = cycle(self.train_loader)
        self.valid_loader = cycle(self.valid_loader)

    def train_one_epoch(self, epoch):
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            with torch.no_grad():
                self.mutator.reset()
            self._write_graph_status()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                self.mutator.reset()
                with torch.no_grad():
                    logits = self.model(x)
                self._write_graph_status()
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
                                meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()

    def validate_one_epoch(self, epoch):
        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                            epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
                            meters.summary())
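
# --- Illustrative sketch, not part of the original file ----------------------
# The 90/10 split of ``init_dataloader`` in miniature: the first 90% of the
# indices train the child model, the last 10% feed the RL controller, and
# ``cycle`` turns both loaders into endless iterators so the trainer can pull
# a fixed number of steps per epoch. ``dataset`` is a hypothetical torch
# Dataset.
def _split_loaders_sketch(dataset, batch_size=64):
    n = len(dataset)
    split = n // 10
    indices = list(range(n))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
    train = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    valid = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)
    return cycle(train), cycle(valid)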
@ -1,14 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import

from .mutator import FBNetMutator  # noqa: F401
from .trainer import FBNetTrainer  # noqa: F401
from .utils import (  # noqa: F401
    LookUpTable,
    NASConfig,
    RegularizerLoss,
    model_init,
    supernet_sample,
)
@ -1,268 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice


class MixedOp(nn.Module):
    """
    This class instantiates and manages the info of one LayerChoice.
    It includes the architecture weights and member functions operating on these weights.
    """

    def __init__(self, mutable, latency):
        """
        Parameters
        ----------
        mutable : LayerChoice
            A LayerChoice in user model
        latency : list
            performance cost for each op in the mutable
        """
        super(MixedOp, self).__init__()
        self.latency = latency
        n_choices = len(mutable)
        self.path_alpha = nn.Parameter(
            torch.FloatTensor([1.0 / n_choices for _ in range(n_choices)])
        )
        self.path_alpha.requires_grad = False
        self.temperature = 1.0

    def get_path_alpha(self):
        """Return the architecture parameter."""
        return self.path_alpha

    def get_weighted_latency(self):
        """Return the weighted perf_cost of the current mutable."""
        soft_masks = self.probs_over_ops()
        weighted_latency = sum(m * l for m, l in zip(soft_masks, self.latency))
        return weighted_latency

    def set_temperature(self, temperature):
        """
        Set the annealed temperature for Gumbel softmax.

        Parameters
        ----------
        temperature : float
            The annealed temperature for Gumbel softmax
        """
        self.temperature = temperature

    def to_requires_grad(self):
        """Enable gradient calculation."""
        self.path_alpha.requires_grad = True

    def to_disable_grad(self):
        """Disable gradient calculation."""
        self.path_alpha.requires_grad = False

    def probs_over_ops(self):
        """Apply Gumbel softmax to generate the probability distribution."""
        return F.gumbel_softmax(self.path_alpha, self.temperature)

    def forward(self, mutable, x):
        """
        Define forward of LayerChoice.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        x : tensor
            input of this layer; only one input is supported

        Returns
        -------
        output: tensor
            output of this layer
        """
        candidate_ops = list(mutable)
        soft_masks = self.probs_over_ops()
        output = sum(m * op(x) for m, op in zip(soft_masks, candidate_ops))

        return output

    @property
    def chosen_index(self):
        """
        Choose the op with the maximal probability.

        Returns
        -------
        int
            index of the chosen op
        """
        alphas = self.path_alpha.data.detach().cpu().numpy()
        index = int(np.argmax(alphas))
        return index
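
# --- Illustrative sketch, not part of the original file ----------------------
# The core of ``MixedOp`` in a few lines: one Gumbel-softmax draw produces
# soft masks that weight both the candidate outputs and their latencies, so a
# single sample serves the task loss and the hardware-aware loss. ``ops`` and
# ``latency`` are hypothetical candidate modules and their per-op costs.
def _mixed_op_sketch(ops, latency, x, alpha, temperature=1.0):
    soft_masks = F.gumbel_softmax(alpha, temperature)
    output = sum(m * op(x) for m, op in zip(soft_masks, ops))
    weighted_latency = sum(m * l for m, l in zip(soft_masks, latency))
    return output, weighted_latency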


class FBNetMutator(BaseMutator):
    """
    This mutator initializes and operates all the LayerChoices of the supernet.
    It allows the related trainer to control the training flow of the LayerChoices,
    coordinating with the whole training process.
    """

    def __init__(self, model, lookup_table):
        """
        Init a MixedOp instance for each mutable, i.e., LayerChoice, and
        register the instantiated MixedOp in the corresponding LayerChoice.
        If it is not registered in the LayerChoice, DataParallel does not work,
        because the architecture weights would not be included in the DataParallel model.
        After the MixedOps are registered, we use ``requires_grad`` to control
        whether to calculate gradients of the architecture weights.

        Parameters
        ----------
        model : pytorch model
            The model that users want to tune,
            it includes search space defined with nni nas apis
        lookup_table : class
            lookup table object to manage model space information,
            including candidate ops for each stage as the model space,
            input channels/output channels/stride/fm_size as the layer config,
            and the performance information for perf_cost accumulation.
        """
        super(FBNetMutator, self).__init__(model)
        self.mutable_list = []

        # Collect the op names of the candidate ops within each mutable
        ops_names_mutable = dict()
        left = 0
        right = 1
        for stage_name in lookup_table.layer_num:
            right = lookup_table.layer_num[stage_name]
            stage_ops = lookup_table.lut_ops[stage_name]
            ops_names = [op_name for op_name in stage_ops]

            for i in range(left, left + right):
                ops_names_mutable[i] = ops_names
            left += right

        # Create the mixed op
        for i, mutable in enumerate(self.undedup_mutables):
            ops_names = ops_names_mutable[i]
            latency_mutable = lookup_table.lut_perf[i]
            latency = [latency_mutable[op_name] for op_name in ops_names]
            self.mutable_list.append(mutable)
            mutable.registered_module = MixedOp(mutable, latency)

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of layer choice forward. This function defines the forward
        logic of the input mutable. The mutable is only an interface; its real
        implementation is defined in the mutator.

        Parameters
        ----------
        mutable: LayerChoice
            the mutable whose forward logic is defined here
        args: list of torch.Tensor
            inputs of this mutable
        kwargs: dict
            inputs of this mutable

        Returns
        -------
        torch.Tensor
            output of this mutable, i.e., LayerChoice
        int
            index of the chosen op
        """
        # FIXME: return mask, to be consistent with other algorithms
        idx = mutable.registered_module.chosen_index
        return mutable.registered_module(mutable, *args, **kwargs), idx

    def num_arch_params(self):
        """
        The number of mutables, i.e., LayerChoice.

        Returns
        -------
        int
            the number of LayerChoice in user model
        """
        return len(self.mutable_list)

    def get_architecture_parameters(self):
        """
        Get all the architecture parameters.

        Yields
        ------
        PyTorch Parameter
            the path_alpha of the traversed mutable
        """
        for mutable in self.undedup_mutables:
            yield mutable.registered_module.get_path_alpha()

    def get_weighted_latency(self):
        """
        Get the latency weighted by Gumbel softmax coefficients.

        Yields
        ------
        Tuple
            the weighted_latency of the traversed mutable
        """
        for mutable in self.undedup_mutables:
            yield mutable.registered_module.get_weighted_latency()

    def set_temperature(self, temperature):
        """
        Set the annealed temperature of the ops for Gumbel softmax.

        Parameters
        ----------
        temperature : float
            The annealed temperature for Gumbel softmax
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.set_temperature(temperature)

    def arch_requires_grad(self):
        """
        Make the architecture weights require gradient.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_requires_grad()

    def arch_disable_grad(self):
        """
        Disable gradient of the architecture weights, i.e., do not
        calculate gradients for them.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_disable_grad()

    def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            the choice of each mutable, i.e., LayerChoice
        """
        result = dict()
        for mutable in self.undedup_mutables:
            assert isinstance(mutable, LayerChoice)
            index = mutable.registered_module.chosen_index
            # pylint: disable=not-callable
            # note: the exported value must be a boolean mask tensor, not a tuple
            result[mutable.key] = (
                F.one_hot(torch.tensor(index), num_classes=len(mutable))
                .view(-1)
                .bool()
            )
        return result
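
# --- Illustrative sketch, not part of the original file ----------------------
# The alternating schedule this mutator is designed for: architecture weights
# only receive gradients inside the window between ``arch_requires_grad()``
# and ``arch_disable_grad()``. ``mutator``, ``arch_optim`` and ``arch_loss_fn``
# are hypothetical stand-ins for the objects the trainer wires together.
def _arch_step_sketch(mutator, arch_optim, arch_loss_fn):
    mutator.arch_requires_grad()  # enable grads for path_alpha only
    arch_optim.zero_grad()
    loss = arch_loss_fn()         # task loss + latency regularizer
    loss.backward()
    arch_optim.step()
    mutator.arch_disable_grad()   # freeze path_alpha for the weight phase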
@ -1,413 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import json
import os
import time
import torch

import numpy as np

from torch.autograd import Variable
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import FBNetMutator
from .utils import RegularizerLoss, accuracy


class FBNetTrainer(BaseTrainer):
    def __init__(
        self,
        model,
        model_optim,
        criterion,
        device,
        device_ids,
        lookup_table,
        train_loader,
        valid_loader,
        n_epochs=120,
        load_ckpt=False,
        arch_path=None,
        logger=None,
    ):
        """
        Parameters
        ----------
        model : pytorch model
            the user model, which has mutables
        model_optim : pytorch optimizer
            the user defined optimizer
        criterion : pytorch loss
            the main task loss; nn.CrossEntropyLoss() is for classification
        device : pytorch device
            the devices to train/search the model
        device_ids : list of int
            the indexes of devices used for training
        lookup_table : class
            lookup table object for fbnet training
        train_loader : pytorch data loader
            data loader for the training set
        valid_loader : pytorch data loader
            data loader for the validation set
        n_epochs : int
            number of epochs to train/search
        load_ckpt : bool
            whether to load a checkpoint
        arch_path : str
            the path to store the chosen architecture
        logger : logger
            the logger
        """
        self.model = model
        self.model_optim = model_optim
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.dev_num = len(device_ids)
        self.n_epochs = n_epochs
        self.lookup_table = lookup_table
        self.config = lookup_table.config
        self.start_epoch = self.config.start_epoch
        self.temp = self.config.init_temperature
        self.exp_anneal_rate = self.config.exp_anneal_rate
        self.mode = self.config.mode

        self.load_ckpt = load_ckpt
        self.arch_path = arch_path
        self.logger = logger

        # scheduler of learning rate
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_optim, T_max=n_epochs, last_epoch=-1
        )

        # init mutator
        self.mutator = FBNetMutator(model, lookup_table)
        self.mutator.set_temperature(self.temp)

        # DataParallel should be set up after the init of the mutator
        self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        self.model.to(device)

        # build architecture optimizer
        self.arch_optimizer = torch.optim.AdamW(
            self.mutator.get_architecture_parameters(),
            self.config.nas_lr,
            weight_decay=self.config.nas_weight_decay,
        )
        self.reg_loss = RegularizerLoss(config=self.config)

        self.criterion = criterion
        self.epoch = 0

    def _layer_choice_sample(self):
        """
        Sample the op index within each layer choice.
        """
        stages = [stage_name for stage_name in self.lookup_table.layer_num]
        stage_lnum = [self.lookup_table.layer_num[stage] for stage in stages]

        # get the choice idx in each layer
        choice_ids = list()
        layer_id = 0
        for param in self.mutator.get_architecture_parameters():
            param_np = param.cpu().detach().numpy()
            op_idx = np.argmax(param_np)
            choice_ids.append(op_idx)
            self.logger.info(
                "layer {}: {}, index: {}".format(layer_id, param_np, op_idx)
            )
            layer_id += 1

        # get the arch_sample
        choice_names = list()
        layer_id = 0
        for i, stage_name in enumerate(stages):
            ops_names = [op for op in self.lookup_table.lut_ops[stage_name]]
            for _ in range(stage_lnum[i]):
                searched_op = ops_names[choice_ids[layer_id]]
                choice_names.append(searched_op)
                layer_id += 1

        self.logger.info(choice_names)
        return choice_names

    def _get_perf_cost(self, requires_grad=True):
        """
        Get the accumulated performance cost.
        """
        perf_cost = Variable(
            torch.zeros(1), requires_grad=requires_grad
        ).to(self.device, non_blocking=True)

        for latency in self.mutator.get_weighted_latency():
            perf_cost = perf_cost + latency

        return perf_cost

    def _validate(self):
        """
        Do validation. During validation, LayerChoices use the mixed-op.

        Returns
        -------
        float, float, float
            average loss, average top1 accuracy, average top5 accuracy
        """
        self.valid_loader.batch_sampler.drop_last = False
        batch_time = AverageMeter("batch_time")
        losses = AverageMeter("losses")
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")

        # test on validation set under eval mode
        self.model.eval()

        end = time.time()
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.valid_loader):
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                output = self.model(images)

                loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                losses.update(loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == len(self.valid_loader):
                    test_log = (
                        "Valid" + ": [{0}/{1}]\t"
                        "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                        "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                        "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
                        "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})".format(
                            i,
                            len(self.valid_loader) - 1,
                            batch_time=batch_time,
                            loss=losses,
                            top1=top1,
                            top5=top5,
                        )
                    )
                    self.logger.info(test_log)

        return losses.avg, top1.avg, top5.avg

    def _train_epoch(self, epoch, optimizer, arch_train=False):
        """
        Train one epoch.
        """
        batch_time = AverageMeter("batch_time")
        data_time = AverageMeter("data_time")
        losses = AverageMeter("losses")
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")

        # switch to train mode
        self.model.train()

        data_loader = self.valid_loader if arch_train else self.train_loader
        end = time.time()
        for i, (images, labels) in enumerate(data_loader):
            data_time.update(time.time() - end)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            output = self.model(images)
            loss = self.criterion(output, labels)

            # hardware-aware loss
            perf_cost = self._get_perf_cost(requires_grad=True)
            regu_loss = self.reg_loss(perf_cost)
            if self.mode.startswith("mul"):
                loss = loss * regu_loss
            elif self.mode.startswith("add"):
                loss = loss + regu_loss

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0].item(), images.size(0))
            top5.update(acc5[0].item(), images.size(0))
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                batch_log = (
                    "Warmup Train [{0}][{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {losses.val:.4f} ({losses.avg:.4f})\t"
                    "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\t".format(
                        epoch + 1,
                        i,
                        batch_time=batch_time,
                        data_time=data_time,
                        losses=losses,
                        top1=top1,
                        top5=top5,
                    )
                )
                self.logger.info(batch_log)

    def _warm_up(self):
        """
        Warm up the model, while the architecture weights are not trained.
        """
        for epoch in range(self.epoch, self.start_epoch):
            self.logger.info("\n--------Warmup epoch: %d--------\n", epoch + 1)
            self._train_epoch(epoch, self.model_optim)
            # adjust learning rate
            self.scheduler.step()

            # validation
            val_loss, val_top1, val_top5 = self._validate()
            val_log = (
                "Warmup Valid [{0}/{1}]\t"
                "loss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}".format(
                    # start_epoch doubles as the total number of warm-up epochs
                    epoch + 1, self.start_epoch, val_loss, val_top1, val_top5
                )
            )
            self.logger.info(val_log)

            if epoch % 10 == 0:
                filename = os.path.join(
                    self.config.model_dir, "checkpoint_%s.pth" % epoch
                )
                self.save_checkpoint(epoch, filename)

    def _train(self):
        """
        Train the model; this trains both model weights and architecture weights.
        Architecture weights are trained according to the schedule:
        before updating the architecture weights, ``requires_grad`` is enabled,
        and it is disabled after the update, so that architecture weights are
        not updated while training model weights.
        """
        arch_param_num = self.mutator.num_arch_params()
        self.logger.info("#arch_params: {}".format(arch_param_num))
        self.epoch = max(self.start_epoch, self.epoch)

        ckpt_path = self.config.model_dir
        choice_names = None
        top1_best = 0.0

        for epoch in range(self.epoch, self.n_epochs):
            self.logger.info("\n--------Train epoch: %d--------\n", epoch + 1)
            # update the weight parameters
            self._train_epoch(epoch, self.model_optim)
            # adjust learning rate
            self.scheduler.step()

            self.logger.info("Update architecture parameters")
            # update the architecture parameters
            self.mutator.arch_requires_grad()
            self._train_epoch(epoch, self.arch_optimizer, True)
            self.mutator.arch_disable_grad()
            # temperature annealing
            self.temp = self.temp * self.exp_anneal_rate
            self.mutator.set_temperature(self.temp)
            # sample the architecture of the sub-network
            choice_names = self._layer_choice_sample()

            # validate
            val_loss, val_top1, val_top5 = self._validate()
            val_log = (
                "Valid [{0}]\t"
                "loss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}".format(
                    epoch + 1, val_loss, val_top1, val_top5
                )
            )
            self.logger.info(val_log)

            if epoch % 10 == 0:
                filename = os.path.join(ckpt_path, "checkpoint_%s.pth" % epoch)
                self.save_checkpoint(epoch, filename, choice_names)

            val_top1 = val_top1.cpu().numpy()
            if val_top1 > top1_best:
                filename = os.path.join(ckpt_path, "checkpoint_best.pth")
                self.save_checkpoint(epoch, filename, choice_names)
                top1_best = val_top1

    def save_checkpoint(self, epoch, filename, choice_names=None):
        """
        Save a checkpoint of the whole model:
        model weights and architecture weights are saved as ``filename``,
        and the currently chosen architecture is saved in ``arch_path``.
        """
        state = {
            "model": self.model.state_dict(),
            "optim": self.model_optim.state_dict(),
            "epoch": epoch,
            "arch_sample": choice_names,
        }
        torch.save(state, filename)
        self.logger.info("Save checkpoint to {0:}".format(filename))

        if self.arch_path:
            self.export(self.arch_path)

    def load_checkpoint(self, filename):
        """
        Load the checkpoint from ``filename``.
        """
        ckpt = torch.load(filename)
        self.epoch = ckpt["epoch"]
        self.model.load_state_dict(ckpt["model"])
        self.model_optim.load_state_dict(ckpt["optim"])

    def train(self):
        """
        Train the whole model.
        """
        if self.load_ckpt:
            ckpt_path = self.config.model_dir
            filename = os.path.join(ckpt_path, "checkpoint_best.pth")
            if os.path.exists(filename):
                self.load_checkpoint(filename)

        if self.epoch < self.start_epoch:
            self._warm_up()
        self._train()

    def export(self, file_name):
        """
        Export the chosen architecture into a file.

        Parameters
        ----------
        file_name : str
            the file that stores the exported chosen architecture
        """
        exported_arch = self.mutator.sample_final()
        with open(file_name, "w") as f:
            json.dump(
                exported_arch,
                f,
                indent=2,
                sort_keys=True,
                cls=TorchTensorEncoder,
            )

    def validate(self):
        raise NotImplementedError

    def checkpoint(self):
        raise NotImplementedError
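
# --- Illustrative sketch, not part of the original file ----------------------
# The exponential temperature schedule applied after each architecture epoch:
# starting from ``init_temperature`` (5.0 by default) and multiplying by
# ``exp_anneal_rate`` (exp(-0.045)), the Gumbel softmax gradually sharpens
# toward a near-discrete choice. Default values mirror NASConfig's defaults.
def _temperature_schedule_sketch(n_epochs, init_temperature=5.0,
                                 exp_anneal_rate=np.exp(-0.045)):
    temp = init_temperature
    schedule = []
    for _ in range(n_epochs):
        schedule.append(temp)
        temp *= exp_anneal_rate
    return schedule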
@ -1,433 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import ast
import os
import timeit
import torch

import numpy as np
import torch.nn as nn

from nni.compression.pytorch.utils import count_flops_params

LUT_FILE = "lut.npy"
LUT_JSON_FILE = "lut.txt"
LUT_PATH = "lut"

DATA_TYPE = "float"


class NASConfig:
    def __init__(
        self,
        perf_metric="flops",
        lut_load=False,
        lut_load_format="json",
        model_dir=None,
        nas_lr=0.01,
        nas_weight_decay=5e-4,
        mode="mul",
        alpha=0.25,
        beta=0.6,
        start_epoch=50,
        init_temperature=5.0,
        exp_anneal_rate=np.exp(-0.045),
        search_space=None,
    ):
        # LUT of the performance metric:
        # flops counts the multiplies, latency is the time cost on the platform
        self.perf_metric = perf_metric
        assert perf_metric in [
            "flops",
            "latency",
        ], "perf_metric should be one of ['flops', 'latency']"
        # whether to load or create the LUT file
        self.lut_load = lut_load

        assert lut_load_format in [
            "json",
            "numpy",
        ], "lut_load_format should be one of ['json', 'numpy']"
        self.lut_load_format = lut_load_format

        # necessary dirs
        self.lut_en = model_dir is not None
        if self.lut_en:
            self.model_dir = model_dir
            os.makedirs(model_dir, exist_ok=True)
            self.lut_path = os.path.join(model_dir, LUT_PATH)
            os.makedirs(self.lut_path, exist_ok=True)
        # NAS learning setting
        self.nas_lr = nas_lr
        self.nas_weight_decay = nas_weight_decay
        # hardware-aware loss setting
        self.mode = mode
        assert mode in ["mul", "add"], "mode should be one of ['mul', 'add']"
        self.alpha = alpha
        self.beta = beta
        # NAS training setting
        self.start_epoch = start_epoch
        self.init_temperature = init_temperature
        self.exp_anneal_rate = exp_anneal_rate
        # definition of search blocks and space
        self.search_space = search_space


class RegularizerLoss(nn.Module):
    """Auxiliary loss for hardware-aware NAS."""

    def __init__(self, config):
        """
        Parameters
        ----------
        config : class
            to manage the configuration for NAS training, and search space etc.
        """
        super(RegularizerLoss, self).__init__()
        self.mode = config.mode
        self.alpha = config.alpha
        self.beta = config.beta

    def forward(self, perf_cost, batch_size=1):
        """
        Parameters
        ----------
        perf_cost : tensor
            the accumulated performance cost
        batch_size : int
            batch size for normalization

        Returns
        -------
        output: tensor
            the hardware-aware constraint loss
        """
        if self.mode == "mul":
            log_loss = torch.log(perf_cost / batch_size) ** self.beta
            return self.alpha * log_loss
        elif self.mode == "add":
            linear_loss = (perf_cost / batch_size) ** self.beta
            return self.alpha * linear_loss
        else:
            raise NotImplementedError


def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    Parameters
    ----------
    output : pytorch tensor
        output, e.g., predicted value
    target : pytorch tensor
        label
    topk : tuple
        specify top1 and top5

    Returns
    -------
    list
        accuracy of top1 and top5
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view, since the sliced tensor may be non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
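
# --- Illustrative sketch, not part of the original file ----------------------
# Shape contract of ``accuracy``: logits are (batch, num_classes), labels are
# (batch,), and each returned entry is a 1-element percentage tensor. Random
# inputs are used purely as an example.
def _accuracy_demo(batch=8, num_classes=10):
    logits = torch.randn(batch, num_classes)
    labels = torch.randint(0, num_classes, (batch,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    return top1.item(), top5.item()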


def supernet_sample(model, state_dict, sampled_arch=[], lookup_table=None):
    """
    Initialize the searched sub-model from the supernet.

    Parameters
    ----------
    model : pytorch model
        the created subnet
    state_dict : checkpoint
        the checkpoint of the supernet, including the pre-trained params
    sampled_arch : list of str
        the searched layer names of the subnet
    lookup_table : class
        to manage the candidate ops, layer information and layer performance
    """
    replace = list()
    stages = [stage for stage in lookup_table.layer_num]
    stage_lnum = [lookup_table.layer_num[stage] for stage in stages]

    if sampled_arch:
        layer_id = 0
        for i, stage in enumerate(stages):
            ops_names = [op_name for op_name in lookup_table.lut_ops[stage]]
            for _ in range(stage_lnum[i]):
                searched_op = sampled_arch[layer_id]
                op_i = ops_names.index(searched_op)
                replace.append(
                    [
                        "blocks.{}.".format(layer_id),
                        "blocks.{}.op.".format(layer_id),
                        "blocks.{}.{}.".format(layer_id, op_i),
                    ]
                )
                layer_id += 1
    model_init(model, state_dict, replace=replace)


def model_init(model, state_dict, replace=[]):
    """Initialize the model from a state_dict."""
    prefix = "module."
    param_dict = dict()
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[7:]
        param_dict[k] = v

    for k, (name, m) in enumerate(model.named_modules()):
        if replace:
            for layer_replace in replace:
                assert len(layer_replace) == 3, "The elements should be three."
                pre_scope, key, replace_key = layer_replace
                if pre_scope in name:
                    name = name.replace(key, replace_key)

        # Copy the state_dict to the current model
        if (name + ".weight" in param_dict) or (
            name + ".running_mean" in param_dict
        ):
            if isinstance(m, nn.BatchNorm2d):
                shape = m.running_mean.shape
                if shape == param_dict[name + ".running_mean"].shape:
                    if m.weight is not None:
                        m.weight.data = param_dict[name + ".weight"]
                        m.bias.data = param_dict[name + ".bias"]
                    m.running_mean = param_dict[name + ".running_mean"]
                    m.running_var = param_dict[name + ".running_var"]

            elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                shape = m.weight.data.shape
                if shape == param_dict[name + ".weight"].shape:
                    m.weight.data = param_dict[name + ".weight"]
                    if m.bias is not None:
                        m.bias.data = param_dict[name + ".bias"]

            elif isinstance(m, nn.ConvTranspose2d):
                m.weight.data = param_dict[name + ".weight"]
                if m.bias is not None:
                    m.bias.data = param_dict[name + ".bias"]


class LookUpTable:
    """Build a look-up table for NAS."""

    def __init__(self, config, primitives):
        """
        Parameters
        ----------
        config : class
            to manage the configuration for NAS training, and search space etc.
        """
        self.config = config
        # definition of search blocks and space
        self.search_space = config.search_space
        # layers for NAS
        self.cnt_layers = len(self.search_space["input_shape"])
        # constructors for each operation
        self.lut_ops = {
            stage_name: {
                op_name: primitives[op_name]
                for op_name in self.search_space["stages"][stage_name]["ops"]
            }
            for stage_name in self.search_space["stages"]
        }
        self.layer_num = {
            stage_name: self.search_space["stages"][stage_name]["layer_num"]
            for stage_name in self.search_space["stages"]
        }

        # arguments for the op constructors; input_shapes are just for convenience
        self.layer_configs, self.layer_in_shapes = self._layer_configs()

        # lookup_table
        self.perf_metric = config.perf_metric

        if config.lut_en:
            self.lut_perf = None
            self.lut_file = os.path.join(config.lut_path, LUT_FILE)
            self.lut_json_file = LUT_JSON_FILE
            if config.lut_load:
                if config.lut_load_format == "numpy":
                    # Load data from the numpy file
                    self._load_from_file()
                else:
                    # Load data from the json file
                    self._load_from_json_file()
            else:
                self._create_perfs()

    def _layer_configs(self):
        """Generate basic params for different layers."""
        # layer_configs are: c_in, c_out, stride, fm_size
        layer_configs = [
            [
                self.search_space["input_shape"][layer_id][0],
                self.search_space["channel_size"][layer_id],
                self.search_space["strides"][layer_id],
                self.search_space["fm_size"][layer_id],
            ]
            for layer_id in range(self.cnt_layers)
        ]

        # layer_in_shapes are (C_in, input_w, input_h)
        layer_in_shapes = self.search_space["input_shape"]

        return layer_configs, layer_in_shapes

    def _create_perfs(self, cnt_of_runs=200):
        """Create the performance cost for each op."""
        if self.perf_metric == "latency":
            self.lut_perf = self._calculate_latency(cnt_of_runs)
        elif self.perf_metric == "flops":
            self.lut_perf = self._calculate_flops()

        self._write_lut_to_file()

    def _calculate_flops(self, eps=0.001):
        """FLOPs cost."""
        flops_lut = [{} for _ in range(self.cnt_layers)]
        layer_id = 0

        for stage_name in self.lut_ops:
            stage_ops = self.lut_ops[stage_name]
            ops_num = self.layer_num[stage_name]

            for _ in range(ops_num):
                for op_name in stage_ops:
                    layer_config = self.layer_configs[layer_id]
                    key_params = {"fm_size": layer_config[3]}
                    op = stage_ops[op_name](*layer_config[0:3], **key_params)

                    # measured in FLOPs
                    in_shape = self.layer_in_shapes[layer_id]
                    x = (1, in_shape[0], in_shape[1], in_shape[2])
                    flops, _, _ = count_flops_params(op, x, verbose=False)
                    flops = eps if flops == 0.0 else flops
                    flops_lut[layer_id][op_name] = float(flops)
                layer_id += 1

        return flops_lut

    def _calculate_latency(self, cnt_of_runs):
        """Latency cost."""
        LATENCY_BATCH_SIZE = 1
        latency_lut = [{} for _ in range(self.cnt_layers)]
        layer_id = 0

        for stage_name in self.lut_ops:
            stage_ops = self.lut_ops[stage_name]
            ops_num = self.layer_num[stage_name]

            for _ in range(ops_num):
                for op_name in stage_ops:
                    layer_config = self.layer_configs[layer_id]
                    key_params = {"fm_size": layer_config[3]}
                    op = stage_ops[op_name](*layer_config[0:3], **key_params)
                    input_data = torch.randn(
                        (LATENCY_BATCH_SIZE, *self.layer_in_shapes[layer_id])
                    )
                    # expose op/input_data to timeit's global namespace
                    globals()["op"], globals()["input_data"] = op, input_data
                    total_time = timeit.timeit(
                        "output = op(input_data)",
                        setup="gc.enable()",
                        globals=globals(),
                        number=cnt_of_runs,
                    )
                    # measured in micro-seconds
                    latency_lut[layer_id][op_name] = (
                        total_time / cnt_of_runs / LATENCY_BATCH_SIZE * 1e6
                    )
                layer_id += 1

        return latency_lut

    def _write_lut_to_file(self):
        """Save the LUT as a numpy file."""
        np.save(self.lut_file, self.lut_perf)

    def _load_from_file(self):
        """Load the numpy LUT file."""
        self.lut_perf = np.load(self.lut_file, allow_pickle=True)

    def _load_from_json_file(self):
        """Load the json LUT file.

        lut_json_file ('lut.txt') format, one record per line:
            {'op_name': operator_name,
             'op_data_shape': (input_w, input_h, C_in, C_out, stride),
             'op_dtype': data_type,
             'op_latency': latency}
            {...}
            {...}
        """
        with open(self.lut_json_file, "r") as latency_file:
            ops_latency = latency_file.readlines()

        # ops_lut: {'op_name': {'op_data_shape': {'op_dtype': latency}}}
        ops_lut = {}

        for op_latency in ops_latency:
            assert isinstance(op_latency, str) or isinstance(op_latency, dict)

            if isinstance(op_latency, str):
                record = ast.literal_eval(op_latency)
            elif isinstance(op_latency, dict):
                record = op_latency

            op_name = record["op_name"]
            # op_data_shape: (input_w, input_h, C_in, C_out, stride)
            op_data_shape = record["op_data_shape"]
            op_dtype = record["op_dtype"]
            op_latency = record["op_latency"]

            if op_name not in ops_lut:
                ops_lut[op_name] = {}

            if op_data_shape not in ops_lut[op_name]:
                ops_lut[op_name][op_data_shape] = {}

            ops_lut[op_name][op_data_shape][op_dtype] = op_latency

        self.lut_perf = [{} for _ in range(self.cnt_layers)]
        layer_id = 0

        for stage_name in self.lut_ops:
            stage_ops = self.lut_ops[stage_name]
            ops_num = self.layer_num[stage_name]

            for _ in range(ops_num):
                for op_name in stage_ops:
                    layer_config = self.layer_configs[layer_id]
                    layer_in_shape = self.layer_in_shapes[layer_id]

                    input_w = layer_in_shape[1]
                    input_h = layer_in_shape[2]
                    c_in = layer_config[0]
                    c_out = layer_config[1]
                    stride = layer_config[2]
                    op_data_shape = (input_w, input_h, c_in, c_out, stride)

                    if op_name in ops_lut and op_data_shape in ops_lut[op_name]:
                        self.lut_perf[layer_id][op_name] = \
                            ops_lut[op_name][op_data_shape][DATA_TYPE]

                layer_id += 1
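
# --- Illustrative sketch, not part of the original file ----------------------
# The measurement idea behind ``_calculate_latency`` without the ``globals()``
# indirection: average wall-clock time of repeated forward passes, reported in
# microseconds per sample. ``op`` (an nn.Module) and ``in_shape`` are
# hypothetical inputs.
def _measure_latency_sketch(op, in_shape, runs=200, batch_size=1):
    data = torch.randn(batch_size, *in_shape)
    total = timeit.timeit(lambda: op(data), number=runs)
    return total / runs / batch_size * 1e6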
@ -1,93 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy

import numpy as np
import torch
from torch import nn

from nni.algorithms.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice


class PdartsMutator(DartsMutator):
    """
    Works with PdartsTrainer to calculate op weights and to drop
    candidate operations across P-DARTS epochs.
    """

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        if switches is None:
            self.switches = {}
        else:
            self.switches = switches

        super(PdartsMutator, self).__init__(model)

        # This loop goes through mutables with different keys;
        # it mainly updates the length of choices.
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):

                switches = self.switches.get(mutable.key, [True for _ in range(len(mutable))])

                operations_count = np.sum(switches)
                # The extra +1 (and the corresponding [:-1] in drop_paths) is caused by
                # the zero operation in the DARTS network: the zero operation is not in
                # the network's choices list, but its weight is, so one more weight and
                # one more switch are needed for zero.
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
                self.switches[mutable.key] = switches

        # Update the LayerChoice instances in the model;
        # this physically removes the dropped choice operations.
        for module in self.model.modules():
            if isinstance(module, LayerChoice):
                switches = self.switches.get(module.key)
                choices = self.choices[module.key]
                if len(module) > len(choices):
                    # Iterate from last to first, so that removing one entry
                    # does not affect the indexes of earlier entries.
                    for index in range(len(switches) - 1, -1, -1):
                        if not switches[index]:
                            del module[index]
                assert len(module) <= len(choices), "Failed to remove dropped choices."

    def export(self):
        # Cannot rely on super().export() because P-DARTS has deleted some of the choices and has misaligned length.
        results = super().sample_final()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # As some operations are physically dropped, False must be
                # filled back in to keep track of the dropped operations.
                trained_result = results[mutable.key]
                trained_index = 0
                switches = self.switches[mutable.key]
                result = torch.Tensor(switches).bool()
                for index in range(len(result)):
                    if result[index]:
                        result[index] = trained_result[trained_index]
                        trained_index += 1
                results[mutable.key] = result
        return results

    def drop_paths(self):
        """
        This method is called when a P-DARTS epoch is finished.
        It prepares the switches for the next epoch: candidate operations
        whose switch is False will be dropped in the next epoch.
        """
        all_switches = copy.deepcopy(self.switches)
        for key in all_switches:
            switches = all_switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
            drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
            for idx in drop:
                switches[idxs[idx]] = False
        return all_switches
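
# A minimal usage sketch (not part of the original file; `create_model` and the
# epoch loop are placeholders) showing how the switches returned by drop_paths()
# feed the mutator of the next P-DARTS epoch:
#
#   switches = None
#   for epoch in range(3):
#       model = create_model(epoch)
#       mutator = PdartsMutator(model, epoch, pdarts_num_to_drop=[3, 2, 1], switches=switches)
#       ...train the supernet for this epoch...
#       switches = mutator.drop_paths()   # e.g. {key: [True, False, True, ...]}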
@ -1,86 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging

from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer, TorchTensorEncoder

from .mutator import PdartsMutator

logger = logging.getLogger(__name__)


class PdartsTrainer(BaseTrainer):
    """
    This trainer implements the P-DARTS algorithm.
    P-DARTS is based on the DARTS algorithm and grows the network progressively to find a deeper and better network.
    The class relies on the pdarts_num_layers and pdarts_num_to_drop parameters to control how the network grows:
    pdarts_num_layers gives, for each P-DARTS epoch, how many layers are added compared with the first epoch,
    and pdarts_num_to_drop gives how many candidate operations are dropped in each epoch,
    so that the grown network stays at a similar size.
    """

    def __init__(self, model_creator, init_layers, metrics,
                 num_epochs, dataset_train, dataset_valid,
                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
        super(PdartsTrainer, self).__init__()
        self.model_creator = model_creator
        self.init_layers = init_layers
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency,
            "unrolled": unrolled
        }
        self.callbacks = callbacks if callbacks is not None else []

    def train(self):
        switches = None
        for epoch in range(self.pdarts_epoch):

            layers = self.init_layers + self.pdarts_num_layers[epoch]
            model, criterion, optim, lr_scheduler = self.model_creator(layers)
            self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            for callback in self.callbacks:
                callback.build(model, self.mutator, self)
                callback.on_epoch_begin(epoch)

            darts_callbacks = []
            if lr_scheduler is not None:
                darts_callbacks.append(LRSchedulerCallback(lr_scheduler))

            self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
                                        callbacks=darts_callbacks, **self.darts_parameters)
            logger.info("start pdarts training epoch %s...", epoch)

            self.trainer.train()

            switches = self.mutator.drop_paths()

            for callback in self.callbacks:
                callback.on_epoch_end(epoch)

    def validate(self):
        self.trainer.validate()

    def export(self, file):
        mutator_export = self.mutator.export()
        with open(file, "w") as f:
            json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def checkpoint(self):
        raise NotImplementedError("Not implemented yet")
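
# A minimal usage sketch (not part of the original file): `create_cnn` is a
# hypothetical factory that takes a layer count and returns
# (model, criterion, optimizer, lr_scheduler) for a DARTS-style search space.
#
#   trainer = PdartsTrainer(create_cnn, init_layers=5, metrics=compute_metrics,
#                           num_epochs=25, dataset_train=train_set, dataset_valid=valid_set)
#   trainer.train()
#   trainer.export("final_arch.json")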
@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import ProxylessNasMutator
from .trainer import ProxylessNasTrainer
@ -1,478 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np

from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
from .utils import detach_variable


class ArchGradientFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        detached_x = detach_variable(x)
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        return grad_x[0], binary_grads, None, None

class MixedOp(nn.Module):
    """
    Instantiates and manages the state of one LayerChoice, including the
    architecture weights, the binary weights, and the member functions that
    operate on those weights.

    forward_mode:
        forward/backward mode for LayerChoice: None, two, full, and full_v2.
        For training architecture weights, full_v2 is used by default; for
        training model weights, None is used.
    """
    forward_mode = None

    def __init__(self, mutable):
        """
        Parameters
        ----------
        mutable : LayerChoice
            A LayerChoice in the user model
        """
        super(MixedOp, self).__init__()
        self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
        self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
        self.ap_path_alpha.requires_grad = False
        self.ap_path_wb.requires_grad = False
        self.active_index = [0]
        self.inactive_index = None
        self.log_prob = None
        self.current_prob_over_ops = None
        self.n_choices = len(mutable)

    def get_ap_path_alpha(self):
        return self.ap_path_alpha

    def to_requires_grad(self):
        self.ap_path_alpha.requires_grad = True
        self.ap_path_wb.requires_grad = True

    def to_disable_grad(self):
        self.ap_path_alpha.requires_grad = False
        self.ap_path_wb.requires_grad = False

    def forward(self, mutable, x):
        """
        Define the forward of LayerChoice. For 'full_v2', backward is also defined.
        The 'two' mode is explained in section 3.2.1 of the paper.
        The 'full_v2' mode is explained in Appendix D of the paper.

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        x : tensor
            inputs of this layer, only one input is supported

        Returns
        -------
        output: tensor
            output of this layer
        """
        if MixedOp.forward_mode == 'full' or MixedOp.forward_mode == 'two':
            # the candidate ops are indexed directly from the mutable
            output = 0
            for _i in self.active_index:
                oi = mutable[_i](x)
                output = output + self.ap_path_wb[_i] * oi
            for _i in self.inactive_index:
                oi = mutable[_i](x)
                output = output + self.ap_path_wb[_i] * oi.detach()
        elif MixedOp.forward_mode == 'full_v2':
            def run_function(key, candidate_ops, active_id):
                def forward(_x):
                    return candidate_ops[active_id](_x)
                return forward

            def backward_function(key, candidate_ops, active_id, binary_gates):
                def backward(_x, _output, grad_output):
                    binary_grads = torch.zeros_like(binary_gates.data)
                    with torch.no_grad():
                        for k in range(len(candidate_ops)):
                            if k != active_id:
                                out_k = candidate_ops[k](_x.data)
                            else:
                                out_k = _output.data
                            grad_k = torch.sum(out_k * grad_output)
                            binary_grads[k] = grad_k
                    return binary_grads
                return backward
            output = ArchGradientFunction.apply(
                x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
                backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
        else:
            output = self.active_op(mutable)(x)
        return output

    @property
    def probs_over_ops(self):
        """
        Apply softmax on alpha to generate a probability distribution.

        Returns
        -------
        pytorch tensor
            probability distribution
        """
        probs = F.softmax(self.ap_path_alpha, dim=0)  # softmax to probability
        return probs

    @property
    def chosen_index(self):
        """
        Choose the op with the maximum probability.

        Returns
        -------
        int
            index of the chosen op
        numpy.float32
            probability of the chosen op
        """
        probs = self.probs_over_ops.data.cpu().numpy()
        index = int(np.argmax(probs))
        return index, probs[index]

    def active_op(self, mutable):
        """
        Assume only one path is active.

        Returns
        -------
        PyTorch module
            the chosen operation
        """
        return mutable[self.active_index[0]]

    @property
    def active_op_index(self):
        """
        Return the active (i.e., sampled) op's index.

        Returns
        -------
        int
            index of the active op
        """
        return self.active_index[0]

    def set_chosen_op_active(self):
        """
        Set the chosen index, and the active and inactive indexes.
        """
        chosen_idx, _ = self.chosen_index
        self.active_index = [chosen_idx]
        self.inactive_index = [_i for _i in range(0, chosen_idx)] + \
                              [_i for _i in range(chosen_idx + 1, self.n_choices)]

    def binarize(self, mutable):
        """
        Sample based on alpha, and set the binary weights accordingly.
        ap_path_wb is set in this function, hence the name "binarize".

        Parameters
        ----------
        mutable : LayerChoice
            this layer's mutable
        """
        self.log_prob = None
        # reset binary gates
        self.ap_path_wb.data.zero_()
        probs = self.probs_over_ops
        if MixedOp.forward_mode == 'two':
            # sample two ops according to probs
            sample_op = torch.multinomial(probs.data, 2, replacement=False)
            probs_slice = F.softmax(torch.stack([
                self.ap_path_alpha[idx] for idx in sample_op
            ]), dim=0)
            self.current_prob_over_ops = torch.zeros_like(probs)
            for i, idx in enumerate(sample_op):
                self.current_prob_over_ops[idx] = probs_slice[i]
            # choose one to be active and the other to be inactive according to probs_slice
            c = torch.multinomial(probs_slice.data, 1)[0]  # 0 or 1
            active_op = sample_op[c].item()
            inactive_op = sample_op[1 - c].item()
            self.active_index = [active_op]
            self.inactive_index = [inactive_op]
            # set binary gate
            self.ap_path_wb.data[active_op] = 1.0
        else:
            sample = torch.multinomial(probs, 1)[0].item()
            self.active_index = [sample]
            self.inactive_index = [_i for _i in range(0, sample)] + \
                                  [_i for _i in range(sample + 1, len(mutable))]
            self.log_prob = torch.log(probs[sample])
            self.current_prob_over_ops = probs
            self.ap_path_wb.data[sample] = 1.0
        # avoid over-regularization
        for choice in mutable:
            for _, param in choice.named_parameters():
                param.grad = None

    @staticmethod
    def delta_ij(i, j):
        if i == j:
            return 1
        else:
            return 0

    def set_arch_param_grad(self, mutable):
        """
        Calculate the alpha gradient for this LayerChoice.
        It is calculated from the gradients of the binary gates and the probabilities of the ops.
        """
        binary_grads = self.ap_path_wb.grad.data
        if self.active_op(mutable).is_zero_layer():
            self.ap_path_alpha.grad = None
            return
        if self.ap_path_alpha.grad is None:
            self.ap_path_alpha.grad = torch.zeros_like(self.ap_path_alpha.data)
        if MixedOp.forward_mode == 'two':
            involved_idx = self.active_index + self.inactive_index
            probs_slice = F.softmax(torch.stack([
                self.ap_path_alpha[idx] for idx in involved_idx
            ]), dim=0).data
            for i in range(2):
                for j in range(2):
                    origin_i = involved_idx[i]
                    origin_j = involved_idx[j]
                    self.ap_path_alpha.grad.data[origin_i] += \
                        binary_grads[origin_j] * probs_slice[j] * (MixedOp.delta_ij(i, j) - probs_slice[i])
            for _i, idx in enumerate(self.active_index):
                self.active_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
            for _i, idx in enumerate(self.inactive_index):
                self.inactive_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
        else:
            probs = self.probs_over_ops.data
            for i in range(self.n_choices):
                for j in range(self.n_choices):
                    self.ap_path_alpha.grad.data[i] += binary_grads[j] * probs[j] * (MixedOp.delta_ij(i, j) - probs[i])
        return
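
    # The loops above implement the ProxylessNAS alpha gradient
    #     dL/dalpha_i = sum_j dL/dg_j * p_j * (delta_ij - p_i),
    # where g are the binary gates and p = softmax(alpha); in 'two' mode the
    # sum only runs over the two sampled ops, whose alphas are rescaled
    # afterwards by rescale_updated_arch_param below.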

    def rescale_updated_arch_param(self):
        """
        Rescale architecture weights for the 'two' mode.
        """
        if not isinstance(self.active_index[0], tuple):
            assert self.active_op.is_zero_layer()
            return
        involved_idx = [idx for idx, _ in (self.active_index + self.inactive_index)]
        old_alphas = [alpha for _, alpha in (self.active_index + self.inactive_index)]
        new_alphas = [self.ap_path_alpha.data[idx] for idx in involved_idx]

        offset = math.log(
            sum([math.exp(alpha) for alpha in new_alphas]) / sum([math.exp(alpha) for alpha in old_alphas])
        )

        for idx in involved_idx:
            self.ap_path_alpha.data[idx] -= offset


class ProxylessNasMutator(BaseMutator):
    """
    This mutator initializes and operates all the LayerChoices of the input model.
    It lets the corresponding trainer control the training of the LayerChoices
    in coordination with the overall training process.
    """
    def __init__(self, model):
        """
        Init a MixedOp instance for each mutable, i.e., LayerChoice, and
        register the instantiated MixedOp in the corresponding LayerChoice.
        If it were not registered in LayerChoice, DataParallel would not work,
        because the architecture weights would not be included in the
        DataParallel model. Once the MixedOps are registered, ``requires_grad``
        controls whether gradients of the architecture weights are calculated.

        Parameters
        ----------
        model : pytorch model
            The model that users want to tune; it includes a search space defined with the NNI NAS APIs
        """
        super(ProxylessNasMutator, self).__init__(model)
        self._unused_modules = None
        self.mutable_list = []
        for mutable in self.undedup_mutables:
            self.mutable_list.append(mutable)
            mutable.registered_module = MixedOp(mutable)

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        Callback of layer choice forward. This function defines the forward
        logic of the input mutable, so the mutable is only an interface whose
        real implementation is defined in the mutator.

        Parameters
        ----------
        mutable: LayerChoice
            the mutable whose forward is being defined
        args: list of torch.Tensor
            inputs of this mutable
        kwargs: dict
            inputs of this mutable

        Returns
        -------
        torch.Tensor
            output of this mutable, i.e., LayerChoice
        int
            index of the chosen op
        """
        # FIXME: return mask, to be consistent with other algorithms
        idx = mutable.registered_module.active_op_index
        return mutable.registered_module(mutable, *args, **kwargs), idx

    def reset_binary_gates(self):
        """
        For each LayerChoice, binarize the binary weights based on alpha so
        that only one op is activated. All the mutables in the model are
        traversed to do this.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.binarize(mutable)

    def set_chosen_op_active(self):
        """
        For each LayerChoice, set the op with the highest alpha as the chosen op.
        Usually used for validation.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.set_chosen_op_active()

    def num_arch_params(self):
        """
        The number of mutables, i.e., LayerChoices.

        Returns
        -------
        int
            the number of LayerChoices in the user model
        """
        return len(self.mutable_list)

    def set_arch_param_grad(self):
        """
        For each LayerChoice, calculate the gradients of the architecture weights, i.e., alpha.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.set_arch_param_grad(mutable)

    def get_architecture_parameters(self):
        """
        Get all the architecture parameters.

        Yields
        ------
        PyTorch Parameter
            ap_path_alpha of the traversed mutable
        """
        for mutable in self.undedup_mutables:
            yield mutable.registered_module.get_ap_path_alpha()

    def change_forward_mode(self, mode):
        """
        Update the forward mode of the MixedOps, as training architecture weights and
        training model weights use different forward modes.
        """
        MixedOp.forward_mode = mode

    def get_forward_mode(self):
        """
        Get the forward mode of MixedOp.

        Returns
        -------
        string
            the current forward mode of MixedOp
        """
        return MixedOp.forward_mode

    def rescale_updated_arch_param(self):
        """
        Rescale architecture weights in the 'two' mode.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.rescale_updated_arch_param()

    def unused_modules_off(self):
        """
        Remove the unused modules of each mutable.
        The removed modules are kept in ``self._unused_modules`` so that they can be restored later.
        """
        self._unused_modules = []
        for mutable in self.undedup_mutables:
            mixed_op = mutable.registered_module
            unused = {}
            if self.get_forward_mode() in ['full', 'two', 'full_v2']:
                involved_index = mixed_op.active_index + mixed_op.inactive_index
            else:
                involved_index = mixed_op.active_index
            for i in range(mixed_op.n_choices):
                if i not in involved_index:
                    unused[i] = mutable[i]
                    mutable[i] = None
            self._unused_modules.append(unused)

    def unused_modules_back(self):
        """
        Restore the removed modules.
        """
        if self._unused_modules is None:
            return
        for m, unused in zip(self.mutable_list, self._unused_modules):
            for i in unused:
                m[i] = unused[i]
        self._unused_modules = None

    def arch_requires_grad(self):
        """
        Make the architecture weights require gradient.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_requires_grad()

    def arch_disable_grad(self):
        """
        Disable gradients of the architecture weights, i.e., do not
        calculate gradients for them.
        """
        for mutable in self.undedup_mutables:
            mutable.registered_module.to_disable_grad()

    def sample_final(self):
        """
        Generate the final chosen architecture.

        Returns
        -------
        dict
            the choice of each mutable, i.e., LayerChoice
        """
        result = dict()
        for mutable in self.undedup_mutables:
            assert isinstance(mutable, LayerChoice)
            index, _ = mutable.registered_module.chosen_index
            # pylint: disable=not-callable
            result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
        return result
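
# A minimal sketch (not part of the original file) of one architecture-weight
# step driven by this mutator; `model`, `criterion`, `arch_optimizer`, `images`
# and `labels` are placeholders. The full loop lives in ProxylessNasTrainer.
#
#   mutator.reset_binary_gates()     # sample active op(s) per LayerChoice
#   mutator.unused_modules_off()     # drop inactive ops to save memory
#   loss = criterion(model(images), labels)
#   model.zero_grad()
#   loss.backward()                  # fills ap_path_wb.grad
#   mutator.set_arch_param_grad()    # convert gate gradients to alpha gradients
#   arch_optimizer.step()
#   mutator.unused_modules_back()    # restore the removed ops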
@ -1,500 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
import time
import json
import logging

import torch
from torch import nn as nn

from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import ProxylessNasMutator
from .utils import cross_entropy_with_label_smoothing, accuracy

logger = logging.getLogger(__name__)


class ProxylessNasTrainer(BaseTrainer):
    def __init__(self, model, model_optim, device,
                 train_loader, valid_loader, label_smoothing=0.1,
                 n_epochs=120, init_lr=0.025, binary_mode='full_v2',
                 arch_init_type='normal', arch_init_ratio=1e-3,
                 arch_optim_lr=1e-3, arch_weight_decay=0,
                 grad_update_arch_param_every=5, grad_update_steps=1,
                 warmup=True, warmup_epochs=25,
                 arch_valid_frequency=1,
                 load_ckpt=False, ckpt_path=None, arch_path=None):
        """
        Parameters
        ----------
        model : pytorch model
            the user model, which has mutables
        model_optim : pytorch optimizer
            the user-defined optimizer
        device : pytorch device
            the device(s) to train/search the model on
        train_loader : pytorch data loader
            data loader for the training set
        valid_loader : pytorch data loader
            data loader for the validation set
        label_smoothing : float
            degree of label smoothing
        n_epochs : int
            number of epochs to train/search
        init_lr : float
            initial learning rate for training the model
        binary_mode : str
            the forward/backward mode for the binary weights in the mutator
        arch_init_type : str
            the way to initialize the architecture parameters
        arch_init_ratio : float
            the ratio used to initialize the architecture parameters
        arch_optim_lr : float
            learning rate of the architecture-parameter optimizer
        arch_weight_decay : float
            weight decay of the architecture-parameter optimizer
        grad_update_arch_param_every : int
            update architecture weights every this number of minibatches
        grad_update_steps : int
            number of training steps in each update of the architecture weights
        warmup : bool
            whether to do warmup
        warmup_epochs : int
            the number of warmup epochs
        arch_valid_frequency : int
            frequency of printing the validation result
        load_ckpt : bool
            whether to load a checkpoint
        ckpt_path : str
            checkpoint path; if load_ckpt is True, ckpt_path cannot be None
        arch_path : str
            the path to store the chosen architecture
        """
        self.model = model
        self.model_optim = model_optim
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.n_epochs = n_epochs
        self.init_lr = init_lr
        self.warmup = warmup
        self.warmup_epochs = warmup_epochs
        self.arch_valid_frequency = arch_valid_frequency
        self.label_smoothing = label_smoothing

        self.train_batch_size = train_loader.batch_sampler.batch_size
        self.valid_batch_size = valid_loader.batch_sampler.batch_size
        # update architecture parameters every this number of minibatches
        self.grad_update_arch_param_every = grad_update_arch_param_every
        # the number of steps per architecture-parameter update
        self.grad_update_steps = grad_update_steps
        self.binary_mode = binary_mode

        self.load_ckpt = load_ckpt
        self.ckpt_path = ckpt_path
        self.arch_path = arch_path

        # init mutator
        self.mutator = ProxylessNasMutator(model)

        # DataParallel should be applied after the mutator is initialized
        self.model = torch.nn.DataParallel(self.model)
        self.model.to(self.device)

        # iterator over the validation set for training architecture weights
        self._valid_iter = None
        # init architecture weights
        self._init_arch_params(arch_init_type, arch_init_ratio)
        # build architecture optimizer
        self.arch_optimizer = torch.optim.Adam(self.mutator.get_architecture_parameters(),
                                               arch_optim_lr,
                                               weight_decay=arch_weight_decay,
                                               betas=(0, 0.999),
                                               eps=1e-8)

        self.criterion = nn.CrossEntropyLoss()
        self.warmup_curr_epoch = 0
        self.train_curr_epoch = 0

    def _init_arch_params(self, init_type='normal', init_ratio=1e-3):
        """
        Initialize the architecture weights.
        """
        for param in self.mutator.get_architecture_parameters():
            if init_type == 'normal':
                param.data.normal_(0, init_ratio)
            elif init_type == 'uniform':
                param.data.uniform_(-init_ratio, init_ratio)
            else:
                raise NotImplementedError

    def _validate(self):
        """
        Do validation. During validation, each LayerChoice uses the chosen active op.

        Returns
        -------
        float, float, float
            average loss, average top-1 accuracy, average top-5 accuracy
        """
        self.valid_loader.batch_sampler.batch_size = self.valid_batch_size
        self.valid_loader.batch_sampler.drop_last = False

        self.mutator.set_chosen_op_active()
        # remove unused modules to save memory
        self.mutator.unused_modules_off()
        # test on the validation set under train mode
        self.model.train()
        batch_time = AverageMeter('batch_time')
        losses = AverageMeter('losses')
        top1 = AverageMeter('top1')
        top5 = AverageMeter('top5')
        end = time.time()
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.valid_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                output = self.model(images)
                loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                losses.update(loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == len(self.valid_loader):
                    test_log = 'Valid' + ': [{0}/{1}]\t'\
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
                        format(i, len(self.valid_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
                    # also report top-5
                    test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
                    logger.info(test_log)
        self.mutator.unused_modules_back()
        return losses.avg, top1.avg, top5.avg

    def _warm_up(self):
        """
        Warm up the model. During warmup, architecture weights are not trained.
        """
        lr_max = 0.05
        data_loader = self.train_loader
        nBatch = len(data_loader)
        T_total = self.warmup_epochs * nBatch  # total number of batches

        for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
            logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
            batch_time = AverageMeter('batch_time')
            data_time = AverageMeter('data_time')
            losses = AverageMeter('losses')
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # switch to train mode
            self.model.train()

            end = time.time()
            logger.info('warm_up epoch: %d', epoch)
            for i, (images, labels) in enumerate(data_loader):
                data_time.update(time.time() - end)
                # cosine-annealed warmup learning rate
                T_cur = epoch * nBatch + i
                warmup_lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
                for param_group in self.model_optim.param_groups:
                    param_group['lr'] = warmup_lr
                images, labels = images.to(self.device), labels.to(self.device)
                # compute output
                self.mutator.reset_binary_gates()  # random sample binary gates
                self.mutator.unused_modules_off()  # remove unused modules for speedup
                output = self.model(images)
                if self.label_smoothing > 0:
                    loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
                else:
                    loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                losses.update(loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # compute gradient and do SGD step
                self.model.zero_grad()
                loss.backward()
                self.model_optim.step()
                # bring back the unused modules
                self.mutator.unused_modules_back()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == nBatch:
                    batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                                'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                                'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
                                'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
                        format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
                               losses=losses, top1=top1, top5=top5, lr=warmup_lr)
                    logger.info(batch_log)
            val_loss, val_top1, val_top5 = self._validate()
            val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
                      'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
            logger.info(val_log)
            self.save_checkpoint()
            self.warmup_curr_epoch += 1

    def _get_update_schedule(self, nBatch):
        """
        Generate the schedule for training architecture weights. Each key is the
        minibatch index after which the architecture weights are updated; each
        value is the number of update steps to run at that point.

        Parameters
        ----------
        nBatch : int
            the total number of minibatches in one epoch

        Returns
        -------
        dict
            the schedule for updating the architecture weights
        """
        schedule = {}
        for i in range(nBatch):
            if (i + 1) % self.grad_update_arch_param_every == 0:
                schedule[i] = self.grad_update_steps
        return schedule
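
    # For example, with nBatch=12 and grad_update_arch_param_every=5 this
    # returns {4: grad_update_steps, 9: grad_update_steps}: the architecture
    # weights are updated after minibatches 4 and 9.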

    def _calc_learning_rate(self, epoch, batch=0, nBatch=None):
        """
        Compute the cosine-annealed learning rate.
        """
        T_total = self.n_epochs * nBatch
        T_cur = epoch * nBatch + batch
        lr = 0.5 * self.init_lr * (1 + math.cos(math.pi * T_cur / T_total))
        return lr

    def _adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
        """
        Adjust the learning rate of a given optimizer and return the new learning rate.

        Parameters
        ----------
        optimizer : pytorch optimizer
            the optimizer to adjust
        epoch : int
            the current epoch number
        batch : int
            the current minibatch
        nBatch : int
            the total number of minibatches in one epoch

        Returns
        -------
        float
            the adjusted learning rate
        """
        new_lr = self._calc_learning_rate(epoch, batch, nBatch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr

    def _train(self):
        """
        Train the model; this trains both model weights and architecture weights.
        Architecture weights are trained according to the schedule.
        Before updating the architecture weights, ``requires_grad`` is enabled;
        it is disabled again after the update, so that the architecture weights
        are not updated while training model weights.
        """
        nBatch = len(self.train_loader)
        arch_param_num = self.mutator.num_arch_params()
        binary_gates_num = self.mutator.num_arch_params()
        logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num, binary_gates_num)

        update_schedule = self._get_update_schedule(nBatch)

        for epoch in range(self.train_curr_epoch, self.n_epochs):
            logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
            batch_time = AverageMeter('batch_time')
            data_time = AverageMeter('data_time')
            losses = AverageMeter('losses')
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # switch to train mode
            self.model.train()

            end = time.time()
            for i, (images, labels) in enumerate(self.train_loader):
                data_time.update(time.time() - end)
                lr = self._adjust_learning_rate(self.model_optim, epoch, batch=i, nBatch=nBatch)
                # train weight parameters
                images, labels = images.to(self.device), labels.to(self.device)
                self.mutator.reset_binary_gates()
                self.mutator.unused_modules_off()
                output = self.model(images)
                if self.label_smoothing > 0:
                    loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
                else:
                    loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                losses.update(loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                self.model.zero_grad()
                loss.backward()
                self.model_optim.step()
                self.mutator.unused_modules_back()
                if epoch > 0:
                    for _ in range(update_schedule.get(i, 0)):
                        start_time = time.time()
                        # GradientArchSearchConfig
                        self.mutator.arch_requires_grad()
                        arch_loss, exp_value = self._gradient_step()
                        self.mutator.arch_disable_grad()
                        used_time = time.time() - start_time
                        log_str = 'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t Expected value %s' % \
                            (epoch + 1, i, used_time, arch_loss, exp_value)
                        logger.info(log_str)
                batch_time.update(time.time() - end)
                end = time.time()
                # training log
                if i % 10 == 0 or i + 1 == nBatch:
                    batch_log = 'Train [{0}][{1}/{2}]\t' \
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                                'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                                'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
                                'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
                        format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
                               losses=losses, top1=top1, top5=top5, lr=lr)
                    logger.info(batch_log)
            # validate
            if (epoch + 1) % self.arch_valid_frequency == 0:
                val_loss, val_top1, val_top5 = self._validate()
                val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
                          'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                    format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
                logger.info(val_log)
            self.save_checkpoint()
            self.train_curr_epoch += 1

    def _valid_next_batch(self):
        """
        Get the next minibatch from the validation set.

        Returns
        -------
        (tensor, tensor)
            a tuple of images and labels
        """
        if self._valid_iter is None:
            self._valid_iter = iter(self.valid_loader)
        try:
            data = next(self._valid_iter)
        except StopIteration:
            self._valid_iter = iter(self.valid_loader)
            data = next(self._valid_iter)
        return data

    def _gradient_step(self):
        """
        This gradient step updates the architecture weights.
        The mutator is used intensively in this function to operate on
        the architecture weights.

        Returns
        -------
        float, None
            loss of the model, None
        """
        # use the same batch size as the train batch size for architecture weights
        self.valid_loader.batch_sampler.batch_size = self.train_batch_size
        self.valid_loader.batch_sampler.drop_last = True
        self.model.train()
        self.mutator.change_forward_mode(self.binary_mode)
        time1 = time.time()  # time
        # sample a batch of data from the validation set
        images, labels = self._valid_next_batch()
        images, labels = images.to(self.device), labels.to(self.device)
        time2 = time.time()  # time
        self.mutator.reset_binary_gates()
        self.mutator.unused_modules_off()
        output = self.model(images)
        time3 = time.time()
        ce_loss = self.criterion(output, labels)
        expected_value = None
        loss = ce_loss
        self.model.zero_grad()
        loss.backward()
        self.mutator.set_arch_param_grad()
        self.arch_optimizer.step()
        if self.mutator.get_forward_mode() == 'two':
            self.mutator.rescale_updated_arch_param()
        self.mutator.unused_modules_back()
        self.mutator.change_forward_mode(None)
        time4 = time.time()
        logger.info('(data %.4f, forward %.4f, backward %.4f)', time2 - time1, time3 - time2, time4 - time3)
        return loss.data.item(), expected_value.item() if expected_value is not None else None

    def save_checkpoint(self):
        """
        Save a checkpoint of the whole model: model weights and architecture weights
        are saved to ``ckpt_path``, and the currently chosen architecture is saved to ``arch_path``.
        """
        if self.ckpt_path:
            state = {
                'warmup_curr_epoch': self.warmup_curr_epoch,
                'train_curr_epoch': self.train_curr_epoch,
                'model': self.model.state_dict(),
                'optim': self.model_optim.state_dict(),
                'arch_optim': self.arch_optimizer.state_dict()
            }
            torch.save(state, self.ckpt_path)
        if self.arch_path:
            self.export(self.arch_path)

    def load_checkpoint(self):
        """
        Load the checkpoint from ``ckpt_path``.
        """
        assert self.ckpt_path is not None, "If load_ckpt is True, ckpt_path should not be None"
        ckpt = torch.load(self.ckpt_path)
        self.warmup_curr_epoch = ckpt['warmup_curr_epoch']
        self.train_curr_epoch = ckpt['train_curr_epoch']
        self.model.load_state_dict(ckpt['model'])
        self.model_optim.load_state_dict(ckpt['optim'])
        self.arch_optimizer.load_state_dict(ckpt['arch_optim'])

    def train(self):
        """
        Train the whole model.
        """
        if self.load_ckpt:
            self.load_checkpoint()
        if self.warmup:
            self._warm_up()
        self._train()

    def export(self, file_name):
        """
        Export the chosen architecture into a file.

        Parameters
        ----------
        file_name : str
            the file that stores the exported chosen architecture
        """
        exported_arch = self.mutator.sample_final()
        with open(file_name, 'w') as f:
            json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)

    def validate(self):
        raise NotImplementedError

    def checkpoint(self):
        raise NotImplementedError
@ -1,78 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn


def detach_variable(inputs):
    """
    Detach variables.

    Parameters
    ----------
    inputs : pytorch tensors
        pytorch tensors
    """
    if isinstance(inputs, tuple):
        return tuple([detach_variable(x) for x in inputs])
    else:
        x = inputs.detach()
        x.requires_grad = inputs.requires_grad
        return x


def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """
    Parameters
    ----------
    pred : pytorch tensor
        predicted value
    target : pytorch tensor
        label
    label_smoothing : float
        the degree of label smoothing

    Returns
    -------
    pytorch tensor
        cross entropy
    """
    logsoftmax = nn.LogSoftmax(dim=1)
    n_classes = pred.size(1)
    # convert to one-hot
    target = torch.unsqueeze(target, 1)
    soft_target = torch.zeros_like(pred)
    soft_target.scatter_(1, target, 1)
    # label smoothing
    soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
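
# For example, with n_classes=10 and label_smoothing=0.1, a one-hot target row
# [0, ..., 1, ..., 0] becomes [0.01, ..., 0.91, ..., 0.01]: every class gets
# label_smoothing / n_classes = 0.01, and the true class gets 0.9 + 0.01 = 0.91.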

def accuracy(output, target, topk=(1,)):
    """
    Compute the precision@k for the specified values of k.

    Parameters
    ----------
    output : pytorch tensor
        output, e.g., predicted value
    target : pytorch tensor
        label
    topk : tuple
        the values of k, e.g., (1, 5) for top-1 and top-5 accuracy

    Returns
    -------
    list
        accuracy at each requested k
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class RandomMutator(Mutator):
    """
    Random mutator that samples a random candidate in the search space each time ``reset()`` is called.
    It uses PyTorch's random functions, so users can set a seed in PyTorch to ensure deterministic behavior.
    """

    def sample_search(self):
        """
        Sample a random candidate.
        """
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                gen_index = torch.randint(high=len(mutable), size=(1, ))
                result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                if mutable.n_chosen is None:
                    result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
                else:
                    perm = torch.randperm(mutable.n_candidates)
                    mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
                    result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
        return result

    def sample_final(self):
        """
        Same as :meth:`sample_search`.
        """
        return self.sample_search()
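
# A minimal usage sketch (names are placeholders, not from the original file):
#
#   mutator = RandomMutator(model_with_mutables)
#   mutator.reset()                  # samples and applies a random architecture
#   logits = model_with_mutables(x)  # forward pass under the sampled architecture
#   arch = mutator.sample_final()    # dict of boolean masks keyed by mutable key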
@ -1,6 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer
@ -1,223 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import os
import re
from collections import deque

import numpy as np
from nni.tuner import Tuner
from nni.algorithms.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE


_logger = logging.getLogger(__name__)


class SPOSEvolution(Tuner):
    """
    SPOS evolution tuner.

    Parameters
    ----------
    max_epochs : int
        Maximum number of epochs to run.
    num_select : int
        Number of survival candidates of each epoch.
    num_population : int
        Number of candidates at the start of each epoch. If candidates generated by
        crossover and mutation are not enough, the rest will be filled with random
        candidates.
    m_prob : float
        The probability of mutation.
    num_crossover : int
        Number of candidates generated by crossover in each epoch.
    num_mutation : int
        Number of candidates generated by mutation in each epoch.
    """

    def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
                 num_crossover=25, num_mutation=25):
        assert num_population >= num_select
        self.max_epochs = max_epochs
        self.num_select = num_select
        self.num_population = num_population
        self.m_prob = m_prob
        self.num_crossover = num_crossover
        self.num_mutation = num_mutation
        self.epoch = 0
        self.candidates = []
        self._search_space = None
        self.random_state = np.random.RandomState(0)

        # async status
        self._to_evaluate_queue = deque()
        self._sending_parameter_queue = deque()
        self._pending_result_ids = set()
        self._reward_dict = dict()
        self._id2candidate = dict()
        self._st_callback = None

    def update_search_space(self, search_space):
        """
        Handle the initialization/update event of the search space.
        """
        self._search_space = search_space
        self._next_round()

    def _next_round(self):
        _logger.info("Epoch %d, generating...", self.epoch)
        if self.epoch == 0:
            self._get_random_population()
            self.export_results(self.candidates)
        else:
            best_candidates = self._select_top_candidates()
            self.export_results(best_candidates)
            if self.epoch >= self.max_epochs:
                return
            self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
            self._get_random_population()
        self.epoch += 1

    def _random_candidate(self):
        chosen_arch = dict()
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                index = self.random_state.randint(len(choices))
                chosen_arch[key] = {"_value": choices[index], "_idx": index}
            elif val["_type"] == INPUT_CHOICE:
                raise NotImplementedError("Input choice is not implemented yet.")
        return chosen_arch

    def _add_to_evaluate_queue(self, cand):
        _logger.info("Generated candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
        self._reward_dict[self._hashcode(cand)] = 0.
        self._to_evaluate_queue.append(cand)

    def _get_random_population(self):
        while len(self.candidates) < self.num_population:
            cand = self._random_candidate()
            if self._is_legal(cand):
                _logger.info("Random candidate generated.")
                self._add_to_evaluate_queue(cand)
                self.candidates.append(cand)

    def _get_crossover(self, best):
        result = []
        for _ in range(10 * self.num_crossover):
            cand_p1 = best[self.random_state.randint(len(best))]
            cand_p2 = best[self.random_state.randint(len(best))]
            assert cand_p1.keys() == cand_p2.keys()
            cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
                    for k in cand_p1.keys()}
            if self._is_legal(cand):
                result.append(cand)
                self._add_to_evaluate_queue(cand)
            if len(result) >= self.num_crossover:
                break
        _logger.info("Found %d architectures with crossover.", len(result))
        return result

    def _get_mutation(self, best):
        result = []
        for _ in range(10 * self.num_mutation):
            cand = best[self.random_state.randint(len(best))].copy()
            mutation_sample = self.random_state.random_sample(len(cand))
            for s, k in zip(mutation_sample, cand):
                if s < self.m_prob:
                    choices = self._search_space[k]["_value"]
                    index = self.random_state.randint(len(choices))
                    cand[k] = {"_value": choices[index], "_idx": index}
            if self._is_legal(cand):
                result.append(cand)
                self._add_to_evaluate_queue(cand)
            if len(result) >= self.num_mutation:
                break
        _logger.info("Found %d architectures with mutation.", len(result))
        return result

    def _get_architecture_repr(self, cand):
        return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
                      self._hashcode(cand))

    def _is_legal(self, cand):
        if self._hashcode(cand) in self._reward_dict:
            return False
        return True

    def _select_top_candidates(self):
        reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
        _logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
        result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
        _logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
        return result

    @staticmethod
    def _hashcode(d):
        return json.dumps(d, sort_keys=True)

    def _bind_and_send_parameters(self):
        """
        There are two types of resources: parameter ids and candidates. This function is
        called whenever necessary to bind these resources and send new trials via st_callback.
        """
        result = []
        while self._sending_parameter_queue and self._to_evaluate_queue:
            parameter_id = self._sending_parameter_queue.popleft()
            parameters = self._to_evaluate_queue.popleft()
            self._id2candidate[parameter_id] = parameters
            result.append(parameters)
            self._pending_result_ids.add(parameter_id)
            self._st_callback(parameter_id, parameters)
            _logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
        return result

    def generate_multiple_parameters(self, parameter_id_list, **kwargs):
        """
        Callback function necessary to implement a tuner. This puts more parameter ids into the
        parameter id queue.
        """
        if "st_callback" in kwargs and self._st_callback is None:
            self._st_callback = kwargs["st_callback"]
        for parameter_id in parameter_id_list:
            self._sending_parameter_queue.append(parameter_id)
        self._bind_and_send_parameters()
        return []  # always return an empty list here; returning parameters directly might cause over-sending

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """
        Callback function. Receive a trial result.
        """
        _logger.info("Candidate %d, reported reward %f", parameter_id, value)
        self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value

    def trial_end(self, parameter_id, success, **kwargs):
        """
        Callback function called when a trial is ended and its resource is released.
        """
        self._pending_result_ids.remove(parameter_id)
        if not self._pending_result_ids and not self._to_evaluate_queue:
            # start a new epoch
            self._next_round()
            assert self._st_callback is not None
            self._bind_and_send_parameters()

    def export_results(self, result):
        """
        Export a number of candidates to the `checkpoints` dir.

        Parameters
        ----------
        result : dict
            Chosen architectures to be exported.
        """
        os.makedirs("checkpoints", exist_ok=True)
        for i, cand in enumerate(result):
            converted = dict()
            for cand_key, cand_val in cand.items():
                onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
                converted[cand_key] = onehot
            with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
                json.dump(converted, fp)
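
# A sketch of what one exported file, e.g. checkpoints/000_000.json, might look
# like for a search space where each key has four candidate ops (keys and
# values here are hypothetical):
#
#   {"stage1_block0": [false, true, false, false],
#    "stage1_block1": [true, false, false, false]}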
@ -1,66 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import numpy as np
from nni.algorithms.nas.pytorch.random import RandomMutator

_logger = logging.getLogger(__name__)


class SPOSSupernetTrainingMutator(RandomMutator):
    """
    A random mutator with a flops limit.

    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    flops_func : callable
        Callable that takes a candidate from `sample_search` and returns its flops. When `flops_func`
        is None, functions related to flops will be deactivated.
    flops_lb : number
        Lower bound of flops.
    flops_ub : number
        Upper bound of flops.
    flops_bin_num : number
        Number of bins the flops interval is divided into to ensure uniformity. A bigger number is more
        uniform, but makes the sampling slower.
    flops_sample_timeout : int
        Maximum number of attempts to sample before giving up and using a random candidate.
    """
    def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
                 flops_bin_num=7, flops_sample_timeout=500):

        super().__init__(model)
        self._flops_func = flops_func
        if self._flops_func is not None:
            self._flops_bin_num = flops_bin_num
            self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
            self._flops_sample_timeout = flops_sample_timeout

    def sample_search(self):
        """
        Sample a candidate for training. When `flops_func` is not None, candidates are sampled uniformly
        with respect to flops.

        Returns
        -------
        dict
        """
        if self._flops_func is not None:
            for times in range(self._flops_sample_timeout):
                idx = np.random.randint(self._flops_bin_num)
                cand = super().sample_search()
                flops = self._flops_func(cand)
                if self._flops_bins[idx] <= flops <= self._flops_bins[idx + 1]:
                    _logger.debug("Sampled candidate with flops %f in %d times.", flops, times)
                    return cand
            _logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
        return super().sample_search()

    def sample_final(self):
        """
        Implemented only to satisfy the interface of Mutator.
        """
        return self.sample_search()
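
# A minimal usage sketch (`count_flops` and the bounds are hypothetical
# placeholders, not from the original file):
#
#   mutator = SPOSSupernetTrainingMutator(supernet, flops_func=count_flops,
#                                         flops_lb=290e6, flops_ub=360e6)
#   mutator.reset()   # samples an architecture whose flops fall into a randomly chosen bin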
@ -1,95 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup

from .mutator import SPOSSupernetTrainingMutator

logger = logging.getLogger(__name__)


class SPOSSupernetTrainer(Trainer):
    """
    This trainer trains a supernet that can then be used for evolution search.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    mutator : nni.nas.pytorch.mutator.Mutator
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    metrics : callable
        Returns a dict that maps metrics keys to metrics data.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raises ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raises ``StopIteration`` when one epoch is exhausted.
    batch_size : int
        Batch size.
    workers: int
        Number of threads for data preprocessing. Not used for this trainer. May be removed in the future.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
        automatically detects a GPU and prefers it.
    log_frequency : int
        Number of mini-batches between metric logs.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None):
        assert torch.cuda.is_available()
        super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
                         loss, metrics, optimizer, num_epochs, None, None,
                         batch_size, workers, device, log_frequency, callbacks)

        self.train_loader = train_loader
        self.valid_loader = valid_loader

    def train_one_epoch(self, epoch):
        self.model.train()
        meters = AverageMeterGroup()
        for step, (x, y) in enumerate(self.train_loader):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            self.mutator.reset()
            logits = self.model(x)
            loss = self.loss(logits, y)
            loss.backward()
            self.optimizer.step()

            metrics = self.metrics(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)

    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                x, y = x.to(self.device), y.to(self.device)
                self.mutator.reset()
                logits = self.model(x)
                loss = self.loss(logits, y)
                metrics = self.metrics(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
|
||||
self.num_epochs, step + 1, len(self.valid_loader), meters)
|
|
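A minimal wiring sketch (the ``accuracy`` helper and the hyperparameters are assumptions, not part of this file)::

    trainer = SPOSSupernetTrainer(model,
                                  loss=torch.nn.CrossEntropyLoss(),
                                  metrics=lambda logits, y: {"acc": accuracy(logits, y)},  # `accuracy` is hypothetical
                                  optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
                                  num_epochs=120,
                                  train_loader=train_loader, valid_loader=valid_loader)
    trainer.train()  # each mini-batch trains a freshly sampled single path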
@@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import get_and_apply_next_architecture
@@ -1,217 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file

import json
import logging
import os
import sys

import tensorflow as tf

import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.tensorflow.mutator import Mutator

logger = logging.getLogger(__name__)

NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"


def get_and_apply_next_architecture(model):
    """
    Wrapper of :class:`~nni.nas.tensorflow.classic_nas.mutator.ClassicMutator` to make it more meaningful,
    similar to ``get_next_parameter`` for HPO.
    It will generate the search space based on ``model``.
    If the env variable ``NNI_GEN_SEARCH_SPACE`` exists, this is a dry run that only
    generates the search space for the experiment.
    If not, there are still two modes: one is NNI experiment mode, where users
    use ``nnictl`` to start an experiment; the other is standalone mode, where users
    directly run the trial command. The standalone mode chooses the first
    one(s) for each LayerChoice and InputChoice.

    Parameters
    ----------
    model : tf.keras.Model
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """
    ClassicMutator(model)


class ClassicMutator(Mutator):
    """
    This mutator applies the architecture chosen by the tuner.
    It implements the forward function of LayerChoice and InputChoice,
    to only activate the chosen ones.

    Parameters
    ----------
    model : tf.keras.Model
        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
    """

    def __init__(self, model):
        super(ClassicMutator, self).__init__(model)
        self._chosen_arch = {}
        self._search_space = self._generate_search_space()
        if NNI_GEN_SEARCH_SPACE in os.environ:
            # dry run for only generating search space
            self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
            sys.exit(0)

        if trial_env_vars.NNI_PLATFORM is None:
            logger.warning("Running in standalone mode; the first candidate(s) will be chosen.")
            self._chosen_arch = self._standalone_generate_chosen()
        else:
            # get chosen arch from tuner
            self._chosen_arch = nni.get_next_parameter()
            if self._chosen_arch is None:
                if trial_env_vars.NNI_PLATFORM == "unittest":
                    # happens if NNI_PLATFORM is intentionally set, e.g., in UT
                    logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
                    self._chosen_arch = self._standalone_generate_chosen()
                else:
                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
        self.reset()

    def _sample_layer_choice(self, mutable, idx, value, search_space_item):
        """
        Convert layer choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : int
            The ``idx``-th item of the list will be selected.
        value : str
            The verbose representation of the selected value.
        search_space_item : list
            The list of the corresponding search space.
        """
        # doesn't support multihot for layer choice yet
        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
            "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
        mask = tf.one_hot(idx, len(mutable))
        return tf.cast(tf.reshape(mask, [-1]), tf.bool)

    def _sample_input_choice(self, mutable, idx, value, search_space_item):
        """
        Convert input choice to tensor representation.

        Parameters
        ----------
        mutable : Mutable
        idx : list of int
            The ``idx``-th items of the list will be selected.
        value : list of str
            The verbose representation of the selected values.
        search_space_item : dict
            The dict of the corresponding search space.
        """
        candidate_repr = search_space_item["candidates"]
        multihot_list = [False] * mutable.n_candidates
        for i, v in zip(idx, value):
            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
            assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
            multihot_list[i] = True
        return tf.cast(multihot_list, tf.bool)  # pylint: disable=not-callable

    def sample_search(self):
        """
        See :meth:`sample_final`.
        """
        return self.sample_final()

    def sample_final(self):
        """
        Convert the chosen arch and apply it on the model.
        """
        assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
            "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                       self._chosen_arch.keys())
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, (LayerChoice, InputChoice)):
                assert mutable.key in self._chosen_arch, \
                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
                data = self._chosen_arch[mutable.key]
                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
                    "'{}' is not a valid choice.".format(data)
                if isinstance(mutable, LayerChoice):
                    result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
                                                                    self._search_space[mutable.key]["_value"])
                elif isinstance(mutable, InputChoice):
                    result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
                                                                    self._search_space[mutable.key]["_value"])
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return result

    def _standalone_generate_chosen(self):
        """
        Generate the chosen architecture for standalone mode,
        i.e., choose the first one(s) for LayerChoice and InputChoice.
        ::

            { key_name: {"_value": "conv1",
                         "_idx": 0} }

            { key_name: {"_value": ["in1"],
                         "_idx": [0]} }

        Returns
        -------
        dict
            the chosen architecture
        """
        chosen_arch = {}
        for key, val in self._search_space.items():
            if val["_type"] == LAYER_CHOICE:
                choices = val["_value"]
                chosen_arch[key] = {"_value": choices[0], "_idx": 0}
            elif val["_type"] == INPUT_CHOICE:
                choices = val["_value"]["candidates"]
                n_chosen = val["_value"]["n_chosen"]
                if n_chosen is None:
                    n_chosen = len(choices)
                chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
            else:
                raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
        return chosen_arch

    def _generate_search_space(self):
        """
        Generate search space from mutables.
        Here is the search space format:
        ::

            { key_name: {"_type": "layer_choice",
                         "_value": ["conv1", "conv2"]} }

            { key_name: {"_type": "input_choice",
                         "_value": {"candidates": ["in1", "in2"],
                                    "n_chosen": 1}} }

        Returns
        -------
        dict
            the generated search space
        """
        search_space = {}
        for mutable in self.mutables:
            # for now we only generate flattened search space
            if isinstance(mutable, LayerChoice):
                key = mutable.key
                val = mutable.names
                search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
            elif isinstance(mutable, InputChoice):
                key = mutable.key
                search_space[key] = {"_type": INPUT_CHOICE,
                                     "_value": {"candidates": mutable.choose_from,
                                                "n_chosen": mutable.n_chosen}}
            elif isinstance(mutable, MutableScope):
                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
            else:
                raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
        return search_space

    def _dump_search_space(self, file_path):
        with open(file_path, "w") as ss_file:
            json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
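For illustration (names are hypothetical), a model containing one three-way ``LayerChoice`` keyed ``conv`` yields the following search space, and in standalone mode the first candidate is chosen::

    {"conv": {"_type": "layer_choice", "_value": ["conv3x3", "conv5x5", "maxpool"]}}
    {"conv": {"_value": "conv3x3", "_idx": 0}}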
@@ -1,5 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import EnasMutator
from .trainer import EnasTrainer
@@ -1,162 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file

import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction

from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope


class EnasMutator(Mutator):
    def __init__(self, model,
                 lstm_size=64,
                 lstm_num_layers=1,
                 tanh_constant=1.5,
                 cell_exit_extra_step=False,
                 skip_target=0.4,
                 temperature=None,
                 branch_bias=0.25,
                 entropy_reduction='sum'):
        super().__init__(model)
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step

        cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
        self.lstm = RNN(cells, stateful=True)
        self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
        self.skip_targets = tf.constant([1.0 - skip_target, skip_target])

        self.max_layer_choice = 0
        self.bias_dict = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choices to have the same number of candidates."
                if 'reduce' in mutable.key:
                    bias = []
                    for choice in mutable.choices:
                        if 'conv' in str(type(choice)).lower():
                            bias.append(branch_bias)
                        else:
                            bias.append(-branch_bias)
                    self.bias_dict[mutable.key] = tf.constant(bias)

        # exposed for trainer
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

        # internal nn layers
        self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
        self.soft = Dense(self.max_layer_choice, use_bias=False)
        self.attn_anchor = Dense(lstm_size, use_bias=False)
        self.attn_query = Dense(lstm_size, use_bias=False)
        self.v_attn = Dense(1, use_bias=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
        self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)

        self._first_sample = True

    def sample_search(self):
        self._initialize()
        self._sample(self.mutables)
        self._first_sample = False
        return self._choices

    def sample_final(self):
        return self.sample_search()

    def _sample(self, tree):
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)

    def _initialize(self):
        self._choices = {}
        self._anchors_hid = {}
        self._inputs = self.g_emb
        # seems the `input_shape` parameter of RNN does not work;
        # work around it by omitting `reset_states` for the first run
        if not self._first_sample:
            self.lstm.reset_states()
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _sample_layer_choice(self, mutable):
        logit = self.soft(self.lstm(self._inputs))
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * tf.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
        branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
        log_prob = self.cross_entropy_loss(branch_id, logit)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = log_prob * tf.math.exp(-log_prob)
        self.sample_entropy += self.entropy_reduction(entropy)
        self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
        mask = tf.one_hot(branch_id, self.max_layer_choice)
        return tf.cast(tf.reshape(mask, [-1]), tf.bool)

    def _sample_input_choice(self, mutable):
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._anchors_hid[label] = self.lstm(self._inputs)
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = tf.concat(query, axis=0)
        query = tf.tanh(query + self.attn_query(anchors[-1]))
        query = self.v_attn(query)

        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * tf.tanh(query)

        if mutable.n_chosen is None:
            logit = tf.concat([-query, query], axis=1)
            softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
            skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
            skip_prob = tf.math.sigmoid(logit)
            kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(skip, logit)

            skip = tf.cast(skip, tf.float32)
            inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
            self._inputs = tf.reshape(inputs, [1, 1, -1])

        else:
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = tf.reshape(query, [1, -1])
            softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
            index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
            skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
            # when the size is 1, tf does not accept a tensor here, complaining the shape is wrong,
            # but using a numpy array seems fine
            log_prob = self.cross_entropy_loss(index.numpy(), logit)
            self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = log_prob * tf.exp(-log_prob)
        self.sample_entropy += self.entropy_reduction(entropy)
        assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
        return tf.cast(skip, tf.bool)
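As a self-contained sketch of the sampling primitive used in ``_sample_layer_choice`` (the logits are made-up values)::

    import tensorflow as tf

    logit = tf.constant([[1.0, 0.5, -0.3]])  # controller logits for a 3-way layer choice
    softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
    branch_id = tf.random.categorical(softmax_logit, num_samples=1)  # e.g. [[0]]
    mask = tf.cast(tf.reshape(tf.one_hot(branch_id[0], 3), [-1]), tf.bool)
    # `mask` is the boolean one-hot the mutator returns for the chosen branch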
@@ -1,205 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file

import logging

import tensorflow as tf
from tensorflow.keras.optimizers import Adam

from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads

from .mutator import EnasMutator

logger = logging.getLogger(__name__)


class EnasTrainer:
    def __init__(
        self,
        model,
        loss,
        metrics,
        reward_function,
        optimizer,
        batch_size,
        num_epochs,
        dataset_train,
        dataset_valid,
        log_frequency=100,
        entropy_weight=0.0001,
        skip_weight=0.8,
        baseline_decay=0.999,
        child_steps=500,
        mutator_lr=0.00035,
        mutator_steps=50,
        mutator_steps_aggregate=20,
        aux_weight=0.4,
        test_arc_per_epoch=1,
    ):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.reward_function = reward_function
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.num_epochs = num_epochs

        x, y = dataset_train
        split = int(len(x) * 0.9)
        self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
        self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
        self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)

        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.child_steps = child_steps
        self.mutator_lr = mutator_lr
        self.mutator_steps = mutator_steps
        self.mutator_steps_aggregate = mutator_steps_aggregate
        self.aux_weight = aux_weight
        self.test_arc_per_epoch = test_arc_per_epoch

        self.mutator = EnasMutator(model)
        self.mutator_optim = Adam(learning_rate=self.mutator_lr)

        self.baseline = 0.0

    def train(self, validate=True):
        for epoch in range(self.num_epochs):
            logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)
            logger.info("Epoch %d Validating", epoch + 1)
            self.validate_one_epoch(epoch)

    def validate(self):
        self.validate_one_epoch(-1)

    def train_one_epoch(self, epoch):
        train_loader, valid_loader = self._create_train_loader()

        # Sample model and train
        meters = AverageMeterGroup()

        for step in range(1, self.child_steps + 1):
            x, y = next(train_loader)
            self.mutator.reset()

            with tf.GradientTape() as tape:
                logits = self.model(x, training=True)
                if isinstance(logits, tuple):
                    logits, aux_logits = logits
                    aux_loss = self.loss(aux_logits, y)
                else:
                    aux_loss = 0.0
                metrics = self.metrics(y, logits)
                loss = self.loss(y, logits) + self.aux_weight * aux_loss

            grads = tape.gradient(loss, self.model.trainable_weights)
            grads = fill_zero_grads(grads, self.model.trainable_weights)
            grads, _ = tf.clip_by_global_norm(grads, 5.0)
            self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

            metrics["loss"] = tf.reduce_mean(loss).numpy()
            meters.update(metrics)

            if self.log_frequency and step % self.log_frequency == 0:
                logger.info(
                    "Model Epoch [%d/%d] Step [%d/%d] %s",
                    epoch + 1,
                    self.num_epochs,
                    step,
                    self.child_steps,
                    meters,
                )

        # Train sampler (mutator)
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            grads_list = []
            for step in range(1, self.mutator_steps_aggregate + 1):
                with tf.GradientTape() as tape:
                    x, y = next(valid_loader)
                    self.mutator.reset()

                    logits = self.model(x, training=False)
                    metrics = self.metrics(y, logits)
                    reward = (
                        self.reward_function(y, logits)
                        + self.entropy_weight * self.mutator.sample_entropy
                    )
                    self.baseline = self.baseline * self.baseline_decay + reward * (
                        1 - self.baseline_decay
                    )
                    loss = self.mutator.sample_log_prob * (reward - self.baseline)
                    loss += self.skip_weight * self.mutator.sample_skip_penalty

                meters.update(
                    {
                        "reward": reward,
                        "loss": tf.reduce_mean(loss).numpy(),
                        "ent": self.mutator.sample_entropy.numpy(),
                        "log_prob": self.mutator.sample_log_prob.numpy(),
                        "baseline": self.baseline,
                        "skip": self.mutator.sample_skip_penalty,
                    }
                )

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency and cur_step % self.log_frequency == 0:
                    logger.info(
                        "RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
                        epoch + 1,
                        self.num_epochs,
                        mutator_step,
                        self.mutator_steps,
                        step,
                        self.mutator_steps_aggregate,
                        meters,
                    )

                grads = tape.gradient(loss, self.mutator.trainable_weights)
                grads = fill_zero_grads(grads, self.mutator.trainable_weights)
                grads_list.append(grads)
            total_grads = [
                tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)
            ]
            total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
            self.mutator_optim.apply_gradients(
                zip(total_grads, self.mutator.trainable_weights)
            )

    def validate_one_epoch(self, epoch):
        test_loader = self._create_validate_loader()

        for arc_id in range(self.test_arc_per_epoch):
            meters = AverageMeterGroup()
            for x, y in test_loader:
                self.mutator.reset()
                logits = self.model(x, training=False)
                if isinstance(logits, tuple):
                    logits, _ = logits
                metrics = self.metrics(y, logits)
                loss = self.loss(y, logits)
                metrics["loss"] = tf.reduce_mean(loss).numpy()
                meters.update(metrics)

            logger.info(
                "Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                epoch + 1,
                self.num_epochs,
                arc_id + 1,
                self.test_arc_per_epoch,
                meters.summary(),
            )

    def _create_train_loader(self):
        train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
        test_set = self.valid_set.shuffle(1000000).repeat().batch(self.batch_size)
        return iter(train_set), iter(test_set)

    def _create_validate_loader(self):
        return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
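The mutator update above is plain REINFORCE with an exponential-moving-average baseline; as a quick check of the baseline update with ``baseline_decay=0.999`` (numbers are illustrative)::

    b_new = b * decay + R * (1 - decay)
    # with b = 0, R = 0.8, decay = 0.999:  b_new = 0.0008
    # the baseline drifts slowly toward the running average reward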
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .execution import *
from .fixed import fixed_arch
from .mutable import *
from .utils import *
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.common.framework import shortcut_framework

from .evaluator import Evaluator
from .functional import FunctionalEvaluator

shortcut_framework(__name__)

del shortcut_framework
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['Evaluator']

import abc
from typing import Any, Callable, Type, Union, cast


class Evaluator(abc.ABC):
    """
    Evaluator of a model. An evaluator should define where the training code is, and the configuration of
    the training code. The configuration includes basic runtime information the trainer needs to know (such as
    the number of GPUs) or tunable parameters (such as the learning rate), depending on the implementation
    of the training code.

    Each config should define how it is interpreted in ``_execute()``, which takes only one argument: the mutated model class.
    For example, a functional evaluator might directly import the function and call the function.
    """

    def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
        """Run evaluation of a model. The model could be either a concrete model or a callable returning a model.

        The concrete implementation of evaluate depends on the implementation of ``_execute()`` in the subclass.
        """
        return self._execute(model_cls)

    def __repr__(self):
        items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
        return f'{self.__class__.__name__}({items})'

    @staticmethod
    def _load(ir: Any) -> 'Evaluator':
        evaluator_type = ir.get('type')
        if isinstance(evaluator_type, str):
            # for debug purposes only
            for subclass in Evaluator.__subclasses__():
                if subclass.__name__ == evaluator_type:
                    evaluator_type = subclass
                    break
        assert issubclass(cast(type, evaluator_type), Evaluator)
        return cast(Type[Evaluator], evaluator_type)._load(ir)

    @abc.abstractmethod
    def _dump(self) -> Any:
        """
        Subclasses implement ``_dump`` for their own serialization.
        They should return a dict, with a key ``type`` which equals ``self.__class__``,
        and optionally other keys.
        """
        pass

    @abc.abstractmethod
    def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
        pass

    @abc.abstractmethod
    def __eq__(self, other) -> bool:
        pass
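A minimal sketch of a custom subclass honoring these contracts (the ``EchoEvaluator`` below is hypothetical, not part of this commit)::

    class EchoEvaluator(Evaluator):
        """Hypothetical evaluator that just instantiates the model and returns it."""

        def _dump(self):
            return {'type': self.__class__}

        @staticmethod
        def _load(ir):
            return EchoEvaluator()

        def _execute(self, model_cls):
            return model_cls() if callable(model_cls) else model_cls

        def __eq__(self, other):
            return isinstance(other, EchoEvaluator)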
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import nni
from .evaluator import Evaluator


@nni.trace
class FunctionalEvaluator(Evaluator):
    """
    Functional evaluator that directly takes a function, and thus should be general.

    Attributes
    ----------
    function
        The full name of the function.
    arguments
        Keyword arguments for the function other than the model.
    """

    def __init__(self, function, **kwargs):
        self.function = function
        self.arguments = kwargs

    @staticmethod
    def _load(ir):
        return FunctionalEvaluator(ir['function'], **ir['arguments'])

    def _dump(self):
        return {
            'type': self.__class__,
            'function': self.function,
            'arguments': self.arguments
        }

    def _execute(self, model_cls):
        return self.function(model_cls, **self.arguments)

    def __eq__(self, other):
        return self.function == other.function and self.arguments == other.arguments
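A usage sketch (``fit`` and its ``learning_rate`` argument are user-supplied placeholders)::

    def fit(model_cls, learning_rate):
        model = model_cls()
        # ... train `model` and report metrics via nni.report_final_result(...)

    evaluator = FunctionalEvaluator(fit, learning_rate=1e-3)
    evaluator.evaluate(MyModelSpace)  # ends up calling fit(MyModelSpace, learning_rate=1e-3)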
@@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import RandomMutator
from .lightning import *
@@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import warnings
from typing import Dict, List, Optional, Union, Type

import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader

import nni

from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer

__all__ = [
    '_MultiModelSupervisedLearningModule', 'MultiModelSupervisedLearningModule',
    '_ClassificationModule', 'Classification',
    '_RegressionModule', 'Regression',
]


@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
                 n_models: int = 0,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam):
        super().__init__()
        self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
        self.criterion = criterion()
        self.criterion_cls = criterion
        self.optimizer = optimizer
        self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
        self.metrics_args = metrics
        self.n_models = n_models

    def dump_kwargs(self):
        kwargs = {}
        kwargs['criterion'] = self.criterion_cls
        kwargs['metrics'] = self.metrics_args
        kwargs['n_models'] = self.n_models
        kwargs['learning_rate'] = self.hparams['learning_rate']
        kwargs['weight_decay'] = self.hparams['weight_decay']
        kwargs['optimizer'] = self.optimizer
        return kwargs

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        multi_loss = []
        for idx, y_hat in enumerate(multi_y_hat):
            loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
            self.log(f'train_loss_{idx}', loss, prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'train_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            multi_loss.append(loss)
        return sum(multi_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'val_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            for name, metric in self.metrics.items():
                self.log(f'test_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self._get_validation_metrics())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        # TODO: split metric of multiple models?
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            ret = []
            for idx in range(self.n_models):
                ret.append(self.trainer.callback_metrics[f'val_{idx}_' + metric_name].item())
            return ret
        else:
            warnings.warn('Multiple metrics without "default" are not supported by the current framework.')
            return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
    """
    Lightning module of supervised learning for cross-graph optimization.
    Users who need cross-graph optimization should use this module.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    """

    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam):
        super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class _ClassificationModule(_MultiModelSupervisedLearningModule):
    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam):
        super().__init__(criterion, {'acc': _AccuracyWithLogits},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Classification(Lightning):
    """
    Trainer that is used for classification.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloader : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to the trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
    """

    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                       weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)


class _RegressionModule(_MultiModelSupervisedLearningModule):
    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam):
        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Regression(Lightning):
    """
    Trainer that is used for regression.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloader : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to the trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
    """

    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                   weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
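A usage sketch (the dataloaders are user-supplied placeholders; extra keyword arguments are forwarded to the CGO ``Trainer``)::

    evaluator = Classification(learning_rate=1e-3,
                               train_dataloader=train_loader,
                               val_dataloaders=val_loader,
                               max_epochs=10)  # extra kwargs reach Trainer(use_cgo=True, ...)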
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy


class BypassStrategy(SingleDeviceStrategy):
    strategy_name = "single_device"

    def model_to_device(self) -> None:
        pass


class Trainer(pl.Trainer):
    """
    Trainer for cross-graph optimization.

    Parameters
    ----------
    use_cgo : bool
        Whether cross-graph optimization (CGO) is used.
        If it is True, CGO will manage device placement.
        Any device placement from pytorch-lightning will be bypassed.
        default: False
    trainer_kwargs : dict
        Optional keyword arguments passed to the trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
    """

    def __init__(self, use_cgo=False, **trainer_kwargs):
        if use_cgo:
            if "accelerator" in trainer_kwargs:
                raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")

            if 'strategy' in trainer_kwargs:
                raise ValueError("cgo.trainer does not support specifying strategy.")
            trainer_kwargs['strategy'] = BypassStrategy()

        super().__init__(**trainer_kwargs)
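A brief usage sketch of the design choice above (illustrative)::

    trainer = Trainer(use_cgo=True, max_epochs=5)
    # roughly equivalent to pl.Trainer(strategy=BypassStrategy(), max_epochs=5):
    # model_to_device() is deliberately a no-op, since CGO manages device
    # placement itself and Lightning's .to(device) must be skipped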
@ -0,0 +1,412 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union, Optional, List, Callable, Type
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as nn_functional
|
||||
import torch.optim as optim
|
||||
import torchmetrics
|
||||
import torch.utils.data as torch_data
|
||||
|
||||
import nni
|
||||
from nni.common.serializer import is_traceable
|
||||
try:
|
||||
from .cgo import trainer as cgo_trainer
|
||||
cgo_import_failed = False
|
||||
except ImportError:
|
||||
cgo_import_failed = True
|
||||
|
||||
from nni.nas.evaluator import Evaluator
|
||||
from nni.typehint import Literal
|
||||
|
||||
|
||||
__all__ = [
|
||||
'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression',
|
||||
'_AccuracyWithLogits', '_SupervisedLearningModule', '_ClassificationModule', '_RegressionModule',
|
||||
# FIXME: hack to make it importable for tests
|
||||
]
|
||||
|
||||
|
||||
class LightningModule(pl.LightningModule):
|
||||
"""
|
||||
Basic wrapper of generated model.
|
||||
Lightning modules used in NNI should inherit this class.
|
||||
|
||||
It's a subclass of ``pytorch_lightning.LightningModule``.
|
||||
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
|
||||
"""
|
||||
|
||||
running_mode: Literal['multi', 'oneshot'] = 'multi'
|
||||
"""An indicator of whether current module is running in a multi-trial experiment or an one-shot.
|
||||
This flag should be automatically set by experiments when they start to run.
|
||||
"""
|
||||
|
||||
def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
|
||||
"""Set the inner model (architecture) to train / evaluate.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : callable or nn.Module
|
||||
Can be a callable returning nn.Module or nn.Module.
|
||||
"""
|
||||
if isinstance(model, nn.Module):
|
||||
self.model = model
|
||||
else:
|
||||
self.model = model()
|
||||
|
||||
|
||||
Trainer = nni.trace(pl.Trainer)
|
||||
Trainer.__doc__ = """
|
||||
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
|
||||
"""
|
||||
DataLoader = nni.trace(torch_data.DataLoader)
|
||||
DataLoader.__doc__ = """
|
||||
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
|
||||
"""
|
||||
|
||||
|
||||
@nni.trace
|
||||
class Lightning(Evaluator):
|
||||
"""
|
||||
Delegate the whole training to PyTorch Lightning.
|
||||
|
||||
Since the arguments passed to the initialization needs to be serialized, ``LightningModule``, ``Trainer`` or
|
||||
``DataLoader`` in this file should be used. Another option is to hide dataloader in the Lightning module, in
|
||||
which case, dataloaders are not required for this class to work.
|
||||
|
||||
Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics``
|
||||
in trainer. Two hooks are added at the end of validation epoch and the end of ``fit``, respectively. The metric name
|
||||
and type depend on the specific task.
|
||||
|
||||
.. warning::
|
||||
|
||||
The Lightning evaluator are stateful. If you try to use a previous Lightning evaluator,
|
||||
please note that the inner ``lightning_module`` and ``trainer`` will be reused.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lightning_module
|
||||
Lightning module that defines the training logic.
|
||||
trainer
|
||||
Lightning trainer that handles the training.
|
||||
train_dataloders
|
||||
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
|
||||
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
|
||||
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
|
||||
val_dataloaders
|
||||
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
|
||||
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
|
||||
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
|
||||
"""
|
||||
|
||||
def __init__(self, lightning_module: LightningModule, trainer: Trainer,
|
||||
train_dataloaders: Optional[Any] = None,
|
||||
val_dataloaders: Optional[Any] = None,
|
||||
train_dataloader: Optional[Any] = None):
|
||||
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
|
||||
if train_dataloader is not None:
|
||||
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
|
||||
train_dataloaders = train_dataloader
|
||||
if cgo_import_failed:
|
||||
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
|
||||
else:
|
||||
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
|
||||
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
|
||||
f'Trainer must be imported from {__name__} or nni.nas.evaluator.pytorch.cgo.trainer'
|
||||
if not _check_dataloader(train_dataloaders):
|
||||
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
|
||||
f'import DataLoader from {__name__}: {train_dataloaders}',
|
||||
RuntimeWarning)
|
||||
if not _check_dataloader(val_dataloaders):
|
||||
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
|
||||
f'import DataLoader from {__name__}: {val_dataloaders}',
|
||||
RuntimeWarning)
|
||||
self.module = lightning_module
|
||||
self.trainer = trainer
|
||||
self.train_dataloaders = train_dataloaders
|
||||
self.val_dataloaders = val_dataloaders
|
||||
|
||||
@staticmethod
|
||||
def _load(ir):
|
||||
return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])
|
||||
|
||||
def _dump(self):
|
||||
return {
|
||||
'type': self.__class__,
|
||||
'module': self.module,
|
||||
'trainer': self.trainer,
|
||||
'train_dataloaders': self.train_dataloaders,
|
||||
'val_dataloaders': self.val_dataloaders
|
||||
}
|
||||
|
||||
def _execute(self, model_cls):
|
||||
return self.fit(model_cls)
|
||||
|
||||
@property
|
||||
def train_dataloader(self):
|
||||
warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
|
||||
|
||||
def __eq__(self, other):
|
||||
eq_func = False
|
||||
eq_args = False
|
||||
if other is None:
|
||||
return False
|
||||
if hasattr(self, "function") and hasattr(other, "function"):
|
||||
eq_func = getattr(self, "function") == getattr(other, "function")
|
||||
elif not (hasattr(self, "function") or hasattr(other, "function")):
|
||||
eq_func = True
|
||||
|
||||
if hasattr(self, "arguments") and hasattr(other, "arguments"):
|
||||
eq_args = getattr(self, "arguments") == getattr(other, "arguments")
|
||||
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
|
||||
eq_args = True
|
||||
|
||||
return eq_func and eq_args
|
||||
|
||||
def fit(self, model):
|
||||
"""
|
||||
Fit the model with provided dataloader, with Lightning trainer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : nn.Module
|
||||
The model to fit.
|
||||
"""
|
||||
self.module.set_model(model)
|
||||
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)
|
||||
|
||||
|
||||
def _check_dataloader(dataloader):
|
||||
# Check the type of dataloader recursively.
|
||||
if isinstance(dataloader, list):
|
||||
return all([_check_dataloader(d) for d in dataloader])
|
||||
if isinstance(dataloader, dict):
|
||||
return all([_check_dataloader(v) for v in dataloader.values()])
|
||||
if isinstance(dataloader, torch_data.DataLoader):
|
||||
return is_traceable(dataloader)
|
||||
return True
|
||||
|
||||
|
||||
### The following are some commonly used Lightning modules ###
|
||||
|
||||
class _SupervisedLearningModule(LightningModule):
|
||||
|
||||
trainer: pl.Trainer
|
||||
|
||||
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
|
||||
learning_rate: float = 0.001,
|
||||
weight_decay: float = 0.,
|
||||
optimizer: Type[optim.Optimizer] = optim.Adam,
|
||||
export_onnx: Union[Path, str, bool, None] = None):
|
||||
super().__init__()
|
||||
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
|
||||
self.criterion = criterion()
|
||||
self.optimizer = optimizer
|
||||
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
|
||||
|
||||
if export_onnx is None or export_onnx is True:
|
||||
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
|
||||
elif export_onnx:
|
||||
self.export_onnx = Path(export_onnx)
|
||||
else:
|
||||
self.export_onnx = None
|
||||
|
||||
def forward(self, x):
|
||||
y_hat = self.model(x)
|
||||
return y_hat
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
loss = self.criterion(y_hat, y)
|
||||
self.log('train_loss', loss, prog_bar=True)
|
||||
for name, metric in self.metrics.items():
|
||||
self.log('train_' + name, metric(y_hat, y), prog_bar=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
|
||||
if self.running_mode == 'multi' and self.export_onnx is not None:
|
||||
self.export_onnx.parent.mkdir(exist_ok=True)
|
||||
try:
|
||||
self.to_onnx(self.export_onnx, x, export_params=True)
|
||||
except RuntimeError as e:
|
||||
warnings.warn(f'ONNX conversion failed. As a result, you might not be able to use visualization. Error message: {e}')
|
||||
self.export_onnx = None
|
||||
|
||||
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
|
||||
for name, metric in self.metrics.items():
|
||||
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self(x)
|
||||
self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
|
||||
for name, metric in self.metrics.items():
|
||||
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
if not self.trainer.sanity_checking and self.running_mode == 'multi':
|
||||
# Don't report metric when sanity checking
|
||||
            nni.report_intermediate_result(self._get_validation_metrics())

    def on_fit_end(self):
        if self.running_mode == 'multi':
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            return self.trainer.callback_metrics['val_' + metric_name].item()
        else:
            warnings.warn('Multiple metrics without "default" is not supported by current framework.')
            return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


class _AccuracyWithLogits(torchmetrics.Accuracy):
    def update(self, pred, target):
        return super().update(nn_functional.softmax(pred, dim=-1), target)


@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: bool = True):
        super().__init__(criterion, {'acc': _AccuracyWithLogits},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
                         export_onnx=export_onnx)


class Classification(Lightning):
    """
    Evaluator that is used for classification.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    export_onnx : bool
        If true, the model will be exported to ``model.onnx`` before training starts. default: true
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

    Examples
    --------
    >>> evaluator = Classification()

    To use a customized criterion and optimizer:

    >>> evaluator = Classification(nn.LabelSmoothingCrossEntropy, optimizer=torch.optim.SGD)

    Extra keyword arguments will be passed to the trainer, some of which might be necessary to enable GPU acceleration:

    >>> evaluator = Classification(accelerator='gpu', devices=2, strategy='ddp')
    """

    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 export_onnx: bool = True,
                 train_dataloader: Optional[DataLoader] = None,
                 **trainer_kwargs):
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                       weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)


@nni.trace
class _RegressionModule(_SupervisedLearningModule):
    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: bool = True):
        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
                         export_onnx=export_onnx)


class Regression(Lightning):
    """
    Evaluator that is used for regression.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    export_onnx : bool
        If true, the model will be exported to ``model.onnx`` before training starts. default: true
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

    Examples
    --------
    >>> evaluator = Regression()

    Extra keyword arguments will be passed to the trainer, some of which might be necessary to enable GPU acceleration:

    >>> evaluator = Regression(gpus=1)
    """

    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 export_onnx: bool = True,
                 train_dataloader: Optional[DataLoader] = None,
                 **trainer_kwargs):
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                   weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
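
# A minimal usage sketch (assumptions: an in-memory TensorDataset stands in for a
# real dataset; extra keyword arguments are forwarded to the Lightning Trainer as
# documented above):
#
#   import torch
#   from torch.utils.data import DataLoader, TensorDataset
#
#   data = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
#   evaluator = Regression(train_dataloaders=DataLoader(data, batch_size=16),
#                          val_dataloaders=DataLoader(data, batch_size=16),
#                          max_epochs=1)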

@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .nasbench201 import NASBench201Cell
from .api import *
from .common import *

@ -0,0 +1,79 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
import warnings
from typing import Iterable

from nni.nas.execution.common import (
    Model, ModelStatus,
    AbstractExecutionEngine,
    DefaultListener
)

_execution_engine = None
_default_listener = None

__all__ = ['get_execution_engine', 'get_and_register_default_listener',
           'list_models', 'submit_models', 'wait_models', 'query_available_resources',
           'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']


def set_execution_engine(engine: AbstractExecutionEngine) -> None:
    global _execution_engine
    if _execution_engine is not None:
        warnings.warn('Execution engine is already set. '
                      'You should avoid instantiating RetiariiExperiment twice in one process. '
                      'If you are running in a Jupyter notebook, please restart the kernel.',
                      RuntimeWarning)
    _execution_engine = engine


def get_execution_engine() -> AbstractExecutionEngine:
    global _execution_engine
    assert _execution_engine is not None, 'You need to set the execution engine before using it.'
    return _execution_engine


def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
    global _default_listener
    if _default_listener is None:
        _default_listener = DefaultListener()
        engine.register_graph_listener(_default_listener)
    return _default_listener


def submit_models(*models: Model) -> None:
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    engine.submit_models(*models)


def list_models(*models: Model) -> Iterable[Model]:
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    return engine.list_models()


def wait_models(*models: Model) -> None:
    get_and_register_default_listener(get_execution_engine())
    while True:
        time.sleep(1)
        left_models = [g for g in models if g.status not in (ModelStatus.Trained, ModelStatus.Failed)]
        if not left_models:
            break


def query_available_resources() -> int:
    engine = get_execution_engine()
    resources = engine.query_available_resource()
    return resources if isinstance(resources, int) else len(resources)


def is_stopped_exec(model: Model) -> bool:
    return model.status in (ModelStatus.Trained, ModelStatus.Failed)


def budget_exhausted() -> bool:
    engine = get_execution_engine()
    return engine.budget_exhausted()
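
# A sketch of the synchronous usage pattern these helpers enable (the
# `generate_model()` call is hypothetical and stands in for a strategy's sampler;
# it is not part of this module):
#
#   while not budget_exhausted():
#       if query_available_resources() > 0:
#           model = generate_model()
#           submit_models(model)
#           wait_models(model)
#           print(model.metric)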

@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .engine import *
from .graph_op import *
from .graph import *
from .integration_api import *
from .integration import *
from .listener import *
from .utils import *

@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union, Type

from .graph import Model, MetricData

__all__ = [
    'GraphData', 'WorkerInfo', 'MetricData',
    'AbstractGraphListener', 'AbstractExecutionEngine'
]


GraphData: Type[Any] = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by the execution engine.

The execution engine will submit this kind of data through NNI to a worker machine, and train it there.

A `GraphData` object describes a (merged) executable graph.

This is a trial's "hyper-parameter" in NNI's term and will be transferred in JSON format.

See `AbstractExecutionEngine` for details.
"""


WorkerInfo: Type[Any] = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.

This describes the properties of a worker machine (e.g., memory size).
"""


class AbstractGraphListener(ABC):
    """
    Abstract listener interface to receive graph events.

    Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener.
    """

    @abstractmethod
    def on_metric(self, model: Model, metric: MetricData) -> None:
        """
        Reports the final metric of a graph.
        """
        raise NotImplementedError

    @abstractmethod
    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        """
        Reports the latest intermediate metric of a training graph.
        """
        pass

    @abstractmethod
    def on_training_end(self, model: Model, success: bool) -> None:
        """
        Reports that a graph is either fully trained or has failed during training.
        """
        pass


class AbstractExecutionEngine(ABC):
    """
    The abstract interface of an execution engine.

    Most of these APIs are used by the strategy, except `trial_execute_graph`, which is invoked by the framework in a trial.
    The strategy will get the singleton execution engine object through a global API,
    and use it in either a sync or async manner.

    The execution engine is responsible for submitting (maybe-optimized) models to NNI,
    and assigning their metrics to the `Model` objects after training.
    The execution engine is also responsible for launching the graph in the trial process,
    because it's the only component that understands graph data, or "hyper-parameter" in NNI's term.

    The execution engine will leverage NNI Advisor APIs, which are yet open for discussion.

    In the synchronous use case, the strategy will have a loop that calls `submit_models` and `wait_models` repeatedly,
    and will receive metrics from `Model` attributes.
    The execution engine may assume that the strategy will only submit a graph when there are available resources (for now).

    In the asynchronous use case, the strategy will register a listener to receive events,
    while still using `submit_models` to train.

    There will be a `BaseExecutionEngine` subclass.
    Inner-graph optimizing is supposed to derive from `BaseExecutionEngine`,
    overriding `submit_models` and `trial_execute_graph`.
    Cross-graph optimizing is supposed to derive from `AbstractExecutionEngine` directly,
    because in this case APIs like `wait_graph` and `listener.on_training_end` will have unique logic.

    There might be some util functions that benefit all optimizing methods,
    but non-mandatory utils should not be covered in the abstract interface.
    """

    @abstractmethod
    def submit_models(self, *models: Model) -> None:
        """
        Submit models to NNI.

        This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
        """
        raise NotImplementedError

    @abstractmethod
    def list_models(self) -> Iterable[Model]:
        """
        Get all models that have been submitted.

        The execution engine should store a copy of models that have been submitted and return a list of copies in this method.
        """
        raise NotImplementedError

    @abstractmethod
    def query_available_resource(self) -> Union[List[WorkerInfo], int]:  # type: ignore
        """
        Returns information of all idle workers.
        If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.

        Could be left unimplemented for the first iteration.
        """
        raise NotImplementedError

    @abstractmethod
    def budget_exhausted(self) -> bool:
        """
        Check whether the user-configured max trial number or max execution duration has been reached.
        """
        raise NotImplementedError

    @abstractmethod
    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """
        Register a listener to receive graph events.

        Could be left unimplemented for the first iteration.
        """
        raise NotImplementedError

    @abstractclassmethod
    def trial_execute_graph(cls) -> MetricData:
        """
        Train a graph and return its metrics, in a separate trial process.

        Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.

        Because this method will be invoked in a trial process on the training platform,
        it has a different context from other methods and has no access to global variables or `self`.
        However, util APIs like `.utils.experiment_config()` should still be available.
        """
        raise NotImplementedError
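
# A minimal listener sketch under the interface above (illustrative only; the
# shipped `DefaultListener` in `.listener` is the real implementation):
#
#   class PrintingListener(AbstractGraphListener):
#       def on_metric(self, model, metric):
#           print(f'final metric: {metric}')
#       def on_intermediate_metric(self, model, metric):
#           print(f'intermediate metric: {metric}')
#       def on_training_end(self, model, success):
#           print('trained' if success else 'failed')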

@ -0,0 +1,774 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Model representation for engines based on graph.
"""

from __future__ import annotations

import json
from enum import Enum
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
                    Optional, Set, Tuple, Type, Union, cast, overload)

if TYPE_CHECKING:
    from .mutator import Mutator

from nni.nas.evaluator import Evaluator
from nni.nas.utils import uid
from .graph_op import Cell, Operation, _IOPseudoOperation

__all__ = [
    'Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData',
    'DebugEvaluator',
]


MetricData = Any
"""
Type hint for graph metrics (loss, accuracy, etc).
"""

EdgeEndpoint = Tuple['Node', Optional[int]]
"""
Type hint for an edge's endpoint. The int indicates the node's order.
"""


class Model:
    """
    Represents a neural network model.

    During mutation, one :class:`Model` object is created for each trainable snapshot.
    For example, consider a mutator that inserts a node at an edge in each iteration.
    In one iteration, the mutator invokes 4 primitives: add node, remove edge, add edge to head, add edge to tail.
    These 4 primitives operate on one :class:`Model` object.
    When they are all done, the model will be set to "frozen" (trainable) status and be submitted to the execution engine.
    Then a new iteration starts, and a new :class:`Model` object is created by forking the last model.

    Attributes
    ----------
    python_object
        Python object of the base model. It will be ``None`` when the base model is not available.
    python_class
        Python class that the base model is converted from.
    python_init_params
        Initialization parameters of the python class.
    status
        See :class:`ModelStatus`.
    root_graph
        The outermost graph, which usually takes the dataset as input and feeds its output to the loss function.
    graphs
        All graphs (subgraphs) in this model.
    evaluator
        Model evaluator.
    history
        Mutation history.
        ``self`` is directly mutated from ``self.history[-1]``;
        ``self.history[-1]`` is mutated from ``self.history[-2]``, and so on.
        ``self.history[0]`` is the base graph.
    metric
        Training result of the model, or ``None`` if it's not yet trained or has failed to train.
    intermediate_metrics
        Intermediate training metrics. If the model is not trained, it's an empty list.
    """

    def __init__(self, _internal=False):
        assert _internal, '`Model()` is private, use `model.fork()` instead'
        self.model_id: int = uid('model')
        self.python_object: Optional[Any] = None  # type is uncertain because it could differ between DL frameworks
        self.python_class: Optional[Type] = None
        self.python_init_params: Optional[Dict[str, Any]] = None

        self.status: ModelStatus = ModelStatus.Mutating

        self._root_graph_name: str = '_model'
        self.graphs: Dict[str, Graph] = {}
        self.evaluator: Optional[Evaluator] = None

        self.history: List['Mutation'] = []

        self.metric: Optional[MetricData] = None
        self.intermediate_metrics: List[MetricData] = []

    def __repr__(self):
        return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
            f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
            f'python_class={self.python_class})'

    @property
    def root_graph(self) -> 'Graph':
        return self.graphs[self._root_graph_name]

    def fork(self) -> 'Model':
        """
        Create a new model which has the same topology, names, and IDs as the current one.

        Can only be invoked on a frozen model.
        The new model will be in `Mutating` state.

        This API is used in the mutator base class.
        """
        new_model = Model(_internal=True)
        new_model._root_graph_name = self._root_graph_name
        new_model.python_class = self.python_class
        new_model.python_init_params = self.python_init_params
        new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
        new_model.evaluator = self.evaluator  # TODO this needs a clever copy (not deepcopy) if we need mutation
        new_model.history = [*self.history]
        # Note: the history is not updated. It will be updated when the model is changed, that is, in mutator.
        return new_model

    @staticmethod
    def _load(ir: Any) -> 'Model':
        model = Model(_internal=True)
        for graph_name, graph_data in ir.items():
            if graph_name != '_evaluator':
                Graph._load(model, graph_name, graph_data)._register()
        if '_evaluator' in ir:
            model.evaluator = Evaluator._load(ir['_evaluator'])
        return model

    def _dump(self) -> Any:
        ret = {name: graph._dump() for name, graph in self.graphs.items()}
        if self.evaluator is not None:
            ret['_evaluator'] = self.evaluator._dump()
        return ret

    def get_nodes(self) -> Iterable['Node']:
        """
        Traverse through all the nodes.
        """
        for graph in self.graphs.values():
            for node in graph.nodes:
                yield node

    def get_nodes_by_label(self, label: str) -> List['Node']:
        """
        Traverse all the nodes to find the matched node(s) with the given label.
        There could be multiple nodes with the same label. A namespace name can uniquely
        identify a graph or node.

        NOTE: the implementation does not support the class abstraction.
        """
        matched_nodes = []
        for graph in self.graphs.values():
            nodes = graph.get_nodes_by_label(label)
            matched_nodes.extend(nodes)
        return matched_nodes

    def get_nodes_by_type(self, type_name: str) -> List['Node']:
        """
        Traverse all the nodes to find the matched node(s) with the given type.
        """
        matched_nodes = []
        for graph in self.graphs.values():
            nodes = graph.get_nodes_by_type(type_name)
            matched_nodes.extend(nodes)
        return matched_nodes

    def get_node_by_name(self, node_name: str) -> 'Node' | None:
        """
        Traverse all the nodes to find the matched node with the given name.
        """
        matched_nodes = []
        for graph in self.graphs.values():
            nodes = graph.get_nodes_by_name(node_name)
            matched_nodes.extend(nodes)
        assert len(matched_nodes) <= 1
        if matched_nodes:
            return matched_nodes[0]
        else:
            return None

    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
        """
        Traverse all the nodes to find the matched node with the given python_name.
        """
        matched_nodes = []
        for graph in self.graphs.values():
            nodes = graph.get_nodes_by_python_name(python_name)
            matched_nodes.extend(nodes)
        # assert len(matched_nodes) <= 1
        if matched_nodes:
            return matched_nodes[0]
        else:
            return None

    def get_cell_nodes(self) -> List['Node']:
        matched_nodes = []
        for graph in self.graphs.values():
            nodes = [node for node in graph.nodes if isinstance(node.operation, Cell)]
            matched_nodes.extend(nodes)
        return matched_nodes


class ModelStatus(Enum):
    """
    The status of a model.

    A model is created in `Mutating` status.
    When the mutation is done and the model gets ready to train, its status becomes `Frozen`.
    When training starts, the model's status becomes `Training`.
    If training successfully ends, the model's `metric` attribute gets set and its status becomes `Trained`.
    If training fails, the status becomes `Failed`.
    """
    Mutating = "mutating"
    Frozen = "frozen"
    Training = "training"
    Trained = "trained"
    Failed = "failed"


_InputPseudoUid = -1
_OutputPseudoUid = -2


class Graph:
    """
    Graph topology.

    This class simply represents the topology, with no semantic meaning.
    All other information, like metric, non-graph functions, mutation history, etc., should go to :class:`Model`.

    Each graph belongs to and only belongs to one :class:`Model`.

    Attributes
    ----------
    model
        The model containing (and owning) this graph.
    id
        Unique ID in the model.
        If two models have graphs of identical ID, they are semantically the same graph.
        Typically this means one graph is mutated from another, or they are both mutated from one ancestor.
    name
        Mnemonic name of this graph. It should have a one-to-one mapping with ID.
    input_names
        Optional mnemonic names of input parameters.
    output_names
        Optional mnemonic names of output values.
    input_node
        Incoming node.
    output_node
        Output node.
    hidden_nodes
        Hidden nodes.
    nodes
        All input/output/hidden nodes.
    edges
        Edges.
    python_name
        The name of torch.nn.Module; should have a one-to-one mapping with items in the python model.
    """

    def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
        assert _internal, '`Graph()` is private'

        self.model: Model = model
        self.id: int = graph_id
        self.name: str = name or f'_generated_{graph_id}'

        # `python_name` is `None` by default. It should be set after initialization if it is needed.
        self.python_name: Optional[str] = None

        self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
        self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
        self.hidden_nodes: List[Node] = []

        self.edges: List[Edge] = []

    def __repr__(self):
        return f'Graph(id={self.id}, name={self.name}, ' + \
            f'input_names={self.input_node.operation.io_names}, ' + \
            f'output_names={self.output_node.operation.io_names}, ' + \
            f'num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'

    @property
    def nodes(self) -> List['Node']:
        return [self.input_node, self.output_node] + self.hidden_nodes

    def _add_input(self, input_name) -> None:
        if self.input_node.operation.io_names is None:
            self.input_node.operation.io_names = [input_name]
        else:
            self.input_node.operation.io_names.append(input_name)

    def _add_output(self, output_name) -> None:
        if self.output_node.operation.io_names is None:
            self.output_node.operation.io_names = [output_name]
        else:
            self.output_node.operation.io_names.append(output_name)

    @overload
    def add_node(self, name: str, operation: Operation) -> 'Node': ...
    @overload
    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

    def add_node(self, name, operation_or_type, parameters=None):  # type: ignore
        if isinstance(operation_or_type, Operation):
            op = operation_or_type
        else:
            op = Operation.new(operation_or_type, cast(dict, parameters), name)
        return Node(self, uid(), name, op, _internal=True)._register()

    @overload
    def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...

    @overload
    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
                            parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':  # type: ignore
        if isinstance(operation_or_type, Operation):
            op = operation_or_type
        else:
            op = Operation.new(operation_or_type, cast(dict, parameters), name)
        new_node = Node(self, uid(), name, op, _internal=True)._register()
        # update edges
        self.add_edge((edge.head, edge.head_slot), (new_node, None))
        self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
        self.del_edge(edge)
        return new_node

    # mutation
    def add_edge(self, head: EdgeEndpoint, tail: EdgeEndpoint) -> 'Edge':
        assert head[0].graph is self and tail[0].graph is self
        return Edge(head, tail, _internal=True)._register()

    def del_edge(self, edge: 'Edge') -> None:
        self.edges.remove(edge)

    def get_node_by_name(self, name: str) -> Optional['Node']:
        """
        Returns the node which has the specified name; or returns `None` if no node has this name.
        """
        found = [node for node in self.nodes if node.name == name]
        return found[0] if found else None

    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
        """
        Returns the node which has the specified python_name; or returns `None` if no node has this python_name.
        """
        found = [node for node in self.nodes if node.python_name == python_name]
        return found[0] if found else None

    def get_nodes_by_type(self, operation_type: str) -> List['Node']:
        """
        Returns nodes whose operation has the specified type.
        """
        return [node for node in self.hidden_nodes if node.operation.type == operation_type]

    def get_node_by_id(self, node_id: int) -> Optional['Node']:
        """
        Returns the node which has the specified ID; or returns `None` if no node has this ID.
        """
        found = [node for node in self.nodes if node.id == node_id]
        return found[0] if found else None

    def get_nodes_by_label(self, label: str) -> List['Node']:
        return [node for node in self.hidden_nodes if node.label == label]

    def get_nodes_by_name(self, name: str) -> List['Node']:
        return [node for node in self.hidden_nodes if node.name == name]

    def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
        return [node for node in self.nodes if node.python_name == python_name]

    def topo_sort(self) -> List['Node']:
        node_to_fanin = {}
        curr_nodes = []
        for node in self.nodes:
            fanin = len(node.incoming_edges)
            node_to_fanin[node] = fanin
            if fanin == 0:
                curr_nodes.append(node)

        sorted_nodes = []
        while curr_nodes:
            curr_node = curr_nodes.pop(0)
            sorted_nodes.append(curr_node)
            # use successor_slots because a node may connect to another node multiple times
            # through different slots
            for successor_slot in curr_node.successor_slots:
                successor = successor_slot[0]
                node_to_fanin[successor] -= 1
                if node_to_fanin[successor] == 0:
                    curr_nodes.append(successor)

        for key in node_to_fanin:
            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
                key,
                node_to_fanin[key],
                key.predecessors[0],
                self.edges,
                node_to_fanin.values(),
                node_to_fanin.keys())

        return sorted_nodes

    def fork(self) -> 'Graph':
        """
        Fork the model and return the corresponding graph in the new model.
        This shortcut might be helpful because many algorithms only care about the "stem" subgraph instead of the whole model.
        """
        return self.model.fork().graphs[self.name]

    def __eq__(self, other: object) -> bool:
        return self is other

    def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
        new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
        # TODO: use node copy instead
        new_graph.input_node.operation.io_names = self.input_node.operation.io_names
        new_graph.output_node.operation.io_names = self.output_node.operation.io_names
        new_graph.input_node.update_label(self.input_node.label)
        new_graph.output_node.update_label(self.output_node.label)
        new_graph.python_name = self.python_name

        for node in self.hidden_nodes:
            new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
            new_node.python_name = node.python_name
            new_node.update_label(node.label)
            new_node._register()

        id_to_new_node = {node.id: node for node in new_graph.nodes}

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph

    def _copy(self) -> 'Graph':
        # Copy this graph inside the model.
        # The new graph will have identical topology, but its nodes' names and IDs will be different.
        new_graph = Graph(self.model, uid(), _internal=True)._register()
        new_graph.input_node.operation.io_names = self.input_node.operation.io_names
        new_graph.output_node.operation.io_names = self.output_node.operation.io_names
        new_graph.input_node.update_label(self.input_node.label)
        new_graph.output_node.update_label(self.output_node.label)
        new_graph.python_name = self.python_name

        id_to_new_node = {}  # old node ID -> new node object

        for old_node in self.hidden_nodes:
            new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
            new_node.python_name = old_node.python_name
            new_node.update_label(old_node.label)
            id_to_new_node[old_node.id] = new_node

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph

    def _register(self) -> 'Graph':
        self.model.graphs[self.name] = self
        return self

    def _rename_graph(self, old_name, new_name):
        self.model.graphs[old_name].name = new_name
        self.model.graphs[new_name] = self.model.graphs[old_name]
        del self.model.graphs[old_name]

    @staticmethod
    def _load(model: Model, name: str, ir: Any) -> 'Graph':
        graph = Graph(model, uid(), name, _internal=True)
        graph.input_node.operation.io_names = ir.get('inputs')
        graph.output_node.operation.io_names = ir.get('outputs')
        for node_name, node_data in ir['nodes'].items():
            Node._load(graph, node_name, node_data)._register()
        for edge_data in ir['edges']:
            Edge._load(graph, edge_data)._register()
        return graph

    def _dump(self) -> Any:
        return {
            'inputs': self.input_node.operation.io_names,
            'outputs': self.output_node.operation.io_names,
            'nodes': {node.name: node._dump() for node in self.hidden_nodes},
            'edges': [edge._dump() for edge in self.edges]
        }
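
# A round-trip sketch of the IR format defined by `_dump`/`_load` above. This is
# illustrative: the graph, node name `fc`, and parameter values are hypothetical,
# and it assumes the default framework is PyTorch so `Operation.new` can resolve
# the operation type:
#
#   ir = {
#       '_model': {
#           'inputs': ['x'], 'outputs': ['y'],
#           'nodes': {'fc': {'operation': {'type': '__torch__.torch.nn.Linear',
#                                          'parameters': {'in_features': 4, 'out_features': 2}}}},
#           'edges': [{'head': ['_inputs', 0], 'tail': ['fc', None]},
#                     {'head': ['fc', None], 'tail': ['_outputs', 0]}],
#       }
#   }
#   model = Model._load(ir)
#   assert model._dump()['_model']['inputs'] == ['x']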


class Node:
    """
    An operation or an opaque subgraph inside a graph.

    Each node belongs to and only belongs to one :class:`Graph`.
    Nodes should never be created with the constructor. Use :meth:`Graph.add_node` instead.

    The node itself is for topology only.
    Information of tensor calculation should all go inside the ``operation`` attribute.

    TODO: parameter of subgraph (cell)
    It's easy to assign parameters on a cell node, but it's hard to "use" them.
    We need to design a way to reference stored cell parameters in inner node operations.
    e.g., ``self.fc = Linear(self.units)`` <- how to express ``self.units`` in IR?

    Attributes
    ----------
    graph
        The graph containing this node.
    id
        Unique ID in the model.
        If two models have nodes with the same ID, they are semantically the same node.
    name
        Mnemonic name. It should have a one-to-one mapping with ID.
    python_name
        The name of torch.nn.Module; should have a one-to-one mapping with items in the python model.
    label
        Optional. If two nodes have the same label, they are considered the same by the mutator.
    operation
        Operation.
    cell
        Read-only shortcut to get the referenced subgraph.
        If this node is not a subgraph (is a primitive operation), accessing ``cell`` will raise an error.
    predecessors
        Predecessor nodes of this node in the graph. This is an optional mutation helper.
    successors
        Successor nodes of this node in the graph. This is an optional mutation helper.
    incoming_edges
        Incoming edges of this node in the graph. This is an optional mutation helper.
    outgoing_edges
        Outgoing edges of this node in the graph. This is an optional mutation helper.
    """

    def __init__(self, graph, node_id, name, operation, _internal=False):
        self.graph: Graph = graph
        self.id: int = node_id
        self.name: str = name or f'_generated_{node_id}'
        # `python_name` is `None` by default. It should be set after initialization if it is needed.
        self.python_name: Optional[str] = None
        # TODO: the operation is likely to be considered editable by the end-user and it will be hard to debug;
        # maybe we should copy it here or make the Operation class immutable, in the next release
        self.operation: Operation = operation
        self.label: Optional[str] = None

    def __repr__(self):
        return f'Node(id={self.id}, name={self.name}, python_name={self.python_name}, label={self.label}, operation={self.operation})'

    @property
    def predecessors(self) -> List['Node']:
        return sorted(set(edge.head for edge in self.incoming_edges), key=(lambda node: node.id))

    @property
    def successors(self) -> List['Node']:
        return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

    @property
    def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
        return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)

    @property
    def incoming_edges(self) -> List['Edge']:
        return [edge for edge in self.graph.edges if edge.tail is self]

    @property
    def outgoing_edges(self) -> List['Edge']:
        return [edge for edge in self.graph.edges if edge.head is self]

    @property
    def cell(self) -> Graph:
        assert isinstance(self.operation, Cell)
        return self.graph.model.graphs[self.operation.parameters['cell']]

    def update_label(self, label: Optional[str]) -> None:
        self.label = label

    @overload
    def update_operation(self, operation: Operation) -> None: ...
    @overload
    def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...

    def update_operation(self, operation_or_type, parameters=None):  # type: ignore
        if isinstance(operation_or_type, Operation):
            self.operation = operation_or_type
        else:
            self.operation = Operation.new(operation_or_type, cast(dict, parameters))

    # mutation
    def remove(self) -> None:
        assert not self.incoming_edges and not self.outgoing_edges
        self.graph.hidden_nodes.remove(self)

    # mutation
    def specialize_cell(self) -> Graph:
        """
        Only available if the operation is a cell.
        Duplicate the cell template and let this node reference the newly created copy.
        """
        new_cell = self.cell._copy()._register()
        self.operation = Cell(new_cell.name)
        return new_cell

    def __eq__(self, other: object) -> bool:
        return self is other

    def __hash__(self) -> int:
        return hash(id(self))

    def _register(self) -> 'Node':
        self.graph.hidden_nodes.append(self)
        return self

    @staticmethod
    def _load(graph: Graph, name: str, ir: Any) -> 'Node':
        if ir['operation']['type'] == '_cell':
            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}), attributes=ir['operation'].get('attributes', {}))
        else:
            op = Operation.new(ir['operation']['type'],
                               ir['operation'].get('parameters', {}),
                               attributes=ir['operation'].get('attributes', {}))
        node = Node(graph, uid(), name, op)
        if 'label' in ir:
            node.update_label(ir['label'])
        return node

    def _dump(self) -> Any:
        ret: Dict[str, Any] = {
            'operation': {
                'type': self.operation.type,
                'parameters': self.operation.parameters,
                'attributes': self.operation.attributes
            }
        }
        if isinstance(self.operation, Cell):
            ret['operation']['cell_name'] = self.operation.cell_name
        if self.label is not None:
            ret['label'] = self.label
        if self.python_name is not None:
            ret['python_name'] = self.python_name
        return ret


class Edge:
    """
    A tensor, or "data flow", between two nodes.

    Example forward code snippet: ::

        a, b, c = split(x)
        p = concat(a, c)
        q = sum(b, p)
        z = relu(q)

    Edges in the above snippet: ::

        + head: (split, 0), tail: (concat, 0)         # a in concat
        + head: (split, 2), tail: (concat, 1)         # c in concat
        + head: (split, 1), tail: (sum, -1 or 0)      # b in sum
        + head: (concat, null), tail: (sum, -1 or 1)  # p in sum
        + head: (sum, null), tail: (relu, null)       # q in relu

    Attributes
    ----------
    graph
        Graph.
    head
        Head node.
    tail
        Tail node.
    head_slot
        Index of the outputs in the head node.
        If the node has only one output, this should be ``null``.
    tail_slot
        Index of the inputs in the tail node.
        If the node has only one input, this should be ``null``.
        If the node does not care about order, this can be ``-1``.
    """

    def __init__(self, head: EdgeEndpoint, tail: EdgeEndpoint, _internal: bool = False):
        assert _internal, '`Edge()` is private'
        self.graph: Graph = head[0].graph
        self.head: Node = head[0]
        self.tail: Node = tail[0]
        self.head_slot: Optional[int] = head[1]
        self.tail_slot: Optional[int] = tail[1]

    def __repr__(self):
        return f'Edge(head=({self.head}, {self.head_slot}), tail=({self.tail}, {self.tail_slot}))'

    # mutation
    def remove(self) -> None:
        self.graph.edges.remove(self)

    def _register(self) -> 'Edge':
        self.graph.edges.append(self)
        return self

    @staticmethod
    def _load(graph: Graph, ir: Any) -> 'Edge':
        head = graph.get_node_by_name(ir['head'][0])
        tail = graph.get_node_by_name(ir['tail'][0])
        assert head is not None and tail is not None
        return Edge((head, ir['head'][1]), (tail, ir['tail'][1]), _internal=True)

    def _dump(self) -> Any:
        return {
            'head': [self.head.name, self.head_slot],
            'tail': [self.tail.name, self.tail_slot]
        }


class Mutation:
    """
    An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
    the model that it comes from, and the model that it becomes.

    In general cases, the mutation logs are not reliable and should not be replayed, as the mutators can
    be arbitrarily complex. However, for inline mutations, where the labels correspond to mutator labels,
    the logs can be useful for metadata visualization and Python execution mode.

    Attributes
    ----------
    mutator
        Mutator.
    samples
        Decisions/choices.
    from_
        Model that it comes from.
    to
        Model that it becomes.
    """

    def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model):  # noqa: F821
        self.mutator: 'Mutator' = mutator  # noqa: F821
        self.samples: List[Any] = samples
        self.from_: Model = from_
        self.to: Model = to

    def __repr__(self):
        return f'Mutation(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'


class IllegalGraphError(ValueError):
    def __init__(self, graph, *args):
        self._debug_dump_graph(graph)
        super().__init__(*args)

    @staticmethod
    def _debug_dump_graph(graph):
        if isinstance(graph, Graph):
            graph = graph._dump()
        with open('generated/debug.json', 'w') as dump_file:
            json.dump(graph, dump_file, indent=4)


class DebugEvaluator(Evaluator):
    @staticmethod
    def _load(ir: Any) -> 'DebugEvaluator':
        return DebugEvaluator()

    def _dump(self) -> Any:
        return {'type': DebugEvaluator}

    def _execute(self, model_cls: type) -> Any:
        pass

    def __eq__(self, other) -> bool:
        return True

@ -0,0 +1,251 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Operations used in graph-based engine.
"""

from typing import (Any, Dict, List, Optional, cast)

from nni.common.framework import get_default_framework


__all__ = ['Operation', 'Cell', 'PyTorchOperation', 'TensorFlowOperation']


def _convert_name(name: str) -> str:
    """
    Convert a name using separator '.' to a valid variable name in code.
    """
    return name.replace('.', '__')


class Operation:
    """
    Calculation logic of a graph node.

    The constructor is private. Use `Operation.new()` to create an operation object.

    `Operation` is a naive record.
    Do not "mutate" its attributes or store information related to a specific node.
    All complex logic should be implemented in the `Node` class.

    Attributes
    ----------
    type
        Operation type name (e.g. Conv2D).
        If it starts with an underscore, the "operation" is a special one (e.g. subgraph, input/output).
    parameters
        Arbitrary key-value parameters (e.g. kernel_size).
    """

    io_names: List[str] = []

    def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
        assert _internal, '`Operation()` is private, use `Operation.new()` instead'
        self.type: str = type_name
        self.parameters: Dict[str, Any] = parameters
        self.attributes: Dict[str, Any] = attributes

    def to_init_code(self, field: str) -> str:
        raise NotImplementedError()

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        raise NotImplementedError()

    def _to_class_name(self) -> str:
        raise NotImplementedError()

    def __bool__(self) -> bool:
        return True

    @staticmethod
    def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
            attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
        parameters = parameters or {}
        attributes = attributes or {}
        if type_name == '_cell':
            # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
            return Cell(cell_name, parameters)
        else:
            if get_default_framework() in ('torch', 'pytorch'):
                from nni.nas.execution.pytorch import op_def  # pylint: disable=unused-import
                cls = PyTorchOperation._find_subclass(type_name)
            elif get_default_framework() in ('tf', 'tensorflow'):
                from nni.nas.execution.tensorflow import op_def  # pylint: disable=unused-import
                cls = TensorFlowOperation._find_subclass(type_name)
            else:
                raise ValueError(f'Unsupported framework: {get_default_framework()}')
            return cls(type_name, parameters, _internal=True, attributes=attributes)

    @classmethod
    def _find_subclass(cls, subclass_name):
        for subclass in cls.__subclasses__():
            if subclass.__name__ == subclass_name:
                return subclass
        return cls

    def __repr__(self):
        type_name = type(self).__name__
        args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
        if type_name != self.type:
            args = [f'type="{self.type}"'] + args
        return f'{type_name}({", ".join(args)})'

    def __eq__(self, other):
        return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
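
# A small sketch of the factory above (it assumes the default framework is set to
# 'pytorch'; the parameter values are arbitrary):
#
#   op = Operation.new('__torch__.torch.nn.modules.conv.Conv2d',
#                      {'in_channels': 3, 'out_channels': 8, 'kernel_size': 3})
#   print(repr(op))  # e.g. ModuleOperator(type="__torch__.torch.nn.modules.conv.Conv2d", in_channels=3, ...)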


class PyTorchOperation(Operation):
    @classmethod
    def _find_subclass(cls, subclass_name):
        if cls.to_class_name(subclass_name) is not None:
            subclass_name = 'ModuleOperator'
        if cls.is_functional(subclass_name):
            subclass_name = 'FunctionalOperator'
        for subclass in cls.__subclasses__():
            if hasattr(subclass, '_ori_type_name') and \
                    subclass_name in cast(Any, subclass)._ori_type_name:
                return subclass
        for subclass in cls.__subclasses__():
            if hasattr(subclass, '_artificial_op_name') and \
                    subclass_name in cast(Any, subclass)._artificial_op_name:
                return subclass
        return cls

    @classmethod
    def to_class_name(cls, type_name) -> Optional[str]:
        if type_name.startswith('__torch__.'):
            return type_name[len('__torch__.'):]
        elif type_name.startswith('__mutated__.'):
            return type_name[len('__mutated__.'):]
        else:
            return None

    @classmethod
    def is_functional(cls, type_name) -> bool:
        return type_name.startswith('Function.')

    def _to_class_name(self) -> Optional[str]:
        if self.type.startswith('__torch__.'):
            return self.type[len('__torch__.'):]
        elif self.type.startswith('__mutated__.'):
            return self.type[len('__mutated__.'):]
        else:
            return None

    def get_import_pkg(self) -> Optional[str]:
        if self.type.startswith('__torch__.'):
            return self.type[len('__torch__.'):].split('.')[0]
        elif self.type.startswith('__mutated__.'):
            return self.type[len('__mutated__.'):].split('.')[0]
        else:
            return None

    def to_init_code(self, field: str) -> Optional[str]:
        if self._to_class_name() is not None:
            assert 'positional_args' not in self.parameters
            kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
            return f'self.{field} = {self._to_class_name()}({kw_params})'
        return None

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        """
        Parameters
        ----------
        field : str
            the name of the member submodule
        output : str
            the output name (lvalue) of this line of code
        inputs : List[str]
            variables used in this line of code
        inputs_value : List[Any]
            some variables are actually constants, and their real values are recorded in ``inputs_value``.
            if not constant, we simply put None at the corresponding index

        Returns
        -------
        str
            generated code line
        """
        if self.type == 'aten::slice':
            raise RuntimeError('not supposed to have aten::slice operation')
        else:
            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')


class TensorFlowOperation(Operation):
    def _to_class_name(self) -> str:
        return 'K.layers.' + self.type


class Cell(PyTorchOperation):
    """
    TODO: this is a pytorch cell

    An operation reference to a subgraph.

    Example code:
    ```
    def __init__(...):
        ...
        self.cell = CustomCell(...)
        self.relu = K.layers.ReLU()
        ...

    def forward(...):
        ...
        x = self.cell(x)
        ...
    ```

    In the above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
    For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.

    TODO: parameters of subgraph (see `Node` class)

    Attributes
    ----------
    type
        Always "_cell".
    parameters
        A dict with only one item; the key is "cell" and the value is the cell's name.
    framework
        No real usage. Exists for compatibility with the base class.
    """

    def __init__(self, cell_name: str,
                 parameters: Dict[str, Any] = cast(Dict[str, Any], None),
                 attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
        self.type = '_cell'
        self.cell_name = cell_name
        self.parameters = parameters or {}
        self.attributes = attributes or {}

    def _to_class_name(self):
        # TODO: ugly, think about how to refactor this part
        return _convert_name(self.cell_name)

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = self.{field}({", ".join(inputs)})'
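
# For instance, generating the forward line for a cell node (the names are
# illustrative):
#
#   Cell('my.cell').to_forward_code('cell', 'x1', ['x0'], [None])
#   # -> 'x1 = self.cell(x0)'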


class _IOPseudoOperation(Operation):
    """
    This is the pseudo operation used by I/O nodes.
    The benefit is that users no longer need to verify `Node.operation is not None`,
    especially in static type checking.
    """

    def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
        assert type_name.startswith('_')
        super(_IOPseudoOperation, self).__init__(type_name, {}, True)
        self.io_names = io_names

    def to_init_code(self, field: str) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def __bool__(self) -> bool:
        return False

@ -0,0 +1,235 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

__all__ = ['RetiariiAdvisor']

import logging
import os
from typing import Any, Callable, Optional, Dict, List, Tuple

import nni
from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.tuner_command_channel import CommandType
from nni.utils import MetricType

from .graph import MetricData
from .integration_api import register_advisor

_logger = logging.getLogger(__name__)


class RetiariiAdvisor(MsgDispatcherBase):
    """
    The class connects Retiarii components to the NNI backend.
    It can be considered as a Python wrapper of NNI manager.

    It will function as the main thread when running a Retiarii experiment through NNI.
    The strategy will be launched as its thread, which will call APIs in the execution engine. The execution
    engine will then find the advisor singleton and send payloads to the advisor.

    When metrics are sent back, the advisor will first receive the payloads, and then call the callback
    function (that is, a member function in the graph listener).

    The conversion the advisor provides is minimal. It is only a send/receive module, and the execution engine
    needs to handle all the rest.

    Attributes
    ----------
    send_trial_callback

    request_trial_jobs_callback

    trial_end_callback

    intermediate_metric_callback

    final_metric_callback
    """

    def __init__(self, url: str):
        super().__init__(url)
        register_advisor(self)  # register the current advisor as the "global only" advisor
        self.search_space = None

        self.send_trial_callback: Optional[Callable[[dict], None]] = None
        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

        self.parameters_count = 0

        # Sometimes messages arrive before the callbacks get registered,
        # or we allow the engine to be absent during the experiment.
        # Here we need to store the messages and invoke them later.
        self.call_queue: List[Tuple[str, list]] = []

    def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]):
        """
        Register callbacks for the NNI backend.

        Parameters
        ----------
        callbacks
            A dictionary of callbacks.
            The key is the name of the callback. The value is the callback function.
        """
        self.send_trial_callback = callbacks.get('send_trial')
        self.request_trial_jobs_callback = callbacks.get('request_trial_jobs')
        self.trial_end_callback = callbacks.get('trial_end')
        self.intermediate_metric_callback = callbacks.get('intermediate_metric')
        self.final_metric_callback = callbacks.get('final_metric')

        self.process_queued_callbacks()

    def process_queued_callbacks(self) -> None:
        """
        Process callbacks in the queue.
        Consume the messages that haven't been handled previously.
        """
        processed_idx = []
        for queue_idx, (call_name, call_args) in enumerate(self.call_queue):
            if call_name == 'send_trial' and self.send_trial_callback is not None:
                self.send_trial_callback(*call_args)  # pylint: disable=not-callable
                processed_idx.append(queue_idx)
            if call_name == 'request_trial_jobs' and self.request_trial_jobs_callback is not None:
                self.request_trial_jobs_callback(*call_args)  # pylint: disable=not-callable
                processed_idx.append(queue_idx)
            if call_name == 'trial_end' and self.trial_end_callback is not None:
                self.trial_end_callback(*call_args)  # pylint: disable=not-callable
                processed_idx.append(queue_idx)
            if call_name == 'intermediate_metric' and self.intermediate_metric_callback is not None:
                self.intermediate_metric_callback(*call_args)  # pylint: disable=not-callable
                processed_idx.append(queue_idx)
            if call_name == 'final_metric' and self.final_metric_callback is not None:
                self.final_metric_callback(*call_args)  # pylint: disable=not-callable
                processed_idx.append(queue_idx)

        # Remove processed messages
        for idx in reversed(processed_idx):
            self.call_queue.pop(idx)

    def invoke_callback(self, name: str, *args: Any) -> None:
        """
        Invoke callback.
        """
        self.call_queue.append((name, list(args)))
        self.process_queued_callbacks()
|
||||
|
||||
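(The queue-then-flush pattern above, in isolation: a minimal self-contained sketch, not part of this commit; all names in it are hypothetical.)

# Messages that arrive before a callback is registered are buffered,
# then replayed in arrival order once registration happens.
class CallbackQueue:
    def __init__(self):
        self.handlers = {}   # name -> callable
        self.pending = []    # buffered (name, args) messages

    def invoke(self, name, *args):
        self.pending.append((name, args))
        self.flush()

    def register(self, name, fn):
        self.handlers[name] = fn
        self.flush()

    def flush(self):
        remaining = []
        for name, args in self.pending:
            handler = self.handlers.get(name)
            if handler is not None:
                handler(*args)   # deliver late messages in order
            else:
                remaining.append((name, args))
        self.pending = remaining

# Usage: a message sent before registration is delivered on register().
q = CallbackQueue()
q.invoke('trial_end', 1, True)                       # buffered, no handler yet
q.register('trial_end', lambda i, ok: print(i, ok))  # prints "1 True"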
    def handle_initialize(self, data):
        """Callback for initializing the advisor.

        Parameters
        ----------
        data : dict
            search space
        """
        self.handle_update_search_space(data)
        self.send(CommandType.Initialized, '')

    def _validate_placement_constraint(self, placement_constraint):
        if placement_constraint is None:
            raise ValueError('placement_constraint is None')
        if 'type' not in placement_constraint:
            raise ValueError('placement_constraint must have `type`')
        if 'gpus' not in placement_constraint:
            raise ValueError('placement_constraint must have `gpus`')
        if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
            raise ValueError('placement_constraint.type must be either `None`, `GPUNumber` or `Device`')
        if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
            raise ValueError('placement_constraint.gpus must be an empty list when type == None')
        if placement_constraint['type'] == 'GPUNumber':
            if len(placement_constraint['gpus']) != 1:
                raise ValueError('placement_constraint.gpus currently only supports one host when type == GPUNumber')
            for e in placement_constraint['gpus']:
                if not isinstance(e, int):
                    raise ValueError('placement_constraint.gpus must be a list of numbers when type == GPUNumber')
        if placement_constraint['type'] == 'Device':
            for e in placement_constraint['gpus']:
                if not isinstance(e, tuple):
                    raise ValueError('placement_constraint.gpus must be a list of tuples when type == Device')
                if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
                    raise ValueError('each tuple in placement_constraint.gpus must be (str, int)')

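(Examples of the three accepted placement_constraint shapes, inferred from the validation above; host names and GPU IDs are made up.)

no_constraint = {'type': 'None', 'gpus': []}                    # no placement requirement
gpu_number    = {'type': 'GPUNumber', 'gpus': [2]}              # one host, a GPU count
device_pin    = {'type': 'Device',
                 'gpus': [('server0', 0), ('server0', 1)]}      # explicit (hostname, gpu_id) pairs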
    def send_trial(self, parameters, placement_constraint=None):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in the future.
        """
        self.parameters_count += 1
        if placement_constraint is None:
            placement_constraint = {
                'type': 'None',
                'gpus': []
            }
        self._validate_placement_constraint(placement_constraint)
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm',
            'placement_constraint': placement_constraint,
            'version_info': version_dump()
        }
        _logger.debug('New trial sent: %s', new_trial)

        try:
            send_payload = nni.dump(new_trial, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
        except PayloadTooLarge:
            raise ValueError(
                'Serialization failed when trying to dump the model because the payload is too large (larger than 64 KB). '
                'This is usually caused by pickling large objects (like datasets) by mistake. '
                'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
                'for how to resolve such issues. '
            )

        # trial parameters can be super large; the pickle size limit is relaxed here,
        # but the payload could nevertheless still be blocked by the pipe / nni-manager
        self.send(CommandType.NewTrialJob, send_payload)

        self.invoke_callback('send_trial', parameters)
        return self.parameters_count

    def mark_experiment_as_ending(self):
        self.send(CommandType.NoMoreTrialJobs, '')

    def handle_request_trial_jobs(self, num_trials):
        _logger.debug('Request trial jobs: %s', num_trials)
        self.invoke_callback('request_trial_jobs', num_trials)

    def handle_update_search_space(self, data):
        _logger.debug('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
        _logger.debug('Trial end: %s', data)
        self.invoke_callback('trial_end', nni.load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
        _logger.debug('Metric reported: %s', data)
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        elif data['type'] == MetricType.PERIODICAL:
            self.invoke_callback('intermediate_metric', data['parameter_id'], self._process_value(data['value']))
        elif data['type'] == MetricType.FINAL:
            self.invoke_callback('final_metric', data['parameter_id'], self._process_value(data['value']))

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        value = nni.load(value)
        if isinstance(value, dict):
            if 'default' in value:
                return value['default']
            else:
                return value
        return value
@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

__all__ = [
    'get_advisor', 'register_advisor', 'send_trial', 'receive_trial_parameters', 'get_experiment_id',
    '_advisor'  # FIXME: hack to make it importable for tests
]

import warnings
from typing import NewType, Any

import nni
from nni.common.version import version_check

# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce a cyclic import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)

_advisor = None  # type is RetiariiAdvisor


def get_advisor():
    # return type: RetiariiAdvisor
    global _advisor
    assert _advisor is not None
    return _advisor


def register_advisor(advisor):
    # type of advisor: RetiariiAdvisor
    global _advisor
    if _advisor is not None:
        warnings.warn('Advisor is already set. '
                      'You should avoid instantiating RetiariiExperiment twice in one process. '
                      'If you are running in a Jupyter notebook, please restart the kernel.')
    _advisor = advisor


def send_trial(parameters: dict, placement_constraint=None) -> int:
    """
    Send a new trial. Executed on the tuner end.
    Return an ID that is the unique identifier for this trial.
    """
    return get_advisor().send_trial(parameters, placement_constraint)


def receive_trial_parameters() -> dict:
    """
    Receive a new trial. Executed on the trial end.
    Reload with our own JSON loads because NNI didn't use the Retiarii serializer to load the data.
    """
    params = nni.get_next_parameter()

    # version check, optional
    raw_params = nni.trial._params
    if raw_params is not None and 'version_info' in raw_params:
        version_check(raw_params['version_info'])
    else:
        warnings.warn('Version check failed because `version_info` is not found.')

    return params


def get_experiment_id() -> str:
    return nni.get_experiment_id()
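(How the two ends pair up in practice: a hedged sketch; the payload keys are hypothetical, and the two snippets run in different processes.)

# Tuner / strategy process:
param_id = send_trial({'arch': {'layer1': 'conv3x3'}})   # returns a fresh parameter ID

# Trial process, on a worker:
params = receive_trial_parameters()   # same payload, after the version check
print(params['arch']['layer1'])       # -> 'conv3x3'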
@ -0,0 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

__all__ = ['DefaultListener']

from .graph import Model, ModelStatus, MetricData
from .engine import AbstractGraphListener


class DefaultListener(AbstractGraphListener):

    def on_metric(self, model: Model, metric: MetricData) -> None:
        model.metric = metric

    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        model.intermediate_metrics.append(metric)

    def on_training_end(self, model: Model, success: bool) -> None:
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
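(A short usage note: DefaultListener simply mirrors engine events back onto the Model objects. Wiring it up is one line; `engine` stands for any execution engine exposing register_graph_listener, as CGOExecutionEngine below does.)

engine.register_graph_listener(DefaultListener())
# After a trial ends, model.metric, model.intermediate_metrics and
# model.status are populated by the three callbacks above.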
@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

__all__ = ['unpack_if_only_one', 'get_mutation_dict', 'mutation_dict_to_summary', 'get_mutation_summary']

from typing import Any, List
from .graph import Model


def unpack_if_only_one(ele: List[Any]):
    if len(ele) == 1:
        return ele[0]
    return ele


def get_mutation_dict(model: Model):
    return {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}


def mutation_dict_to_summary(mutation: dict) -> dict:
    mutation_summary = {}
    for label, samples in mutation.items():
        # FIXME: this check might be wrong
        if not isinstance(samples, list):
            mutation_summary[label] = samples
        else:
            for i, sample in enumerate(samples):
                mutation_summary[f'{label}_{i}'] = sample
    return mutation_summary


def get_mutation_summary(model: Model) -> dict:
    mutation = get_mutation_dict(model)
    return mutation_dict_to_summary(mutation)
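(A worked example of the flattening done by mutation_dict_to_summary above; the labels and values are hypothetical.)

# Single-sample mutators keep their label; multi-sample mutators get an index suffix.
mutation = {'layer1': 'conv3x3', 'cell': ['op_a', 'op_b']}
summary = mutation_dict_to_summary(mutation)
assert summary == {'layer1': 'conv3x3', 'cell_0': 'op_a', 'cell_1': 'op_b'}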
@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast

from nni.nas.execution.common import Model, receive_trial_parameters, get_mutation_dict
from .graph import BaseExecutionEngine


class BenchmarkGraphData:

    SUPPORTED_BENCHMARK_LIST = [
        'nasbench101',
        'nasbench201-cifar10',
        'nasbench201-cifar100',
        'nasbench201-imagenet16',
        'nds-cifar10',
        'nds-imagenet',
        'nlp'
    ]

    def __init__(self, mutation: Dict[str, Any], benchmark: str,
                 metric_name: Optional[str] = None,
                 db_path: Optional[str] = None) -> None:
        self.mutation = mutation        # mutation dict. e.g., {'layer1': 'conv3x3', ...}
        self.benchmark = benchmark      # e.g., nasbench101, nasbench201, ...
        self.metric_name = metric_name  # which metric to retrieve from the benchmark
        self.db_path = db_path          # path to the directory of the database

    def dump(self) -> dict:
        from nni.nas.benchmarks.constants import DATABASE_DIR
        return {
            'mutation': self.mutation,
            'benchmark': self.benchmark,
            'metric_name': self.metric_name,
            'db_path': self.db_path or DATABASE_DIR  # the database path needs to be passed from manager to worker
        }

    @staticmethod
    def load(data) -> 'BenchmarkGraphData':
        return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])

    def __repr__(self) -> str:
        return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"


class BenchmarkExecutionEngine(BaseExecutionEngine):
    """
    Execution engine that does not actually run any trial, but queries the database for results.

    The database query is done at the trial end to make sure intermediate metrics are available.
    It will also support an accelerated mode that returns the metric immediately without even going through the NNI manager
    (not implemented yet).
    """

    def __init__(self, benchmark: Union[str, Callable[[BenchmarkGraphData], Tuple[float, List[float]]]], acceleration: bool = False):
        super().__init__()
        if isinstance(benchmark, str):
            # callables are dispatched directly in query_in_benchmark and skip this check
            assert benchmark in BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST, \
                f'{benchmark} is not one of the supported benchmarks: {BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST}'
        self.benchmark = benchmark
        self.acceleration = acceleration

    def pack_model_data(self, model: Model) -> Any:
        # called when a new model is submitted to the backend.
        # converts a Model into data that is acceptable by the trial end.
        mutation = get_mutation_dict(model)
        graph_data = BenchmarkGraphData(mutation, self.benchmark)

        return graph_data

    @classmethod
    def trial_execute_graph(cls) -> None:
        graph_data = BenchmarkGraphData.load(receive_trial_parameters())
        assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
        os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
        final, intermediates = cls.query_in_benchmark(graph_data)

        import nni
        for i in intermediates:
            nni.report_intermediate_result(i)
        nni.report_final_result(final)

    @staticmethod
    def query_in_benchmark(graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
        if not isinstance(graph_data.benchmark, str):
            return graph_data.benchmark(graph_data)

        # built-in benchmarks with default query settings
        if graph_data.benchmark == 'nasbench101':
            from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
            arch = None
            for t in graph_data.mutation.values():
                if isinstance(t, dict):
                    arch = t
            if arch is None:
                raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
            return _convert_to_final_and_intermediates(
                query_nb101_trial_stats(arch, 108, include_intermediates=True),
                'valid_acc'
            )
        elif graph_data.benchmark.startswith('nasbench201'):
            from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
            dataset = graph_data.benchmark.split('-')[-1]
            return _convert_to_final_and_intermediates(
                query_nb201_trial_stats(_flatten_architecture(graph_data.mutation), 200, dataset, include_intermediates=True),
                'valid_acc',
            )
        elif graph_data.benchmark.startswith('nds'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nds import query_nds_trial_stats
            dataset = graph_data.benchmark.split('-')[-1]
            return _convert_to_final_and_intermediates(
                query_nds_trial_stats(None, None, None, None, _flatten_architecture(graph_data.mutation),
                                      dataset, include_intermediates=True),
                'valid_acc'
            )
        elif graph_data.benchmark.startswith('nlp'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nlp import query_nlp_trial_stats
            # TODO: the available datasets in this benchmark are unclear, and the docs are missing.
            return _convert_to_final_and_intermediates(
                query_nlp_trial_stats(_flatten_architecture(graph_data.mutation), 'ptb', include_intermediates=True),
                'valid_acc'
            )
        else:
            raise ValueError(f'{graph_data.benchmark} is not a supported benchmark.')


def _flatten_architecture(mutation: Dict[str, Any], benchmark: Optional[str] = None):
    # STRONG ASSUMPTION HERE!
    # This assumes that the benchmarked search space is a one-level search space,
    # i.e., it is either ONE cell or ONE network.
    # Two-cell search spaces like NDS are not supported for now.
    # Some benchmarks even need special handling to pop out invalid keys. I don't think this is a good design.

    # support double underscores to be compatible with the naming convention of the base engine
    ret = {k.split('/')[-1].split('__')[-1]: v for k, v in mutation.items()}
    if benchmark == 'nasbench101':
        ret = {k: v for k, v in ret.items() if k.startswith('op') or k.startswith('input')}
        ret = {k: v if k.startswith('op') or isinstance(v, list) else [v] for k, v in ret.items()}
    return ret


def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_name: str) -> Tuple[float, List[float]]:
    # convert benchmark results from the database into
    # a final result (float) and intermediate results (list of floats)
    benchmark_result = list(benchmark_result)
    assert len(benchmark_result) > 0, 'Invalid query. Results from benchmark are empty.'
    if len(benchmark_result) > 1:
        benchmark_result = random.choice(benchmark_result)
    else:
        benchmark_result = benchmark_result[0]
    benchmark_result = cast(dict, benchmark_result)
    return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
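(A hedged usage sketch of the benchmark engine; the mutation key and the metric numbers are illustrative only.)

# The engine never trains; it resolves a mutation dict to recorded metrics.
engine = BenchmarkExecutionEngine('nasbench201-cifar10')

# On the trial side, query_in_benchmark turns the architecture into metrics:
data = BenchmarkGraphData({'cell/0_1': 'conv_3x3'}, 'nasbench201-cifar10')
# final, intermediates = BenchmarkExecutionEngine.query_in_benchmark(data)
# -> e.g. (91.2, [10.0, 45.3, ...])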
@ -1,4 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .trainer import PdartsTrainer
from .engine import *
@ -0,0 +1,402 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['CGOExecutionEngine', 'TrialSubmission']

import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass

from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas import utils
from nni.nas.execution.common import (
    AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
    Model, ModelStatus, MetricData, Node,
    RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData

from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer

_logger = logging.getLogger(__name__)


def _noop(*args, **kwargs):
    pass


@dataclass
class TrialSubmission:
    model: Model
    placement: Dict[Node, Device]
    grouped_models: List[Model]


class CGOExecutionEngine(AbstractExecutionEngine):
    """
    The execution engine with Cross-Graph Optimization (CGO).

    Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
    Otherwise, a model will be submitted independently without any cross-graph optimization.

    Parameters
    ----------
    training_service
        The remote training service config.
    max_concurrency
        The maximum number of trials to run concurrently.
    batch_waiting_time
        Seconds to wait for each batch of trial submission.
        The trials within one batch could apply cross-graph optimization.
    rest_port
        The port of the experiment's REST server.
    rest_url_prefix
        The URL prefix of the experiment's REST entry.
    """

    def __init__(self, training_service: RemoteConfig,
                 max_concurrency: int | None = None,
                 batch_waiting_time: int = 60,
                 rest_port: int | None = None,
                 rest_url_prefix: str | None = None
                 ) -> None:
        self.port = rest_port
        self.url_prefix = rest_url_prefix

        self._listeners: List[AbstractGraphListener] = []
        self._running_models: Dict[int, Model] = dict()
        self.logical_plan_counter = 0
        self.available_devices: List[Device] = []
        self.max_concurrency: int | None = max_concurrency

        devices = self._construct_devices(training_service)
        for device in devices:
            self.available_devices.append(device)
        self.all_devices = self.available_devices.copy()

        self._batch_waiting_time = batch_waiting_time  # seconds to wait for all models in a batch to do cross-graph optimization
        self._optimizers = [DedupInputOptimizer()]
        self._original_models = {}
        self._original_model_to_multi_model = {}
        self._trial_to_original_models = {}
        self._trial_used_devices: Dict[int, List[Device]] = {}

        self._history: List[Model] = []

        self._queuing_models: List[Tuple[float, Model]] = []
        self._models_to_retry: List[Model] = []
        self._queue_lock = threading.Lock()

        # register advisor callbacks
        advisor: RetiariiAdvisor = get_advisor()
        advisor.register_callbacks({
            'send_trial': _noop,
            'request_trial_jobs': _noop,
            'trial_end': self._trial_end_callback,
            'intermediate_metric': self._intermediate_metric_callback,
            'final_metric': self._final_metric_callback
        })

        self._stopped = False
        self._consumer_thread = threading.Thread(target=self._consume_models)
        self._consumer_thread.start()

    def _construct_devices(self, training_service):
        devices = []
        if hasattr(training_service, 'machine_list'):
            for machine in cast(RemoteConfig, training_service).machine_list:
                assert machine.gpu_indices is not None, \
                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
                for gpu_idx in machine.gpu_indices:
                    devices.append(GPUDevice(machine.host, gpu_idx))
        return devices

    def join(self):
        self._stopped = True
        self._consumer_thread.join()

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: Model) -> None:
        curr_time = time.time()
        _logger.info('%d models are submitted', len(models))
        self._queue_lock.acquire()
        self._queuing_models.extend([(curr_time, _) for _ in models])
        self._queue_lock.release()

    def _submit_retry_models(self, models: List[Model]) -> None:
        _logger.info('%d models are retried', len(models))
        self._queue_lock.acquire()
        self._models_to_retry.extend(models)
        self._queue_lock.release()

    def _consume_models(self):
        # a thread to monitor self._models_to_retry and self._queuing_models and consume them in batches
        while not self._stopped:
            if len(self._models_to_retry) > 0:
                self._queue_lock.acquire()
                # retrying jobs should be scheduled first.
                for m in self._models_to_retry:
                    if len(self.available_devices) > 0:
                        self._submit_models_in_batch(m)  # submit the single model to avoid cross-graph optimization.
                        self._models_to_retry = self._models_to_retry[1:]
                self._queue_lock.release()

            if len(self._queuing_models) > 0:
                self._queue_lock.acquire()
                curr_time = time.time()

                num_models_to_submit = len(self.available_devices)
                if self.max_concurrency:
                    num_models_to_submit = min(num_models_to_submit, self.max_concurrency)

                if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
                    num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
                if num_models_to_submit > 0:
                    self._submit_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]])
                    self._queuing_models = self._queuing_models[num_models_to_submit:]
                self._queue_lock.release()
            time.sleep(1)

    def _extract_placement_constraint(self, placement_mapping: Dict[Node, Device]):
        unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
        placement_constraint = None
        if len(unique_gpus) > 0:
            placement_constraint = {}
            placement_constraint['type'] = 'Device'
            placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
        return placement_constraint

    def _submit_models_in_batch(self, *models: Model) -> None:
        _logger.info('%d models are submitted in batch', len(models))
        _logger.debug('model id: %s', str([m.model_id for m in models]))
        logical = self._build_logical(models)

        for opt in self._optimizers:
            opt.convert(logical)

        phy_models_and_placements = self._assemble(logical)
        for model, placement, grouped_models in phy_models_and_placements:
            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
            placement_constraint = self._extract_placement_constraint(placement)
            trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
            # unique non-CPU devices used by the trial
            self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))

            # currently, it is impossible for the search strategy to submit more models than the number of available devices
            for used_device in self._trial_used_devices[trial_id]:
                self.available_devices.remove(used_device)  # used_device must be in self.available_devices
            self._running_models[trial_id] = model

            self._trial_to_original_models[trial_id] = []
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
                self._trial_to_original_models[trial_id].append(m.model_id)
                self._history.append(m)

    def list_models(self) -> Iterable[Model]:
        return self._history

    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
        """
        Return the assembled models as a list of tuples.
        Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
        """
        # try to use the available_devices first so that it can be launched as early as possible
        # if free devices are not enough to assemble all models in one trial, try all devices
        if len(self.available_devices) > 0:
            grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)

        if len(self.available_devices) == 0 or len(grouped_models) > 1:
            grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)

        phy_models_and_placements = []
        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
            assert isinstance(model.evaluator, Lightning), \
                "cross-graph optimization only supports PyTorch Lightning as evaluator"
            assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
                "cross-graph optimization only supports MultiModelSupervisedLearningModule"

            # replace the module with a new instance whose n_models is set
            # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
            new_module_init_params = model.evaluator.module.dump_kwargs().copy()

            # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
            new_module_init_params['n_models'] = len(multi_model)
            new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
            model.evaluator.module = new_module
            phy_models_and_placements.append((model, model_placement, multi_model.keys()))
        return phy_models_and_placements

    def _build_logical(self, models: List[Model]) -> LogicalPlan:
        logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
        return logical_plan

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    # def _send_trial_callback(self, parameter: dict) -> None:
    #     if len(self.available_devices) == 0:
    #         _logger.warning('There is no available device, but a trial is submitted.')
    #     _logger.debug('Resource used. Remaining: %d', len(self.available_devices))

    # def _request_trial_jobs_callback(self, num_trials: int) -> None:
    #     self.resources += num_trials
    #     _logger.info('on_resource_available: %d', self.resources)

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        models_to_retry = []
        for model_id in self._original_model_to_multi_model:
            if self._original_model_to_multi_model[model_id] == model:
                original_model = self._original_models[model_id]
                if success:
                    original_model.status = ModelStatus.Trained
                else:
                    original_model.status = ModelStatus.Failed
                    # the failed models in a multi-model will be retried one by one w/o CGO
                    if len(self._trial_to_original_models[trial_id]) > 1:
                        models_to_retry.append(original_model)
                for listener in self._listeners:
                    listener.on_training_end(original_model, success)

        if len(models_to_retry) > 0:
            self._submit_retry_models(models_to_retry)

        self.available_devices.extend(self._trial_used_devices[trial_id])
        self.available_devices = sorted(list(set(self.available_devices)))
        del self._running_models[trial_id]

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        merged_metrics = {}
        for idx, _ in enumerate(metrics):
            merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
        for model_id in merged_metrics:
            self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
            for listener in self._listeners:
                listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        _logger.debug(metrics)

        if isinstance(metrics, float):
            self._listeners[0].on_metric(self._running_models[trial_id], metrics)
        else:
            merged_metrics = {}
            for idx, _ in enumerate(metrics):
                merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
            for model_id in merged_metrics:
                self._original_models[model_id].metric = merged_metrics[model_id]
                for listener in self._listeners:
                    listener.on_metric(self._original_models[model_id], merged_metrics[model_id])

    def query_available_resource(self) -> int:
        # the _queuing_models need to use available_devices first
        self._queue_lock.acquire()
        available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
        self._queue_lock.release()
        return available_for_more_models

    def budget_exhausted(self) -> bool:
        advisor = get_advisor()
        return advisor.stopping

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model, and hand it over to the trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        _logger.info('CGO_ENGINE trial parameters received')
        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
        file_name = f'_generated_model/{random_str}.py'
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, 'w') as f:
            f.write(graph_data.model_script)

        trainer_instance = graph_data.evaluator
        model_cls = utils.import_(f'_generated_model.{random_str}._model')

        trainer_instance.fit(model_cls())
        os.remove(file_name)


class AssemblePolicy:
    @staticmethod
    def _is_related_node(model: Model, node: Node):
        if isinstance(node, AbstractLogicalNode):
            if model in node.related_models:
                return True
        else:
            if model == node.graph.model:
                return True
        return False

    @staticmethod
    def _check_graph_connectivity(model: Model,
                                  group_model: Dict[Model, Device],
                                  logical_plan: LogicalPlan) -> bool:
        for edge in logical_plan.logical_graph.edges:
            if AssemblePolicy._is_related_node(model, edge.head) or \
                    AssemblePolicy._is_related_node(model, edge.tail):
                for grouped_model in group_model:
                    if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
                            AssemblePolicy._is_related_node(grouped_model, edge.tail):
                        return True
        return False

    @staticmethod
    def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
        if not (isinstance(new_model.evaluator, Lightning)
                and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)):
            return False
        for m in group_model:
            if not m.evaluator == new_model.evaluator:
                return False
        return True

    @staticmethod
    def group(logical_plan, available_devices):
        # TODO: packing multiple models into one GPU
        # Currently, we only support one model per GPU
        all_grouped_models = []
        group_model = {}
        assert len(available_devices) > 0  # There should be at least 1 device, set in CGO_DEVICES
        for idx, m in enumerate(logical_plan.models):
            # models in one group should
            # (1) not use more GPUs than available_devices
            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
            # (3) use the same MultiModelSupervisedLearningModule
            if len(group_model) > 0 and \
                (not AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) or
                    not AssemblePolicy._check_evaluator(m, group_model)):
                all_grouped_models.append(group_model)
                group_model = {}
            group_model[m] = available_devices[idx % len(available_devices)]
            if len(group_model) == len(available_devices) or \
                    idx == len(logical_plan.models) - 1:
                all_grouped_models.append(group_model)
                group_model = {}
        return all_grouped_models
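(A toy illustration of AssemblePolicy.group above; the model and device names are hypothetical.)

# With 2 GPUs and 3 mutually compatible models, group() yields batches of at
# most len(available_devices):
#     models  = [m0, m1, m2]
#     devices = [gpu0, gpu1]
#     group(plan, devices) -> [{m0: gpu0, m1: gpu1}, {m2: gpu0}]
# A model failing _check_graph_connectivity or _check_evaluator closes the
# current group and starts a new one, so incompatible models never share a trial.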
@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC

from .logical_plan import LogicalPlan


class AbstractOptimizer(ABC):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        raise NotImplementedError
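(A minimal sketch of a custom pass conforming to this interface; the class is hypothetical.)

class NoopOptimizer(AbstractOptimizer):
    """A do-nothing pass: CGO optimizers mutate the LogicalPlan in place."""

    def convert(self, logical_plan: LogicalPlan) -> None:
        # inspect or rewrite logical_plan.logical_graph here
        pass

# engine.add_optimizer(NoopOptimizer())  # runs before each batch is assembled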
@ -0,0 +1,336 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
from typing import Dict, Tuple, Any

from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice

from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation


class AbstractLogicalNode(Node):
    def __init__(self, graph, node_id, name, operation, _internal=False):
        super().__init__(graph, node_id, name, operation, _internal=_internal)
        self.related_models = []

    def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
        """
        Given a set of models to be formed into a physical model and their device placement,
        this function replaces the logical node with an executable physical node for the physical model.

        Parameters
        ----------
        multi_model_placement : dict
            a dict of models and their device placement.
            These models will be assembled into the same physical model to run.

        Returns
        -------
        node : Node
            the physical node to replace the logical node in the physical model
        placement : Device
            the device placement of the returned physical node
        """
        raise NotImplementedError

    def _fork_to(self, graph: Graph):
        raise NotImplementedError


class LogicalGraph(Graph):
    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
        super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)

    def _dump(self) -> Any:
        nodes_dump = {}
        for node in self.hidden_nodes:
            if isinstance(node, OriginNode):
                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
            else:
                nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()

        edges_dump = []
        for edge in self.edges:
            if isinstance(edge.head, OriginNode):
                head_info = f'{edge.head.original_graph.model.model_id}_{edge.head.name}'
            else:
                head_info = edge.head.name
            if isinstance(edge.tail, OriginNode):
                tail_info = f'{edge.tail.original_graph.model.model_id}_{edge.tail.name}'
            else:
                tail_info = edge.tail.name
            edges_dump.append((head_info, tail_info))
        return {
            'inputs': self.input_node.operation.io_names,
            'outputs': self.output_node.operation.io_names,
            'nodes': nodes_dump,
            'edges': edges_dump
        }

    def _fork_to(self, model: Model) -> Graph:
        new_graph = Graph(model, self.id, self.name,
                          _internal=True)._register()

        for node in self.hidden_nodes:
            if isinstance(node, AbstractLogicalNode):
                node._fork_to(new_graph)
            else:
                Node(new_graph, node.id, node.name,
                     node.operation, _internal=True)._register()

        id_to_new_node = {node.__repr__(): node for node in new_graph.nodes}

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.__repr__()]
            new_tail = id_to_new_node[edge.tail.__repr__()]
            Edge((new_head, edge.head_slot),
                 (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph


class OriginNode(AbstractLogicalNode):
    """
    This logical node represents the original node without any modification.
    In assemble, it just returns the original node along with the physical placement given by multi_model_placement.
    """

    def __init__(self, logical_graph: LogicalGraph,
                 original_graph: Graph, original_node: Node,
                 name: str, operation, _internal=False):
        super().__init__(logical_graph, original_node.id, name, operation)
        self.original_graph = original_graph
        self.original_node = original_node

    def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
        model_id = self.original_node.graph.model.model_id
        new_node = Node(self.original_node.graph, self.original_node.id,
                        f"M_{model_id}_" +
                        self.original_node.name,
                        self.original_node.operation)
        return new_node, multi_model_placement[self.original_node.graph.model]

    def __repr__(self):
        return f'OriginNode(id={self.id}, name={self.name}, \
operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'

    def _fork_to(self, graph: Graph):
        OriginNode(graph, self.original_graph, self.original_node,
                   self.name, self.operation)._register()


class LogicalPlan:
    def __init__(self, plan_id=0) -> None:
        self.lp_model = Model(_internal=True)
        self.id = plan_id
        self.logical_graph = LogicalGraph(
            self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
        self.lp_model._root_graph_name = self.logical_graph.name
        self.models = []

    def add_model(self, model: Model):
        self.models.append(model)
        # Only optimize the root graph.
        self._merge_graph(model.root_graph)

    def _merge_graph(self, from_graph):
        to_graph = self.logical_graph
        id_to_new_node = {}  # old node ID -> new node object

        for old_node in from_graph.nodes:
            new_node = OriginNode(to_graph, old_node.graph,
                                  old_node, old_node.name,
                                  old_node.operation, _internal=True)._register()
            id_to_new_node[old_node.id] = new_node

        for edge in from_graph.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

    def assemble(self, multi_model_placement: Dict[Model, Device]) \
            -> Tuple[Model, Dict[Node, Device]]:
        """
        Given a set of models to be formed into a physical model and their device placement,
        this function replaces all the logical nodes in this LogicalPlan with executable physical nodes
        for the physical model.

        Parameters
        ----------
        multi_model_placement : dict
            a dict of models and their device placement.
            These models will be assembled into the same physical model to run.

        Returns
        -------
        phy_model : Model
            the physical model formed by the models in `multi_model_placement`;
            all logical nodes are replaced by physical nodes
        node_placements : dict
            the device placement of the nodes in `phy_model`
        """
        phy_model = Model(_internal=True)
        phy_graph = self.lp_model.root_graph._fork_to(phy_model)
        phy_graph._rename_graph(phy_graph.name, "_model")

        # merge sub-graphs
        for model in multi_model_placement:
            if phy_model.evaluator is None and model.evaluator is not None:
                phy_model.evaluator = model.evaluator
            for graph_name in model.graphs:
                if graph_name != model._root_graph_name:
                    new_graph = model.graphs[graph_name]._fork_to(
                        phy_model, name_prefix=f'M_{model.model_id}_')

                    # the M_ prefix of hidden_nodes names in non-root graphs is added here
                    for new_node in new_graph.hidden_nodes:
                        if isinstance(new_node.operation, Cell):
                            old_cell_name = new_node.operation.cell_name
                            new_node.operation = copy.deepcopy(new_node.operation)
                            new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'

        assert phy_model.evaluator is not None

        # When replacing logical nodes, merge the training configs when
        # input/output nodes are replaced.
        evaluator_slot = {}  # Model ID -> Slot ID
        input_slot_mapping = {}
        output_slot_mapping = {}
        # Replace all logical nodes with executable physical nodes
        hidden_nodes = phy_graph.hidden_nodes.copy()
        node_placements = {}

        added_models = []

        for node in hidden_nodes:
            if isinstance(node, OriginNode):
                model_id = node.original_graph.model.model_id
                if node.original_graph.model not in multi_model_placement:
                    for edge in node.incoming_edges:
                        edge.remove()
                    for edge in node.outgoing_edges:
                        edge.remove()
                    node.remove()
                    continue

            if isinstance(node, AbstractLogicalNode):
                new_node, placement = node.assemble(multi_model_placement)
                if isinstance(new_node.operation, _IOPseudoOperation):
                    model_id = new_node.graph.model.model_id
                    if model_id not in evaluator_slot:
                        added_models.append(model_id)
                        evaluator_slot[model_id] = len(added_models) - 1
                        slot = evaluator_slot[model_id]
                    else:
                        slot = evaluator_slot[model_id]
                    # If a model's inputs/outputs are not used in the multi-model,
                    # the codegen and trainer should not generate and use them.
                    # "use_input" and "use_output" are used to mark whether
                    # an input/output of a model is used in a multi-model
                    if new_node.operation.type == '_inputs':
                        input_slot_mapping[new_node] = slot
                    if new_node.operation.type == '_outputs':
                        output_slot_mapping[new_node] = slot

                self.node_replace(node, new_node)

                # the M_ name prefix of cells in hidden_nodes of root graphs is added here
                # FIXME: merge this rename with the non-root graph one, and only do it once.
                if isinstance(new_node.operation, Cell):
                    old_cell_name = new_node.operation.cell_name
                    new_node.operation = copy.deepcopy(new_node.operation)
                    new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'

                # input should be on CPU; move it to GPU first if necessary
                if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
                    # hack: only support single_server
                    node_placements[new_node] = CPUDevice(node_id=placement.node_id)
                else:
                    node_placements[new_node] = placement

                node.remove()

        # If two nodes are placed on different devices, use a ToDevice op to copy the node
        # TODO: when copying one node to multiple devices, broadcast is more efficient than P2P communication
        existing_edges = phy_graph.edges.copy()
        # Avoid copying a node multiple times onto the same device
        copied_op: Dict[Tuple[Node, Device], Node] = {}
        for edge in existing_edges:
            head_placement = node_placements[edge.head]
            tail_placement = node_placements[edge.tail]
            if head_placement != tail_placement:
                if head_placement.node_id != tail_placement.node_id:
                    raise ValueError('Cross-server placement is not supported.')
                # Same server, different devices
                if (edge.head, tail_placement) in copied_op:
                    to_node = copied_op[(edge.head, tail_placement)]
                else:
                    dst_name = edge.head.name + "_to_" + edge.tail.name
                    to_operation = Operation.new(
                        'ToDevice', {
                            "device": tail_placement, "src": (
                                edge.head.name, edge.head_slot), "dst": dst_name})
                    to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
                    Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                    copied_op[(edge.head, tail_placement)] = to_node
                    node_placements[to_node] = head_placement
                edge.head = to_node
                edge.head_slot = None

        # merge all input nodes into one with multiple slots
        input_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
                input_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.head in input_nodes:
                edge.head_slot = input_slot_mapping[edge.head]
                edge.head = phy_graph.input_node

        # merge all output nodes into one with multiple slots
        output_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
                output_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.tail in output_nodes:
                edge.tail_slot = output_slot_mapping[edge.tail]
                edge.tail = phy_graph.output_node

        for node in input_nodes:
            node.remove()
        for node in output_nodes:
            node.remove()

        return phy_model, node_placements

    def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
        # TODO: currently, only a single input slot and output slot are supported.
        if input_slot_mapping is not None or output_slot_mapping is not None:
            raise ValueError('Slot mapping is not supported')

        phy_graph = old_node.graph
        new_node.graph = phy_graph

        new_node._register()

        for edge in phy_graph.edges:
            if edge.head == old_node:
                edge.head = new_node
            elif edge.tail == old_node:
                edge.tail = new_node

        # after the replacement, there might be multiple duplicated edges
        # with the same input and output nodes, which should be de-duplicated
        self._remove_duplicated_edges()

    def _remove_duplicated_edges(self):
        # TODO: there are no duplicated edges if only dedup input is supported
        # Duplicated edges appear when a chain of prefix nodes is deduplicated
        pass
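(A condensed sketch of how CGOExecutionEngine drives LogicalPlan, mirroring _build_logical and _assemble above; `models` and `placement` are assumed to come from the search strategy and AssemblePolicy.)

plan = LogicalPlan(plan_id=0)
for m in models:
    plan.add_model(m)   # merge each root graph into the shared logical graph
# placement: Dict[Model, Device], e.g. chosen by AssemblePolicy.group(...)
phy_model, node_placements = plan.assemble(placement)
# phy_model contains all merged graphs; node_placements drives the codegen below.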
@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Dict, Tuple

from nni.nas.utils import uid
from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice

from nni.nas.execution.common.graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
                           OriginNode)


_supported_evaluators = [MultiModelSupervisedLearningModule]


class DedupInputNode(AbstractLogicalNode):
    """
    This logical node represents the nodes to be deduplicated.
    In assemble, it returns just one copy of the original node when multiple models are assembled.
    These models will share the result of a single calculation.
    """

    def __init__(self, logical_graph: LogicalGraph, node_id: int,
                 nodes_to_dedup: List[Node], _internal=False):
        super().__init__(logical_graph, node_id,
                         "Dedup_" + nodes_to_dedup[0].name,
                         nodes_to_dedup[0].operation)
        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
        self.related_models = [_.original_graph.model for _ in self.origin_nodes]

    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
        for node in self.origin_nodes:
            if node.original_graph.model in multi_model_placement:
                new_node = Node(node.original_graph, node.id,
                                f'M_{node.original_graph.model.model_id}_{node.name}',
                                node.operation)
                return new_node, multi_model_placement[node.original_graph.model]
        raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')

    def _fork_to(self, graph: Graph):
        DedupInputNode(graph, self.id, self.origin_nodes)._register()

    def __repr__(self) -> str:
        return f'DedupNode(id={self.id}, name={self.name}, \
len(nodes_to_dedup)={len(self.origin_nodes)})'


class DedupInputOptimizer(AbstractOptimizer):
    def __init__(self) -> None:
        pass

    def _check_supported_evaluator(self, evaluator):
        for e in _supported_evaluators:
            if isinstance(evaluator, e):
                return True
        return False

    def _check_deduplicate_by_node(self, root_node, node_to_check):
        if root_node == node_to_check:
            return True
        if root_node.operation.type == '_inputs' and \
                node_to_check.operation.type == '_inputs' and \
                isinstance(root_node, OriginNode) and \
                isinstance(node_to_check, OriginNode):
            if self._check_supported_evaluator(root_node.original_graph.model.evaluator):
                return False
            if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
                return True
            else:
                return False
        else:
            return False

    def convert(self, logical_plan: LogicalPlan) -> None:
        nodes_to_skip = set()
        while True:  # repeat until the logical_graph converges
            input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
            # _PseudoOperation(type_name="_inputs"))
            root_node = None
            for node in input_nodes:
                if node in nodes_to_skip:
                    continue
                root_node = node
                break
            if root_node is None:
                break  # end of convert
            else:
                nodes_to_dedup = []
                for node in input_nodes:
                    if node in nodes_to_skip:
                        continue
                    if self._check_deduplicate_by_node(root_node, node):
                        nodes_to_dedup.append(node)
                assert len(nodes_to_dedup) >= 1
                if len(nodes_to_dedup) == 1:
                    assert nodes_to_dedup[0] == root_node
                    nodes_to_skip.add(root_node)
                else:
                    dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
                    for edge in logical_plan.logical_graph.edges:
                        if edge.head in nodes_to_dedup:
                            edge.head = dedup_node
                        if edge.tail in nodes_to_dedup:
                            edge.tail = dedup_node
                    for node in nodes_to_dedup:
                        node.remove()
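(A toy trace of what the pass does; the model IDs and node names are hypothetical.)

# Two models with the same evaluator each contribute an identical '_inputs' node.
#     before: OriginNode(M_0/_inputs), OriginNode(M_1/_inputs)
#     after:  DedupInputNode(related_models=[M_0, M_1])
# The shared input is then computed once per trial and its result reused,
# which is what makes cross-graph batching profitable.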
@ -0,0 +1,234 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
__all__ = ['model_to_pytorch_script']
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Any, cast
|
||||
|
||||
from nni.common.device import Device, GPUDevice
|
||||
from nni.nas.execution.common.graph import IllegalGraphError, Edge, Graph, Node, Model
|
||||
from nni.nas.execution.common.graph_op import PyTorchOperation
|
||||
from nni.nas.utils import STATE_DICT_PY_MAPPING
|
||||
|
||||
from .op_def import ToDevice
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def model_to_pytorch_script(model: Model, placement=None) -> str:
|
||||
graphs = []
|
||||
total_pkgs = set()
|
||||
for name, cell in model.graphs.items():
|
||||
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
|
||||
graphs.append(graph_code)
|
||||
total_pkgs.update(import_pkgs)
|
||||
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
|
||||
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
|
||||
|
||||
|
||||
def _sorted_incoming_edges(node: Node) -> List[Edge]:
|
||||
edges = [edge for edge in node.graph.edges if edge.tail is node]
|
||||
_logger.debug('sorted_incoming_edges: %s', str(edges))
|
||||
if not edges:
|
||||
return []
|
||||
_logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
|
||||
if all(edge.tail_slot is None for edge in edges):
|
||||
return edges
|
||||
if all(isinstance(edge.tail_slot, int) for edge in edges):
|
||||
edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
|
||||
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
|
||||
return edges
|
||||
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
|
||||
|
||||
|
||||
def _format_inputs(node: Node, graph_name: str) -> Tuple[List[str], List[Any]]:
|
||||
"""
|
||||
Format the inputs of a given node.
|
||||
Inputs will be formatted with ``_format_variable_name``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node : Node
|
||||
a graph node, get and format its inputs
|
||||
graph_name : str
|
||||
subgraph name, to format variable names
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
the list of input names
|
||||
list
|
||||
the list of input values, if an input is simple type, record its value,
|
||||
otherwise the value is None
|
||||
"""
|
||||
edges = _sorted_incoming_edges(node)
|
||||
inputs = []
|
||||
inputs_value = []
|
||||
for edge in edges:
|
||||
if edge.head.name == '_inputs':
|
||||
assert isinstance(edge.head_slot, int)
|
||||
if edge.head.operation.io_names is not None:
|
||||
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
|
||||
inputs.append(_format_variable_name(edge.head.operation.io_names[edge.head_slot], graph_name))
|
||||
else:
|
||||
# when input has no name, e.g., forward(*_inputs)
|
||||
inputs.append('_inputs[{}]'.format(edge.head_slot))
|
||||
inputs_value.append(None)
|
||||
else:
|
||||
if edge.head_slot is None:
|
||||
# when the input comes from a single-output operator
|
||||
inputs.append(_format_variable_name(edge.head.name, graph_name))
|
||||
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
|
||||
'value' in edge.head.operation.parameters:
|
||||
inputs_value.append(edge.head.operation.parameters['value'])
|
||||
else:
|
||||
inputs_value.append(None)
|
||||
else:
|
||||
# when the input comes from a multi-output operator: needs to know which one it comes from
|
||||
inputs.append('{}[{}]'.format(_format_variable_name(edge.head.name, graph_name), edge.head_slot))
|
||||
inputs_value.append(None)
|
||||
return inputs, inputs_value
|
||||
|
||||
|
||||
def _format_variable_name(name: str, graph_name: str) -> str:
|
||||
"""
|
||||
1. replace invalid characters in node name
|
||||
2. variables name (full name space) is too long, shorten the name by removing the prefix ```graph_name```
|
||||
"""
|
||||
name = name[len(graph_name):] if name.startswith(graph_name) else name
|
||||
name = name.replace('/', '__')
|
||||
|
||||
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
|
||||
name = re.sub(r'\W|^(?=\d)','_', name)
|
||||
|
||||
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
|
||||
# name can't start with double underscore
|
||||
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
|
||||
# but it's actually very common in our generated code
|
||||
name = name[1:]
|
||||
elif name.startswith('_'):
|
||||
# to avoid conflicts between '_' and '__'
|
||||
name = 'i' + name
|
||||
|
||||
return name
|
||||
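# Worked example (illustrative): with graph_name='_model',
#     _format_variable_name('_model/conv1', '_model')
# strips the prefix to '/conv1', maps '/' to '__' giving '__conv1', and then
# drops one leading underscore, yielding '_conv1'.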
|
||||
|
||||
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
|
||||
'''
|
||||
Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU IDs,
|
||||
we need to remap the GPU IDs when generating code so that they match.
|
||||
For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0", "cuda:1" in the generated code.
|
||||
'''
|
||||
unique_devices = sorted(list(set([e for e in placement.values() if isinstance(e, GPUDevice)])))
|
||||
node_gpu_cnt = {}
|
||||
cuda_remapped_id = {}
|
||||
for d in unique_devices:
|
||||
if d.node_id not in node_gpu_cnt:
|
||||
node_gpu_cnt[d.node_id] = 0
|
||||
node_gpu_cnt[d.node_id] += 1
|
||||
cuda_remapped_id[d] = node_gpu_cnt[d.node_id] - 1
|
||||
|
||||
return cuda_remapped_id
|
||||
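# Worked example (illustrative): suppose the placement maps two nodes to
# GPUDevice entries on the same machine with gpu ids 1 and 3. Sorting the
# unique devices and counting per machine yields remapped ids 0 and 1, so the
# generated code refers to them as 'cuda:0' and 'cuda:1'.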
|
||||
|
||||
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
|
||||
nodes = graph.topo_sort()
|
||||
|
||||
# handle module node and function node differently
|
||||
# only need to generate code for module here
|
||||
import_pkgs = set()
|
||||
node_codes = []
|
||||
node_python_mappings = {}
|
||||
cuda_remapped_id = None
|
||||
if placement:
|
||||
cuda_remapped_id = generate_cuda_mapping(placement)
|
||||
for node in nodes:
|
||||
if node.operation:
|
||||
if placement and isinstance(node.operation, ToDevice):
|
||||
cuda_remapped_id = cast(dict, cuda_remapped_id)
|
||||
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])
|
||||
|
||||
if node.operation.type == 'shared':
|
||||
continue
|
||||
pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
|
||||
if pkg_name is not None:
|
||||
import_pkgs.add(pkg_name)
|
||||
|
||||
py_variable_name = _format_variable_name(node.name, graph_name)
|
||||
node_code = node.operation.to_init_code(py_variable_name)
|
||||
if node_code is not None:
|
||||
if placement and node in placement and len(node_code) > 0:
|
||||
if isinstance(placement[node], GPUDevice):
|
||||
assert cuda_remapped_id is not None
|
||||
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
|
||||
else:
|
||||
device_repr = placement[node].device_repr()
|
||||
node_codes.append(f"{node_code}.to('{device_repr}')")
|
||||
else:
|
||||
node_codes.append(node_code)
|
||||
|
||||
# Map to module hierarchies in original search space python code
|
||||
node_python_mappings[py_variable_name] = node.python_name
|
||||
|
||||
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
|
||||
|
||||
if graph.input_node.operation.io_names is None:
|
||||
input_code = '*_inputs'
|
||||
else:
|
||||
for name in graph.input_node.operation.io_names:
|
||||
assert not name.startswith(graph_name)
|
||||
input_code = ', '.join(graph.input_node.operation.io_names)
|
||||
|
||||
edge_codes = []
|
||||
sorted_nodes = graph.topo_sort()
|
||||
for node in sorted_nodes:
|
||||
if node.operation:
|
||||
inputs, inputs_value = _format_inputs(node, graph_name)
|
||||
node_name = _format_variable_name(node.name, graph_name)
|
||||
submodule_name = node_name
|
||||
if node.operation.type == 'shared':
|
||||
submodule_name = _format_variable_name(node.operation.parameters['reference'], graph_name)
|
||||
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))
|
||||
|
||||
output_names, _ = _format_inputs(graph.output_node, graph_name)
|
||||
if not output_names:
|
||||
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
|
||||
output_code = ', '.join(output_names)
|
||||
|
||||
linebreak = '\n '
|
||||
return import_pkgs, _PyTorchModelTemplate.format(
|
||||
graph_name=('Graph' if graph_name == '_graph' else graph_name),
|
||||
inputs=input_code,
|
||||
outputs=output_code,
|
||||
nodes=linebreak.join(node_codes),
|
||||
edges=linebreak.join(edge_codes)
|
||||
)
|
||||
|
||||
|
||||
# TODO: handle imports
|
||||
|
||||
_PyTorchScriptTemplate = '''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
import nni.nas.nn.pytorch
|
||||
|
||||
{}
|
||||
|
||||
{}
|
||||
'''
|
||||
|
||||
_PyTorchModelTemplate = '''
|
||||
class {graph_name}(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
{nodes}
|
||||
|
||||
def forward(self, {inputs}):
|
||||
{edges}
|
||||
return {outputs}
|
||||
'''
|
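# For illustration, the generated script for a hypothetical graph with a single
# conv node (all names made up) would look roughly like:
#
#     class Graph(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.conv = nn.Conv2d(3, 16, kernel_size=3)
#
#         def forward(self, x):
#             conv = self.conv(x)
#             return conv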
nni/algorithms/nas/pytorch/cream/__init__.py → nni/nas/execution/pytorch/converter/__init__.py
Executable file → Normal file
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .trainer import CreamSupernetTrainer
|
||||
from .graph_gen import convert_to_graph
|
|
@ -0,0 +1,848 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from nni.nas.execution.common import Graph, Model, Node, Cell, Operation
|
||||
from nni.nas.nn.pytorch import InputChoice, Placeholder, LayerChoice
|
||||
from nni.nas.utils import get_init_parameters_or_fail, get_importable_name
|
||||
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
|
||||
from .utils import (
|
||||
_convert_name, build_full_name, _without_shape_info,
|
||||
_extract_info_from_trace_node, get_full_name_by_scope_name,
|
||||
is_layerchoice_node, match_node, build_cand_name,
|
||||
build_python_name
|
||||
)
|
||||
|
||||
|
||||
class GraphConverter:
|
||||
def __init__(self):
|
||||
self.global_seq = 0
|
||||
self.global_graph_id = 0
|
||||
|
||||
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
|
||||
if _input in output_remap:
|
||||
assert output_remap[_input].kind() == 'aten::append'
|
||||
predecessor_node = output_remap[_input]
|
||||
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
|
||||
src_node_idx = None
|
||||
src_node = node_index[predecessor_node]
|
||||
assert isinstance(src_node, Node)
|
||||
elif _input in graph_inputs:
|
||||
idx = graph_inputs.index(_input)
|
||||
src_node = ir_graph.input_node
|
||||
src_node_idx = idx
|
||||
else:
|
||||
predecessor_node = _input.node()
|
||||
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
|
||||
# find out the index of _input in the outputs of predecessor_node
|
||||
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
|
||||
if len(predecessor_outputs) == 1:
|
||||
idx = None
|
||||
else:
|
||||
idx = predecessor_outputs.index(_input)
|
||||
ir_predecessor_node = node_index[predecessor_node]
|
||||
src_node_idx = idx
|
||||
assert isinstance(ir_predecessor_node, Node)
|
||||
src_node = ir_predecessor_node
|
||||
return src_node, src_node_idx
|
||||
|
||||
def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
ir_graph : Graph
|
||||
node : torch._C.Node
|
||||
graph_inputs : List[torch._C.Value]
|
||||
a list of a script graph's inputs
|
||||
node_index : Dict
|
||||
new_node : Node
|
||||
newly created ir node corresponding to `node`
|
||||
output_remap : Dict
|
||||
ignore_first : bool
|
||||
if it is true, skip the first input
|
||||
"""
|
||||
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
|
||||
new_node_input_idx = 0
|
||||
for _input in node.inputs():
|
||||
if ignore_first:
|
||||
ignore_first = False
|
||||
continue
|
||||
# handle source node
|
||||
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
|
||||
# handle destination node
|
||||
dst_node = new_node
|
||||
if is_single_input:
|
||||
dst_node_idx = None
|
||||
else:
|
||||
dst_node_idx = new_node_input_idx
|
||||
# create edge
|
||||
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
|
||||
|
||||
new_node_input_idx += 1
|
||||
|
||||
def create_prim_constant_node(self, ir_graph, node, module_name):
|
||||
# NOTE: compare with the string, not the type, because the type is defined in pytorch C code.
|
||||
# `.kind()` can also be used here
|
||||
if node.outputsAt(0).type().str() == 'None':
|
||||
attrs = {'type': 'None'}
|
||||
else:
|
||||
attrs = {'type': node.outputsAt(0).type().str(), 'value': node.outputsAt(0).toIValue()}
|
||||
self.global_seq += 1
|
||||
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
|
||||
node.kind(), attrs)
|
||||
return new_node
|
||||
|
||||
def handle_prim_attr_node(self, node, module):
|
||||
assert node.hasAttribute('name')
|
||||
value = None
|
||||
if node.inputsAt(0).debugName() == 'self':
|
||||
_val = getattr(module, node.s('name'))
|
||||
# TODO: serialize complex data type, and output proper error message
|
||||
if isinstance(_val, (int, float, str, bool)):
|
||||
value = _val
|
||||
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName(), 'value': value}
|
||||
return node.kind(), attrs
|
||||
|
||||
def _remove_mangle(self, module_type_str):
|
||||
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
|
||||
|
||||
def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
ir_graph : Graph
|
||||
our ir graph representation
|
||||
targeted_type : str
|
||||
nodes with ```targeted_type``` will be removed from the graph if their fanout is 0.
|
||||
```None``` means removing all the nodes whose fanout is 0.
|
||||
"""
|
||||
# build index of outputs of Node(s)
|
||||
node_fanout = set()
|
||||
for edge in ir_graph.edges:
|
||||
if edge.head.id not in node_fanout:
|
||||
node_fanout.add(edge.head.id)
|
||||
|
||||
to_removes = []
|
||||
for hidden_node in ir_graph.hidden_nodes:
|
||||
if hidden_node.id not in node_fanout:
|
||||
assert isinstance(hidden_node, Node)
|
||||
if targeted_type is None:
|
||||
to_removes.append(hidden_node)
|
||||
elif hidden_node.operation.type == targeted_type:
|
||||
to_removes.append(hidden_node)
|
||||
|
||||
for hidden_node in to_removes:
|
||||
hidden_node.remove()
|
||||
|
||||
def handle_graph_nodes(self, script_module, sm_graph,
|
||||
module, module_name, module_python_name,
|
||||
ir_model, ir_graph,
|
||||
shared_module_index=None):
|
||||
"""
|
||||
Convert torch script node to our node ir, and build our graph ir
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script_module : torch.jit.RecursiveScriptModule
|
||||
the torch script of ```module```
|
||||
sm_graph : torch._C.Graph
|
||||
the graph in torch script
|
||||
module : nn.Module
|
||||
the targeted pytorch module
|
||||
module_name : str
|
||||
```module```'s name
|
||||
ir_model : Model
|
||||
the whole graph ir
|
||||
ir_graph : Graph
|
||||
the graph ir of ```module```
|
||||
shared_module_index : dict
|
||||
it is used to record which modules already have an ir node created;
|
||||
if such a module is invoked again, the new ir node can simply reference the existing one.
|
||||
this way we can identify shared modules (i.e., one module invoked multiple times in the `forward` function)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
the mapping from graph node to our graph ir node
|
||||
"""
|
||||
# handle inputs
|
||||
graph_inputs = []
|
||||
for _input in sm_graph.inputs():
|
||||
if _input.debugName() == 'self':
|
||||
assert _input.unique() == 0
|
||||
continue
|
||||
graph_inputs.append(_input)
|
||||
# TODO: add scope name
|
||||
ir_graph._add_input(_convert_name(_input.debugName()))
|
||||
|
||||
node_index = {} # graph node to graph ir node
|
||||
if shared_module_index is None:
|
||||
shared_module_index = {}
|
||||
|
||||
# some nodes do not have an output but modify a variable, for example aten::append
|
||||
# %17 : Tensor[] = aten::append(%out.1, %16)
|
||||
# %out.1 is updated, and %17 is None
|
||||
# we add an output to this type of node and connect it to the following node which uses %out.1
|
||||
# key: tensor (%out.1), value: node (this node)
|
||||
output_remap = {}
|
||||
|
||||
# ===================handle control flow: if===================
|
||||
def handle_if_condition(cond_tensor):
|
||||
"""
|
||||
to calculate the condition, we only deal with a limited set of op types by tracing back,
|
||||
e.g., `prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
|
||||
|
||||
generate the expression using recursive calls
|
||||
|
||||
NOTE: dynamic graphs are not supported
|
||||
"""
|
||||
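# For example (illustrative), for `if self.use_skip:` the traced condition is a
# prim::GetAttr node; _generate_expr produces the string '(True)' when
# module.use_skip is True, and eval(expr) then returns True.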
def _generate_expr(tensor):
|
||||
if tensor.node().kind() == 'prim::GetAttr':
|
||||
return f'({getattr(module, tensor.node().s("name"))})'
|
||||
elif tensor.node().kind() == 'aten::__getitem__':
|
||||
t = _generate_expr(tensor.node().inputsAt(0))
|
||||
idx = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({t}[{idx}])'
|
||||
elif tensor.node().kind() == 'prim::Constant':
|
||||
return f'{tensor.toIValue()}'
|
||||
elif tensor.node().kind() == 'aten::eq':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} == {right})'
|
||||
elif tensor.node().kind() == 'aten::le':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} <= {right})'
|
||||
elif tensor.node().kind() == 'aten::ge':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} >= {right})'
|
||||
elif tensor.node().kind() == 'aten::__not__':
|
||||
value = _generate_expr(tensor.node().inputsAt(0))
|
||||
return f'(not {value})'
|
||||
elif tensor.node().kind() == 'aten::Bool':
|
||||
value = _generate_expr(tensor.node().inputsAt(0))
|
||||
return f'bool({value})'
|
||||
elif tensor.node().kind() == 'aten::__is__':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} is {right})'
|
||||
elif tensor.node().kind() == 'aten::__isnot__':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} is not {right})'
|
||||
elif tensor.node().kind() == 'aten::ne':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} != {right})'
|
||||
elif tensor.node().kind() == 'aten::gt':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} > {right})'
|
||||
elif tensor.node().kind() == 'aten::lt':
|
||||
left = _generate_expr(tensor.node().inputsAt(0))
|
||||
right = _generate_expr(tensor.node().inputsAt(1))
|
||||
return f'({left} < {right})'
|
||||
elif tensor.node().kind() == 'prim::If':
|
||||
raise RuntimeError('`if A and/or B` is not supported yet, please use two `if` statements instead.')
|
||||
elif tensor.node().kind() == 'aten::abs':
|
||||
value = _generate_expr(tensor.node().inputsAt(0))
|
||||
return f'(torch.abs({value}))'
|
||||
elif tensor.node().kind() == 'aten::sum':
|
||||
value = _generate_expr(tensor.node().inputsAt(0))
|
||||
return f'(torch.sum({value}))'
|
||||
elif tensor.node().kind() == 'aten::item':
|
||||
value = _generate_expr(tensor.node().inputsAt(0))
|
||||
return f'({value}.item())'
|
||||
else:
|
||||
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
|
||||
'please consider decorating the corresponding class with "@basic_unit".')
|
||||
expr = _generate_expr(cond_tensor)
|
||||
return eval(expr)
|
||||
|
||||
def handle_if_node(node):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
node : torch._C.Node
|
||||
the node from TorchScript graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
Node
|
||||
the created node ir
|
||||
"""
|
||||
# for now, only deal with the case where the input of prim::If is a constant or an attribute
|
||||
# constant expressions will be supported in the future
|
||||
inputs = [i for i in node.inputs()]
|
||||
assert len(inputs) == 1
|
||||
cond = handle_if_condition(inputs[0])
|
||||
chosen_block = 0 if cond else 1
|
||||
blocks = [block for block in node.blocks()]
|
||||
assert len(blocks) == 2
|
||||
last_block_node = None
|
||||
for node in blocks[chosen_block].nodes():
|
||||
last_block_node = handle_single_node(node)
|
||||
self.global_seq += 1
|
||||
new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
|
||||
self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
|
||||
last_block_node = new_node
|
||||
return last_block_node
|
||||
|
||||
# ===================handle function call===================
|
||||
def handle_function_callmethod(node):
|
||||
# get and handle the first input, which should be an nn.Module
|
||||
assert node.hasAttribute('name')
|
||||
# NOTE: "forward__0" is hacky: an LSTM instance is parsed to call forward__0 in torchscript
|
||||
if node.s('name') in ['forward', 'forward__0']:
|
||||
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
|
||||
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
|
||||
submodule = node.inputsAt(0).node()
|
||||
assert submodule.kind() == 'prim::GetAttr'
|
||||
assert submodule.hasAttribute('name')
|
||||
submodule_name = submodule.s('name')
|
||||
|
||||
if submodule.inputsAt(0).debugName() == 'self':
|
||||
# module is usually instantiated in __init__.
|
||||
# when calling a module in forward,
|
||||
# prim::GetAttr is used to obtain the module in torch script.
|
||||
# therefore, we do this check for a module. example below:
|
||||
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
|
||||
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
|
||||
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
|
||||
submodule_name, script_module._modules.keys())
|
||||
|
||||
submodule_full_name = build_full_name(module_name, submodule_name)
|
||||
submodule_python_name = build_python_name(module_python_name, submodule_name)
|
||||
submodule_obj = getattr(module, submodule_name)
|
||||
subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
|
||||
submodule_obj,
|
||||
submodule_full_name, submodule_python_name,
|
||||
ir_model)
|
||||
else:
|
||||
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
|
||||
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
|
||||
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
|
||||
if submodule.inputsAt(0).type().name() == 'ModuleList':
|
||||
# handle ModuleList
|
||||
predecessor = submodule.inputsAt(0).node()
|
||||
module_name_space = [submodule_name]
|
||||
while predecessor.inputsAt(0).debugName() != 'self':
|
||||
# this is for dealing with nested ModuleList. below is an example
|
||||
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
|
||||
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
|
||||
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
|
||||
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
|
||||
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
|
||||
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
|
||||
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
|
||||
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
|
||||
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
|
||||
assert predecessor.kind() == 'prim::GetAttr'
|
||||
module_name_space.append(predecessor.s('name'))
|
||||
predecessor = predecessor.inputsAt(0).node()
|
||||
assert predecessor.kind() == 'prim::GetAttr'
|
||||
assert predecessor.hasAttribute('name')
|
||||
module_name_space.append(predecessor.s('name'))
|
||||
submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
|
||||
submodule_python_name = build_python_name(module_python_name, list(reversed(module_name_space)))
|
||||
submodule_obj = module
|
||||
script_submodule = script_module
|
||||
for each_name in list(reversed(module_name_space)):
|
||||
submodule_obj = getattr(submodule_obj, each_name)
|
||||
script_submodule = script_submodule._modules[each_name]
|
||||
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name,
|
||||
submodule_python_name, ir_model)
|
||||
else:
|
||||
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
|
||||
|
||||
if submodule_full_name in shared_module_index:
|
||||
# this module is invoked more than once, the ir node has already been created
|
||||
# create a reference node for it.
|
||||
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
|
||||
self.global_seq += 1
|
||||
shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
|
||||
shared_node_python_name = build_python_name(submodule_python_name, self.global_seq)
|
||||
shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
|
||||
subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
|
||||
subcell.python_name = shared_node_python_name
|
||||
else:
|
||||
# this module is processed for the first time, build cell for it
|
||||
if subgraph is None:
|
||||
# if we do not parse this module's graph, we create Node for this module
|
||||
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
|
||||
subcell.python_name = submodule_python_name
|
||||
if isinstance(submodule_obj, Placeholder):
|
||||
subcell.update_label(submodule_obj.label)
|
||||
elif isinstance(submodule_obj, InputChoice):
|
||||
subcell.update_label(sub_m_attrs['label'])
|
||||
else:
|
||||
# Graph already created, create Cell for it
|
||||
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
|
||||
subcell = ir_graph.add_node(submodule_full_name, new_cell)
|
||||
subcell.python_name = submodule_python_name
|
||||
shared_module_index[submodule_full_name] = subcell
|
||||
node_index[node] = subcell
|
||||
# connect the cell into graph
|
||||
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
|
||||
else:
|
||||
# handle normal member function
|
||||
assert hasattr(script_module, node.s('name'))
|
||||
# TODO: support non-member functions
|
||||
assert node.inputsAt(0).debugName() == 'self'
|
||||
script_method = getattr(script_module, node.s('name')) # <class 'torch._C.ScriptMethod'>
|
||||
|
||||
# step #1: generate graph ir for this method
|
||||
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
|
||||
self.handle_graph_nodes(script_module, script_method.graph, module,
|
||||
module_name, module_python_name, ir_model, method_ir_graph, shared_module_index)
|
||||
self.refine_graph(method_ir_graph)
|
||||
|
||||
# step #2: merge this graph to its module graph
|
||||
for h_node in method_ir_graph.hidden_nodes:
|
||||
h_node.graph = ir_graph
|
||||
ir_graph.hidden_nodes.append(h_node)
|
||||
for edge in method_ir_graph.edges:
|
||||
edge.graph = ir_graph
|
||||
if edge.head == method_ir_graph.input_node:
|
||||
# this is a member method, 'self' is the first argument, thus +1
|
||||
assert edge.head_slot is not None
|
||||
_input = node.inputsAt(edge.head_slot + 1)
|
||||
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
|
||||
edge.head = src_node
|
||||
edge.head_slot = src_node_idx
|
||||
if edge.tail == method_ir_graph.output_node:
|
||||
# since the following nodes have not been created, skip this edge
|
||||
# edge.head is the output node of this method
|
||||
# TODO: check whether there could be multiple output nodes???
|
||||
node_index[node] = edge.head
|
||||
continue
|
||||
ir_graph.edges.append(edge)
|
||||
|
||||
# ===================handle each single node===================
|
||||
def handle_single_node(node):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
node : torch._C.Node
|
||||
the node from TorchScript graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
Node
|
||||
the created node ir
|
||||
"""
|
||||
if node.kind() == 'prim::CallMethod':
|
||||
handle_function_callmethod(node)
|
||||
elif node.kind() == 'prim::CallFunction':
|
||||
func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
|
||||
func = node.inputsAt(0).node()
|
||||
assert func.kind() == 'prim::Constant'
|
||||
assert func.hasAttribute('name')
|
||||
func_name = func.s('name')
|
||||
# create node for func
|
||||
self.global_seq += 1
|
||||
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
|
||||
'{}.{}'.format(func_type_str, func_name))
|
||||
func_python_name = build_python_name(module_python_name, func_name)
|
||||
func_node.python_name = func_python_name
|
||||
node_index[node] = func_node
|
||||
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
|
||||
elif node.kind() == 'prim::Constant':
|
||||
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
|
||||
node_index[node] = new_node
|
||||
elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
|
||||
self.global_seq += 1
|
||||
prim_op_name = node.kind().split('::')[-1]
|
||||
new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
|
||||
node_index[node] = new_node
|
||||
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
|
||||
elif node.kind() == 'prim::GetAttr':
|
||||
node_type, attrs = self.handle_prim_attr_node(node, module)
|
||||
self.global_seq += 1
|
||||
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
|
||||
node_type, attrs)
|
||||
node_index[node] = new_node
|
||||
elif node.kind() == 'prim::If':
|
||||
last_block_node = handle_if_node(node)
|
||||
# last_block_node being None means there is no node in the branch block
|
||||
node_index[node] = last_block_node
|
||||
elif node.kind() == 'prim::Loop':
|
||||
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
|
||||
raise RuntimeError('Loop has not been supported yet!')
|
||||
elif node.kind().startswith('prim::'):
|
||||
self.global_seq += 1
|
||||
prim_op_name = node.kind().replace('::', '__')
|
||||
prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
|
||||
node_index[node] = prim_node
|
||||
self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
|
||||
elif node.kind() == 'aten::append':
|
||||
self.global_seq += 1
|
||||
aten_op_name = node.kind().replace('::', '__')
|
||||
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
|
||||
node_index[node] = aten_node
|
||||
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
|
||||
output_remap[node.inputsAt(0)] = node
|
||||
elif node.kind().startswith('aten::'):
|
||||
# handle aten::XXX
|
||||
self.global_seq += 1
|
||||
aten_op_name = node.kind().replace('::', '__')
|
||||
aten_op_python_name = node.kind().replace('aten::', '')
|
||||
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
|
||||
aten_python_name = build_python_name(module_python_name, aten_op_python_name)
|
||||
aten_node.python_name = aten_python_name
|
||||
node_index[node] = aten_node
|
||||
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
|
||||
else:
|
||||
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
|
||||
|
||||
return node_index[node]
|
||||
|
||||
for node in sm_graph.nodes():
|
||||
handle_single_node(node)
|
||||
|
||||
if node_index != {}:
|
||||
for _output in sm_graph.outputs():
|
||||
ir_graph._add_output(_convert_name(_output.debugName()))
|
||||
predecessor_node_outputs = [o for o in _output.node().outputs()]
|
||||
if len(predecessor_node_outputs) == 1:
|
||||
src_node_idx = None
|
||||
else:
|
||||
src_node_idx = predecessor_node_outputs.index(_output)
|
||||
|
||||
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
|
||||
tail=(ir_graph.output_node, None))
|
||||
else:
|
||||
# here is an example that the ir_graph and node_index is empty
|
||||
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
|
||||
# %x.1 : Tensor): return (%x.1)
|
||||
# add an edge from head to tail to handle this situation
|
||||
ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ir_graph.output_node, None))
|
||||
|
||||
|
||||
def merge_aten_slices(self, ir_graph):
|
||||
"""
|
||||
if there are aten::slice nodes, merge consecutive ones together.
|
||||
```x[:, :, 1:, 1:]``` in python code will be converted into 4 nodes in torch script,
|
||||
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
|
||||
"""
|
||||
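# For example (illustrative), ```x[:, :, 1:, 1:]``` yields a chain of 4
# aten::slice nodes; after merging, the single MergedSlice node receives the
# tensor plus (dim, start, end, step) slots from each slice: 5 + 3 * 4 = 17
# inputs in total.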
head_slice_nodes = []
|
||||
has_slice_node = False
|
||||
for node in ir_graph.hidden_nodes:
|
||||
if node.operation.type == 'aten::slice':
|
||||
has_slice_node = True
|
||||
for pred in node.predecessors:
|
||||
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
|
||||
head_slice_nodes.append(node)
|
||||
break
|
||||
if has_slice_node:
|
||||
assert head_slice_nodes
|
||||
|
||||
for head_node in head_slice_nodes:
|
||||
slot = 0
|
||||
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
|
||||
if len(head_node.incoming_edges) == 4:
|
||||
# when slicing a one-dimensional list, there are only 4 inputs, thus no merge is needed
|
||||
for edge in head_node.incoming_edges:
|
||||
edge.tail = new_slice_node
|
||||
for edge in head_node.outgoing_edges:
|
||||
edge.head = new_slice_node
|
||||
ir_graph.hidden_nodes.remove(head_node)
|
||||
break
|
||||
assert len(head_node.incoming_edges) == 5
|
||||
for edge in head_node.incoming_edges:
|
||||
edge.tail = new_slice_node
|
||||
slot += 5
|
||||
node = head_node
|
||||
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
|
||||
suc_node = node.successors[0]
|
||||
assert len(suc_node.incoming_edges) == 5
|
||||
for edge in suc_node.incoming_edges:
|
||||
if edge.tail_slot == 0:
|
||||
edge.remove()
|
||||
else:
|
||||
edge.tail = new_slice_node
|
||||
edge.tail_slot = slot + edge.tail_slot - 1
|
||||
slot += 4
|
||||
ir_graph.hidden_nodes.remove(node)
|
||||
node = suc_node
|
||||
|
||||
for edge in node.outgoing_edges:
|
||||
edge.head = new_slice_node
|
||||
ir_graph.hidden_nodes.remove(node)
|
||||
|
||||
def refine_graph(self, ir_graph):
|
||||
"""
|
||||
Do the following process to simplify graph:
|
||||
1. remove unconnected constant node
|
||||
2. remove unconnected getattr node
|
||||
"""
|
||||
# some constants are not used, for example, a function name stored as prim::Constant
|
||||
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
|
||||
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
|
||||
self.merge_aten_slices(ir_graph)
|
||||
|
||||
def _handle_inputchoice(self, module):
|
||||
return {
|
||||
'n_candidates': module.n_candidates,
|
||||
'n_chosen': module.n_chosen,
|
||||
'reduction': module.reduction,
|
||||
'label': module.label
|
||||
}
|
||||
|
||||
def _handle_valuechoice(self, module):
|
||||
return {
|
||||
'candidates': module.candidates,
|
||||
'label': module.label,
|
||||
}
|
||||
|
||||
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
|
||||
# NOTE: nested LayerChoice is not supported yet, i.e., a candidate module
|
||||
# that itself contains LayerChoice, InputChoice, or ValueChoice
|
||||
original_type_name = script_module.original_name
|
||||
m_attrs = None
|
||||
if original_type_name == OpTypeName.LayerChoice:
|
||||
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
|
||||
graph.python_name = module_python_name
|
||||
candidate_name_list = []
|
||||
for cand_name in module.names:
|
||||
cand = module[cand_name]
|
||||
script_cand = script_module._modules[cand_name]
|
||||
cand_full_name = build_cand_name(cand_name, module.label)
|
||||
cand_python_name = build_python_name(module_python_name, cand_name)
|
||||
candidate_name_list.append(cand_full_name)
|
||||
subgraph, attrs = self._convert_module(script_cand, cand, cand_full_name, cand_python_name, ir_model)
|
||||
if subgraph is not None:
|
||||
cand_node = graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
|
||||
cand_node.python_name = cand_python_name
|
||||
else:
|
||||
cand_type = '__torch__.' + get_importable_name(cand.__class__)
|
||||
cand_node = graph.add_node(cand_full_name, cand_type, attrs)
|
||||
cand_node.python_name = cand_python_name
|
||||
graph._register()
|
||||
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
|
||||
elif original_type_name == OpTypeName.InputChoice:
|
||||
m_attrs = self._handle_inputchoice(module)
|
||||
elif original_type_name == OpTypeName.ValueChoice:
|
||||
m_attrs = self._handle_valuechoice(module)
|
||||
elif original_type_name == OpTypeName.Placeholder:
|
||||
m_attrs = get_init_parameters_or_fail(module)
|
||||
elif module.__class__.__module__.startswith('torch.nn') and \
|
||||
original_type_name in torch.nn.__dict__ and \
|
||||
original_type_name not in MODULE_EXCEPT_LIST:
|
||||
# this is a basic module from pytorch, no need to parse its graph
|
||||
m_attrs = get_init_parameters_or_fail(module)
|
||||
elif getattr(module, '_nni_basic_unit', False):
|
||||
# this module is marked as a basic unit (serialized), so we do not parse it further
|
||||
m_attrs = get_init_parameters_or_fail(module)
|
||||
if m_attrs is not None:
|
||||
return None, m_attrs
|
||||
|
||||
# handle TorchScript graph
|
||||
sm_graph = script_module.graph
|
||||
self.global_graph_id += 1
|
||||
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
|
||||
ir_graph.python_name = module_python_name
|
||||
|
||||
# handle graph nodes
|
||||
self.handle_graph_nodes(script_module, sm_graph, module,
|
||||
module_name, module_python_name, ir_model, ir_graph)
|
||||
self.refine_graph(ir_graph)
|
||||
|
||||
ir_graph._register()
|
||||
|
||||
# add mutation signal for special modules
|
||||
if original_type_name == OpTypeName.Repeat:
|
||||
attrs = {
|
||||
'mutation': 'repeat',
|
||||
'label': module.label,
|
||||
'depth': module.depth_choice,
|
||||
'max_depth': module.max_depth,
|
||||
'min_depth': module.min_depth,
|
||||
}
|
||||
return ir_graph, attrs
|
||||
|
||||
return ir_graph, {}
|
||||
|
||||
def convert_module(self, script_module, module, module_name, ir_model):
|
||||
"""
|
||||
Convert a module to its graph ir (i.e., Graph) along with its input arguments
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script_module : torch.jit.RecursiveScriptModule
|
||||
the script module of ```module``` obtained with torch.jit.script
|
||||
module : nn.Module
|
||||
the targeted module instance
|
||||
module_name : str
|
||||
the constructed name space of ```module```
|
||||
ir_model : Model
|
||||
the whole graph ir
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
the graph ir built from the module; ```None``` means the module is not parsed further
|
||||
dict
|
||||
the input arguments of this module
|
||||
"""
|
||||
return self._convert_module(script_module, module, module_name, None, ir_model)
|
||||
|
||||
|
||||
class GraphConverterWithShape(GraphConverter):
|
||||
"""
|
||||
Convert a pytorch model to nni ir along with input/output shape info.
|
||||
The base ir is acquired through ``torch.jit.script``,
|
||||
and the shape info through ``torch.jit.trace``.
|
||||
|
||||
.. warning::
|
||||
|
||||
Known issues:
|
||||
|
||||
1. ``InputChoice`` and ``ValueChoice`` not supported yet.
|
||||
2. Currently random inputs are fed while tracing layerchoice.
|
||||
If the forward path of a candidate depends on input data, the wrong path will be traced.
|
||||
This will result in incomplete shape info.
|
||||
"""
|
||||
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
|
||||
module.eval()
|
||||
|
||||
ir_graph, attrs = self._convert_module(script_module, module, module_name, None, ir_model)
|
||||
self.remove_dummy_nodes(ir_model)
|
||||
self._initialize_parameters(ir_model)
|
||||
self._trace_module(module, module_name, ir_model, dummy_input)
|
||||
return ir_graph, attrs
|
||||
|
||||
def _initialize_parameters(self, ir_model: 'Model'):
|
||||
for ir_node in ir_model.get_nodes():
|
||||
if ir_node.operation.parameters is None:
|
||||
ir_node.operation.parameters = {}
|
||||
ir_node.operation.attributes.setdefault('input_shape', [])
|
||||
ir_node.operation.attributes.setdefault('output_shape', [])
|
||||
|
||||
def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
|
||||
# First, trace the whole graph
|
||||
tm_graph = self._trace(module, dummy_input)
|
||||
|
||||
for node in tm_graph.nodes():
|
||||
shape_parameters, parameters = _extract_info_from_trace_node(node)
|
||||
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
|
||||
ir_node = match_node(ir_model, node, module_name)
|
||||
if ir_node is not None:
|
||||
ir_node.operation.attributes.update(shape_parameters)
|
||||
if parameters:
|
||||
ir_node.operation.parameters.update(parameters)
|
||||
|
||||
self.propagate_shape(ir_model)
|
||||
|
||||
# trace each layerchoice
|
||||
for name, submodule in module.named_modules():
|
||||
# TODO: support InputChoice and ValueChoice
|
||||
if isinstance(submodule, LayerChoice):
|
||||
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
|
||||
lc_node = ir_model.get_node_by_name(full_name)
|
||||
assert lc_node is not None, f'Cannot find a node with name {full_name}'
|
||||
|
||||
for cand_name in submodule.names:
|
||||
cand = submodule[cand_name]
|
||||
cand_name = build_cand_name(cand_name, submodule.label)
|
||||
# TODO: Feed the exact input tensor if user provides input,
|
||||
# in case the path changes according to input data.
|
||||
lc_inputs = [torch.randn(shape) for shape in lc_node.operation.attributes['input_shape']]
|
||||
self._trace_module(cand, cand_name, ir_model, lc_inputs)
|
||||
|
||||
def propagate_shape(self, ir_model: 'Model'):
|
||||
|
||||
def propagate_shape_for_graph(graph: 'Graph'):
|
||||
if graph == ir_model.root_graph:
|
||||
return
|
||||
|
||||
graph_node = ir_model.get_node_by_name(graph.name)
|
||||
assert graph_node is not None, f'Cannot find a node with name {graph.name}'
|
||||
if not _without_shape_info(graph_node):
|
||||
return
|
||||
|
||||
if is_layerchoice_node(graph_node):
|
||||
cand_name = graph_node.operation.parameters['candidates'][0]
|
||||
cand_node = ir_model.get_node_by_name(cand_name)
|
||||
assert cand_node is not None, f'Cannot find a node with name {cand_name}'
|
||||
if _without_shape_info(cand_node):
|
||||
propagate_shape_for_graph(ir_model.graphs[cand_name])
|
||||
graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
|
||||
graph_node.operation.attributes['output_shape'] = cand_node.operation.attributes['output_shape']
|
||||
else:
|
||||
input_shape = [[]] * len(graph.input_node.operation.io_names or [])
|
||||
output_shape = [[]] * len(graph.output_node.operation.io_names or [])
|
||||
for edge in graph.input_node.outgoing_edges:
|
||||
node = edge.tail
|
||||
if _without_shape_info(node):
|
||||
if node.name in ir_model.graphs:
|
||||
propagate_shape_for_graph(ir_model.graphs[node.name])
|
||||
if node.operation.attributes['input_shape']:
|
||||
input_shape[edge.head_slot or 0] = node.operation.attributes['input_shape'][edge.tail_slot or 0]
|
||||
graph_node.operation.attributes['input_shape'] = input_shape
|
||||
for edge in graph.output_node.incoming_edges:
|
||||
node = edge.head
|
||||
if _without_shape_info(node):
|
||||
if node.name in ir_model.graphs:
|
||||
propagate_shape_for_graph(ir_model.graphs[node.name])
|
||||
if node.operation.attributes['output_shape']:
|
||||
output_shape[edge.tail_slot or 0] = node.operation.attributes['output_shape'][edge.head_slot or 0]
|
||||
graph_node.operation.attributes['output_shape'] = output_shape
|
||||
|
||||
propagate_shape_for_graph(graph_node.graph)
|
||||
|
||||
# propagate from node to graph
|
||||
for node in ir_model.get_nodes():
|
||||
propagate_shape_for_graph(node.graph)
|
||||
|
||||
def _trace(self, module, dummy_input):
|
||||
traced_module = torch.jit.trace(module, dummy_input)
|
||||
torch._C._jit_pass_inline(traced_module.graph)
|
||||
return traced_module.graph
|
||||
|
||||
def remove_dummy_nodes(self, ir_model: 'Model'):
|
||||
# remove identity nodes
|
||||
for node in ir_model.get_nodes_by_type('noop_identity'):
|
||||
graph = node.graph
|
||||
for in_edge in node.incoming_edges:
|
||||
for out_edge in node.outgoing_edges:
|
||||
if in_edge.tail_slot == out_edge.head_slot:
|
||||
graph.add_edge(head=(in_edge.head, in_edge.head_slot), tail=(out_edge.tail, out_edge.tail_slot))
|
||||
graph.del_edge(in_edge)
|
||||
graph.del_edge(out_edge)
|
||||
break
|
||||
node.remove()
|
||||
|
||||
|
||||
def convert_to_graph(script_module, module, converter=None, **kwargs):
|
||||
"""
|
||||
Convert module to our graph ir, i.e., build a :class:`Model` type
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script_module : torch.jit.RecursiveScriptModule
|
||||
the script module obtained with torch.jit.script
|
||||
module : nn.Module
|
||||
the targeted module instance
|
||||
converter : `TorchConverter`
|
||||
default `GraphConverter` is used
|
||||
kwargs:
|
||||
will be passed to `converter.convert_module()`
|
||||
|
||||
Returns
|
||||
-------
|
||||
Model
|
||||
the constructed IR model
|
||||
"""
|
||||
|
||||
model = Model(_internal=True)
|
||||
module_name = '_model'
|
||||
if converter is None:
|
||||
converter = GraphConverter()
|
||||
converter.convert_module(script_module, module, module_name, model, **kwargs)
|
||||
|
||||
return model
|
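# A minimal usage sketch (illustrative; `net` is a hypothetical nn.Module):
#
#     import torch
#     script_module = torch.jit.script(net)
#     ir_model = convert_to_graph(script_module, net)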
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
# special cases that cannot be treated as basic modules from pytorch
|
||||
MODULE_EXCEPT_LIST = ['Sequential']
|
||||
|
||||
|
||||
class OpTypeName(str, Enum):
|
||||
"""
|
||||
Mapping from op type to its type name string.
|
||||
"""
|
||||
Attr = 'Attr'
|
||||
Constant = 'Constant'
|
||||
LayerChoice = 'LayerChoice'
|
||||
InputChoice = 'InputChoice'
|
||||
ValueChoice = 'ValueChoice'
|
||||
Placeholder = 'Placeholder'
|
||||
MergedSlice = 'MergedSlice'
|
||||
Repeat = 'Repeat'
|
||||
Cell = 'Cell'
|
|
@ -0,0 +1,259 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from nni.nas.execution.common import Cell, Model, Graph, Node, Edge
|
||||
|
||||
|
||||
def build_full_name(prefix, name, seq=None):
|
||||
if isinstance(name, list):
|
||||
name = '__'.join(name)
|
||||
if seq is None:
|
||||
return '{}__{}'.format(prefix, name)
|
||||
else:
|
||||
return '{}__{}{}'.format(prefix, name, str(seq))
|
||||
|
||||
|
||||
def build_python_name(prefix, name):
|
||||
if isinstance(name, list):
|
||||
name = '.'.join(name)
|
||||
if prefix:
|
||||
return '{}.{}'.format(prefix, name)
|
||||
else: # prefix could be None
|
||||
return name
|
||||
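# Worked examples (illustrative):
#     build_full_name('model', ['cells', '0'])    -> 'model__cells__0'
#     build_full_name('model', 'conv', seq=2)     -> 'model__conv2'
#     build_python_name('model', ['cells', '0'])  -> 'model.cells.0'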
|
||||
|
||||
def build_cand_name(name, label):
|
||||
return f'layerchoice_{label}_{name}'
|
||||
|
||||
|
||||
def _convert_name(name: str) -> str:
|
||||
"""
|
||||
Convert names that use the separator '.' into valid variable names in code
|
||||
"""
|
||||
return name.replace('.', '__')
|
||||
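# e.g. (illustrative): _convert_name('features.0.conv') -> 'features__0__conv'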
|
||||
|
||||
def _extract_info_from_trace_node(trace_node):
|
||||
"""
|
||||
Extract parameters from a trace node.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trace_node: torch._C.Value
|
||||
"""
|
||||
input_shape = []
|
||||
output_shape = []
|
||||
|
||||
inputs = list(trace_node.inputs())
|
||||
|
||||
# input tensors of aten::cat are in a strange place
|
||||
if trace_node.kind() == 'aten::cat':
|
||||
input_shape = [input.type().sizes() for input in inputs[0].node().inputs()]
|
||||
else:
|
||||
for _input in inputs:
|
||||
input_type = _input.type()
|
||||
if input_type.kind() == 'TensorType':
|
||||
shape = input_type.sizes()
|
||||
if shape:
|
||||
input_shape.append(shape)
|
||||
|
||||
for _output in trace_node.outputs():
|
||||
output_type = _output.type()
|
||||
if output_type.kind() == 'TensorType':
|
||||
shape = output_type.sizes()
|
||||
if shape:
|
||||
output_shape.append(shape)
|
||||
|
||||
shape_parameters = {
|
||||
'input_shape': input_shape,
|
||||
'output_shape': output_shape,
|
||||
}
|
||||
|
||||
if trace_node.kind() == 'aten::cat':
|
||||
parameters = {'dim': inputs[1].toIValue()}
|
||||
return shape_parameters, parameters
|
||||
else:
|
||||
return shape_parameters, None
|
||||
|
||||
|
||||
def is_layerchoice_node(ir_node: Optional[Node]) -> TypeGuard[Node]:
|
||||
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''):
|
||||
full_name = prefix
|
||||
|
||||
for last_scope in range(len(scope_names)):
|
||||
ir_node = ir_model.get_node_by_name(full_name)
|
||||
# check if it's layerchoice
|
||||
if is_layerchoice_node(ir_node):
|
||||
full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}'
|
||||
else:
|
||||
full_name = build_full_name(full_name, scope_names[last_scope])
|
||||
|
||||
return full_name
|
||||
|
||||
|
||||
def match_node(ir_model: Model, torch_node, prefix=''):
|
||||
"""
|
||||
Match the corresponding node of a torch._C.Value
|
||||
"""
|
||||
scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:]
|
||||
full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix)
|
||||
# handle the case when the node is not an nn.Module but is used directly in forward()
|
||||
# because the name can't be directly matched, we use a hacky way:
|
||||
# match the first unshaped node of that kind
|
||||
graph = ir_model.graphs.get(full_name)
|
||||
if graph is not None:
|
||||
for node in graph.get_nodes_by_type(torch_node.kind()):
|
||||
if not node.operation.attributes['input_shape']:
|
||||
return node
|
||||
return None
|
||||
else:
|
||||
return ir_model.get_node_by_name(full_name)
|
||||
|
||||
|
||||
def _without_shape_info(node: Node):
|
||||
return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
|
||||
|
||||
|
||||
def flatten_model_graph(ir_model: Model):
|
||||
"""
|
||||
Flatten subgraphs into the root graph.
|
||||
"""
|
||||
def _flatten(graph: Graph):
|
||||
"""
|
||||
flatten this graph
|
||||
"""
|
||||
model = graph.model
|
||||
node_to_remove = []
|
||||
|
||||
for node in graph.hidden_nodes:
|
||||
node_graph = model.graphs.get(node.name)
|
||||
if node_graph is not None:
|
||||
_flatten(node_graph)
|
||||
|
||||
# flatten node graph into this graph
|
||||
id_to_new_node = {}
|
||||
for node_graph_node in node_graph.hidden_nodes:
|
||||
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
|
||||
new_node.update_label(node_graph_node.label)
|
||||
new_node._register()
|
||||
id_to_new_node[new_node.id] = new_node
|
||||
|
||||
# reconnect node edges
|
||||
for in_edge in node.incoming_edges:
|
||||
graph.del_edge(in_edge)
|
||||
for input_node_edge in node_graph.input_node.outgoing_edges:
|
||||
if input_node_edge.head_slot == in_edge.tail_slot:
|
||||
graph.add_edge(
|
||||
head=(in_edge.head, in_edge.head_slot),
|
||||
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
|
||||
|
||||
for out_edge in node.outgoing_edges:
|
||||
graph.del_edge(out_edge)
|
||||
for output_node_edge in node_graph.output_node.incoming_edges:
|
||||
if output_node_edge.head_slot == out_edge.tail_slot:
|
||||
graph.add_edge(
|
||||
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
|
||||
tail=(out_edge.tail, out_edge.tail_slot))
|
||||
|
||||
for edge in node_graph.edges:
|
||||
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
|
||||
continue
|
||||
new_head = id_to_new_node[edge.head.id]
|
||||
new_tail = id_to_new_node[edge.tail.id]
|
||||
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
|
||||
|
||||
node_to_remove.append(node)
|
||||
del model.graphs[node.name]
|
||||
|
||||
for node in node_to_remove:
|
||||
node.remove()
|
||||
|
||||
new_ir_model = ir_model.fork()
|
||||
_flatten(new_ir_model.root_graph)
|
||||
|
||||
# remove subgraphs
|
||||
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
|
||||
return new_ir_model
|
||||
|
||||
|
||||
def flatten_model_graph_without_layerchoice(ir_model: Model):
|
||||
"""
|
||||
Flatten subgraphs into the root graph and skip all layerchoice nodes.
|
||||
"""
|
||||
def _flatten_without_layerchoice(graph: Graph):
|
||||
"""
|
||||
flatten this graph
|
||||
"""
|
||||
model = graph.model
|
||||
node_to_remove = []
|
||||
|
||||
for node in graph.hidden_nodes:
|
||||
if is_layerchoice_node(node):
|
||||
for in_edge in node.incoming_edges:
|
||||
graph.del_edge(in_edge)
|
||||
for out_edge in node.outgoing_edges:
|
||||
graph.del_edge(out_edge)
|
||||
del model.graphs[node.name]
|
||||
node.remove()
|
||||
return
|
||||
|
||||
node_graph = model.graphs.get(node.name)
|
||||
if node_graph is not None:
|
||||
_flatten_without_layerchoice(node_graph)
|
||||
|
||||
# flatten node graph into this graph
|
||||
id_to_new_node = {}
|
||||
for node_graph_node in node_graph.hidden_nodes:
|
||||
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
|
||||
new_node.update_label(node_graph_node.label)
|
||||
new_node._register()
|
||||
id_to_new_node[new_node.id] = new_node
|
||||
|
||||
# reconnect node edges
|
||||
for in_edge in node.incoming_edges:
|
||||
graph.del_edge(in_edge)
|
||||
for input_node_edge in node_graph.input_node.outgoing_edges:
|
||||
if input_node_edge.head_slot == in_edge.tail_slot:
|
||||
graph.add_edge(
|
||||
head=(in_edge.head, in_edge.head_slot),
|
||||
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
|
||||
|
||||
for out_edge in node.outgoing_edges:
|
||||
graph.del_edge(out_edge)
|
||||
for output_node_edge in node_graph.output_node.incoming_edges:
|
||||
if output_node_edge.head_slot == out_edge.tail_slot:
|
||||
graph.add_edge(
|
||||
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
|
||||
tail=(out_edge.tail, out_edge.tail_slot))
|
||||
|
||||
|
||||
for edge in node_graph.edges:
|
||||
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
|
||||
continue
|
||||
new_head = id_to_new_node[edge.head.id]
|
||||
new_tail = id_to_new_node[edge.tail.id]
|
||||
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
|
||||
|
||||
node_to_remove.append(node)
|
||||
del model.graphs[node.name]
|
||||
|
||||
for node in node_to_remove:
|
||||
node.remove()
|
||||
|
||||
new_ir_model = ir_model.fork()
|
||||
_flatten_without_layerchoice(new_ir_model.root_graph)
|
||||
|
||||
# remove subgraphs
|
||||
new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
|
||||
return new_ir_model
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import graphviz
|
||||
|
||||
|
||||
def convert_to_visualize(graph_ir, vgraph):
|
||||
for name, graph in graph_ir.items():
|
||||
if name == '_evaluator':
|
||||
continue
|
||||
with vgraph.subgraph(name='cluster'+name) as subgraph:
|
||||
subgraph.attr(color='blue')
|
||||
cell_node = {}
|
||||
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
|
||||
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
|
||||
subgraph.node(ioput['_inputs'])
|
||||
subgraph.node(ioput['_outputs'])
|
||||
for node_name, node_value in graph['nodes'].items():
|
||||
value = node_value['operation']
|
||||
if value['type'] == '_cell':
|
||||
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
|
||||
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
|
||||
cell_node[node_name] = (cell_input_name, cell_output_name)
|
||||
print('cell: ', node_name, cell_input_name, cell_output_name)
|
||||
else:
|
||||
subgraph.node(node_name)
|
||||
for edge in graph['edges']:
|
||||
src = edge['head'][0]
|
||||
if src == '_inputs':
|
||||
src = ioput['_inputs']
|
||||
elif src in cell_node:
|
||||
src = cell_node[src][1]
|
||||
dst = edge['tail'][0]
|
||||
if dst == '_outputs':
|
||||
dst = ioput['_outputs']
|
||||
elif dst in cell_node:
|
||||
dst = cell_node[dst][0]
|
||||
subgraph.edge(src, dst)
|
||||
|
||||
|
||||
def visualize_model(graph_ir):
|
||||
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
|
||||
convert_to_visualize(graph_ir, vgraph)
|
||||
vgraph.render()
|
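# A minimal usage sketch (illustrative; assumes `ir_model` is a Model and that
# its dumped dict form is the graph ir expected here):
#
#     graph_ir = ir_model._dump()
#     visualize_model(graph_ir)  # renders vgraph.jpg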
|
@ -0,0 +1,167 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ['BaseGraphData', 'BaseExecutionEngine']
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from typing import Any, Dict, Iterable, List
|
||||
|
||||
from nni.experiment import rest
|
||||
|
||||
from nni.nas.execution.common import (
|
||||
AbstractExecutionEngine, AbstractGraphListener, RetiariiAdvisor, get_mutation_summary,
|
||||
Model, ModelStatus, MetricData, Evaluator,
|
||||
send_trial, receive_trial_parameters, get_advisor
|
||||
)
|
||||
from nni.nas.utils import import_
|
||||
from .codegen import model_to_pytorch_script
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseGraphData:
|
||||
"""
|
||||
Data sent between strategy and trial, in the graph-based execution engine.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
model_script
|
||||
code of an instantiated PyTorch model
|
||||
evaluator
|
||||
training approach for model_script
|
||||
mutation_summary
|
||||
a dict of all the choices during mutations in the HPO search space format
|
||||
"""
|
||||
def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None:
|
||||
self.model_script = model_script
|
||||
self.evaluator = evaluator
|
||||
self.mutation_summary = mutation_summary
|
||||
|
||||
def dump(self) -> dict:
|
||||
return {
|
||||
'model_script': self.model_script,
|
||||
# engine needs to call dump here,
|
||||
# otherwise, evaluator will become binary
|
||||
# also, evaluator can be None in tests
|
||||
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
|
||||
'mutation_summary': self.mutation_summary
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def load(data) -> 'BaseGraphData':
|
||||
return BaseGraphData(data['model_script'], Evaluator._load(data['evaluator']), data['mutation_summary'])
|
||||
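# Round-trip sketch (illustrative): for a BaseGraphData instance `d`,
# BaseGraphData.load(d.dump()) reconstructs an equivalent object, provided the
# evaluator supports _dump()/_load().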
|
||||
|
||||
class BaseExecutionEngine(AbstractExecutionEngine):
|
||||
"""
|
||||
The execution engine with no optimization at all.
|
||||
Resource management is implemented in this class.
|
||||
"""
|
||||
|
||||
def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
|
||||
"""
|
||||
Upon initialization, advisor callbacks need to be registered.
|
||||
Advisor will call the callbacks when the corresponding event has been triggered.
|
||||
Base execution engine will get those callbacks and broadcast them to graph listener.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rest_port
|
||||
The port of the experiment's rest server
|
||||
rest_url_prefix
|
||||
The url prefix of the experiment's rest entry
|
||||
"""
|
||||
self.port = rest_port
|
||||
self.url_prefix = rest_url_prefix
|
||||
|
||||
self._listeners: List[AbstractGraphListener] = []
|
||||
self._running_models: Dict[int, Model] = dict()
|
||||
self._history: List[Model] = []
|
||||
|
||||
self.resources = 0
|
||||
|
||||
# register advisor callbacks
|
||||
advisor: RetiariiAdvisor = get_advisor()
|
||||
advisor.register_callbacks({
|
||||
'send_trial': self._send_trial_callback,
|
||||
'request_trial_jobs': self._request_trial_jobs_callback,
|
||||
'trial_end': self._trial_end_callback,
|
||||
'intermediate_metric': self._intermediate_metric_callback,
|
||||
'final_metric': self._final_metric_callback
|
||||
})
|
||||
|
||||
def submit_models(self, *models: Model) -> None:
|
||||
for model in models:
|
||||
data = self.pack_model_data(model)
|
||||
self._running_models[send_trial(data.dump())] = model
|
||||
self._history.append(model)
|
||||
|
||||
def list_models(self) -> Iterable[Model]:
|
||||
return self._history
|
||||
|
||||
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
|
||||
self._listeners.append(listener)
|
||||
|
||||
def _send_trial_callback(self, paramater: dict) -> None:
|
||||
if self.resources <= 0:
|
||||
# FIXME: should be a warning message here
|
||||
_logger.debug('There is no available resource, but trial is submitted.')
|
||||
self.resources -= 1
|
||||
_logger.debug('Resource used. Remaining: %d', self.resources)
|
||||
|
||||
def _request_trial_jobs_callback(self, num_trials: int) -> None:
|
||||
self.resources += num_trials
|
||||
_logger.debug('New resource available. Remaining: %d', self.resources)
|
||||
|
||||
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
|
||||
model = self._running_models[trial_id]
|
||||
if success:
|
||||
model.status = ModelStatus.Trained
|
||||
else:
|
||||
model.status = ModelStatus.Failed
|
||||
for listener in self._listeners:
|
||||
listener.on_training_end(model, success)
|
||||
|
||||
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
|
||||
model = self._running_models[trial_id]
|
||||
model.intermediate_metrics.append(metrics)
|
||||
for listener in self._listeners:
|
||||
listener.on_intermediate_metric(model, metrics)
|
||||
|
||||
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
|
||||
model = self._running_models[trial_id]
|
||||
model.metric = metrics
|
||||
for listener in self._listeners:
|
||||
listener.on_metric(model, metrics)
|
||||
|
||||
def query_available_resource(self) -> int:
|
||||
return self.resources
|
||||
|
||||
def budget_exhausted(self) -> bool:
|
||||
resp = rest.get(self.port, '/check-status', self.url_prefix)
|
||||
return resp['status'] == 'DONE'
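As an aside: the two callbacks above implement a simple token-bucket style counter, where `request_trial_jobs` deposits resources and each dispatched trial withdraws one. A strategy written against this engine would poll it roughly as follows; this is a hedged sketch, where `engine` is assumed to be the instance created by the experiment and `candidate_models` an iterable of Model objects produced by a strategy.

import time

for model in candidate_models:
    # block until a slot is free (slots are freed by _request_trial_jobs_callback)
    while engine.query_available_resource() <= 0:
        time.sleep(1)
    engine.submit_models(model)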
    @classmethod
    def pack_model_data(cls, model: Model) -> Any:
        mutation_summary = get_mutation_summary(model)
        assert model.evaluator is not None, 'Model evaluator can not be None'
        return BaseGraphData(model_to_pytorch_script(model), model.evaluator, mutation_summary)  # type: ignore

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model and hand it over to the trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
        file_name = f'_generated_model/{random_str}.py'
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, 'w') as f:
            f.write(graph_data.model_script)
        model_cls = import_(f'_generated_model.{random_str}._model')
        graph_data.evaluator._execute(model_cls)
        os.remove(file_name)
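`trial_execute_graph` is the half of the protocol that runs inside the trial process; it is normally reached through the trial entrypoint rather than called directly. A sketch of that call path (see the `trial_entry` module later in this changeset):

# The generated trial command for this engine is effectively:
#   python -m nni.retiarii.trial_entry base
# which boils down to:
BaseExecutionEngine.trial_execute_graph()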

@@ -0,0 +1,548 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import (Any, Dict, List)

import torch
import torch.nn.functional as nn_functional

from nni.nas.execution.common import PyTorchOperation


mem_format = [
    'torch.contiguous_format',      # 0
    'torch.preserve_format',        # 1
    'torch.channels_last',          # 2
]

# this snippet is copied from torch/onnx/symbolic_helper.py,
# the original definition is in c10/core/ScalarType.h
# this maps each scalar type index to its corresponding PyTorch type string
scalar_type_to_pytorch_type = [
    'torch.uint8',        # 0
    'torch.int8',         # 1
    'torch.short',        # 2
    'torch.int',          # 3
    'torch.int64',        # 4
    'torch.half',         # 5
    'torch.float',        # 6
    'torch.double',       # 7
    'torch.complex32',    # 8
    'torch.complex64',    # 9
    'torch.complex128',   # 10
    'torch.bool',         # 11
]


class NoOpIdentity(PyTorchOperation):
    """
    This operator type is added by us.
    """
    _ori_type_name = ['noop_identity']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {", ".join(inputs)}'


class ModuleOperator(PyTorchOperation):
    _ori_type_name = ['ModuleOperator', 'shared']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = self.{field}({", ".join(inputs)})'


class FunctionalOperator(PyTorchOperation):
    _ori_type_name = ['FunctionalOperator']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        func_name = self.type[len('Function.'):]
        if not hasattr(nn_functional, func_name):
            raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
                               f'{func_name} is not in it.')
        return f'{output} = F.{func_name}({", ".join(inputs)})'


class PrimConstant(PyTorchOperation):
    _ori_type_name = ['prim::Constant']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
        # TODO: deal with all the types
        if self.parameters['type'] in ['None', 'NoneType']:
            return f'{output} = None'
        elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):  # 'Long()' ???
            return f'{output} = {self.parameters["value"]}'
        elif self.parameters['type'] == 'str':
            str_val = self.parameters["value"]
            return f'{output} = "{str_val}"'
        elif self.parameters['type'] == 'Device':
            value = self.parameters['value']
            return f'{output} = torch.device("{value}")'
        elif self.parameters['type'] in ('dict', 'list', 'tuple'):
            # TODO: prim::TupleIndex is not supported yet
            return f'{output} = {repr(self.parameters["value"])}'
        else:
            raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')


class PrimListConstruct(PyTorchOperation):
    _ori_type_name = ['prim::ListConstruct']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = [{", ".join(inputs)}]'


class PrimListUnpack(PyTorchOperation):
    _ori_type_name = ['prim::ListUnpack']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}'


class PrimTupleConstruct(PyTorchOperation):
    _ori_type_name = ['prim::TupleConstruct']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = ({", ".join(inputs)})'


class PrimTupleUnpack(PyTorchOperation):
    _ori_type_name = ['prim::TupleUnpack']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # have a single output here, because the following code uses an index to access the unpacked values
        assert len(inputs) == 1
        return f'{output} = {inputs[0]}'


class PrimGetAttr(PyTorchOperation):
    _ori_type_name = ['prim::GetAttr']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.parameters['value'] is not None:
            return f"{output} = {self.parameters['value']}"
        else:
            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"


class PrimUncheckedCast(PyTorchOperation):
    _ori_type_name = ['prim::unchecked_cast']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}'


class SimpleMember(PyTorchOperation):
    _ori_type_name = ['prim::is_cuda', 'prim::data']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        member_name = self.type.split('::')[-1]
        return f'{output} = {inputs[0]}.{member_name}'


class AtenContiguous(PyTorchOperation):
    _ori_type_name = ['aten::contiguous']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # defined in pytorch/c10/core/MemoryFormat.h
        assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'


class AtenGetitem(PyTorchOperation):
    _ori_type_name = ['aten::__getitem__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'{output} = {inputs[0]}[{inputs[1]}]'


class AtenAppend(PyTorchOperation):
    _ori_type_name = ['aten::append']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'


class MergedSlice(PyTorchOperation):
    _ori_type_name = ['MergedSlice']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if (len(inputs) - 1) % 4 == 0:
            slices = []
            dim = int((len(inputs) - 1) / 4)
            for i in range(dim):
                slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
            slice_str = ','.join(slices)
            return f'{output} = {inputs[0]}[{slice_str}]'
        elif len(inputs) == 4:
            # this case is for a simple list
            return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
        else:
            raise RuntimeError('Unsupported slice pattern')
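To make the index arithmetic above concrete, here is the line such an op would emit for a hypothetical two-dimensional merged slice (nine inputs: the tensor, then four values per dimension, of which positions 2-4 and 6-8 are start:end:step):

# inputs = ['x', 'd0', 's0', 'e0', 'st0', 'd1', 's1', 'e1', 'st1']
# (len(inputs) - 1) % 4 == 0 and dim == 2, so the generated code is:
#   out = x[s0:e0:st0,s1:e1:st1]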

# the following Aten classes mean these aten ops are not in torch.Tensor


class AtenBool(PyTorchOperation):
    _ori_type_name = ['aten::Bool']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = bool({inputs[0]})'


class AtenNot(PyTorchOperation):
    _ori_type_name = ['aten::__not__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = not {inputs[0]}'


class AtenCat(PyTorchOperation):
    _ori_type_name = ['aten::cat']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'

# ====================================


class AtenTensors(PyTorchOperation):
    _ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
                      'aten::ones_like', 'aten::zeros_like', 'aten::rand',
                      'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
                      'aten::new_empty', 'aten::new_zeros', 'aten::arange',
                      'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        schemas = torch._C._jit_get_schemas_for_operator(self.type)
        # match number of inputs
        overloaded_defs = [len(s.arguments) for s in schemas]
        matched = overloaded_defs.index(len(inputs))
        args_list = []
        for idx, arg in enumerate(schemas[matched].arguments):
            if arg.name == 'dtype':
                arg_str = f'dtype={scalar_type_to_pytorch_type[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
            elif arg.name == 'layout':
                if inputs_value[idx] is not None:
                    arg_str = 'layout=torch.strided'
                    print('Warning: only `torch.strided` is supported for now!')
                else:
                    arg_str = ''
            elif arg.name == 'device':
                arg_str = f'device=torch.device({inputs[idx]})' if inputs_value[idx] is not None else ''
            elif arg.name == 'memory_format':
                arg_str = f'memory_format={mem_format[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
            elif arg.name == 'pin_memory':
                # TODO: deal with this argument
                continue
            elif arg.name == 'requires_grad':
                arg_str = f'requires_grad={inputs[idx]}' if inputs_value[idx] else ''
            elif str(arg.type).startswith('Optional['):
                arg_str = f'{arg.name}={inputs[idx]}'
            else:
                arg_str = f'{inputs[idx]}'
            if arg_str != '':
                args_list.append(arg_str)
        op_name = self.type.split('::')[-1]
        if hasattr(torch, op_name):
            return f'{output} = torch.{op_name}({", ".join(args_list)})'
        else:
            return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'

# ====================================


class AtenFloordiv(PyTorchOperation):
    _ori_type_name = ['aten::floordiv']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]} // {inputs[1]}'


class AtenMul(PyTorchOperation):
    _ori_type_name = ['aten::mul']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]} * {inputs[1]}'


class AtenLen(PyTorchOperation):
    _ori_type_name = ['aten::len']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = len({inputs[0]})'


class AtenIntImplicit(PyTorchOperation):
    _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.type.endswith('Implicit'):
            return f'{output} = {inputs[0]}'
        elif self.type == 'aten::Int':
            return f'{output} = int({inputs[0]})'
        elif self.type == 'aten::Float':
            return f'{output} = float({inputs[0]})'
        raise TypeError(f'Unexpected type: {self.type}')


class AtenIndex(PyTorchOperation):
    _ori_type_name = ['aten::index']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}[{inputs[1]}]'


ManuallyChooseDef = {
    'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
    'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
    # in v1.9, dtype is supported as an input argument of view, but torch script does not support it
    'aten::view': [('size', 'List[int]', 'None')],
    # NOTE: dim supports different types: List[int], List[str], Optional[List[int]]; we only support the first two for now, refactor needed
    # torch.std(input, dim, unbiased, keepdim=False, *, out=None) -> Tensor
    # torch.std(input, unbiased) -> Tensor
    'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}

TensorOpExceptions = {
    'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}',  # example: x.size(1) - 3
    'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}'  # example: input.shape[0] + 5
}

TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
                  'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
                  'aten::save', 'aten::tensor', 'aten::wait'
                  ]


def _hidden(name):
    return name.startswith('_') and not name.startswith('__')


def _emit_args(args):
    # filter out the `out` argument here
    return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args]  # if arg.name != 'out'


def _get_tensor_ops():
    def is_tensor_method(schema):
        if len(schema.arguments) == 0:
            return False
        self = schema.arguments[0]
        if self.name != 'self':
            return False
        if not self.type.isSubtypeOf(torch._C.TensorType.get()):
            return False
        return True

    op_args = {}
    # discover methods
    for elem in dir(torch.Tensor):
        if not _hidden(elem):
            schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
            for schema in schemas:
                if is_tensor_method(schema):
                    op_name = 'aten::' + elem
                    args = _emit_args(schema.arguments[1:])
                    if op_name in op_args:
                        op_args[op_name].append(args)
                    else:
                        op_args[op_name] = [args]

    return op_args.keys(), op_args


def _get_torch_ops():
    torch_op_args = {}
    for mod in torch.jit._builtins._modules_containing_builtins:  # type: ignore
        name = mod.__name__
        if name == 'torch._C._nn':
            continue
        # only process 'torch.XXX'
        for elem in dir(mod):
            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))  # type: ignore
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        op_name = 'aten::' + elem
                        if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
                            continue
                        args = _emit_args(schema.arguments)
                        if op_name in torch_op_args:
                            torch_op_args[op_name].append(args)
                        else:
                            torch_op_args[op_name] = [args]

    return torch_op_args.keys(), torch_op_args


def _get_torch_ops_exclude_tensor_ops():
    tensor_op_names, _ = _get_tensor_ops()
    torch_op_names, torch_ops = _get_torch_ops()

    torch_exclude_ops = {}
    for name in torch_op_names:
        if name not in tensor_op_names:
            if name not in TorchOpExclude:
                # exclude the ops that are not in
                # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
                torch_exclude_ops[name] = torch_ops[name]

    return torch_exclude_ops.keys(), torch_exclude_ops


class TensorOps(PyTorchOperation):
    """
    Corresponding to _get_tensor_ops in torch.jit.supported_ops.
    """
    _ori_type_name, _op_args = _get_tensor_ops()

    comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=', 'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            concated_names = []
            for i, each in enumerate(matched):
                name = ','.join([arg[0] for arg in each])
                concated_names.append(name)
            for i in range(len(concated_names) - 1):
                if concated_names[i] != concated_names[i + 1]:
                    return False
            return True

        overloaded_defs = TensorOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            # plus 1 because we skip the first argument when generating the tensor op def
            if len(each) + 1 == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with arg's type. manually choose for now
            if has_same_arg_name(matched):
                # returning any one is okay
                return matched[0]
            elif _type in ManuallyChooseDef:
                return ManuallyChooseDef[_type]
            else:
                raise RuntimeError(f'tensor op type {_type} has more than one match: {matched}')
        else:
            if _type in TensorOpExceptions:
                return None
            raise RuntimeError(f'tensor op type {_type} has no match')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # TODO: deal with conditional ops
        if self.type in TensorOps.comparison_ops:
            return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
        matched_args = TensorOps._get_matched_args(self.type, inputs)
        if matched_args is None:
            return TensorOpExceptions[self.type](output, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = {inputs[0]}.{op_name}({args_str})'
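As an illustration of the matching logic: `aten::view` with inputs `['x', 'shape_1']` matches two single-argument overloads with different argument names and is resolved through `ManuallyChooseDef`, while `aten::sub` with two scalar-like inputs finds no single-argument method overload and falls through to `TensorOpExceptions`. The generated lines, assuming those inputs:

#   aten::view, inputs ['x', 'shape_1']  ->  out = x.view(size=shape_1)
#   aten::sub,  inputs ['a', 'b']        ->  out = a - b   (via TensorOpExceptions)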

class TorchOps(PyTorchOperation):
    """
    Corresponding to _get_nn_functional_ops in torch.jit.supported_ops.
    """
    _ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
    # add 'aten::pixel_shuffle'
    _op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
    _ori_type_name = _op_args.keys()

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            concated_names = []
            for i, each in enumerate(matched):
                name = ','.join([arg[0] for arg in each])
                concated_names.append(name)
            for i in range(len(concated_names) - 1):
                if concated_names[i] != concated_names[i + 1]:
                    return False
            return True

        overloaded_defs = TorchOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            if len(each) == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with arg's type. manually choose for now
            if has_same_arg_name(matched):
                # returning any one is okay
                return matched[0]
            else:
                raise RuntimeError(f'torch op type {_type} has more than one match: {matched}')
        else:
            raise RuntimeError(f'torch op type {_type} has no match')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        matched_args = TorchOps._get_matched_args(self.type, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                              for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = torch.{op_name}({args_str})'


class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the above aten ops for an unknown reason
    _ori_type_name = ['aten::avg_pool2d']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = F.avg_pool2d({", ".join(inputs)})'


class ToDevice(PyTorchOperation):
    _artificial_op_name = "ToDevice"

    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
                 attributes: Dict[str, Any] = {}):
        self.type = "ToDevice"
        self.device = parameters['device']
        self.overridden_device_repr = None
        self.src = parameters['src']
        self.dst = parameters['dst']

    def override_device_repr(self, device_repr):
        # CUDA GPUDevice may remap a GPU's physical ID to a CUDA ID, so the device repr can differ from
        # GPUDevice.device_repr(). override_device_repr is called in pytorch.graph_to_pytorch_model to
        # replace device_repr with the correct CUDA ID. For example, when a job uses physical GPU-1,2,
        # their CUDA IDs should be "cuda:0" and "cuda:1": self.device.device_repr() would return
        # "cuda:1" and "cuda:2", while the overridden reprs should be "cuda:0" and "cuda:1".
        self.overridden_device_repr = device_repr

    def __repr__(self):
        if self.overridden_device_repr is None:
            return f'to("{self.device.device_repr()}")'
        else:
            return f'to("{self.overridden_device_repr}")'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.overridden_device_repr is None:
            forward_code = f'{output} = {inputs[0]}.to("{self.device.device_repr()}")'
        else:
            forward_code = f'{output} = {inputs[0]}.to("{self.overridden_device_repr}")'
        return forward_code
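A concrete instance of the remapping described in `override_device_repr`; the `_FakeGPUDevice` class below is a hypothetical stand-in for the CGO device abstraction, which is assumed to expose `device_repr()`:

class _FakeGPUDevice:
    def device_repr(self):
        return 'cuda:1'    # physical ID as seen by the scheduler

op = ToDevice('ToDevice', {'device': _FakeGPUDevice(), 'src': None, 'dst': None})
assert repr(op) == 'to("cuda:1")'
op.override_device_repr('cuda:0')   # remapped CUDA ID inside the job
assert repr(op) == 'to("cuda:0")'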

class AtenDet(PyTorchOperation):
    # for torch 1.9
    # NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
    _ori_type_name = ['aten::linalg_det']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = torch.det({inputs[0]})'

@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, Any, Type, cast

import torch.nn as nn

from nni.nas.execution.common import (
    Model, receive_trial_parameters,
    get_mutation_dict, mutation_dict_to_summary
)
from nni.nas.evaluator import Evaluator
from nni.nas.utils import ContextStack
from .graph import BaseExecutionEngine


class PythonGraphData:
    def __init__(self, class_: Type[nn.Module], init_parameters: Dict[str, Any],
                 mutation: Dict[str, Any], evaluator: Evaluator) -> None:
        self.class_ = class_
        self.init_parameters = init_parameters
        self.mutation = mutation
        self.evaluator = evaluator
        self.mutation_summary = mutation_dict_to_summary(mutation)

    def dump(self) -> dict:
        return {
            'class': self.class_,
            'init_parameters': self.init_parameters,
            'mutation': self.mutation,
            # The engine needs to call dump here;
            # otherwise, the evaluator will become binary.
            # Also, the evaluator can be None in tests.
            'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
            'mutation_summary': self.mutation_summary
        }

    @staticmethod
    def load(data) -> 'PythonGraphData':
        return PythonGraphData(data['class'], data['init_parameters'], data['mutation'], Evaluator._load(data['evaluator']))


class PurePythonExecutionEngine(BaseExecutionEngine):
    """
    This is the execution engine that doesn't rely on the Python-IR converter.

    We haven't stated this independence explicitly for now. The front-end needs to decide which converter
    (or no converter) to use depending on the execution type. In the future, that logic may be moved into
    this execution engine.

    The execution engine needs to store the class path of the base model and its init parameters, so that
    it can re-initialize the model with the mutation dict in the context, making the mutable modules
    resolve to fixed instances on the fly.
    """

    @classmethod
    def pack_model_data(cls, model: Model) -> Any:
        mutation = get_mutation_dict(model)
        assert model.evaluator is not None, 'Model evaluator is not available.'
        graph_data = PythonGraphData(
            cast(Type[nn.Module], model.python_class),
            model.python_init_params or {}, mutation, model.evaluator
        )
        return graph_data

    @classmethod
    def trial_execute_graph(cls) -> None:
        graph_data = PythonGraphData.load(receive_trial_parameters())

        def _model():
            return graph_data.class_(**graph_data.init_parameters)

        with ContextStack('fixed', graph_data.mutation):
            graph_data.evaluator._execute(_model)
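Outside the trial runner, the same mechanism can be reproduced by hand. A minimal sketch, assuming `MyBaseModel` is a `@model_wrapper`-decorated search space and the mutation dict's keys match the labels of its mutables (both names hypothetical):

mutation = {'layerchoice_1': 'conv3x3', 'inputchoice_1': [0]}
with ContextStack('fixed', mutation):
    fixed_model = MyBaseModel()   # mutables resolve to the fixed choices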

@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.nas.execution.common import TensorFlowOperation


class Conv2D(TensorFlowOperation):
    def __init__(self, type_name, parameters, _internal, attributes=None):
        if 'padding' not in parameters:
            parameters['padding'] = 'same'
        super().__init__(type_name, parameters, _internal)

@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Entrypoint for trials.
"""

import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
    args = parser.parse_args()
    if args.exec == 'base':
        from .pytorch.graph import BaseExecutionEngine
        engine = BaseExecutionEngine
    elif args.exec == 'cgo':
        from .pytorch.cgo import CGOExecutionEngine
        engine = CGOExecutionEngine
    elif args.exec == 'py':
        from .pytorch.simplified import PurePythonExecutionEngine
        engine = PurePythonExecutionEngine
    elif args.exec == 'benchmark':
        from .pytorch.benchmark import BenchmarkExecutionEngine
        engine = BenchmarkExecutionEngine
    else:
        raise ValueError(f'Unrecognized execution engine: {args.exec}')
    engine.trial_execute_graph()


if __name__ == '__main__':
    main()
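This module is not meant to be invoked by hand; the canonicalized trial command generated by the experiment configuration calls it, for example:

# python -m nni.retiarii.trial_entry py      (pure-Python engine)
# python -m nni.retiarii.trial_entry base    (graph-based engine)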

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.common.framework import shortcut_framework

shortcut_framework(__name__)

del shortcut_framework

@@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import get_and_apply_next_architecture
from .experiment_config import *
from .engine_config import *

@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from typing import Optional, List

from nni.experiment.config.base import ConfigBase

__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
           'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']

@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
    name: str

@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
    name: str = 'py'

@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
    name: str = 'oneshot'

@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
    name: str = 'base'
    # input used in GraphConverterWithShape; currently only a shape tuple is supported
    dummy_input: Optional[List[int]] = None

@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
    name: str = 'cgo'
    max_concurrency_cgo: Optional[int] = None
    batch_waiting_time: Optional[int] = None
    # input used in GraphConverterWithShape; currently only a shape tuple is supported
    dummy_input: Optional[List[int]] = None

@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
    name: str = 'benchmark'
    benchmark: Optional[str] = None
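A small sketch of how these dataclasses are meant to be consumed; `RetiariiExeConfig` is introduced in the next file of this changeset, and the dummy-input shape below is illustrative:

config = RetiariiExeConfig('local')
config.execution_engine = BaseEngineConfig()
config.execution_engine.dummy_input = [1, 3, 224, 224]   # shape tuple for the converter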

@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Union, Optional

from nni.experiment.config import utils, ExperimentConfig

from .engine_config import ExecutionEngineConfig

__all__ = ['RetiariiExeConfig']

def execution_engine_config_factory(engine_name):
    # FIXME: may move this function to experiment utils in the future
    cls = _get_ee_config_class(engine_name)
    if cls is None:
        raise ValueError(f'Invalid execution engine name: {engine_name}')
    return cls()

def _get_ee_config_class(engine_name):
    for cls in ExecutionEngineConfig.__subclasses__():
        if cls.name == engine_name:
            return cls
    return None

@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
    # FIXME: refactor this class to inherit from a new common base class shared with the HPO config
    search_space: Any = ''
    trial_code_directory: utils.PathLike = '.'
    trial_command: str = '_reserved'
    # new config field for NAS
    execution_engine: Union[str, ExecutionEngineConfig]

    # Internal: to support customized fields in the trial command.
    # Useful when customized Python executables / environment variables are needed.
    _trial_command_params: Optional[Dict[str, Any]] = None

    def __init__(self, training_service_platform: Union[str, None] = None,
                 execution_engine: Union[str, ExecutionEngineConfig] = 'py',
                 **kwargs):
        super().__init__(training_service_platform, **kwargs)
        self.execution_engine = execution_engine

    def _canonicalize(self, _parents):
        msg = '{} is not supposed to be set in a Retiarii experiment by users, your config is {}.'
        if self.search_space != '':
            raise ValueError(msg.format('search_space', self.search_space))
        # TODO: maybe we should also allow users to specify trial_code_directory
        if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
            raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))

        trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
        if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
            raise ValueError(msg.format('trial_command', self.trial_command))

        if isinstance(self.execution_engine, str):
            self.execution_engine = execution_engine_config_factory(self.execution_engine)

        _trial_command_params = {
            # Default variables
            'envs': '',
            # TODO: maybe use sys.executable rendered on the trial side (e.g., trial_runner)
            'python': sys.executable,
            'execution_engine': self.execution_engine.name,

            # This should override the parameters above.
            **(self._trial_command_params or {})
        }

        self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()

        super()._canonicalize([self])
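With the defaults above and no overrides, the canonicalized trial command is simply `<sys.executable> -m nni.retiarii.trial_entry <engine>`. As a hedged sketch, `_trial_command_params` (an internal field) can prepend environment variables to that command:

config = RetiariiExeConfig('local')
config._trial_command_params = {'envs': 'CUDA_VISIBLE_DEVICES=0'}
# After _canonicalize, trial_command is roughly:
#   CUDA_VISIBLE_DEVICES=0 /usr/bin/python -m nni.retiarii.trial_entry py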

@@ -0,0 +1,361 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']

import logging
import warnings
from threading import Thread
from typing import Any, List, cast

import colorama

import torch
import torch.nn as nn
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig

from nni.nas.execution import list_models, set_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
from nni.nas.evaluator import Evaluator
from nni.nas.mutable import Mutator
from nni.nas.nn.pytorch.mutator import (
    extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from nni.nas.utils import is_model_wrapped
from nni.nas.strategy import BaseStrategy
from nni.nas.strategy.utils import dry_run_for_formatted_search_space
from .config import (
    RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
    PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)

_logger = logging.getLogger(__name__)


def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
    # TODO: this logic might need to be refactored into the execution engine
    if oneshot:
        base_model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
    elif full_ir:
        try:
            script_module = torch.jit.script(base_model)
        except Exception as e:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            raise e
        if dummy_input is not None:
            # FIXME: this is a workaround as full tensors are not supported in configs
            dummy_input = torch.randn(*dummy_input)
            converter = GraphConverterWithShape()
            base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
        else:
            base_model_ir = convert_to_graph(script_module, base_model)
        # handle inline mutations
        mutators = process_inline_mutation(base_model_ir)
    else:
        base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
    base_model_ir.evaluator = evaluator

    if mutators is not None and applied_mutators:
        raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not supported yet; '
                           'do not use mutators when you use LayerChoice/InputChoice.')
    if mutators is not None:
        applied_mutators = mutators

    # Add mutations on evaluators
    applied_mutators += process_evaluator_mutations(evaluator, applied_mutators)

    return base_model_ir, applied_mutators
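A hedged sketch of calling this preprocessing step directly; the model and evaluator names are placeholders, and `full_ir=True` exercises the TorchScript-based converter path, so the base model must be scriptable:

model_ir, mutators = preprocess_model(
    MyNet(),              # assumed @model_wrapper-decorated search space
    my_evaluator,         # assumed nni.nas.evaluator.Evaluator
    [],                   # no manually applied mutators
    full_ir=True,
)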

def debug_mutated_model(base_model, evaluator, applied_mutators):
    """
    Locally run only one trial without launching an experiment, then exit. Useful for debugging;
    for example, it can be used to quickly check for shape mismatches.

    Specifically, it applies the mutators (choosing the first candidate of each choice by default)
    to generate a new model, then runs this model locally.

    The model will be parsed with the graph execution engine.

    Parameters
    ----------
    base_model : nni.retiarii.nn.pytorch.nn.Module
        the base model
    evaluator : nni.retiarii.graph.Evaluator
        the training class of the generated models
    applied_mutators : list
        a list of mutators that will be applied on the base model for generating a new model
    """
    base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
    from nni.nas.strategy.debug import _LocalDebugStrategy
    strategy = _LocalDebugStrategy()
    strategy.run(base_model_ir, applied_mutators)
    _logger.info('local debug completed!')
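Typical usage, e.g., to surface shape mismatches before paying for a full experiment (model and evaluator names are placeholders):

debug_mutated_model(MyNet(), my_evaluator, [])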

class RetiariiExperiment(Experiment):
    """
    The entry point for a NAS experiment.
    Users can use this class to start/stop or inspect an experiment, e.g., exporting its results.

    RetiariiExperiment is a subclass of :class:`nni.experiment.Experiment`, so there are many similarities,
    such as a configurable training service for running the experiment distributed on remote servers.
    But unlike :class:`nni.experiment.Experiment`, RetiariiExperiment doesn't support configuring:

    - ``trial_code_directory``, which can only be the current working directory.
    - ``search_space``, which is auto-generated in NAS.
    - ``trial_command``, which must be ``python -m nni.retiarii.trial_entry`` to launch the modularized trial code.

    RetiariiExperiment also doesn't have a tuner/assessor/advisor, because their roles are covered by the strategy.

    Also, unlike :class:`nni.experiment.Experiment`, which is bound to a node server,
    RetiariiExperiment only optionally starts a node server to schedule the trials, namely when the strategy
    is a multi-trial strategy. When the strategy is one-shot, the step of launching a node server is omitted,
    and the experiment runs locally by default.

    Configurations of experiments, such as the execution engine or the number of GPUs allocated,
    should be put into a :class:`RetiariiExeConfig` and used as an argument of :meth:`RetiariiExperiment.run`.

    Parameters
    ----------
    base_model : nn.Module
        The model defining the search space / base skeleton without mutation.
        It should be wrapped by the decorator ``nni.retiarii.model_wrapper``.
    evaluator : nni.retiarii.Evaluator, default = None
        Evaluator for the experiment.
        If you are using a one-shot trainer, it should be placed here, although this usage is deprecated.
    applied_mutators : list of nni.retiarii.Mutator, default = None
        Mutators to mutate the base model. If none, mutators are skipped.
        Note that when ``base_model`` uses inline mutations (e.g., LayerChoice), ``applied_mutators`` must be empty / none.
    strategy : nni.retiarii.strategy.BaseStrategy, default = None
        Exploration strategy. Can be multi-trial or one-shot.
    trainer : BaseOneShotTrainer
        Kept for compatibility purposes.

    Examples
    --------
    Multi-trial NAS:

    >>> base_model = Net()
    >>> search_strategy = strategy.Random()
    >>> model_evaluator = FunctionalEvaluator(evaluate_model)
    >>> exp = RetiariiExperiment(base_model, model_evaluator, [], search_strategy)
    >>> exp_config = RetiariiExeConfig('local')
    >>> exp_config.trial_concurrency = 2
    >>> exp_config.max_trial_number = 20
    >>> exp_config.training_service.use_active_gpu = False
    >>> exp.run(exp_config, 8081)

    One-shot NAS:

    >>> base_model = Net()
    >>> search_strategy = strategy.DARTS()
    >>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
    >>> exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
    >>> exp_config = RetiariiExeConfig()
    >>> exp_config.execution_engine = 'oneshot'  # must be set for a one-shot strategy
    >>> exp.run(exp_config)

    Export top models:

    >>> for model_dict in exp.export_top_models(formatter='dict'):
    ...     print(model_dict)
    >>> with nni.retiarii.fixed_arch(model_dict):
    ...     final_model = Net()
    """

    def __init__(self, base_model: nn.Module,
                 evaluator: Evaluator = cast(Evaluator, None),
                 applied_mutators: List[Mutator] = cast(List[Mutator], None),
                 strategy: BaseStrategy = cast(BaseStrategy, None),
                 trainer: Any = None):
        super().__init__(None)
        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)

        if trainer is not None:
            warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                          'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
            evaluator = trainer

        if evaluator is None:
            raise ValueError('Evaluator should not be none.')

        self.base_model = base_model
        self.evaluator: Evaluator = evaluator
        self.applied_mutators = applied_mutators
        self.strategy = strategy

        self._dispatcher = None
        self._dispatcher_thread = None

        # check for sanity
        if not is_model_wrapped(base_model):
            warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
                          '`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
                          'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
                          RuntimeWarning)

    def _run_strategy(self, config: RetiariiExeConfig):
        base_model_ir, self.applied_mutators = preprocess_model(
            self.base_model, self.evaluator, self.applied_mutators,
            full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
            dummy_input=config.execution_engine.dummy_input
            if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
        )

        _logger.info('Start strategy...')
        search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
        self.update_search_space(search_space)
        self.strategy.run(base_model_ir, self.applied_mutators)
        _logger.info('Strategy exit')
        # TODO: find out a proper way to show a "no more trials" message on the WebUI

    def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
        # TODO: we will probably need an execution engine factory to make this clean and elegant
        if isinstance(config.execution_engine, BaseEngineConfig):
            from nni.nas.execution.pytorch.graph import BaseExecutionEngine
            engine = BaseExecutionEngine(self.port, self.url_prefix)
        elif isinstance(config.execution_engine, CgoEngineConfig):
            from nni.nas.execution.pytorch.cgo import CGOExecutionEngine

            assert not isinstance(config.training_service, list) \
                and config.training_service.platform == 'remote', \
                "CGO execution engine currently only supports remote training service"
            assert config.execution_engine.batch_waiting_time is not None \
                and config.execution_engine.max_concurrency_cgo is not None
            engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
                                        max_concurrency=config.execution_engine.max_concurrency_cgo,
                                        batch_waiting_time=config.execution_engine.batch_waiting_time,
                                        rest_port=self.port,
                                        rest_url_prefix=self.url_prefix)
        elif isinstance(config.execution_engine, PyEngineConfig):
            from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
            engine = PurePythonExecutionEngine(self.port, self.url_prefix)
        elif isinstance(config.execution_engine, BenchmarkEngineConfig):
            from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
            assert config.execution_engine.benchmark is not None, \
                '"benchmark" must be set when the benchmark execution engine is used.'
            engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
        else:
            raise ValueError(f'Unsupported engine type: {config.execution_engine}')
        set_execution_engine(engine)

    def start(self, *args, **kwargs) -> None:
        """
        By design, the only difference between `start` and `run` is that `start` is asynchronous,
        while `run` waits for the experiment to complete. RetiariiExperiment always waits for the
        experiment to complete, as the strategy runs in the foreground.
        """
        raise NotImplementedError('RetiariiExperiment is not supposed to provide a `start` method')

    def run(self,
            config: RetiariiExeConfig | None = None,
            port: int = 8080,
            debug: bool = False) -> None:
        """
        Run the experiment.
        This function will block until the experiment finishes or errors out.
        """

        from nni.retiarii.oneshot.interface import BaseOneShotTrainer
        if isinstance(self.evaluator, BaseOneShotTrainer):
            warnings.warn('You are using the old implementation of one-shot algos based on the one-shot trainer. '
                          'We will try to convert this trainer to our new implementation to run the algorithm. '
                          'In case you want to stick to the old implementation, '
                          'please consider using ``trainer.fit()`` instead of the experiment.', DeprecationWarning)
            self.evaluator.fit()
            return

        if config is None:
            warnings.warn('`config = None` is deprecated and will be removed in the future. If you are running a '
                          'one-shot experiment, please consider creating a config and setting the execution engine '
                          'to `oneshot`.', DeprecationWarning)
            self.config = RetiariiExeConfig()
            self.config.execution_engine = OneshotEngineConfig()
        else:
            self.config = config

        if isinstance(self.config.execution_engine, OneshotEngineConfig) \
                or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
            # this is hacky, will be refactored when one-shot can run on training services
            base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
            self.strategy.run(base_model_ir, self.applied_mutators)
        else:
            ws_url = f'ws://localhost:{port}/tuner'
            canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
            canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
            self._dispatcher = RetiariiAdvisor(ws_url)
            self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
            self._dispatcher_thread.start()
            # FIXME: engine cannot be created twice
            self._create_execution_engine(canonicalized_config)
            try:
                self._run_strategy(canonicalized_config)
                # FIXME: move this logic to strategy with a new API provided by the execution engine
                self._wait_completion()
            except KeyboardInterrupt:
                _logger.warning('KeyboardInterrupt detected')
                self.stop()
            _logger.info('Search process is done; the experiment is still alive, and `stop()` can terminate it.')

    def stop(self) -> None:
        """
        Stop the background experiment.
        """
        _logger.info('Stopping experiment, please wait...')
        self._stop_impl()
        if self._dispatcher_thread:
            self._dispatcher_thread.join()
        self._dispatcher = cast(RetiariiAdvisor, None)
        self._dispatcher_thread = None
        _logger.info('Experiment stopped')

    def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'dict') -> Any:
        """
        Export several top-performing models.

        For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
        available for customization.

        The concrete behavior of export depends on each strategy.
        See the documentation of each strategy for detailed specifications.

        Parameters
        ----------
        top_k : int
            How many models are intended to be exported.
        optimize_mode : str
            ``maximize`` or ``minimize``. Not supported by one-shot algorithms.
            ``optimize_mode`` is likely to be removed and defined in the strategy in the future.
        formatter : str
            Supports ``code`` and ``dict``. Not supported by one-shot algorithms.
            If ``code``, the Python code of the model will be returned.
            If ``dict``, the mutation history will be returned.
        """
        # TODO: the base class may also need this method
        if formatter == 'code':
            config = self.config.canonical_copy()
            assert not isinstance(config.execution_engine, PyEngineConfig), \
                'You should use the `dict` formatter when using the Python execution engine.'
        from nni.retiarii.oneshot.interface import BaseOneShotTrainer
        if isinstance(self.evaluator, BaseOneShotTrainer):
            assert top_k == 1, 'Only top_k == 1 is supported for now.'
            return self.evaluator.export()
        try:
            # this currently works for one-shot algorithms
            return self.strategy.export_top_models(top_k=top_k)
        except NotImplementedError:
            # when the strategy hasn't implemented its own export logic
            all_models = filter(lambda m: m.metric is not None, list_models())
            assert optimize_mode in ['maximize', 'minimize']
            all_models = sorted(all_models, key=lambda m: cast(float, m.metric), reverse=optimize_mode == 'maximize')
            assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
            if formatter == 'code':
                return [model_to_pytorch_script(model) for model in all_models[:top_k]]
            elif formatter == 'dict':
                return [get_mutation_dict(model) for model in all_models[:top_k]]

@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
from pathlib import Path
from typing import Union, Dict, Any

from .utils import ContextStack

_logger = logging.getLogger(__name__)


def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
    """
    Load an architecture from ``fixed_arch`` and apply it to the model. This should be used as a context manager. For example,

    .. code-block:: python

        with fixed_arch('/path/to/export.json'):
            model = Model(3, 224, 224)

    Parameters
    ----------
    fixed_arch : str, Path or dict
        Path to the JSON file that stores the architecture, or a dict that stores the exported architecture.
    verbose : bool
        Print log messages if set to True.

    Returns
    -------
    ContextStack
        Context manager that provides a fixed architecture when creating the model.
    """

    if isinstance(fixed_arch, (str, Path)):
        with open(fixed_arch) as f:
            fixed_arch = json.load(f)

    if verbose:
        _logger.info('Fixed architecture: %s', fixed_arch)

    return ContextStack('fixed', fixed_arch)
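Since a dict is accepted as well, the export of a finished experiment can be fed straight back in; the architecture dict below is hypothetical, and `Model` refers to the search-space class from the docstring example:

arch = {'layerchoice_1': 'conv3x3'}
with fixed_arch(arch, verbose=False):
    final_model = Model(3, 224, 224)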

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.common.framework import shortcut_framework

shortcut_framework(__name__)

del shortcut_framework

@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mobilenetv3 import MobileNetV3Space
from .nasbench101 import NasBench101
from .nasbench201 import NasBench201
from .nasnet import NDS, NASNet, ENAS, AmoebaNet, PNAS, DARTS
from .proxylessnas import ProxylessNAS
from .shufflenet import ShuffleNetSpace
from .autoformer import AutoformerSpace
|
|
@ -0,0 +1,469 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional, Tuple, cast, Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import trunc_normal_, DropPath
|
||||
|
||||
import nni.nas.nn.pytorch as nn
|
||||
from nni.nas import model_wrapper, basic_unit
|
||||
from nni.nas.nn.pytorch.choice import ValueChoiceX
|
||||
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
|
||||
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
|
||||
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
|
||||
|
||||
from .utils.fixed import FixedFactory
|
||||
from .utils.pretrained import load_pretrained_weight
|
||||
|
||||
|
||||
class RelativePosition2D(nn.Module):
|
||||
def __init__(self, head_embed_dim, length=14) -> None:
|
||||
super().__init__()
|
||||
self.head_embed_dim = head_embed_dim
|
||||
self.length = length
|
||||
self.embeddings_table_v = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
|
||||
self.embeddings_table_h = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
|
||||
|
||||
trunc_normal_(self.embeddings_table_v, std=.02)
|
||||
trunc_normal_(self.embeddings_table_h, std=.02)
|
||||
|
||||
def forward(self, length_q, length_k):
|
||||
# exclude the cls token from the distance computation
|
||||
length_q = length_q - 1
|
||||
length_k = length_k - 1
|
||||
# initialize on the target device directly, rather than moving to it afterwards
|
||||
range_vec_q = torch.arange(length_q, device=self.embeddings_table_v.device)
|
||||
range_vec_k = torch.arange(length_k, device=self.embeddings_table_v.device)
|
||||
# compute the row and column distance
|
||||
length_q_sqrt = int(length_q ** 0.5)
|
||||
distance_mat_v = (range_vec_k[None, :] // length_q_sqrt - range_vec_q[:, None] // length_q_sqrt)
|
||||
distance_mat_h = (range_vec_k[None, :] % length_q_sqrt - range_vec_q[:, None] % length_q_sqrt)
|
||||
# clip the distance to the range of [-length, length]
|
||||
distance_mat_clipped_v = torch.clamp(distance_mat_v, -self.length, self.length)
|
||||
distance_mat_clipped_h = torch.clamp(distance_mat_h, -self.length, self.length)
|
||||
|
||||
# translate the distance to the range [1, 2 * length + 1]; 0 is reserved for the cls token
|
||||
final_mat_v = distance_mat_clipped_v + self.length + 1
|
||||
final_mat_h = distance_mat_clipped_h + self.length + 1
|
||||
# pad the 0 which represent the cls token
|
||||
final_mat_v = F.pad(final_mat_v, (1, 0, 1, 0), "constant", 0)
|
||||
final_mat_h = F.pad(final_mat_h, (1, 0, 1, 0), "constant", 0)
|
||||
|
||||
final_mat_v = final_mat_v.long()
|
||||
final_mat_h = final_mat_h.long()
|
||||
# get the embeddings with the corresponding distance
|
||||
embeddings = self.embeddings_table_v[final_mat_v] + self.embeddings_table_h[final_mat_h]
|
||||
|
||||
return embeddings
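# Illustrative shape check (added; not part of the original file). Assumes a
# 14x14 patch grid plus one cls token, i.e. 197 query/key positions.
def _demo_relative_position_2d():
    rp = RelativePosition2D(head_embed_dim=64, length=14)
    emb = rp(197, 197)  # 196 patches + 1 cls token on each side
    assert emb.shape == (197, 197, 64)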
|
||||
|
||||
class RelativePositionAttention(nn.Module):
|
||||
"""
|
||||
This class is designed to support the relative position in attention.
|
||||
The PyTorch built-in nn.MultiheadAttention() does not support relative position embedding.
|
||||
Different from the absolute position embedding, the relative position embedding
|
||||
encodes the relative distance between input tokens and learns their pairwise relations.
|
||||
It is commonly calculated via a look-up table with learnable parameters interacting with queries
|
||||
and keys in self-attention modules.
|
||||
"""
|
||||
def __init__(
|
||||
self, embed_dim, num_heads,
|
||||
attn_drop=0., proj_drop=0.,
|
||||
qkv_bias=False, qk_scale=None,
|
||||
rpe_length=14, rpe=False,
|
||||
head_dim=64):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
# head_dim is fixed to 64 in the official Autoformer. Set head_dim = None to use a flexible head dim.
|
||||
self.head_dim = head_dim or (embed_dim // num_heads)
|
||||
self.scale = qk_scale or self.head_dim ** -0.5  # use the resolved head dim so the fallback works when head_dim is None
|
||||
|
||||
# Please refer to MixedMultiheadAttention for details.
|
||||
self.q = nn.Linear(embed_dim, head_dim * num_heads, bias=qkv_bias)
|
||||
self.k = nn.Linear(embed_dim, head_dim * num_heads, bias=qkv_bias)
|
||||
self.v = nn.Linear(embed_dim, head_dim * num_heads, bias=qkv_bias)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(head_dim * num_heads, embed_dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.rpe = rpe
|
||||
if rpe:
|
||||
self.rel_pos_embed_k = RelativePosition2D(head_dim, rpe_length)
|
||||
self.rel_pos_embed_v = RelativePosition2D(head_dim, rpe_length)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
head_dim = self.head_dim
|
||||
# num_heads cannot be read from self.num_heads directly,
|
||||
# so use -1 to let reshape infer it.
|
||||
num_heads = -1
|
||||
q = self.q(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
|
||||
k = self.k(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
|
||||
v = self.v(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
|
||||
num_heads = q.size(1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if self.rpe:
|
||||
r_p_k = self.rel_pos_embed_k(N, N)
|
||||
attn = attn + (
|
||||
q.permute(2, 0, 1, 3).reshape(N, num_heads * B, head_dim) @ r_p_k.transpose(2, 1)
|
||||
).transpose(1, 0).reshape(B, num_heads, N, N) * self.scale
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * head_dim)
|
||||
|
||||
if self.rpe:
|
||||
attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, N)
|
||||
r_p_v = self.rel_pos_embed_v(N, N)
|
||||
# The size of attention is (B, num_heads, N, N), reshape it to (N, B*num_heads, N) and do batch matmul with
|
||||
# the relative position embedding of V (N, N, head_dim) to get shape (N, B*num_heads, head_dim). We then reshape it to the
|
||||
# same size as x (B, N, num_heads * head_dim).
|
||||
x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B, num_heads, N, head_dim).transpose(2, 1).reshape(B, N, num_heads * head_dim)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
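# Illustrative sketch (added). Plain ints are passed where the search space
# would normally supply ValueChoice objects; shapes are the only claim here.
def _demo_relative_position_attention():
    attn = RelativePositionAttention(embed_dim=192, num_heads=3, rpe=True,
                                     rpe_length=14, head_dim=64)
    x = torch.randn(2, 197, 192)  # (batch, 1 + 14 * 14 patches, embed_dim)
    assert attn(x).shape == x.shape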
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
This class is designed to support the RelativePositionAttention().
|
||||
The PyTorch built-in nn.TransformerEncoderLayer() does not support customized attention.
|
||||
"""
|
||||
def __init__(
|
||||
self, embed_dim, num_heads, mlp_ratio=4.,
|
||||
qkv_bias=False, qk_scale=None, rpe=False,
|
||||
drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
|
||||
pre_norm=True, rpe_length=14, head_dim=64
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.normalize_before = pre_norm
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.dropout = drop_rate
|
||||
self.attn = RelativePositionAttention(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
rpe=rpe,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
rpe_length=rpe_length,
|
||||
head_dim=head_dim
|
||||
)
|
||||
|
||||
self.attn_layer_norm = nn.LayerNorm(embed_dim)
|
||||
self.ffn_layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.activation_fn = nn.GELU()
|
||||
|
||||
self.fc1 = nn.Linear(
|
||||
cast(int, embed_dim),
|
||||
cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio))
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)),
|
||||
cast(int, embed_dim)
|
||||
)
|
||||
|
||||
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
|
||||
assert before ^ after
|
||||
if after ^ self.normalize_before:
|
||||
return layer_norm(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(batch, patch_num, sample_embed_dim)`
|
||||
Returns:
|
||||
encoded output of shape `(batch, patch_num, sample_embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True)
|
||||
x = self.attn(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = self.drop_path(x)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True)
|
||||
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True)
|
||||
x = self.fc1(x)
|
||||
x = self.activation_fn(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = self.drop_path(x)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True)
|
||||
|
||||
return x
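# Note (added): with pre_norm=True, maybe_layer_norm normalizes in the
# before=True calls (pre-norm transformer); with pre_norm=False it normalizes
# in the after=True calls instead. The `before ^ after` assertion guarantees
# each call site requests exactly one of the two placements.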
|
||||
|
||||
|
||||
@basic_unit
|
||||
class ClsToken(nn.Module):
|
||||
""" Concat class token with dim=embed_dim before patch embedding.
|
||||
"""
|
||||
def __init__(self, embed_dim: int):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
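# Illustrative shape check (added): the learnable cls token is prepended
# along the sequence dimension.
def _demo_cls_token():
    tok = ClsToken(embed_dim=192)
    x = torch.randn(2, 196, 192)          # patch embeddings
    assert tok(x).shape == (2, 197, 192)  # cls token prepended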
|
||||
|
||||
|
||||
class MixedClsToken(MixedOperation, ClsToken):
|
||||
""" Mixed class token concat operation.
|
||||
|
||||
Supported arguments are:
|
||||
|
||||
- ``embed_dim``
|
||||
|
||||
Prefix of cls_token will be sliced.
|
||||
"""
|
||||
bound_type = ClsToken
|
||||
argument_list = ['embed_dim']
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
return max(traverse_all_options(value_choice))
|
||||
|
||||
def forward_with_args(self, embed_dim,
|
||||
inputs: torch.Tensor) -> torch.Tensor:
|
||||
embed_dim_ = _W(embed_dim)
|
||||
cls_token = _S(self.cls_token)[..., :embed_dim_]
|
||||
|
||||
return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)
|
||||
|
||||
|
||||
@basic_unit
|
||||
class AbsPosEmbed(nn.Module):
|
||||
""" Add absolute position embedding on patch embedding.
|
||||
"""
|
||||
def __init__(self, length: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, length, embed_dim))
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.pos_embed
|
||||
|
||||
|
||||
class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
|
||||
""" Mixed absolute position embedding add operation.
|
||||
|
||||
Supported arguments are:
|
||||
|
||||
- ``embed_dim``
|
||||
|
||||
Prefix of pos_embed will be sliced.
|
||||
"""
|
||||
bound_type = AbsPosEmbed
|
||||
argument_list = ['embed_dim']
|
||||
|
||||
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
|
||||
return max(traverse_all_options(value_choice))
|
||||
|
||||
def forward_with_args(self, embed_dim,
|
||||
inputs: torch.Tensor) -> torch.Tensor:
|
||||
embed_dim_ = _W(embed_dim)
|
||||
pos_embed = _S(self.pos_embed)[..., :embed_dim_]
|
||||
|
||||
return inputs + pos_embed
|
||||
|
||||
|
||||
@model_wrapper
|
||||
class AutoformerSpace(nn.Module):
|
||||
"""
|
||||
The search space that is proposed in `Autoformer <https://arxiv.org/abs/2107.00651>`__.
|
||||
There are four searchable variables: depth, embedding dimension, number of heads, and MLP ratio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
search_embed_dim : list of int
|
||||
The search space of embedding dimension.
|
||||
search_mlp_ratio : list of float
|
||||
The search space of MLP ratio.
|
||||
search_num_heads : list of int
|
||||
The search space of number of heads.
|
||||
search_depth: list of int
|
||||
The search space of depth.
|
||||
img_size : int
|
||||
Size of input image.
|
||||
patch_size : int
|
||||
Size of image patch.
|
||||
in_chans : int
|
||||
Number of channels of the input image.
|
||||
num_classes : int
|
||||
Number of classes for classifier.
|
||||
qkv_bias : bool
|
||||
Whether to use bias item in the qkv embedding.
|
||||
drop_rate : float
|
||||
Drop rate of the MLP projection in MSA and FFN.
|
||||
attn_drop_rate : float
|
||||
Drop rate of attention.
|
||||
drop_path_rate : float
|
||||
Drop path rate.
|
||||
pre_norm : bool
|
||||
Whether to use pre_norm. Otherwise post_norm is used.
|
||||
global_pool : bool
|
||||
Whether to use global pooling to generate the image representation. Otherwise the cls_token is used.
|
||||
abs_pos : bool
|
||||
Whether to use absolute positional embeddings.
|
||||
qk_scale : float
|
||||
The scaler on score map in self-attention.
|
||||
rpe : bool
|
||||
Whether to use relative position encoding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
search_embed_dim: Tuple[int, ...] = (192, 216, 240),
|
||||
search_mlp_ratio: Tuple[float, ...] = (3.0, 3.5, 4.0),
|
||||
search_num_heads: Tuple[int, ...] = (3, 4),
|
||||
search_depth: Tuple[int, ...] = (12, 13, 14),
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
qkv_bias: bool = False,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
pre_norm: bool = True,
|
||||
global_pool: bool = False,
|
||||
abs_pos: bool = True,
|
||||
qk_scale: Optional[float] = None,
|
||||
rpe: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
# define search space parameters
|
||||
embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
|
||||
depth = nn.ValueChoice(list(search_depth), label="depth")
|
||||
mlp_ratios = [nn.ValueChoice(list(search_mlp_ratio), label=f"mlp_ratio_{i}") for i in range(max(search_depth))]
|
||||
num_heads = [nn.ValueChoice(list(search_num_heads), label=f"num_head_{i}") for i in range(max(search_depth))]
|
||||
|
||||
self.patch_embed = nn.Conv2d(
|
||||
in_chans, cast(int, embed_dim),
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size
|
||||
)
|
||||
self.patches_num = int((img_size // patch_size) ** 2)
|
||||
self.global_pool = global_pool
|
||||
|
||||
self.cls_token = ClsToken(cast(int, embed_dim))
|
||||
self.pos_embed = AbsPosEmbed(self.patches_num+1, cast(int, embed_dim)) if abs_pos else nn.Identity()
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, max(search_depth))] # stochastic depth decay rule
|
||||
|
||||
self.blocks = nn.Repeat(
|
||||
lambda index: TransformerEncoderLayer(
|
||||
embed_dim=embed_dim, num_heads=num_heads[index], mlp_ratio=mlp_ratios[index],
|
||||
qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[index],
|
||||
rpe_length=img_size // patch_size, qk_scale=qk_scale, rpe=rpe, pre_norm=pre_norm, head_dim=64
|
||||
), depth
|
||||
)
|
||||
|
||||
self.norm = nn.LayerNorm(cast(int, embed_dim)) if pre_norm else nn.Identity()
|
||||
self.head = nn.Linear(cast(int, embed_dim), num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
@classmethod
|
||||
def get_extra_mutation_hooks(cls):
|
||||
return [MixedAbsPosEmbed.mutate, MixedClsToken.mutate]
|
||||
|
||||
@classmethod
|
||||
def load_searched_model(
|
||||
cls, name: str,
|
||||
pretrained: bool = False, download: bool = False, progress: bool = True
|
||||
) -> nn.Module:
|
||||
|
||||
init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
|
||||
if name == 'autoformer-tiny':
|
||||
mlp_ratio = [3.5, 3.5, 3.0, 3.5, 3.0, 3.0, 4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5] + [3.0]
|
||||
num_head = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3] + [3]
|
||||
arch: Dict[str, Any] = {
|
||||
'embed_dim': 192,
|
||||
'depth': 13
|
||||
}
|
||||
for i in range(14):
|
||||
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
|
||||
arch[f'num_head_{i}'] = num_head[i]
|
||||
|
||||
init_kwargs.update({
|
||||
'search_embed_dim': (240, 216, 192),
|
||||
'search_mlp_ratio': (4.0, 3.5, 3.0),
|
||||
'search_num_heads': (4, 3),
|
||||
'search_depth': (14, 13, 12),
|
||||
})
|
||||
elif name == 'autoformer-small':
|
||||
mlp_ratio = [3.0, 3.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, 4.0] + [3.0]
|
||||
num_head = [6, 6, 5, 7, 5, 5, 5, 6, 6, 7, 7, 6, 7] + [5]
|
||||
arch: Dict[str, Any] = {
|
||||
'embed_dim': 384,
|
||||
'depth': 13
|
||||
}
|
||||
for i in range(14):
|
||||
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
|
||||
arch[f'num_head_{i}'] = num_head[i]
|
||||
|
||||
init_kwargs.update({
|
||||
'search_embed_dim': (448, 384, 320),
|
||||
'search_mlp_ratio': (4.0, 3.5, 3.0),
|
||||
'search_num_heads': (7, 6, 5),
|
||||
'search_depth': (14, 13, 12),
|
||||
})
|
||||
|
||||
elif name == 'autoformer-base':
|
||||
mlp_ratio = [3.5, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.5] + [3.0, 3.0]
|
||||
num_head = [9, 9, 9, 9, 9, 10, 9, 9, 10, 9, 10, 9, 9, 10] + [8, 8]
|
||||
arch: Dict[str, Any] = {
|
||||
'embed_dim': 576,
|
||||
'depth': 14
|
||||
}
|
||||
for i in range(16):
|
||||
arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
|
||||
arch[f'num_head_{i}'] = num_head[i]
|
||||
|
||||
init_kwargs.update({
|
||||
'search_embed_dim': (624, 576, 528),
|
||||
'search_mlp_ratio': (4.0, 3.5, 3.0),
|
||||
'search_num_heads': (10, 9, 8),
|
||||
'search_depth': (16, 15, 14),
|
||||
})
|
||||
else:
|
||||
raise ValueError(f'Unsupported architecture with name: {name}')
|
||||
|
||||
model_factory = FixedFactory(cls, arch)
|
||||
model = model_factory(**init_kwargs)
|
||||
|
||||
if pretrained:
|
||||
weight_file = load_pretrained_weight(name, download=download, progress=progress)
|
||||
pretrained_weights = torch.load(weight_file)
|
||||
model.load_state_dict(pretrained_weights)
|
||||
|
||||
return model
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
x = x.permute(0, 2, 3, 1).view(B, self.patches_num, -1)
|
||||
x = self.cls_token(x)
|
||||
x = self.pos_embed(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if self.global_pool:
|
||||
x = torch.mean(x[:, 1:], dim=1)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
|
||||
x = self.head(x)
|
||||
|
||||
return x
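# Hedged sketch (added): build the fixed "tiny" subnet without downloading
# weights. load_searched_model returns a plain module; pretrained=True would
# additionally fetch weights via load_pretrained_weight.
def _demo_autoformer_space():
    model = AutoformerSpace.load_searched_model('autoformer-tiny', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    assert model(x).shape == (1, 1000)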
|
|
@ -0,0 +1,664 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple, Optional, Callable, Union, List, Type, cast
|
||||
|
||||
import torch
|
||||
import nni.nas.nn.pytorch as nn
|
||||
from nni.nas import model_wrapper
|
||||
from nni.typehint import Literal
|
||||
|
||||
from .proxylessnas import ConvBNReLU, InvertedResidual, DepthwiseSeparableConv, make_divisible, reset_parameters
|
||||
from .utils.fixed import FixedFactory
|
||||
from .utils.pretrained import load_pretrained_weight
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Module):
|
||||
"""Squeeze-and-excite layer.
|
||||
|
||||
We can't use the op from ``torchvision.ops`` because it's not (yet) properly wrapped,
|
||||
and ValueChoice cannot be processed.
|
||||
|
||||
Reference:
|
||||
|
||||
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L26
|
||||
- https://github.com/d-li14/mobilenetv3.pytorch/blob/3e6938cedcbbc5ee5bc50780ea18e644702d85fc/mobilenetv3.py#L53
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels: int,
|
||||
reduction_ratio: float = 0.25,
|
||||
gate_layer: Optional[Callable[..., nn.Module]] = None,
|
||||
activation_layer: Optional[Callable[..., nn.Module]] = None):
|
||||
super().__init__()
|
||||
|
||||
rd_channels = make_divisible(channels * reduction_ratio, 8)
|
||||
gate_layer = gate_layer or nn.Hardsigmoid
|
||||
activation_layer = activation_layer or nn.ReLU
|
||||
self.conv_reduce = nn.Conv2d(channels, rd_channels, 1, bias=True)
|
||||
self.act1 = activation_layer(inplace=True)
|
||||
self.conv_expand = nn.Conv2d(rd_channels, channels, 1, bias=True)
|
||||
self.gate = gate_layer()
|
||||
|
||||
def forward(self, x):
|
||||
x_se = x.mean((2, 3), keepdim=True)
|
||||
x_se = self.conv_reduce(x_se)
|
||||
x_se = self.act1(x_se)
|
||||
x_se = self.conv_expand(x_se)
|
||||
return x * self.gate(x_se)
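# Illustrative shape check (added): squeeze-and-excite only rescales channels,
# so the output shape always matches the input.
def _demo_squeeze_excite():
    se = SqueezeExcite(channels=64)
    x = torch.randn(2, 64, 14, 14)
    assert se(x).shape == x.shape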
|
||||
|
||||
|
||||
def _se_or_skip(hidden_ch: int, input_ch: int, optional: bool, se_from_exp: bool, label: str) -> nn.Module:
|
||||
ch = hidden_ch if se_from_exp else input_ch
|
||||
if optional:
|
||||
return nn.LayerChoice({
|
||||
'identity': nn.Identity(),
|
||||
'se': SqueezeExcite(ch)
|
||||
}, label=label)
|
||||
else:
|
||||
return SqueezeExcite(ch)
|
||||
|
||||
|
||||
def _act_fn(act_alias: Literal['hswish', 'swish', 'relu']) -> Type[nn.Module]:
|
||||
if act_alias == 'hswish':
|
||||
return nn.Hardswish
|
||||
elif act_alias == 'swish':
|
||||
return nn.SiLU
|
||||
elif act_alias == 'relu':
|
||||
return nn.ReLU
|
||||
else:
|
||||
raise ValueError(f'Unsupported act alias: {act_alias}')
|
||||
|
||||
|
||||
@model_wrapper
|
||||
class MobileNetV3Space(nn.Module):
|
||||
"""
|
||||
MobileNetV3Space implements the largest search space in `TuNAS <https://arxiv.org/abs/2008.06120>`__.
|
||||
|
||||
The search dimensions include widths, expand ratios, kernel sizes, SE ratio.
|
||||
Some of them can be turned off via arguments to narrow down the search space.
|
||||
|
||||
Different from ProxylessNAS search space, this space is implemented with :class:`nn.ValueChoice`.
|
||||
|
||||
We use the following snippet as reference.
|
||||
https://github.com/google-research/google-research/blob/20736344591f774f4b1570af64624ed1e18d2867/tunas/mobile_search_space_v3.py#L728
|
||||
|
||||
We have ``num_blocks`` which equals the length of ``self.blocks`` (the main body of the network).
|
||||
For simplicity, the following parameter specification assumes ``num_blocks`` equals 8 (body + head).
|
||||
If a shallower body is intended, arrays including ``base_widths``, ``squeeze_excite``, ``depth_range``,
|
||||
``stride``, ``activation`` should also be shortened accordingly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_labels
|
||||
Dimensions for classification head.
|
||||
base_widths
|
||||
Widths of each stage, from stem, to body, to head.
|
||||
Length should be 9, i.e., ``num_blocks + 1`` (because there is a stem width in front).
|
||||
width_multipliers
|
||||
A range of widths multiplier to choose from. The choice is independent for each stage.
|
||||
Or it can be a fixed float. This will be applied on ``base_widths``,
|
||||
and we also make sure that widths are divisible by 8.
|
||||
expand_ratios
|
||||
A list of expand ratios to choose from. Independent for every **block**.
|
||||
squeeze_excite
|
||||
Indicating whether the current stage can have an optional SE layer.
|
||||
Expect array of length 6 for stage 0 to 5. Each element can be one of ``force``, ``optional``, ``none``.
|
||||
depth_range
|
||||
A range (e.g., ``(1, 4)``),
|
||||
or a list of range (e.g., ``[(1, 3), (1, 4), (1, 4), (1, 3), (0, 2)]``).
|
||||
If a list, the length should be 5. The depths are specified for stages 1 to 5.
|
||||
stride
|
||||
Stride for all stages (including stem and head). Length should be the same as ``base_widths``.
|
||||
activation
|
||||
Activation (class) for all stages. Length is the same as ``base_widths``.
|
||||
se_from_exp
|
||||
Calculate SE channel reduction from expanded (mid) channels.
|
||||
dropout_rate
|
||||
Dropout rate at classification head.
|
||||
bn_eps
|
||||
Epsilon of batch normalization.
|
||||
bn_momentum
|
||||
Momentum of batch normalization.
|
||||
"""
|
||||
|
||||
widths: List[Union[nn.ChoiceOf[int], int]]
|
||||
depth_range: List[Tuple[int, int]]
|
||||
|
||||
def __init__(
|
||||
self, num_labels: int = 1000,
|
||||
base_widths: Tuple[int, ...] = (16, 16, 16, 32, 64, 128, 256, 512, 1024),
|
||||
width_multipliers: Union[Tuple[float, ...], float] = (0.5, 0.625, 0.75, 1.0, 1.25, 1.5, 2.0),
|
||||
expand_ratios: Tuple[float, ...] = (1., 2., 3., 4., 5., 6.),
|
||||
squeeze_excite: Tuple[Literal['force', 'optional', 'none'], ...] = (
|
||||
'none', 'none', 'optional', 'none', 'optional', 'optional'
|
||||
),
|
||||
depth_range: Union[List[Tuple[int, int]], Tuple[int, int]] = (1, 4),
|
||||
stride: Tuple[int, ...] = (2, 1, 2, 2, 2, 1, 2, 1, 1),
|
||||
activation: Tuple[Literal['hswish', 'swish', 'relu'], ...] = (
|
||||
'hswish', 'relu', 'relu', 'relu', 'hswish', 'hswish', 'hswish', 'hswish', 'hswish'
|
||||
),
|
||||
se_from_exp: bool = True,
|
||||
dropout_rate: float = 0.2,
|
||||
bn_eps: float = 1e-3,
|
||||
bn_momentum: float = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_blocks = len(base_widths) - 1 # without stem, equal to len(self.blocks)
|
||||
assert self.num_blocks >= 4
|
||||
|
||||
assert len(base_widths) == len(stride) == len(activation) == self.num_blocks + 1
|
||||
|
||||
# The final two blocks can't have SE
|
||||
assert len(squeeze_excite) == self.num_blocks - 2 and all(se in ['force', 'optional', 'none'] for se in squeeze_excite)
|
||||
|
||||
# The first and final two blocks can't have variational depth
|
||||
if isinstance(depth_range[0], int):
|
||||
depth_range = cast(Tuple[int, int], depth_range)
|
||||
assert len(depth_range) == 2 and depth_range[1] >= depth_range[0] >= 1
|
||||
self.depth_range = [depth_range] * (self.num_blocks - 3)
|
||||
else:
|
||||
assert len(depth_range) == self.num_blocks - 3
|
||||
self.depth_range = cast(List[Tuple[int, int]], depth_range)
|
||||
for d in self.depth_range:
|
||||
d = cast(Tuple[int, int], d)
|
||||
# pylint: disable=unsubscriptable-object
|
||||
assert len(d) == 2 and d[1] >= d[0] >= 1, f'{d} does not satisfy depth constraints'
|
||||
|
||||
self.widths = []
|
||||
for i, base_width in enumerate(base_widths):
|
||||
if isinstance(width_multipliers, float):
|
||||
self.widths.append(make_divisible(base_width * width_multipliers, 8))
|
||||
else:
|
||||
self.widths.append(
|
||||
# According to tunas, stem and stage 0 share one width multiplier
|
||||
# https://github.com/google-research/google-research/blob/20736344/tunas/mobile_search_space_v3.py#L791
|
||||
make_divisible(
|
||||
nn.ValueChoice(list(width_multipliers), label=f's{max(i - 1, 0)}_width_mult') * base_width, 8
|
||||
)
|
||||
)
|
||||
|
||||
self.expand_ratios = expand_ratios
|
||||
self.se_from_exp = se_from_exp
|
||||
|
||||
# NOTE: The built-in hardswish produces slightly different output from 3rd-party implementation
|
||||
# But I guess it doesn't really matter.
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/layers/activations.py#L79
|
||||
|
||||
self.stem = ConvBNReLU(
|
||||
3, self.widths[0],
|
||||
nn.ValueChoice([3, 5], label='stem_ks'),
|
||||
stride=stride[0], activation_layer=_act_fn(activation[0])
|
||||
)
|
||||
|
||||
blocks: List[nn.Module] = [
|
||||
# Stage 0
|
||||
# FIXME: this should be an optional layer.
|
||||
# https://github.com/google-research/google-research/blob/20736344/tunas/mobile_search_space_v3.py#L791
|
||||
DepthwiseSeparableConv(
|
||||
self.widths[0], self.widths[1],
|
||||
nn.ValueChoice([3, 5, 7], label='s0_i0_ks'),
|
||||
stride=stride[1],
|
||||
squeeze_excite=cast(Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module], partial(
|
||||
_se_or_skip, optional=squeeze_excite[0] == 'optional', se_from_exp=self.se_from_exp, label='s0_i0_se'
|
||||
)) if squeeze_excite[0] != 'none' else None,
|
||||
activation_layer=_act_fn(activation[1])
|
||||
),
|
||||
]
|
||||
|
||||
blocks += [
|
||||
# Stage 1-5 (by default)
|
||||
self._make_stage(i, self.widths[i], self.widths[i + 1], squeeze_excite[i], stride[i + 1], _act_fn(activation[i + 1]))
|
||||
for i in range(1, self.num_blocks - 2)
|
||||
]
|
||||
|
||||
# Head
|
||||
blocks += [
|
||||
ConvBNReLU(
|
||||
self.widths[self.num_blocks - 2],
|
||||
self.widths[self.num_blocks - 1],
|
||||
kernel_size=1,
|
||||
stride=stride[self.num_blocks - 1],
|
||||
activation_layer=_act_fn(activation[self.num_blocks - 1])
|
||||
),
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
|
||||
# In some implementations, this is a linear layer instead.
|
||||
# Should be equivalent.
|
||||
ConvBNReLU(
|
||||
self.widths[self.num_blocks - 1],
|
||||
self.widths[self.num_blocks],
|
||||
kernel_size=1,
|
||||
stride=stride[self.num_blocks],
|
||||
norm_layer=nn.Identity,
|
||||
activation_layer=_act_fn(activation[self.num_blocks])
|
||||
)
|
||||
]
|
||||
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(cast(int, self.widths[self.num_blocks]), num_labels),
|
||||
)
|
||||
|
||||
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.blocks(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def _make_stage(self, stage_idx, inp, oup, se, stride, act):
|
||||
def layer_builder(idx):
|
||||
exp = nn.ValueChoice(list(self.expand_ratios), label=f's{stage_idx}_i{idx}_exp')
|
||||
ks = nn.ValueChoice([3, 5, 7], label=f's{stage_idx}_i{idx}_ks')
|
||||
# if SE is true, assign a layer choice to SE
|
||||
se_or_skip = cast(Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module], partial(
|
||||
_se_or_skip, optional=se == 'optional', se_from_exp=self.se_from_exp, label=f's{stage_idx}_i{idx}_se'
|
||||
)) if se != 'none' else None
|
||||
return InvertedResidual(
|
||||
inp if idx == 0 else oup,
|
||||
oup, exp, ks,
|
||||
stride=stride if idx == 0 else 1, # only the first layer in each stage can have stride > 1
|
||||
squeeze_excite=se_or_skip,
|
||||
activation_layer=act,
|
||||
)
|
||||
|
||||
# mutable depth
|
||||
min_depth, max_depth = self.depth_range[stage_idx - 1]
|
||||
if stride != 1:
|
||||
min_depth = max(min_depth, 1)
|
||||
return nn.Repeat(layer_builder, depth=(min_depth, max_depth), label=f's{stage_idx}_depth')
|
||||
|
||||
@classmethod
|
||||
def fixed_arch(cls, arch: dict) -> FixedFactory:
|
||||
return FixedFactory(cls, arch)
|
||||
|
||||
@classmethod
|
||||
def load_searched_model(
|
||||
cls, name: str,
|
||||
pretrained: bool = False, download: bool = False, progress: bool = True
|
||||
) -> nn.Module:
|
||||
|
||||
init_kwargs = {} # all default
|
||||
|
||||
if name == 'mobilenetv3-large-100':
|
||||
# NOTE: Use bicubic interpolation to evaluate this
|
||||
# With default interpolation, it yields top-1 75.722
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 2,
|
||||
's1_i0_exp': 4,
|
||||
's1_i0_ks': 3,
|
||||
's1_i1_exp': 3,
|
||||
's1_i1_ks': 3,
|
||||
's2_depth': 3,
|
||||
's2_i0_exp': 3,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 3,
|
||||
's2_i1_ks': 5,
|
||||
's2_i2_exp': 3,
|
||||
's2_i2_ks': 5,
|
||||
's3_depth': 4,
|
||||
's3_i0_exp': 6,
|
||||
's3_i0_ks': 3,
|
||||
's3_i1_exp': 2.5,
|
||||
's3_i1_ks': 3,
|
||||
's3_i2_exp': 2.3,
|
||||
's3_i2_ks': 3,
|
||||
's3_i3_exp': 2.3,
|
||||
's3_i3_ks': 3,
|
||||
's4_depth': 2,
|
||||
's4_i0_exp': 6,
|
||||
's4_i0_ks': 3,
|
||||
's4_i1_exp': 6,
|
||||
's4_i1_ks': 3,
|
||||
's5_depth': 3,
|
||||
's5_i0_exp': 6,
|
||||
's5_i0_ks': 5,
|
||||
's5_i1_exp': 6,
|
||||
's5_i1_ks': 5,
|
||||
's5_i2_exp': 6,
|
||||
's5_i2_ks': 5,
|
||||
}
|
||||
|
||||
init_kwargs.update(
|
||||
base_widths=[16, 16, 24, 40, 80, 112, 160, 960, 1280],
|
||||
expand_ratios=[1.0, 2.0, 2.3, 2.5, 3.0, 4.0, 6.0],
|
||||
bn_eps=1e-5,
|
||||
bn_momentum=0.1,
|
||||
width_multipliers=1.0,
|
||||
squeeze_excite=['none', 'none', 'force', 'none', 'force', 'force']
|
||||
)
|
||||
|
||||
elif name.startswith('mobilenetv3-small-'):
|
||||
# Evaluate with bicubic interpolation
|
||||
multiplier = int(name.split('-')[-1]) / 100
|
||||
widths = [16, 16, 24, 40, 48, 96, 576, 1024]
|
||||
for i in range(7):
|
||||
if i > 0 or multiplier >= 0.75:
|
||||
# fix_stem = True when multiplier < 0.75
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/mobilenetv3.py#L421
|
||||
widths[i] = make_divisible(widths[i] * multiplier, 8)
|
||||
init_kwargs.update(
|
||||
base_widths=widths,
|
||||
width_multipliers=1.0,
|
||||
expand_ratios=[3.0, 3.67, 4.0, 4.5, 6.0],
|
||||
bn_eps=1e-05,
|
||||
bn_momentum=0.1,
|
||||
squeeze_excite=['force', 'none', 'force', 'force', 'force'],
|
||||
activation=['hswish', 'relu', 'relu', 'hswish', 'hswish', 'hswish', 'hswish', 'hswish'],
|
||||
stride=[2, 2, 2, 2, 1, 2, 1, 1],
|
||||
depth_range=(1, 2),
|
||||
)
|
||||
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 2,
|
||||
's1_i0_exp': 4.5,
|
||||
's1_i0_ks': 3,
|
||||
's1_i1_exp': 3.67,
|
||||
's1_i1_ks': 3,
|
||||
's2_depth': 3,
|
||||
's2_i0_exp': 4.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 6.0,
|
||||
's2_i1_ks': 5,
|
||||
's2_i2_exp': 6.0,
|
||||
's2_i2_ks': 5,
|
||||
's3_depth': 2,
|
||||
's3_i0_exp': 3.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 3.0,
|
||||
's3_i1_ks': 5,
|
||||
's4_depth': 3,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 5,
|
||||
's4_i1_exp': 6.0,
|
||||
's4_i1_ks': 5,
|
||||
's4_i2_exp': 6.0,
|
||||
's4_i2_ks': 5
|
||||
}
|
||||
|
||||
elif name.startswith('cream'):
|
||||
# https://github.com/microsoft/Cream/tree/main/Cream
|
||||
# bilinear interpolation
|
||||
|
||||
level = name.split('-')[-1]
|
||||
|
||||
# region cream arch specification
|
||||
if level == '014':
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_depth': 1,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 1,
|
||||
's1_i0_exp': 4.0,
|
||||
's1_i0_ks': 3,
|
||||
's2_depth': 2,
|
||||
's2_i0_exp': 6.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 6.0,
|
||||
's2_i1_ks': 5,
|
||||
's3_depth': 2,
|
||||
's3_i0_exp': 6.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 6.0,
|
||||
's3_i1_ks': 5,
|
||||
's4_depth': 1,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 3,
|
||||
's5_depth': 1,
|
||||
's5_i0_exp': 6.0,
|
||||
's5_i0_ks': 5
|
||||
}
|
||||
elif level == '043':
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_depth': 1,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 1,
|
||||
's1_i0_exp': 4.0,
|
||||
's1_i0_ks': 3,
|
||||
's2_depth': 2,
|
||||
's2_i0_exp': 6.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 6.0,
|
||||
's2_i1_ks': 3,
|
||||
's3_depth': 2,
|
||||
's3_i0_exp': 6.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 6.0,
|
||||
's3_i1_ks': 3,
|
||||
's4_depth': 3,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 5,
|
||||
's4_i1_exp': 6.0,
|
||||
's4_i1_ks': 5,
|
||||
's4_i2_exp': 6.0,
|
||||
's4_i2_ks': 5,
|
||||
's5_depth': 2,
|
||||
's5_i0_exp': 6.0,
|
||||
's5_i0_ks': 5,
|
||||
's5_i1_exp': 6.0,
|
||||
's5_i1_ks': 5
|
||||
}
|
||||
elif level == '114':
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_depth': 1,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 1,
|
||||
's1_i0_exp': 4.0,
|
||||
's1_i0_ks': 3,
|
||||
's2_depth': 2,
|
||||
's2_i0_exp': 6.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 6.0,
|
||||
's2_i1_ks': 5,
|
||||
's3_depth': 2,
|
||||
's3_i0_exp': 6.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 6.0,
|
||||
's3_i1_ks': 5,
|
||||
's4_depth': 3,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 5,
|
||||
's4_i1_exp': 6.0,
|
||||
's4_i1_ks': 5,
|
||||
's4_i2_exp': 6.0,
|
||||
's4_i2_ks': 5,
|
||||
's5_depth': 2,
|
||||
's5_i0_exp': 6.0,
|
||||
's5_i0_ks': 5,
|
||||
's5_i1_exp': 6.0,
|
||||
's5_i1_ks': 5
|
||||
}
|
||||
elif level == '287':
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_depth': 1,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 1,
|
||||
's1_i0_exp': 4.0,
|
||||
's1_i0_ks': 3,
|
||||
's2_depth': 2,
|
||||
's2_i0_exp': 6.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 6.0,
|
||||
's2_i1_ks': 5,
|
||||
's3_depth': 3,
|
||||
's3_i0_exp': 6.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 6.0,
|
||||
's3_i1_ks': 3,
|
||||
's3_i2_exp': 6.0,
|
||||
's3_i2_ks': 5,
|
||||
's4_depth': 4,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 5,
|
||||
's4_i1_exp': 6.0,
|
||||
's4_i1_ks': 5,
|
||||
's4_i2_exp': 6.0,
|
||||
's4_i2_ks': 5,
|
||||
's4_i3_exp': 6.0,
|
||||
's4_i3_ks': 5,
|
||||
's5_depth': 3,
|
||||
's5_i0_exp': 6.0,
|
||||
's5_i0_ks': 5,
|
||||
's5_i1_exp': 6.0,
|
||||
's5_i1_ks': 5,
|
||||
's5_i2_exp': 6.0,
|
||||
's5_i2_ks': 5
|
||||
}
|
||||
elif level == '481':
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_depth': 1,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 4,
|
||||
's1_i0_exp': 6.0,
|
||||
's1_i0_ks': 5,
|
||||
's1_i1_exp': 4.0,
|
||||
's1_i1_ks': 7,
|
||||
's1_i2_exp': 6.0,
|
||||
's1_i2_ks': 5,
|
||||
's1_i3_exp': 6.0,
|
||||
's1_i3_ks': 3,
|
||||
's2_depth': 4,
|
||||
's2_i0_exp': 6.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 4.0,
|
||||
's2_i1_ks': 5,
|
||||
's2_i2_exp': 6.0,
|
||||
's2_i2_ks': 5,
|
||||
's2_i3_exp': 4.0,
|
||||
's2_i3_ks': 3,
|
||||
's3_depth': 5,
|
||||
's3_i0_exp': 6.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 6.0,
|
||||
's3_i1_ks': 5,
|
||||
's3_i2_exp': 6.0,
|
||||
's3_i2_ks': 5,
|
||||
's3_i3_exp': 6.0,
|
||||
's3_i3_ks': 3,
|
||||
's3_i4_exp': 6.0,
|
||||
's3_i4_ks': 3,
|
||||
's4_depth': 4,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 5,
|
||||
's4_i1_exp': 6.0,
|
||||
's4_i1_ks': 5,
|
||||
's4_i2_exp': 6.0,
|
||||
's4_i2_ks': 5,
|
||||
's4_i3_exp': 6.0,
|
||||
's4_i3_ks': 5,
|
||||
's5_depth': 4,
|
||||
's5_i0_exp': 6.0,
|
||||
's5_i0_ks': 5,
|
||||
's5_i1_exp': 6.0,
|
||||
's5_i1_ks': 5,
|
||||
's5_i2_exp': 6.0,
|
||||
's5_i2_ks': 5,
|
||||
's5_i3_exp': 6.0,
|
||||
's5_i3_ks': 5
|
||||
}
|
||||
elif level == '604':
|
||||
arch = {
|
||||
'stem_ks': 3,
|
||||
's0_depth': 1,
|
||||
's0_i0_ks': 3,
|
||||
's1_depth': 5,
|
||||
's1_i0_exp': 6.0,
|
||||
's1_i0_ks': 5,
|
||||
's1_i1_exp': 6.0,
|
||||
's1_i1_ks': 5,
|
||||
's1_i2_exp': 4.0,
|
||||
's1_i2_ks': 5,
|
||||
's1_i3_exp': 6.0,
|
||||
's1_i3_ks': 5,
|
||||
's1_i4_exp': 6.0,
|
||||
's1_i4_ks': 5,
|
||||
's2_depth': 5,
|
||||
's2_i0_exp': 6.0,
|
||||
's2_i0_ks': 5,
|
||||
's2_i1_exp': 4.0,
|
||||
's2_i1_ks': 5,
|
||||
's2_i2_exp': 6.0,
|
||||
's2_i2_ks': 5,
|
||||
's2_i3_exp': 4.0,
|
||||
's2_i3_ks': 5,
|
||||
's2_i4_exp': 6.0,
|
||||
's2_i4_ks': 5,
|
||||
's3_depth': 5,
|
||||
's3_i0_exp': 6.0,
|
||||
's3_i0_ks': 5,
|
||||
's3_i1_exp': 4.0,
|
||||
's3_i1_ks': 5,
|
||||
's3_i2_exp': 6.0,
|
||||
's3_i2_ks': 5,
|
||||
's3_i3_exp': 4.0,
|
||||
's3_i3_ks': 5,
|
||||
's3_i4_exp': 6.0,
|
||||
's3_i4_ks': 5,
|
||||
's4_depth': 6,
|
||||
's4_i0_exp': 6.0,
|
||||
's4_i0_ks': 5,
|
||||
's4_i1_exp': 6.0,
|
||||
's4_i1_ks': 5,
|
||||
's4_i2_exp': 4.0,
|
||||
's4_i2_ks': 5,
|
||||
's4_i3_exp': 4.0,
|
||||
's4_i3_ks': 5,
|
||||
's4_i4_exp': 6.0,
|
||||
's4_i4_ks': 5,
|
||||
's4_i5_exp': 6.0,
|
||||
's4_i5_ks': 5,
|
||||
's5_depth': 6,
|
||||
's5_i0_exp': 6.0,
|
||||
's5_i0_ks': 5,
|
||||
's5_i1_exp': 6.0,
|
||||
's5_i1_ks': 5,
|
||||
's5_i2_exp': 4.0,
|
||||
's5_i2_ks': 5,
|
||||
's5_i3_exp': 6.0,
|
||||
's5_i3_ks': 5,
|
||||
's5_i4_exp': 6.0,
|
||||
's5_i4_ks': 5,
|
||||
's5_i5_exp': 6.0,
|
||||
's5_i5_ks': 5
|
||||
}
|
||||
else:
|
||||
raise ValueError(f'Unsupported cream model level: {level}')
|
||||
# endregion
|
||||
|
||||
init_kwargs.update(
|
||||
base_widths=[16, 16, 24, 40, 80, 96, 192, 320, 1280],
|
||||
width_multipliers=1.0,
|
||||
expand_ratios=[4.0, 6.0],
|
||||
bn_eps=1e-5,
|
||||
bn_momentum=0.1,
|
||||
squeeze_excite=['force'] * 6,
|
||||
activation=['swish'] * 9
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f'Unsupported architecture with name: {name}')
|
||||
|
||||
model_factory = cls.fixed_arch(arch)
|
||||
model = model_factory(**init_kwargs)
|
||||
|
||||
if pretrained:
|
||||
weight_file = load_pretrained_weight(name, download=download, progress=progress)
|
||||
pretrained_weights = torch.load(weight_file)
|
||||
model.load_state_dict(pretrained_weights)
|
||||
|
||||
return model
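# Hedged sketch (added): instantiate a fixed MobileNetV3 subnet without
# downloading weights; the arch dict above resolves every ValueChoice.
def _demo_mobilenetv3_space():
    model = MobileNetV3Space.load_searched_model('mobilenetv3-large-100', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    assert model(x).shape == (1, 1000)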
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Famous building blocks of search spaces."""
|
||||
|
||||
from .autoactivation import *
|
||||
from .nasbench101 import *
|
||||
from .nasbench201 import *
|
|
@ -0,0 +1,259 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.nas.utils import basic_unit
|
||||
|
||||
from nni.nas.nn.pytorch import LayerChoice
|
||||
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
|
||||
|
||||
__all__ = ['AutoActivation']
|
||||
|
||||
TorchVersion = '1.5.0'
|
||||
|
||||
# ============== unary function modules ==============
|
||||
|
||||
@basic_unit
|
||||
class UnaryIdentity(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
@basic_unit
|
||||
class UnaryNegative(nn.Module):
|
||||
def forward(self, x):
|
||||
return -x
|
||||
|
||||
@basic_unit
|
||||
class UnaryAbs(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.abs(x)
|
||||
|
||||
@basic_unit
|
||||
class UnarySquare(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.square(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryPow(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.pow(x, 3)
|
||||
|
||||
@basic_unit
|
||||
class UnarySqrt(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sqrt(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryMul(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# element-wise for now, will change to per-channel trainable parameter
|
||||
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
|
||||
def forward(self, x):
|
||||
return x * self.beta
|
||||
|
||||
@basic_unit
|
||||
class UnaryAdd(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# element-wise for now, will change to per-channel trainable parameter
|
||||
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
|
||||
def forward(self, x):
|
||||
return x + self.beta
|
||||
|
||||
@basic_unit
|
||||
class UnaryLogAbs(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.log(torch.abs(x) + 1e-7)
|
||||
|
||||
@basic_unit
|
||||
class UnaryExp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.exp(x)
|
||||
|
||||
@basic_unit
|
||||
class UnarySin(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sin(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryCos(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.cos(x)
|
||||
|
||||
@basic_unit
|
||||
class UnarySinh(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sinh(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryCosh(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.cosh(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryTanh(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.tanh(x)
|
||||
|
||||
if Version(torch.__version__) >= Version(TorchVersion):  # guard: torch.asinh does not exist on older torch
|
||||
@basic_unit
|
||||
class UnaryAsinh(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.asinh(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryAtan(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.atan(x)
|
||||
|
||||
if Version(torch.__version__) >= Version(TorchVersion):  # guard: torch.sinc does not exist on older torch
|
||||
@basic_unit
|
||||
class UnarySinc(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sinc(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryMax(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.max(x, torch.zeros_like(x))
|
||||
|
||||
@basic_unit
|
||||
class UnaryMin(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.min(x, torch.zeros_like(x))
|
||||
|
||||
@basic_unit
|
||||
class UnarySigmoid(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sigmoid(x)
|
||||
|
||||
@basic_unit
|
||||
class UnaryLogExp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.log(1 + torch.exp(x))
|
||||
|
||||
@basic_unit
|
||||
class UnaryExpSquare(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.exp(-torch.square(x))
|
||||
|
||||
@basic_unit
|
||||
class UnaryErf(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.erf(x)
|
||||
|
||||
unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
|
||||
'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
|
||||
'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
|
||||
'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']
|
||||
|
||||
if Version(torch.__version__) >= Version(TorchVersion):
|
||||
unary_modules.append('UnaryAsinh')
|
||||
unary_modules.append('UnarySinc')
|
||||
|
||||
# ============== binary function modules ==============
|
||||
|
||||
@basic_unit
|
||||
class BinaryAdd(nn.Module):
|
||||
def forward(self, x):
|
||||
return x[0] + x[1]
|
||||
|
||||
@basic_unit
|
||||
class BinaryMul(nn.Module):
|
||||
def forward(self, x):
|
||||
return x[0] * x[1]
|
||||
|
||||
@basic_unit
|
||||
class BinaryMinus(nn.Module):
|
||||
def forward(self, x):
|
||||
return x[0] - x[1]
|
||||
|
||||
@basic_unit
|
||||
class BinaryDivide(nn.Module):
|
||||
def forward(self, x):
|
||||
return x[0] / (x[1] + 1e-7)
|
||||
|
||||
@basic_unit
|
||||
class BinaryMax(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.max(x[0], x[1])
|
||||
|
||||
@basic_unit
|
||||
class BinaryMin(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.min(x[0], x[1])
|
||||
|
||||
@basic_unit
|
||||
class BinarySigmoid(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.sigmoid(x[0]) * x[1]
|
||||
|
||||
@basic_unit
|
||||
class BinaryExpSquare(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
|
||||
def forward(self, x):
|
||||
return torch.exp(-self.beta * torch.square(x[0] - x[1]))
|
||||
|
||||
@basic_unit
|
||||
class BinaryExpAbs(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
|
||||
def forward(self, x):
|
||||
return torch.exp(-self.beta * torch.abs(x[0] - x[1]))
|
||||
|
||||
@basic_unit
|
||||
class BinaryParamAdd(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
|
||||
def forward(self, x):
|
||||
return self.beta * x[0] + (1 - self.beta) * x[1]
|
||||
|
||||
binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
|
||||
'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']
|
||||
|
||||
|
||||
class AutoActivation(nn.Module):
|
||||
"""
|
||||
This module is an implementation of the paper `Searching for Activation Functions <https://arxiv.org/abs/1710.05941>`__.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
unit_num : int
|
||||
The number of core units.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Currently, `beta` is not a per-channel parameter.
|
||||
"""
|
||||
def __init__(self, unit_num: int = 1, label: str | None = None):
|
||||
super().__init__()
|
||||
self._label = generate_new_label(label)
|
||||
self.unaries = nn.ModuleList()
|
||||
self.binaries = nn.ModuleList()
|
||||
self.first_unary = LayerChoice([globals()[unary]() for unary in unary_modules], label=f'{self.label}__unary_0')
|
||||
for i in range(unit_num):
|
||||
one_unary = LayerChoice([globals()[unary]() for unary in unary_modules], label=f'{self.label}__unary_{i + 1}')
|
||||
self.unaries.append(one_unary)
|
||||
for i in range(unit_num):
|
||||
one_binary = LayerChoice([globals()[binary]() for binary in binary_modules], label=f'{self.label}__binary_{i}')
|
||||
self.binaries.append(one_binary)
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return self._label
|
||||
|
||||
def forward(self, x):
|
||||
out = self.first_unary(x)
|
||||
for unary, binary in zip(self.unaries, self.binaries):
|
||||
out = binary(torch.stack([out, unary(x)]))
|
||||
return out
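# Hedged sketch (added): AutoActivation as a drop-in activation module. This
# assumes LayerChoice falls back to its first candidate (with a warning) when
# run outside a search or fixed-arch context, so only shapes are checked.
def _demo_auto_activation():
    act = AutoActivation(unit_num=1)
    x = torch.randn(4, 8)
    assert act(x).shape == x.shape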
|
|
@ -0,0 +1,418 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
__all__ = ['NasBench101Cell', 'NasBench101Mutator']
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, List, Optional, Union, Dict, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nni.nas.mutable import InvalidMutation, Mutator
|
||||
from nni.nas.execution.common import Model
|
||||
from nni.nas.nn.pytorch import InputChoice, ValueChoice, LayerChoice
|
||||
from nni.nas.nn.pytorch.mutation_utils import Mutable, generate_new_label, get_fixed_dict
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_vertex_channels(input_channels, output_channels, matrix):
|
||||
"""
|
||||
This is (almost) copied from the original NAS-Bench-101 implementation.
|
||||
|
||||
Computes the number of channels at every vertex.
|
||||
|
||||
Given the input channels and output channels, this calculates the number of channels at each interior vertex.
|
||||
Interior vertices have the same number of channels as the max of the channels of the vertices they feed into.
|
||||
The output channels are divided amongst the vertices that are directly connected to the output.
|
||||
When the division is not even, some vertices may receive an extra channel to compensate.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_channels : int
|
||||
input channels count.
|
||||
output_channels : int
|
||||
output channel count.
|
||||
matrix : np.ndarray
|
||||
adjacency matrix for the module (pruned by model_spec).
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of int
|
||||
list of channel counts, in order of the vertices.
|
||||
"""
|
||||
|
||||
num_vertices = np.shape(matrix)[0]
|
||||
|
||||
vertex_channels = [0] * num_vertices
|
||||
vertex_channels[0] = input_channels
|
||||
vertex_channels[num_vertices - 1] = output_channels
|
||||
|
||||
if num_vertices == 2:
|
||||
# Edge case where module only has input and output vertices
|
||||
return vertex_channels
|
||||
|
||||
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
|
||||
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
|
||||
in_degree = np.sum(matrix[1:], axis=0)
|
||||
interior_channels = output_channels // in_degree[num_vertices - 1]
|
||||
correction = output_channels % in_degree[num_vertices - 1] # Remainder to add
|
||||
|
||||
# Set channels of vertices that flow directly to output
|
||||
for v in range(1, num_vertices - 1):
|
||||
if matrix[v, num_vertices - 1]:
|
||||
vertex_channels[v] = interior_channels
|
||||
if correction:
|
||||
vertex_channels[v] += 1
|
||||
correction -= 1
|
||||
|
||||
# Set channels for all other vertices to the max of the out edges, going backwards.
|
||||
# (num_vertices - 2) index skipped because it only connects to output.
|
||||
for v in range(num_vertices - 3, 0, -1):
|
||||
if not matrix[v, num_vertices - 1]:
|
||||
for dst in range(v + 1, num_vertices - 1):
|
||||
if matrix[v, dst]:
|
||||
vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
|
||||
assert vertex_channels[v] > 0
|
||||
|
||||
_logger.debug('vertex_channels: %s', str(vertex_channels))
|
||||
|
||||
# Sanity check, verify that channels never increase and final channels add up.
|
||||
final_fan_in = 0
|
||||
for v in range(1, num_vertices - 1):
|
||||
if matrix[v, num_vertices - 1]:
|
||||
final_fan_in += vertex_channels[v]
|
||||
for dst in range(v + 1, num_vertices - 1):
|
||||
if matrix[v, dst]:
|
||||
assert vertex_channels[v] >= vertex_channels[dst]
|
||||
assert final_fan_in == output_channels or num_vertices == 2
|
||||
# num_vertices == 2 means only input/output nodes, so 0 fan-in
|
||||
|
||||
return vertex_channels
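# Worked example (added): a 4-vertex cell IN -> {v1, v2} -> OUT with 16 input
# and 24 output channels. Both interior vertices feed the output, so each
# receives 24 // 2 = 12 channels; an uneven 25 would yield [16, 13, 12, 25].
def _demo_compute_vertex_channels():
    matrix = np.array([[0, 1, 1, 0],
                       [0, 0, 0, 1],
                       [0, 0, 0, 1],
                       [0, 0, 0, 0]])
    assert compute_vertex_channels(16, 24, matrix) == [16, 12, 12, 24]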
|
||||
|
||||
|
||||
def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
|
||||
"""
|
||||
Prune the extraneous parts of the graph.
|
||||
|
||||
General procedure:
|
||||
|
||||
1. Remove parts of graph not connected to input.
|
||||
2. Remove parts of graph not connected to output.
|
||||
3. Reorder the vertices so that they are consecutive after steps 1 and 2.
|
||||
|
||||
These 3 steps can be combined by deleting the rows and columns of the
|
||||
vertices that are not reachable from both the input and output (in reverse).
|
||||
"""
|
||||
num_vertices = np.shape(matrix)[0]
|
||||
|
||||
# calculate the connection matrix within V number of steps.
|
||||
connections = np.linalg.matrix_power(matrix + np.eye(num_vertices), num_vertices)
|
||||
|
||||
visited_from_input = set([i for i in range(num_vertices) if connections[0, i]])
|
||||
visited_from_output = set([i for i in range(num_vertices) if connections[i, -1]])
|
||||
|
||||
# Any vertex that isn't connected to both input and output is extraneous to the computation graph.
|
||||
extraneous = set(range(num_vertices)).difference(
|
||||
visited_from_input.intersection(visited_from_output))
|
||||
|
||||
if len(extraneous) > num_vertices - 2:
|
||||
raise InvalidMutation('Non-extraneous graph has less than 2 vertices, '
|
||||
'the input is not connected to the output and the spec is invalid.')
|
||||
|
||||
matrix = np.delete(matrix, list(extraneous), axis=0)
|
||||
matrix = np.delete(matrix, list(extraneous), axis=1)
|
||||
for index in sorted(extraneous, reverse=True):
|
||||
del ops[index]
|
||||
return matrix, ops
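# Worked example (added): vertex 1 below is reachable from the input but never
# reaches the output, so prune() drops it and keeps the direct IN -> OUT edge.
def _demo_prune():
    matrix = np.array([[0, 1, 1],
                       [0, 0, 0],
                       [0, 0, 0]])
    pruned_matrix, pruned_ops = prune(matrix, ['IN', 'conv3x3', 'OUT'])
    assert pruned_ops == ['IN', 'OUT']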
|
||||
|
||||
|
||||
def truncate(inputs, channels):
|
||||
input_channels = inputs.size(1)
|
||||
if input_channels < channels:
|
||||
raise ValueError('input channels < output channels for truncate')
|
||||
elif input_channels == channels:
|
||||
return inputs # No truncation necessary
|
||||
else:
|
||||
# Truncation should only be necessary when channel division leads to
|
||||
# vertices with +1 channels. The input vertex should always be projected to
|
||||
# the minimum channel count.
|
||||
assert input_channels - channels == 1
|
||||
return inputs[:, :channels]
|
||||
|
||||
|
||||
class _NasBench101CellFixed(nn.Module):
|
||||
"""
|
||||
The fixed version of NAS-Bench-101 Cell, used in the Python-based execution engine.
|
||||
"""
|
||||
|
||||
def __init__(self, operations: List[Callable[[int], nn.Module]],
|
||||
adjacency_list: List[List[int]],
|
||||
in_features: int, out_features: int, num_nodes: int,
|
||||
projection: Callable[[int, int], nn.Module]):
|
||||
super().__init__()
|
||||
|
||||
assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1
|
||||
|
||||
raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
|
||||
del operations # operations is no longer needed. Delete it to avoid misuse
|
||||
|
||||
# add pseudo nodes for input and output
|
||||
raw_operations.insert(0, 'IN')
|
||||
raw_operations.append('OUT')
|
||||
|
||||
self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
|
||||
del num_nodes # raw number of nodes is no longer used
|
||||
|
||||
self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)
|
||||
|
||||
self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
|
||||
|
||||
self.num_nodes = len(self.connection_matrix)
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
_logger.info('Pruned number of nodes: %d', self.num_nodes)
|
||||
_logger.info('Pruned connection matrix: %s', str(self.connection_matrix))
|
||||
|
||||
self.projections = nn.ModuleList([nn.Identity()])
|
||||
self.ops = nn.ModuleList([nn.Identity()])
|
||||
for i in range(1, self.num_nodes):
|
||||
self.projections.append(projection(in_features, self.hidden_features[i]))
|
||||
|
||||
for i in range(1, self.num_nodes - 1):
|
||||
operation = cast(Callable[[int], nn.Module], self.operations[i])
|
||||
self.ops.append(operation(self.hidden_features[i]))
|
||||
|
||||
@staticmethod
|
||||
def build_connection_matrix(adjacency_list, num_nodes):
|
||||
adjacency_list = [[]] + adjacency_list # add adjacency for first node
|
||||
connections = np.zeros((num_nodes, num_nodes), dtype='int')
|
||||
for i, lst in enumerate(adjacency_list):
|
||||
assert all([0 <= k < i for k in lst])
|
||||
for k in lst:
|
||||
connections[k, i] = 1
|
||||
return connections
|
||||
|
||||
def forward(self, inputs):
|
||||
tensors = [inputs]
|
||||
for t in range(1, self.num_nodes - 1):
|
||||
|
||||
# Create interior connections, truncating if necessary
|
||||
add_in = [truncate(tensors[src], self.hidden_features[t])
|
||||
for src in range(1, t) if self.connection_matrix[src, t]]
|
||||
|
||||
# Create add connection from projected input
|
||||
if self.connection_matrix[0, t]:
|
||||
add_in.append(self.projections[t](tensors[0]))
|
||||
|
||||
if len(add_in) == 1:
|
||||
vertex_input = add_in[0]
|
||||
else:
|
||||
vertex_input = sum(add_in)
|
||||
|
||||
# Perform op at vertex t
|
||||
vertex_out = self.ops[t](vertex_input)
|
||||
tensors.append(vertex_out)
|
||||
|
||||
# Construct the final output tensor by concatenating all fan-ins and adding the input.
|
||||
if np.sum(self.connection_matrix[:, -1]) == 1:
|
||||
src = np.where(self.connection_matrix[:, -1] == 1)[0][0]
|
||||
return self.projections[-1](tensors[0]) if src == 0 else tensors[src]
|
||||
|
||||
outputs = torch.cat([tensors[src] for src in range(1, self.num_nodes - 1) if self.connection_matrix[src, -1]], 1)
|
||||
if self.connection_matrix[0, -1]:
|
||||
outputs += self.projections[-1](tensors[0])
|
||||
assert outputs.size(1) == self.out_features
|
||||
return outputs
|
||||
|
||||
|
||||
class NasBench101Cell(Mutable):
|
||||
"""
|
||||
Cell structure that is proposed in NAS-Bench-101.
|
||||
|
||||
Proposed by `NAS-Bench-101: Towards Reproducible Neural Architecture Search <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.
|
||||
|
||||
This cell is usually used in evaluation of NAS algorithms because there is a "comprehensive analysis" of this search space
|
||||
available, which includes a full architecture-dataset that "maps 423k unique architectures to metrics
|
||||
including run time and accuracy". You can also use the space in your own space design, in which scenario it should be possible
|
||||
to leverage results in the benchmark to narrow the huge space down to a few efficient architectures.
|
||||
|
||||
The space of this cell architecture consists of all possible directed acyclic graphs on no more than ``max_num_nodes`` nodes,
|
||||
where each possible node (other than IN and OUT) has one of ``op_candidates``, representing the corresponding operation.
|
||||
Edges connecting the nodes can be no more than ``max_num_edges``.
|
||||
To align with the paper settings, two vertices specially labeled as operation IN and OUT, are also counted into
|
||||
``max_num_nodes`` in our implementaion, the default value of ``max_num_nodes`` is 7 and ``max_num_edges`` is 9.
|
||||
|
||||
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. The shape
|
||||
of each hidden nodes will be first automatically computed, depending on the cell structure. Each of the ``op_candidates``
|
||||
should be a callable that accepts computed ``num_features`` and returns a ``Module``. For example,
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def conv_bn_relu(num_features):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(num_features, num_features, 1),
|
||||
nn.BatchNorm2d(num_features),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
The output of each node is the sum of its input node feed into its operation, except for the last node (output node),
|
||||
which is the concatenation of its input *hidden* nodes, adding the *IN* node (if IN and OUT are connected).
|
||||
|
||||
When input tensor is added with any other tensor, there could be shape mismatch. Therefore, a projection transformation
|
||||
is needed to transform the input tensor. In paper, this is simply a Conv1x1 followed by BN and ReLU. The ``projection``
|
||||
parameters accepts ``in_features`` and ``out_features``, returns a ``Module``. This parameter has no default value,
|
||||
as we hold no assumption that users are dealing with images. An example for this parameter is,
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def projection_fn(in_features, out_features):
|
||||
return nn.Conv2d(in_features, out_features, 1)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op_candidates : list of callable
|
||||
Operation candidates. Each should be a function accepts number of feature, returning nn.Module.
|
||||
in_features : int
|
||||
Input dimension of cell.
|
||||
out_features : int
|
||||
Output dimension of cell.
|
||||
projection : callable
|
||||
Projection module that is used to preprocess the input tensor of the whole cell.
|
||||
A callable that accept input feature and output feature, returning nn.Module.
|
||||
max_num_nodes : int
|
||||
Maximum number of nodes in the cell, input and output included. At least 2. Default: 7.
|
||||
max_num_edges : int
|
||||
Maximum number of edges in the cell. Default: 9.
|
||||
label : str
|
||||
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
|
||||
|
||||
Warnings
|
||||
--------
|
||||
:class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-execution-engine>`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_dict(x):
|
||||
if isinstance(x, list):
|
||||
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
|
||||
return OrderedDict(x)
|
||||
|
||||
@classmethod
|
||||
def create_fixed_module(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
|
||||
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
|
||||
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
|
||||
def make_list(x): return x if isinstance(x, list) else [x]
|
||||
|
||||
label, selected = get_fixed_dict(label)
|
||||
op_candidates = cls._make_dict(op_candidates)
|
||||
num_nodes = selected[f'{label}/num_nodes']
|
||||
adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
|
||||
if sum([len(e) for e in adjacency_list]) > max_num_edges:
|
||||
raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
|
||||
return _NasBench101CellFixed(
|
||||
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
|
||||
adjacency_list, in_features, out_features, num_nodes, projection)
|
||||
|
||||
# FIXME: weight inheritance on nasbench101 is not supported yet
|
||||
|
||||
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
|
||||
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
|
||||
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
|
||||
|
||||
super().__init__()
|
||||
self._label = generate_new_label(label)
|
||||
num_vertices_prior = [2 ** i for i in range(2, max_num_nodes + 1)]
|
||||
num_vertices_prior = (np.array(num_vertices_prior) / sum(num_vertices_prior)).tolist()
|
||||
self.num_nodes = ValueChoice(list(range(2, max_num_nodes + 1)),
|
||||
prior=num_vertices_prior,
|
||||
label=f'{self._label}/num_nodes')
|
||||
self.max_num_nodes = max_num_nodes
|
||||
self.max_num_edges = max_num_edges
|
||||
|
||||
op_candidates = self._make_dict(op_candidates)
|
||||
|
||||
# this is only for input validation and instantiating enough layer choice and input choice
|
||||
self.hidden_features = out_features
|
||||
|
||||
self.projections = nn.ModuleList([nn.Identity()])
|
||||
self.ops = nn.ModuleList([nn.Identity()])
|
||||
self.inputs = nn.ModuleList([nn.Identity()])
|
||||
for _ in range(1, max_num_nodes):
|
||||
self.projections.append(projection(in_features, self.hidden_features))
|
||||
for i in range(1, max_num_nodes):
|
||||
if i < max_num_nodes - 1:
|
||||
self.ops.append(LayerChoice(OrderedDict([(k, op(self.hidden_features)) for k, op in op_candidates.items()]),
|
||||
label=f'{self._label}/op{i}'))
|
||||
self.inputs.append(InputChoice(i, None, label=f'{self._label}/input{i}'))
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return self._label
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
The forward of input choice is simply selecting first on all choices.
|
||||
It shouldn't be called directly by users in most cases.
|
||||
"""
|
||||
tensors = [x]
|
||||
for i in range(1, self.max_num_nodes):
|
||||
node_input = self.inputs[i]([self.projections[i](tensors[0])] + [t for t in tensors[1:]])
|
||||
if i < self.max_num_nodes - 1:
|
||||
node_output = self.ops[i](node_input)
|
||||
else:
|
||||
node_output = node_input
|
||||
tensors.append(node_output)
|
||||
return tensors[-1]
|
||||
|
||||
|
||||
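# A minimal usage sketch of the cell above (hypothetical sizes; not part of the original file):
#
#     cell = NasBench101Cell([conv_bn_relu], in_features=16, out_features=16,
#                            projection=projection_fn)
#     out = cell(torch.randn(2, 16, 32, 32))   # sketch: expect shape [2, 16, 32, 32]
#
# ``conv_bn_relu`` and ``projection_fn`` refer to the example callables in the docstring.
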
class NasBench101Mutator(Mutator):
    # for validation purposes
    # for python execution engine

    def __init__(self, label: str):
        super().__init__(label=label)

    @staticmethod
    def candidates(node):
        if 'n_candidates' in node.operation.parameters:
            return list(range(node.operation.parameters['n_candidates']))
        else:
            return node.operation.parameters['candidates']

    @staticmethod
    def number_of_chosen(node):
        if 'n_chosen' in node.operation.parameters:
            return node.operation.parameters['n_chosen']
        return 1

    def mutate(self, model: Model):
        max_num_edges = cast(int, None)
        for node in model.get_nodes_by_label(self.label):
            max_num_edges = node.operation.parameters['max_num_edges']
            break
        assert max_num_edges is not None
        mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
        num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
        adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
        if sum([len(e) for e in adjacency_list]) > max_num_edges:
            raise InvalidMutation(f'Expected at most {max_num_edges} edges, found: {adjacency_list}')
        matrix = _NasBench101CellFixed.build_connection_matrix(adjacency_list, num_nodes)

        operations = ['IN'] + [mutation_dict[f'{self.label}/op{i}'][0] for i in range(1, num_nodes - 1)] + ['OUT']
        assert len(operations) == len(matrix)
        matrix, operations = prune(matrix, operations)  # possible to raise InvalidMutation inside

        # NOTE: a hack to maintain a clean copy of what the nasbench101 cell looks like
        self._cur_samples = {}
        for i in range(1, len(matrix)):
            if i + 1 < len(matrix):
                self._cur_samples[f'op{i}'] = operations[i]
            self._cur_samples[f'input{i}'] = [k for k in range(i) if matrix[k, i]]
        self._cur_samples = [self._cur_samples]  # by design, _cur_samples is a list of samples

    def dry_run(self, model):
        return [], model
@@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

__all__ = ['NasBench201Cell']

from collections import OrderedDict
from typing import Callable, List, Dict, Union, Optional

import torch
import torch.nn as nn

from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label


class NasBench201Cell(nn.Module):
    """
    Cell structure that is proposed in NAS-Bench-201.

    Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.

    This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
    For every i < j, there is an edge from the i-th node to the j-th node.
    Each edge in this DAG is associated with an operation transforming the hidden state from the source node
    to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
    Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
    and returns a ``Module``.

    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.

    The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
    and :math:`N` is defined by ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable
        Operation candidates. Each should be a function that accepts input and output features and returns an ``nn.Module``.
    in_features : int
        Input dimension of cell.
    out_features : int
        Output dimension of cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4
    label : str
        Identifier of the cell. Cells sharing the same label will semantically share the same choice.
    """

    @staticmethod
    def _make_dict(x):
        if isinstance(x, list):
            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
        return OrderedDict(x)

    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        self.layers = nn.ModuleList()
        self.in_features = in_features
        self.out_features = out_features
        self.num_tensors = num_tensors

        op_candidates = self._make_dict(op_candidates)

        for tid in range(1, num_tensors):
            node_ops = nn.ModuleList()
            for j in range(tid):
                inp = in_features if j == 0 else out_features
                op_choices = OrderedDict([(key, cls(inp, out_features))
                                          for key, cls in op_candidates.items()])
                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
            self.layers.append(node_ops)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        The forward of input choice is simply selecting the first among all choices.
        It shouldn't be called directly by users in most cases.
        """
        tensors: List[torch.Tensor] = [inputs]
        for layer in self.layers:
            current_tensor: List[torch.Tensor] = []
            for i, op in enumerate(layer):  # type: ignore
                current_tensor.append(op(tensors[i]))  # type: ignore
            tensors.append(torch.sum(torch.stack(current_tensor), 0))
        return tensors[-1]
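
# A minimal usage sketch (hypothetical op set; not part of the original file):
#
#     ops = {'conv3x3': lambda cin, cout: nn.Conv2d(cin, cout, 3, padding=1)}
#     cell = NasBench201Cell(ops, in_features=16, out_features=16, num_tensors=4)
#     out = cell(torch.randn(2, 16, 32, 32))   # sketch: expect shape [2, 16, 32, 32]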
@@ -0,0 +1,121 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math

import torch
import torch.nn as nn

from nni.nas import model_wrapper
from .modules.nasbench101 import NasBench101Cell


__all__ = ['NasBench101']


def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
    # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
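
# For illustration: the helper draws four standard-normal candidates per element, keeps one
# falling inside (-2, 2), then scales and shifts in place. A hypothetical use:
#
#     w = torch.empty(3, 3)
#     truncated_normal_(w, mean=0., std=0.1)   # values lie roughly within (-0.2, 0.2)
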
class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super(ConvBNReLU, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_bn_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                truncated_normal_(m.weight.data, mean=0., std=math.sqrt(1. / fan_in))
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.conv_bn_relu(x)


class Conv3x3BNReLU(ConvBNReLU):
    def __init__(self, in_channels, out_channels):
        super(Conv3x3BNReLU, self).__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)


class Conv1x1BNReLU(ConvBNReLU):
    def __init__(self, in_channels, out_channels):
        super(Conv1x1BNReLU, self).__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)


Projection = Conv1x1BNReLU


@model_wrapper
class NasBench101(nn.Module):
    """The full search space, proposed by `NAS-Bench-101 <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.

    It's simply a stack of :class:`NasBench101Cell`. The operations are conv3x3, conv1x1 and maxpool, respectively.
    """

    def __init__(self,
                 stem_out_channels: int = 128,
                 num_stacks: int = 3,
                 num_modules_per_stack: int = 3,
                 max_num_vertices: int = 7,
                 max_num_edges: int = 9,
                 num_labels: int = 10,
                 bn_eps: float = 1e-5,
                 bn_momentum: float = 0.003):
        super().__init__()

        op_candidates = {
            'conv3x3-bn-relu': lambda num_features: Conv3x3BNReLU(num_features, num_features),
            'conv1x1-bn-relu': lambda num_features: Conv1x1BNReLU(num_features, num_features),
            'maxpool3x3': lambda num_features: nn.MaxPool2d(3, 1, 1)
        }

        # initial stem convolution
        self.stem_conv = Conv3x3BNReLU(3, stem_out_channels)

        layers = []
        in_channels = out_channels = stem_out_channels
        for stack_num in range(num_stacks):
            if stack_num > 0:
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                layers.append(downsample)
                out_channels *= 2
            for _ in range(num_modules_per_stack):
                cell = NasBench101Cell(op_candidates, in_channels, out_channels,
                                       lambda cin, cout: Projection(cin, cout),
                                       max_num_vertices, max_num_edges, label='cell')
                layers.append(cell)
                in_channels = out_channels

        self.features = nn.ModuleList(layers)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(out_channels, num_labels)

        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = bn_eps
                module.momentum = bn_momentum

    def forward(self, x):
        bs = x.size(0)
        out = self.stem_conv(x)
        for layer in self.features:
            out = layer(out)
        out = self.gap(out).view(bs, -1)
        out = self.classifier(out)
        return out
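
# A minimal usage sketch (assumes CIFAR-10-sized inputs; not part of the original file):
#
#     model = NasBench101()
#     logits = model(torch.randn(1, 3, 32, 32))   # sketch: expect shape [1, 10]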
@@ -0,0 +1,205 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Dict

import torch
import torch.nn as nn

from nni.nas import model_wrapper
from .modules.nasbench201 import NasBench201Cell


__all__ = ['NasBench201']


OPS_WITH_STRIDE = {
    'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
    'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
    'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
    'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
    'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
    else FactorizedReduce(C_in, C_out, stride),
}

PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
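# For illustration: each entry maps (C_in, C_out, stride) to a module, e.g. (hypothetical sizes)
#
#     op = OPS_WITH_STRIDE['conv_3x3'](16, 32, 2)   # ReLU-Conv-BN, downsampling by 2
#     y = op(torch.randn(2, 16, 32, 32))            # -> shape [2, 32, 16, 16]
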
class ReLUConvBN(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out)
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
        )

    def forward(self, x):
        return self.op(x)


class Pooling(nn.Module):
    def __init__(self, C_in, C_out, stride, mode):
        super(Pooling, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
        if mode == 'avg':
            self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max':
            self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError('Invalid mode={:} in Pooling'.format(mode))

    def forward(self, x):
        if self.preprocess:
            x = self.preprocess(x)
        return self.op(x)


class Zero(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1:
                return x.mul(0.)
            else:
                return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:
            shape = list(x.shape)
            shape[1] = self.C_out
            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
            return zeros


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out


class ResNetBasicblock(nn.Module):
    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            inputs = self.downsample(inputs)  # residual
        return inputs + basicblock


@model_wrapper
class NasBench201(nn.Module):
    """The full search space proposed by `NAS-Bench-201 <https://arxiv.org/abs/2001.00326>`__.

    It's a stack of :class:`NasBench201Cell`.
    """
    def __init__(self,
                 stem_out_channels: int = 16,
                 num_modules_per_stack: int = 5,
                 num_labels: int = 10):
        super().__init__()
        self.channels = C = stem_out_channels
        self.num_modules = N = num_modules_per_stack
        self.num_labels = num_labels

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                ops: Dict[str, Callable[[int, int], nn.Module]] = {
                    # bind ``prim`` via a default argument; a plain closure would late-bind
                    # and make every candidate build the last primitive
                    prim: lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
                }
                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
            self.cells.append(cell)
            C_prev = C_curr

        self.lastact = nn.Sequential(
            nn.BatchNorm2d(C_prev),
            nn.ReLU(inplace=True)
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return logits
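
# A minimal usage sketch (assumes CIFAR-10-sized inputs; not part of the original file):
#
#     model = NasBench201(stem_out_channels=16, num_modules_per_stack=5)
#     logits = model(torch.randn(1, 3, 32, 32))   # sketch: expect shape [1, 10]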
@@ -0,0 +1,862 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""File containing the NASNet-series search spaces.

The implementation is based on NDS.
It's called ``nasnet.py`` simply because NASNet was the first to propose such a structure.
"""

from collections import OrderedDict
from functools import partial
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch

import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper

from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.nas.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat

from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight


# the following are NAS operations from
# https://github.com/facebookresearch/unnas/blob/main/pycls/models/nas/operations.py

OPS = {
    'none': lambda C, stride, affine:
        Zero(stride),
    'avg_pool_2x2': lambda C, stride, affine:
        nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
    'avg_pool_3x3': lambda C, stride, affine:
        nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    'avg_pool_5x5': lambda C, stride, affine:
        nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
    'max_pool_2x2': lambda C, stride, affine:
        nn.MaxPool2d(2, stride=stride, padding=0),
    'max_pool_3x3': lambda C, stride, affine:
        nn.MaxPool2d(3, stride=stride, padding=1),
    'max_pool_5x5': lambda C, stride, affine:
        nn.MaxPool2d(5, stride=stride, padding=2),
    'max_pool_7x7': lambda C, stride, affine:
        nn.MaxPool2d(7, stride=stride, padding=3),
    'skip_connect': lambda C, stride, affine:
        nn.Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'conv_1x1': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'conv_3x3': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'sep_conv_3x3': lambda C, stride, affine:
        SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine:
        SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine:
        SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine:
        DilConv(C, C, 3, stride, 2, 2, affine=affine),
    'dil_conv_5x5': lambda C, stride, affine:
        DilConv(C, C, 5, stride, 4, 2, affine=affine),
    'dil_sep_conv_3x3': lambda C, stride, affine:
        DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
    'conv_3x1_1x3': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1, 3), stride=(1, stride), padding=(0, 1), bias=False),
            nn.Conv2d(C, C, (3, 1), stride=(stride, 1), padding=(1, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
    'conv_7x1_1x7': lambda C, stride, affine:
        nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
            nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
            nn.BatchNorm2d(C, affine=affine)
        ),
}
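# For illustration: each entry maps (C, stride, affine) to a module, e.g. (hypothetical sizes)
#
#     op = OPS['sep_conv_3x3'](16, 1, True)   # separable conv preserving 16 channels
#     y = op(torch.randn(2, 16, 32, 32))      # -> shape [2, 16, 32, 32]
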
class ReLUConvBN(nn.Sequential):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_out, kernel_size, stride=stride,
                padding=padding, bias=False
            ),
            nn.BatchNorm2d(C_out, affine=affine)
        )


class DilConv(nn.Sequential):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )


class SepConv(nn.Sequential):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )


class DilSepConv(nn.Sequential):

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in, C_in, kernel_size=kernel_size, stride=1,
                padding=padding, dilation=dilation, groups=C_in, bias=False
            ),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )


class Zero(nn.Module):

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)


class FactorizedReduce(nn.Module):

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        if isinstance(C_out, int):
            assert C_out % 2 == 0
        else:  # is a value choice
            assert all(c % 2 == 0 for c in C_out.all_options())
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)
        self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out


class DropPath_(nn.Module):
    # https://github.com/khanrc/pt.darts/blob/0.1/models/ops.py
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.training and self.drop_prob > 0.:
            keep_prob = 1. - self.drop_prob
            mask = torch.zeros((x.size(0), 1, 1, 1), dtype=torch.float, device=x.device).bernoulli_(keep_prob)
            return x.div(keep_prob).mul(mask)
        return x
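# For illustration: with drop_prob=0.2, each sample's path is zeroed with probability 0.2
# during training and survivors are rescaled by 1/0.8, so the expected activation is
# unchanged: E[x * mask / keep_prob] = x.
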
class AuxiliaryHead(nn.Module):
    def __init__(self, C: int, num_labels: int, dataset: Literal['imagenet', 'cifar']):
        super().__init__()
        if dataset == 'imagenet':
            # assuming input size 14x14
            stride = 2
        elif dataset == 'cifar':
            stride = 3

        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=stride, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_labels)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class SequentialBreakdown(nn.Sequential):
    """Return the outputs of all layers of a sequential."""

    def __init__(self, sequential: nn.Sequential):
        super().__init__(OrderedDict(sequential.named_children()))

    def forward(self, inputs):
        result = []
        for module in self:
            inputs = module(inputs)
            result.append(inputs)
        return result


class CellPreprocessor(nn.Module):
    """
    Aligning the shapes of predecessors.

    If the last cell is a reduction cell, ``pre0`` should be ``FactorizedReduce`` instead of ``ReLUConvBN``.
    See :class:`CellBuilder` on how to calculate those channel numbers.
    """

    def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
        super().__init__()

        if last_cell_reduce:
            self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
        else:
            self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
        self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)

    def forward(self, cells):
        assert len(cells) == 2
        pprev, prev = cells
        pprev = self.pre0(pprev)
        prev = self.pre1(prev)

        return [pprev, prev]


class CellPostprocessor(nn.Module):
    """
    The cell outputs ``[previous cell, this cell]``, so that cells can be directly chained.
    """

    def forward(self, this_cell, previous_cells):
        return [previous_cells[-1], this_cell]


class CellBuilder:
    """The cell builder is used in Repeat.
    Builds a cell each time it's "called".
    Note that the builder is ephemeral: it can only be called once for every index.
    """

    def __init__(self, op_candidates: List[str],
                 C_prev_in: nn.MaybeChoice[int],
                 C_in: nn.MaybeChoice[int],
                 C: nn.MaybeChoice[int],
                 num_nodes: int,
                 merge_op: Literal['all', 'loose_end'],
                 first_cell_reduce: bool, last_cell_reduce: bool):
        self.C_prev_in = C_prev_in  # This is the out channels of the cell before the last cell.
        self.C_in = C_in  # This is the out channels of the last cell.
        self.C = C  # This is NOT C_out of this stage; instead, C_out = C * len(cell.output_node_indices)
        self.op_candidates = op_candidates
        self.num_nodes = num_nodes
        self.merge_op: Literal['all', 'loose_end'] = merge_op
        self.first_cell_reduce = first_cell_reduce
        self.last_cell_reduce = last_cell_reduce
        self._expect_idx = 0

        # It takes an index that is the index in the repeat.
        # The number of predecessors for each cell is fixed to 2.
        self.num_predecessors = 2

        # The number of ops per node is fixed to 2.
        self.num_ops_per_node = 2

    def op_factory(self, node_index: int, op_index: int, input_index: Optional[int], *,
                   op: str, channels: int, is_reduction_cell: bool):
        if is_reduction_cell and (
            input_index is None or input_index < self.num_predecessors
        ):  # could be None when constructing the search space
            stride = 2
        else:
            stride = 1
        return OPS[op](channels, stride, True)

    def __call__(self, repeat_idx: int):
        if self._expect_idx != repeat_idx:
            raise ValueError(f'Expect index {self._expect_idx}, found {repeat_idx}')

        # Reduction cell means stride = 2 and channels multiplied by 2.
        is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce

        # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
        preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)

        ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {}
        for op in self.op_candidates:
            ops_factory[op] = partial(self.op_factory, op=op, channels=cast(int, self.C), is_reduction_cell=is_reduction_cell)

        cell = nn.Cell(ops_factory, self.num_nodes, self.num_ops_per_node, self.num_predecessors, self.merge_op,
                       preprocessor=preprocessor, postprocessor=CellPostprocessor(),
                       label='reduce' if is_reduction_cell else 'normal')

        # update state
        self.C_prev_in = self.C_in
        self.C_in = self.C * len(cell.output_node_indices)
        self.last_cell_reduce = is_reduction_cell
        self._expect_idx += 1

        return cell
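# For illustration of the channel bookkeeping above: with ``merge_op='all'`` and
# ``num_nodes=4``, every node of the built cell is an output node, so a cell created
# with C=16 produces C_out = 16 * len(cell.output_node_indices) = 16 * 4 = 64 channels.
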
class NDSStage(nn.Repeat):
    """This class defines NDSStage, a special type of Repeat, for isinstance checks and shape alignment.

    In NDS, we can't simply use Repeat to stack the blocks,
    because the output shape of each stacked block can be different.
    This is a problem for one-shot strategies, because they assume every possible candidate
    should return values of the same shape.

    Therefore, we need :class:`NDSStagePathSampling` and :class:`NDSStageDifferentiable`
    to manually align the shapes -- specifically, to transform the first block in each stage.

    This is not required though, when depth is not changing, or the mutable depth causes no problem
    (e.g., when the minimum depth is large enough).

    .. attention::

        Assumption: Loose end is treated as ``all`` in ``merge_op`` (the case in one-shot),
        which enforces reduction cells and normal cells in the same stage to have the exact same output shape.
    """

    estimated_out_channels_prev: int
    """Output channels of cells in the last stage."""

    estimated_out_channels: int
    """Output channels of this stage. It's **estimated** because it assumes ``all`` as ``merge_op``."""

    downsampling: bool
    """Whether this stage has downsampling."""

    def first_cell_transformation_factory(self) -> Optional[nn.Module]:
        """To make the "previous cell" in the first cell's output have the same shape as cells in this stage."""
        if self.downsampling:
            return FactorizedReduce(self.estimated_out_channels_prev, self.estimated_out_channels)
        elif self.estimated_out_channels_prev is not self.estimated_out_channels:
            # Can't use != here; ValueChoice doesn't support it
            return ReLUConvBN(self.estimated_out_channels_prev, self.estimated_out_channels, 1, 1, 0)
        return None


class NDSStagePathSampling(PathSamplingRepeat):
    """The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""
    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
            return cls(
                module.first_cell_transformation_factory(),
                cast(List[nn.Module], module.blocks),
                module.depth_choice
            )

    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.first_cell_transformation = first_cell_transformation

    def reduction(self, items: List[Tuple[torch.Tensor, torch.Tensor]], sampled: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        if 1 not in sampled or self.first_cell_transformation is None:
            return super().reduction(items, sampled)
        # items[0] must be the result of the first cell
        assert len(items[0]) == 2
        # Only apply the transformation on the "prev" output.
        items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
        return super().reduction(items, sampled)


class NDSStageDifferentiable(DifferentiableMixedRepeat):
    """The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""
    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
            # Only interesting when depth is mutable
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(
                module.first_cell_transformation_factory(),
                cast(List[nn.Module], module.blocks),
                module.depth_choice,
                softmax,
                memo
            )

    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.first_cell_transformation = first_cell_transformation

    def reduction(
        self, items: List[Tuple[torch.Tensor, torch.Tensor]], weights: List[float], depths: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if 1 not in depths or self.first_cell_transformation is None:
            return super().reduction(items, weights, depths)
        # Same as NDSStagePathSampling
        assert len(items[0]) == 2
        items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
        return super().reduction(items, weights, depths)


_INIT_PARAMETER_DOCS = """

    Parameters
    ----------
    width : int or tuple of int
        A fixed initial width or a tuple of widths to choose from.
    num_cells : int or tuple of int
        A fixed number of cells (depth) to stack, or a tuple of depths to choose from.
    dataset : "cifar" | "imagenet"
        The essential differences are in the "stem" cells, i.e., how they process the raw image input.
        Choosing "imagenet" means more downsampling at the beginning of the network.
    auxiliary_loss : bool
        If true, an auxiliary classification head will produce another prediction.
        This makes the network output two logits in the training phase.

"""
class NDS(nn.Module):
    __doc__ = """
    The unified version of the NASNet search space.

    We follow the implementation in
    `unnas <https://github.com/facebookresearch/unnas/blob/main/pycls/models/nas/nas.py>`__.
    See `On Network Design Spaces for Visual Recognition <https://arxiv.org/abs/1905.13214>`__ for details.

    Different NAS papers usually differ in the way that they specify ``op_candidates`` and ``merge_op``.
    ``dataset`` here is to give a hint about input resolution, so as to create reasonable stem and auxiliary heads.

    A specialty of NDS is that it has mutable depths/widths.
    This is implemented by accepting a list of int as ``num_cells`` / ``width``.
    """ + _INIT_PARAMETER_DOCS + """
    op_candidates : list of str
        List of operator candidates. Must be from ``OPS``.
    merge_op : ``all`` or ``loose_end``
        See :class:`~nni.retiarii.nn.pytorch.Cell`.
    num_nodes_per_cell : int
        See :class:`~nni.retiarii.nn.pytorch.Cell`.
    """

    def __init__(self,
                 op_candidates: List[str],
                 merge_op: Literal['all', 'loose_end'] = 'all',
                 num_nodes_per_cell: int = 4,
                 width: Union[Tuple[int, ...], int] = 16,
                 num_cells: Union[Tuple[int, ...], int] = 20,
                 dataset: Literal['cifar', 'imagenet'] = 'imagenet',
                 auxiliary_loss: bool = False):
        super().__init__()

        self.dataset = dataset
        self.num_labels = 10 if dataset == 'cifar' else 1000
        self.auxiliary_loss = auxiliary_loss

        # preprocess the specified width and depth
        if isinstance(width, Iterable):
            C = nn.ValueChoice(list(width), label='width')
        else:
            C = width

        self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
        if isinstance(num_cells, Iterable):
            self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
        num_cells_per_stage = [(i + 1) * self.num_cells // 3 - i * self.num_cells // 3 for i in range(3)]

        # the stem and auxiliary head are different for networks targeted at different datasets
        if dataset == 'imagenet':
            self.stem0 = nn.Sequential(
                nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(cast(int, C // 2)),
                nn.ReLU(inplace=True),
                nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(C),
            )
            self.stem1 = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(C),
            )
            C_pprev = C_prev = C_curr = C
            last_cell_reduce = True
        elif dataset == 'cifar':
            self.stem = nn.Sequential(
                nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
                nn.BatchNorm2d(cast(int, 3 * C))
            )
            C_pprev = C_prev = 3 * C
            C_curr = C
            last_cell_reduce = False
        else:
            raise ValueError(f'Unsupported dataset: {dataset}')

        self.stages = nn.ModuleList()
        for stage_idx in range(3):
            if stage_idx > 0:
                C_curr *= 2
            # For a stage, we get C_in, C_curr, and C_out.
            # C_in is only used in the first cell.
            # C_curr is the number of channels for each operator in the current stage.
            # C_out is usually `C * num_nodes_per_cell` because of the concat operator.
            cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
                                       merge_op, stage_idx > 0, last_cell_reduce)
            stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])

            if isinstance(stage, NDSStage):
                stage.estimated_out_channels_prev = cast(int, C_prev)
                stage.estimated_out_channels = cast(int, C_curr * num_nodes_per_cell)
                stage.downsampling = stage_idx > 0

            self.stages.append(stage)

            # NOTE: output_node_indices will be computed on-the-fly in trial code.
            # When constructing the model space, it's just all the nodes in the cell,
            # which happens to be the case of the one-shot supernet.

            # C_pprev is the output channel number of the second-to-last cell among all the cells already built.
            if len(stage) > 1:
                # Contains more than one cell
                C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
            else:
                # Look up the out channels of the last stage.
                C_pprev = C_prev

            # This was originally,
            # C_prev = num_nodes_per_cell * C_curr.
            # but due to loose end, it becomes,
            C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr

            # Useful in aligning the pprev and prev cell.
            last_cell_reduce = cell_builder.last_cell_reduce

            if stage_idx == 2:
                C_to_auxiliary = C_prev

        if auxiliary_loss:
            assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
            self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)  # type: ignore

        self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)

    def forward(self, inputs):
        if self.dataset == 'imagenet':
            s0 = self.stem0(inputs)
            s1 = self.stem1(s0)
        else:
            s0 = s1 = self.stem(inputs)

        for stage_idx, stage in enumerate(self.stages):
            if stage_idx == 2 and self.auxiliary_loss:
                # SequentialBreakdown returns the list of per-cell outputs
                s = stage([s0, s1])
                s0, s1 = s[-1]
                if self.training:
                    # auxiliary loss is attached to the first cell of the last stage.
                    logits_aux = self.auxiliary_head(s[0][1])
            else:
                s0, s1 = stage([s0, s1])

        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        if self.training and self.auxiliary_loss:
            return logits, logits_aux  # type: ignore
        else:
            return logits

    def set_drop_path_prob(self, drop_prob):
        """
        Set the drop probability of Drop-path in the network.
        Reference: `FractalNet: Ultra-Deep Neural Networks without Residuals <https://arxiv.org/pdf/1605.07648v4.pdf>`__.
        """
        for module in self.modules():
            if isinstance(module, DropPath_):
                module.drop_prob = drop_prob

    @classmethod
    def fixed_arch(cls, arch: dict) -> FixedFactory:
        return FixedFactory(cls, arch)
@model_wrapper
class NASNet(NDS):
    __doc__ = """
    Search space proposed in `Learning Transferable Architectures for Scalable Image Recognition <https://arxiv.org/abs/1707.07012>`__.

    It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
    Its operator candidates are :attr:`~NASNet.NASNET_OPS`.
    It has 5 nodes per cell, and the output is the concatenation of the nodes not used as input to other nodes.
    """ + _INIT_PARAMETER_DOCS

    NASNET_OPS = [
        'skip_connect',
        'conv_3x1_1x3',
        'conv_7x1_1x7',
        'dil_conv_3x3',
        'avg_pool_3x3',
        'max_pool_3x3',
        'max_pool_5x5',
        'max_pool_7x7',
        'conv_1x1',
        'conv_3x3',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'sep_conv_7x7',
    ]

    def __init__(self,
                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                 dataset: Literal['cifar', 'imagenet'] = 'cifar',
                 auxiliary_loss: bool = False):
        super().__init__(self.NASNET_OPS,
                         merge_op='loose_end',
                         num_nodes_per_cell=5,
                         width=width,
                         num_cells=num_cells,
                         dataset=dataset,
                         auxiliary_loss=auxiliary_loss)


@model_wrapper
class ENAS(NDS):
    __doc__ = """Search space proposed in `Efficient neural architecture search via parameter sharing <https://arxiv.org/abs/1802.03268>`__.

    It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
    Its operator candidates are :attr:`~ENAS.ENAS_OPS`.
    It has 5 nodes per cell, and the output is the concatenation of the nodes not used as input to other nodes.
    """ + _INIT_PARAMETER_DOCS

    ENAS_OPS = [
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'avg_pool_3x3',
        'max_pool_3x3',
    ]

    def __init__(self,
                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                 dataset: Literal['cifar', 'imagenet'] = 'cifar',
                 auxiliary_loss: bool = False):
        super().__init__(self.ENAS_OPS,
                         merge_op='loose_end',
                         num_nodes_per_cell=5,
                         width=width,
                         num_cells=num_cells,
                         dataset=dataset,
                         auxiliary_loss=auxiliary_loss)


@model_wrapper
class AmoebaNet(NDS):
    __doc__ = """Search space proposed in
    `Regularized evolution for image classifier architecture search <https://arxiv.org/abs/1802.01548>`__.

    It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
    Its operator candidates are :attr:`~AmoebaNet.AMOEBA_OPS`.
    It has 5 nodes per cell, and the output is the concatenation of the nodes not used as input to other nodes.
    """ + _INIT_PARAMETER_DOCS

    AMOEBA_OPS = [
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'sep_conv_7x7',
        'avg_pool_3x3',
        'max_pool_3x3',
        'dil_sep_conv_3x3',
        'conv_7x1_1x7',
    ]

    def __init__(self,
                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                 dataset: Literal['cifar', 'imagenet'] = 'cifar',
                 auxiliary_loss: bool = False):
        super().__init__(self.AMOEBA_OPS,
                         merge_op='loose_end',
                         num_nodes_per_cell=5,
                         width=width,
                         num_cells=num_cells,
                         dataset=dataset,
                         auxiliary_loss=auxiliary_loss)


@model_wrapper
class PNAS(NDS):
    __doc__ = """Search space proposed in
    `Progressive neural architecture search <https://arxiv.org/abs/1712.00559>`__.

    It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
    Its operator candidates are :attr:`~PNAS.PNAS_OPS`.
    It has 5 nodes per cell, and the output is the concatenation of all nodes in the cell.
    """ + _INIT_PARAMETER_DOCS

    PNAS_OPS = [
        'sep_conv_3x3',
        'sep_conv_5x5',
        'sep_conv_7x7',
        'conv_7x1_1x7',
        'skip_connect',
        'avg_pool_3x3',
        'max_pool_3x3',
        'dil_conv_3x3',
    ]

    def __init__(self,
                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                 dataset: Literal['cifar', 'imagenet'] = 'cifar',
                 auxiliary_loss: bool = False):
        super().__init__(self.PNAS_OPS,
                         merge_op='all',
                         num_nodes_per_cell=5,
                         width=width,
                         num_cells=num_cells,
                         dataset=dataset,
                         auxiliary_loss=auxiliary_loss)


@model_wrapper
class DARTS(NDS):
    __doc__ = """Search space proposed in `DARTS: Differentiable architecture search <https://arxiv.org/abs/1806.09055>`__.

    It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
    Its operator candidates are :attr:`~DARTS.DARTS_OPS`.
    It has 4 nodes per cell, and the output is the concatenation of all nodes in the cell.
    """ + _INIT_PARAMETER_DOCS

    DARTS_OPS = [
        'none',
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
    ]

    def __init__(self,
                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                 dataset: Literal['cifar', 'imagenet'] = 'cifar',
                 auxiliary_loss: bool = False):
        super().__init__(self.DARTS_OPS,
                         merge_op='all',
                         num_nodes_per_cell=4,
                         width=width,
                         num_cells=num_cells,
                         dataset=dataset,
                         auxiliary_loss=auxiliary_loss)

    @classmethod
    def load_searched_model(
        cls, name: str,
        pretrained: bool = False, download: bool = False, progress: bool = True
    ) -> nn.Module:

        init_kwargs = {}  # all default

        if name == 'darts-v2':
            init_kwargs.update(
                num_cells=20,
                width=36,
            )
            arch = {
                'normal/op_2_0': 'sep_conv_3x3',
                'normal/op_2_1': 'sep_conv_3x3',
                'normal/input_2_0': 0,
                'normal/input_2_1': 1,
                'normal/op_3_0': 'sep_conv_3x3',
                'normal/op_3_1': 'sep_conv_3x3',
                'normal/input_3_0': 0,
                'normal/input_3_1': 1,
                'normal/op_4_0': 'sep_conv_3x3',
                'normal/op_4_1': 'skip_connect',
                'normal/input_4_0': 1,
                'normal/input_4_1': 0,
                'normal/op_5_0': 'skip_connect',
                'normal/op_5_1': 'dil_conv_3x3',
                'normal/input_5_0': 0,
                'normal/input_5_1': 2,
                'reduce/op_2_0': 'max_pool_3x3',
                'reduce/op_2_1': 'max_pool_3x3',
                'reduce/input_2_0': 0,
                'reduce/input_2_1': 1,
                'reduce/op_3_0': 'skip_connect',
                'reduce/op_3_1': 'max_pool_3x3',
                'reduce/input_3_0': 2,
                'reduce/input_3_1': 1,
                'reduce/op_4_0': 'max_pool_3x3',
                'reduce/op_4_1': 'skip_connect',
                'reduce/input_4_0': 0,
                'reduce/input_4_1': 2,
                'reduce/op_5_0': 'skip_connect',
                'reduce/op_5_1': 'max_pool_3x3',
                'reduce/input_5_0': 2,
                'reduce/input_5_1': 1
            }

        else:
            raise ValueError(f'Unsupported architecture with name: {name}')

        model_factory = cls.fixed_arch(arch)
        model = model_factory(**init_kwargs)

        if pretrained:
            weight_file = load_pretrained_weight(name, download=download, progress=progress)
            pretrained_weights = torch.load(weight_file)
            model.load_state_dict(pretrained_weights)

        return model
|
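
A minimal usage sketch for the class above (hedged: the public import path is assumed to be ``nni.nas.hub.pytorch``; the first call downloads the checkpoint):

import torch
from nni.nas.hub.pytorch import DARTS  # import path is an assumption

model = DARTS.load_searched_model('darts-v2', pretrained=True, download=True)
logits = model(torch.randn(1, 3, 32, 32))  # CIFAR-sized input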
@ -0,0 +1,538 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload

import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper

from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight


@overload
def make_divisible(v: Union[int, float], divisor, min_val=None) -> int:
    ...


@overload
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float]], divisor, min_val=None) -> nn.ChoiceOf[int]:
    ...


def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float], int, float], divisor, min_val=None) -> nn.MaybeChoice[int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_val is None:
        min_val = divisor
    # This should work for both value choices and constants.
    new_v = nn.ValueChoice.max(min_val, round(v + divisor // 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    return nn.ValueChoice.condition(new_v < 0.9 * v, new_v + divisor, new_v)
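
# A quick illustration on plain ints (hedged; with value choices the same
# formula is recorded symbolically instead of evaluated eagerly):
#
#   >>> make_divisible(32 * 1.4, 8)  # 44.8 -> nearest multiple of 8 is 48
#   48
#   >>> make_divisible(20, 16)       # 16 would shrink v by more than 10%, so round up
#   32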


def simplify_sequential(sequentials: List[nn.Module]) -> Iterator[nn.Module]:
    """
    Flatten the sequential blocks so that the hierarchy looks better.
    Eliminate identity modules automatically.
    """
    for module in sequentials:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                # no recursive expansion
                if not isinstance(submodule, nn.Identity):
                    yield submodule
        else:
            if not isinstance(module, nn.Identity):
                yield module
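
# A hedged illustration: one level of nesting is flattened and Identity modules
# are dropped, so
#   list(simplify_sequential([nn.Sequential(nn.ReLU(), nn.Identity()), nn.Identity(), nn.ReLU()]))
# evaluates to a list of two ReLU modules.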


class ConvBNReLU(nn.Sequential):
    """
    The template for a conv-bn-relu block.
    """

    def __init__(
        self,
        in_channels: nn.MaybeChoice[int],
        out_channels: nn.MaybeChoice[int],
        kernel_size: nn.MaybeChoice[int] = 3,
        stride: int = 1,
        groups: nn.MaybeChoice[int] = 1,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
    ) -> None:
        padding = (kernel_size - 1) // 2 * dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6
        # If no normalization is used, set bias to True
        # https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L194
        norm = norm_layer(cast(int, out_channels))
        no_normalization = isinstance(norm, nn.Identity)
        blocks: List[nn.Module] = [
            nn.Conv2d(
                cast(int, in_channels),
                cast(int, out_channels),
                cast(int, kernel_size),
                stride,
                cast(int, padding),
                dilation=dilation,
                groups=cast(int, groups),
                bias=no_normalization
            ),
            # Normalization, regardless of batchnorm or identity
            norm,
            # Some PyTorch implementations insert an SE module here to faithfully reproduce the paper;
            # we follow the more widely accepted approach of placing SE outside this block.
            # Reference: https://github.com/d-li14/mobilenetv3.pytorch/issues/18
            activation_layer(inplace=True)
        ]

        super().__init__(*simplify_sequential(blocks))


class DepthwiseSeparableConv(nn.Sequential):
    """
    In the original MobileNetV2 implementation, this is InvertedResidual when expand ratio = 1.
    Residual connection is added if input and output shape are the same.

    References:

    - https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L90
    - https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L433
    - https://github.com/ultmaster/AceNAS/blob/46c8895f/searchspace/proxylessnas/utils.py#L100
    """

    def __init__(
        self,
        in_channels: nn.MaybeChoice[int],
        out_channels: nn.MaybeChoice[int],
        kernel_size: nn.MaybeChoice[int] = 3,
        stride: int = 1,
        squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        blocks = [
            # dw
            ConvBNReLU(in_channels, in_channels, stride=stride, kernel_size=kernel_size, groups=in_channels,
                       norm_layer=norm_layer, activation_layer=activation_layer),
            # optional se
            squeeze_excite(in_channels, in_channels) if squeeze_excite else nn.Identity(),
            # pw-linear
            ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
        ]
        super().__init__(*simplify_sequential(blocks))
        # NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
        self.has_skip = stride == 1 and in_channels is out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.has_skip:
            return x + super().forward(x)
        else:
            return super().forward(x)


class InvertedResidual(nn.Sequential):
    """
    An Inverted Residual Block, sometimes called an MBConv Block, is a type of residual block used for image models
    that uses an inverted structure for efficiency reasons.

    It was originally proposed for the `MobileNetV2 <https://arxiv.org/abs/1801.04381>`__ CNN architecture.
    It has since been reused for several mobile-optimized CNNs.
    It follows a narrow -> wide -> narrow approach, hence the inversion.
    It first widens with a 1x1 convolution, then uses a 3x3 depthwise convolution (which greatly reduces the number of parameters),
    then a 1x1 convolution is used to reduce the number of channels so that input and output can be added.

    This implementation is sort of a mixture between:

    - https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L453
    - https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L134
    """

    def __init__(
        self,
        in_channels: nn.MaybeChoice[int],
        out_channels: nn.MaybeChoice[int],
        expand_ratio: nn.MaybeChoice[float],
        kernel_size: nn.MaybeChoice[int] = 3,
        stride: int = 1,
        squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.stride = stride
        self.out_channels = out_channels
        assert stride in [1, 2]

        hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))

        # NOTE: this equivalence check (==) does NOT work for ValueChoice, need to use "is"
        self.has_skip = stride == 1 and in_channels is out_channels

        layers: List[nn.Module] = [
            # point-wise convolution
            # NOTE: some papers omit this point-wise convolution when stride = 1.
            # In our implementation, if this pw convolution is intended to be omitted,
            # please use SepConv instead.
            ConvBNReLU(in_channels, hidden_ch, kernel_size=1,
                       norm_layer=norm_layer, activation_layer=activation_layer),
            # depth-wise
            ConvBNReLU(hidden_ch, hidden_ch, stride=stride, kernel_size=kernel_size, groups=hidden_ch,
                       norm_layer=norm_layer, activation_layer=activation_layer),
            # SE
            squeeze_excite(
                cast(int, hidden_ch),
                cast(int, in_channels)
            ) if squeeze_excite is not None else nn.Identity(),
            # pw-linear
            ConvBNReLU(hidden_ch, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity),
        ]

        super().__init__(*simplify_sequential(layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.has_skip:
            return x + super().forward(x)
        else:
            return super().forward(x)
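
# A hedged walk-through of the narrow -> wide -> narrow flow above, with
# in_channels=32, expand_ratio=6, out_channels=32, stride=1:
#   1x1 pw conv:     32 -> make_divisible(32 * 6, 8) = 192
#   3x3 dw conv:    192 -> 192 (groups=192)
#   1x1 pw-linear:  192 -> 32
# has_skip holds when in_channels and out_channels refer to the same object
# (the "is" check above), so the residual addition applies in forward().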


def inverted_residual_choice_builder(
    expand_ratios: List[int],
    kernel_sizes: List[int],
    downsample: bool,
    stage_input_width: int,
    stage_output_width: int,
    label: str
):
    def builder(index):
        stride = 1
        inp = stage_output_width

        if index == 0:
            # first layer in stage
            # do downsample and width reshape
            inp = stage_input_width
            if downsample:
                stride = 2

        oup = stage_output_width

        op_choices = {}
        for exp_ratio in expand_ratios:
            for kernel_size in kernel_sizes:
                op_choices[f'k{kernel_size}e{exp_ratio}'] = InvertedResidual(inp, oup, exp_ratio, kernel_size, stride)

        # It can be implemented with ValueChoice, but we use LayerChoice here
        # to be aligned with the intention of the original ProxylessNAS.
        return nn.LayerChoice(op_choices, label=f'{label}_i{index}')

    return builder
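
# A hedged sketch of what the builder yields: with expand_ratios=[3, 6] and
# kernel_sizes=[3, 5, 7] (as used by ProxylessNAS below), each builder(index)
# call returns a LayerChoice over six candidates, in this order:
#   k3e3, k5e3, k7e3, k3e6, k5e6, k7e6
# registered under the mutation label f'{label}_i{index}', e.g. 's2_i0'.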


@model_wrapper
class ProxylessNAS(nn.Module):
    """
    The search space proposed by `ProxylessNAS <https://arxiv.org/abs/1812.00332>`__.

    Following the official implementation, the inverted residual with kernel size / expand ratio variations in each layer
    is implemented with a :class:`nn.LayerChoice` with all-combination candidates. That means,
    when used in weight sharing, these candidates will be treated as separate layers and will not share weights
    at a finer granularity. We note that :class:`MobileNetV3Space` differs in this respect.

    This space could be implemented as part of :class:`MobileNetV3Space`, but we keep it separate, following conventions.
    """

    def __init__(self, num_labels: int = 1000,
                 base_widths: Tuple[int, ...] = (32, 16, 32, 40, 80, 96, 192, 320, 1280),
                 dropout_rate: float = 0.,
                 width_mult: float = 1.0,
                 bn_eps: float = 1e-3,
                 bn_momentum: float = 0.1):

        super().__init__()

        assert len(base_widths) == 9
        # include the last stage info widths here
        widths = [make_divisible(width * width_mult, 8) for width in base_widths]
        downsamples = [True, False, True, True, True, False, True, False]

        self.num_labels = num_labels
        self.dropout_rate = dropout_rate
        self.bn_eps = bn_eps
        self.bn_momentum = bn_momentum

        self.stem = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)

        blocks: List[nn.Module] = [
            # first stage is fixed
            DepthwiseSeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
        ]

        # https://github.com/ultmaster/AceNAS/blob/46c8895fd8a05ffbc61a6b44f1e813f64b4f66b7/searchspace/proxylessnas/__init__.py#L21
        for stage in range(2, 8):
            # Rather than returning a fixed module here,
            # we return a builder that dynamically creates modules for different `repeat_idx`.
            builder = inverted_residual_choice_builder(
                [3, 6], [3, 5, 7], downsamples[stage], widths[stage - 1], widths[stage], f's{stage}')
            if stage < 7:
                blocks.append(nn.Repeat(builder, (1, 4), label=f's{stage}_depth'))
            else:
                # No mutation for depth in the last stage.
                # Directly call the builder to instantiate one block.
                blocks.append(builder(0))

        self.blocks = nn.Sequential(*blocks)

        # final layers
        self.feature_mix_layer = ConvBNReLU(widths[7], widths[8], kernel_size=1, norm_layer=nn.BatchNorm2d)
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout_layer = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(widths[-1], num_labels)

        reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.feature_mix_layer(x)
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.dropout_layer(x)
        x = self.classifier(x)
        return x

    def no_weight_decay(self):
        # this is useful for timm optimizer
        # no regularizer to linear layer
        if hasattr(self, 'classifier'):
            return {'classifier.weight', 'classifier.bias'}
        return set()

    @classmethod
    def fixed_arch(cls, arch: dict) -> FixedFactory:
        return FixedFactory(cls, arch)

    @classmethod
    def load_searched_model(
        cls, name: str,
        pretrained: bool = False, download: bool = False, progress: bool = True
    ) -> nn.Module:

        init_kwargs = {}  # all default

        if name == 'acenas-m1':
            arch = {
                's2_depth': 2,
                's2_i0': 'k3e6',
                's2_i1': 'k3e3',
                's3_depth': 3,
                's3_i0': 'k5e3',
                's3_i1': 'k3e3',
                's3_i2': 'k5e3',
                's4_depth': 2,
                's4_i0': 'k3e6',
                's4_i1': 'k5e3',
                's5_depth': 4,
                's5_i0': 'k7e6',
                's5_i1': 'k3e6',
                's5_i2': 'k3e6',
                's5_i3': 'k7e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e6',
                's6_i2': 'k7e3',
                's6_i3': 'k7e3',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

        elif name == 'acenas-m2':
            arch = {
                's2_depth': 1,
                's2_i0': 'k5e3',
                's3_depth': 3,
                's3_i0': 'k3e6',
                's3_i1': 'k3e3',
                's3_i2': 'k5e3',
                's4_depth': 2,
                's4_i0': 'k7e6',
                's4_i1': 'k5e6',
                's5_depth': 4,
                's5_i0': 'k5e6',
                's5_i1': 'k5e3',
                's5_i2': 'k5e6',
                's5_i3': 'k3e6',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k5e6',
                's6_i2': 'k5e3',
                's6_i3': 'k5e6',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

        elif name == 'acenas-m3':
            arch = {
                's2_depth': 2,
                's2_i0': 'k3e3',
                's2_i1': 'k3e6',
                's3_depth': 2,
                's3_i0': 'k5e3',
                's3_i1': 'k3e3',
                's4_depth': 3,
                's4_i0': 'k5e6',
                's4_i1': 'k7e6',
                's4_i2': 'k3e6',
                's5_depth': 4,
                's5_i0': 'k7e6',
                's5_i1': 'k7e3',
                's5_i2': 'k7e3',
                's5_i3': 'k5e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e3',
                's6_i2': 'k7e6',
                's6_i3': 'k3e3',
                's7_depth': 1,
                's7_i0': 'k5e6'
            }

        elif name == 'proxyless-cpu':
            arch = {
                's2_depth': 4,
                's2_i0': 'k3e6',
                's2_i1': 'k3e3',
                's2_i2': 'k3e3',
                's2_i3': 'k3e3',
                's3_depth': 4,
                's3_i0': 'k3e6',
                's3_i1': 'k3e3',
                's3_i2': 'k3e3',
                's3_i3': 'k5e3',
                's4_depth': 2,
                's4_i0': 'k3e6',
                's4_i1': 'k3e3',
                's5_depth': 4,
                's5_i0': 'k5e6',
                's5_i1': 'k3e3',
                's5_i2': 'k3e3',
                's5_i3': 'k3e3',
                's6_depth': 4,
                's6_i0': 'k5e6',
                's6_i1': 'k5e3',
                's6_i2': 'k5e3',
                's6_i3': 'k3e3',
                's7_depth': 1,
                's7_i0': 'k5e6'
            }

            init_kwargs['base_widths'] = [40, 24, 32, 48, 88, 104, 216, 360, 1432]

        elif name == 'proxyless-gpu':
            arch = {
                's2_depth': 1,
                's2_i0': 'k5e3',
                's3_depth': 2,
                's3_i0': 'k7e3',
                's3_i1': 'k3e3',
                's4_depth': 2,
                's4_i0': 'k7e6',
                's4_i1': 'k5e3',
                's5_depth': 3,
                's5_i0': 'k5e6',
                's5_i1': 'k3e3',
                's5_i2': 'k5e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e6',
                's6_i2': 'k7e6',
                's6_i3': 'k5e6',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

            init_kwargs['base_widths'] = [40, 24, 32, 56, 112, 128, 256, 432, 1728]

        elif name == 'proxyless-mobile':
            arch = {
                's2_depth': 2,
                's2_i0': 'k5e3',
                's2_i1': 'k3e3',
                's3_depth': 4,
                's3_i0': 'k7e3',
                's3_i1': 'k3e3',
                's3_i2': 'k5e3',
                's3_i3': 'k5e3',
                's4_depth': 4,
                's4_i0': 'k7e6',
                's4_i1': 'k5e3',
                's4_i2': 'k5e3',
                's4_i3': 'k5e3',
                's5_depth': 4,
                's5_i0': 'k5e6',
                's5_i1': 'k5e3',
                's5_i2': 'k5e3',
                's5_i3': 'k5e3',
                's6_depth': 4,
                's6_i0': 'k7e6',
                's6_i1': 'k7e6',
                's6_i2': 'k7e3',
                's6_i3': 'k7e3',
                's7_depth': 1,
                's7_i0': 'k7e6'
            }

        else:
            raise ValueError(f'Unsupported architecture with name: {name}')

        model_factory = cls.fixed_arch(arch)
        model = model_factory(**init_kwargs)

        if pretrained:
            weight_file = load_pretrained_weight(name, download=download, progress=progress)
            pretrained_weights = torch.load(weight_file)
            model.load_state_dict(pretrained_weights)

        return model


def reset_parameters(model, model_init='he_fout', init_div_groups=False,
                     bn_momentum=0.1, bn_eps=1e-5):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            if model_init == 'he_fout':
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                if init_div_groups:
                    n /= m.groups
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif model_init == 'he_fin':
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                if init_div_groups:
                    n /= m.groups
                m.weight.data.normal_(0, math.sqrt(2. / n))
            else:
                raise NotImplementedError
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
            m.momentum = bn_momentum
            m.eps = bn_eps
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm1d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
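
A minimal usage sketch for this space (hedged: the public import path is assumed to be ``nni.nas.hub.pytorch``; the first call downloads weights):

import torch
from nni.nas.hub.pytorch import ProxylessNAS  # import path is an assumption

model = ProxylessNAS.load_searched_model('proxyless-mobile', pretrained=True, download=True)
logits = model(torch.randn(1, 3, 224, 224))  # ImageNet-sized input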
@ -0,0 +1,297 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import cast

import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper

from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight


class ShuffleNetBlock(nn.Module):
    """
    The basic building block of ShuffleNet, as described in
    `ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices <https://arxiv.org/pdf/1707.01083.pdf>`__.

    When stride = 1, the block expects an input with ``2 * input channels``; otherwise, it expects ``input channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
                 kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
        super().__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5, 7]
        self.channels = in_channels // 2 if stride == 1 else in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = kernel_size // 2
        self.oup_main = out_channels - self.channels
        self.affine = affine
        assert self.oup_main > 0

        self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))

        if stride == 2:
            self.branch_proj = nn.Sequential(
                # dw
                nn.Conv2d(self.channels, self.channels, kernel_size, stride, self.pad,
                          groups=self.channels, bias=False),
                nn.BatchNorm2d(self.channels, affine=affine),
                # pw-linear
                nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.channels, affine=affine),
                nn.ReLU(inplace=True)
            )
        else:
            # empty block to be compatible with torchscript
            self.branch_proj = nn.Sequential()

    def forward(self, x):
        if self.stride == 2:
            x_proj, x = self.branch_proj(x), x
        else:
            x_proj, x = self._channel_shuffle(x)
        return torch.cat((x_proj, self.branch_main(x)), 1)

    def _decode_point_depth_conv(self, sequence):
        result = []
        first_depth = first_point = True
        pc: int = self.channels
        c: int = self.channels
        for i, token in enumerate(sequence):
            # compute output channels of this conv
            if i + 1 == len(sequence):
                assert token == "p", "Last conv must be point-wise conv."
                c = self.oup_main
            elif token == "p" and first_point:
                c = cast(int, self.mid_channels)
            if token == "d":
                # depth-wise conv
                if isinstance(pc, int) and isinstance(c, int):
                    # check can only be done for static channels
                    assert pc == c, "Depth-wise conv must not change channels."
                result.append(nn.Conv2d(pc, c, self.kernel_size, self.stride if first_depth else 1, self.pad,
                                        groups=c, bias=False))
                result.append(nn.BatchNorm2d(c, affine=self.affine))
                first_depth = False
            elif token == "p":
                # point-wise conv
                result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
                result.append(nn.BatchNorm2d(c, affine=self.affine))
                result.append(nn.ReLU(inplace=True))
                first_point = False
            else:
                raise ValueError("Conv sequence must be d and p.")
            pc = c
        return result

    def _channel_shuffle(self, x):
        bs, num_channels, height, width = x.size()
        # NOTE: this line is commented for torchscript
        # assert (num_channels % 4 == 0)
        x = x.reshape(bs * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]
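
    # A hedged walk-through of the shuffle above: for input of shape (bs, C, H, W),
    # channels are paired as (0, 1), (2, 3), ... and the reshape/permute splits the
    # pairs apart, so x[0] carries channels 0, 2, 4, ... and x[1] carries channels
    # 1, 3, 5, ...; each returned half has shape (bs, C // 2, H, W).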


class ShuffleXceptionBlock(ShuffleNetBlock):
    """
    The ``choice_x`` version of the ShuffleNet block, described in
    `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
        super().__init__(in_channels, out_channels, mid_channels,
                         kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)


@model_wrapper
class ShuffleNetSpace(nn.Module):
    """
    The search space proposed in `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.

    The basic building block design is inspired by a state-of-the-art manually-designed network --
    `ShuffleNetV2 <https://openaccess.thecvf.com/content_ECCV_2018/html/Ningning_Light-weight_CNN_Architecture_ECCV_2018_paper.html>`__.
    There are 20 choice blocks in total. Each choice block has 4 candidates, namely ``choice_3``, ``choice_5``,
    ``choice_7`` and ``choice_x``. They differ in kernel sizes and the number of depthwise convolutions.
    The size of the search space is :math:`4^{20}`.

    Parameters
    ----------
    num_labels : int
        Number of classes for the classification head. Default: 1000.
    channel_search : bool
        If true, for each building block, the number of ``mid_channels``
        (output channels of the first 1x1 conv in each building block) varies from 0.2x to 1.6x (quantized to multiples of 0.2).
        Here, "k-x" means k times the number of default channels.
        Otherwise, 1.0x is used by default. Default: false.
    affine : bool
        Apply affine to all batch norm. Default: true.
    """

    def __init__(self,
                 num_labels: int = 1000,
                 channel_search: bool = False,
                 affine: bool = True):
        super().__init__()

        self.num_labels = num_labels
        self.channel_search = channel_search
        self.affine = affine

        # the block number in each stage. 4 stages in total. 20 blocks in total.
        self.stage_repeats = [4, 4, 8, 4]

        # output channels for all stages, including the very first layer and the very last layer
        self.stage_out_channels = [-1, 16, 64, 160, 320, 640, 1024]

        # building first layer
        out_channels = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, out_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        feature_blocks = []

        global_block_idx = 0
        for stage_idx, num_repeat in enumerate(self.stage_repeats):
            for block_idx in range(num_repeat):
                # count global index to give names to choices
                global_block_idx += 1

                # get ready for input and output
                in_channels = out_channels
                out_channels = self.stage_out_channels[stage_idx + 2]
                stride = 2 if block_idx == 0 else 1

                # mid channels can be searched
                base_mid_channels = out_channels // 2
                if self.channel_search:
                    k_choice_list = [int(base_mid_channels * (.2 * k)) for k in range(1, 9)]
                    mid_channels = nn.ValueChoice(k_choice_list, label=f'channel_{global_block_idx}')
                else:
                    mid_channels = int(base_mid_channels)

                mid_channels = cast(nn.MaybeChoice[int], mid_channels)

                choice_block = nn.LayerChoice(dict(
                    k3=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
                    k5=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
                    k7=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
                    xcep=ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
                ), label=f'layer_{global_block_idx}')
                feature_blocks.append(choice_block)

        self.features = nn.Sequential(*feature_blocks)

        # final layers
        last_conv_channels = self.stage_out_channels[-1]
        self.conv_last = nn.Sequential(
            nn.Conv2d(out_channels, last_conv_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(last_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(last_conv_channels, num_labels, bias=False),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.first_conv(x)
        x = self.features(x)
        x = self.conv_last(x)

        x = self.globalpool(x)

        x = self.dropout(x)
        x = x.contiguous().view(-1, self.stage_out_channels[-1])
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    torch.nn.init.normal_(m.weight, 0, 0.01)
                else:
                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0001)
                if m.running_mean is not None:
                    torch.nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.BatchNorm1d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0.0001)
                if m.running_mean is not None:
                    torch.nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)

    @classmethod
    def fixed_arch(cls, arch: dict) -> FixedFactory:
        return FixedFactory(cls, arch)

    @classmethod
    def load_searched_model(
        cls, name: str,
        pretrained: bool = False, download: bool = False, progress: bool = True
    ) -> nn.Module:
        if name == 'spos':
            # NOTE: Need BGR tensor, with no normalization
            # https://github.com/ultmaster/spacehub-conversion/blob/371a4fd6646b4e11eda3f61187f7c9a1d484b1ca/cutils.py#L63
            arch = {
                'layer_1': 'k7',
                'layer_2': 'k5',
                'layer_3': 'k3',
                'layer_4': 'k5',
                'layer_5': 'k7',
                'layer_6': 'k3',
                'layer_7': 'k7',
                'layer_8': 'k3',
                'layer_9': 'k7',
                'layer_10': 'k3',
                'layer_11': 'k7',
                'layer_12': 'xcep',
                'layer_13': 'k3',
                'layer_14': 'k3',
                'layer_15': 'k3',
                'layer_16': 'k3',
                'layer_17': 'xcep',
                'layer_18': 'k7',
                'layer_19': 'xcep',
                'layer_20': 'xcep'
            }

        else:
            raise ValueError(f'Unsupported architecture with name: {name}')

        model_factory = cls.fixed_arch(arch)
        model = model_factory()

        if pretrained:
            weight_file = load_pretrained_weight(name, download=download, progress=progress)
            pretrained_weights = torch.load(weight_file)
            model.load_state_dict(pretrained_weights)

        return model
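
A minimal usage sketch (hedged: import path assumed to be ``nni.nas.hub.pytorch``). Note the BGR/no-normalization caveat from the comment above:

import torch
from nni.nas.hub.pytorch import ShuffleNetSpace  # import path is an assumption

model = ShuffleNetSpace.load_searched_model('spos', pretrained=True, download=True)
x = torch.randn(1, 3, 224, 224)  # real data must be a BGR tensor without normalization
logits = model(x)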
@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""This file should be merged into nni/nas/fixed.py"""

from typing import Type

from nni.nas.utils import ContextStack


class FixedFactory:
    """Make a model space ready to create a fixed model.

    Examples
    --------
    >>> factory = FixedFactory(ModelSpaceClass, {"choice1": 3})
    >>> model = factory(channels=16, classes=10)
    """

    # TODO: mutations on ``init_args`` and ``init_kwargs`` themselves are not supported.

    def __init__(self, cls: Type, arch: dict):
        self.cls = cls
        self.arch = arch

    def __call__(self, *init_args, **init_kwargs):
        with ContextStack('fixed', self.arch):
            return self.cls(*init_args, **init_kwargs)

    def __repr__(self):
        return f'FixedFactory(class={self.cls}, arch={self.arch})'
@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Weights available in this file are processed with scripts in https://github.com/ultmaster/spacehub-conversion,
and uploaded with :func:`nni.common.blob_utils.upload_file`.
"""

import os

from nni.common.blob_utils import NNI_BLOB, nni_cache_home, load_or_download_file


PRETRAINED_WEIGHT_URLS = {
    # proxylessnas
    'acenas-m1': f'{NNI_BLOB}/nashub/acenas-m1-e215f1b8.pth',
    'acenas-m2': f'{NNI_BLOB}/nashub/acenas-m2-a8ee9e8f.pth',
    'acenas-m3': f'{NNI_BLOB}/nashub/acenas-m3-66a5ed7b.pth',
    'proxyless-cpu': f'{NNI_BLOB}/nashub/proxyless-cpu-2df03430.pth',
    'proxyless-gpu': f'{NNI_BLOB}/nashub/proxyless-gpu-dbe6dd15.pth',
    'proxyless-mobile': f'{NNI_BLOB}/nashub/proxyless-mobile-8668a978.pth',

    # mobilenetv3
    'mobilenetv3-large-100': f'{NNI_BLOB}/nashub/mobilenetv3-large-100-420e040a.pth',
    'mobilenetv3-small-050': f'{NNI_BLOB}/nashub/mobilenetv3-small-050-05cb7a80.pth',
    'mobilenetv3-small-075': f'{NNI_BLOB}/nashub/mobilenetv3-small-075-c87d8acb.pth',
    'mobilenetv3-small-100': f'{NNI_BLOB}/nashub/mobilenetv3-small-100-8332faac.pth',
    'cream-014': f'{NNI_BLOB}/nashub/cream-014-060aea24.pth',
    'cream-043': f'{NNI_BLOB}/nashub/cream-043-bec949e1.pth',
    'cream-114': f'{NNI_BLOB}/nashub/cream-114-fc272590.pth',
    'cream-287': f'{NNI_BLOB}/nashub/cream-287-a0fcba33.pth',
    'cream-481': f'{NNI_BLOB}/nashub/cream-481-d85779b6.pth',
    'cream-604': f'{NNI_BLOB}/nashub/cream-604-9ee425f7.pth',

    # nasnet
    'darts-v2': f'{NNI_BLOB}/nashub/darts-v2-5465b0d2.pth',

    # spos
    'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',

    # autoformer
    'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
    'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
    'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
}


def load_pretrained_weight(name: str, **kwargs) -> str:
    if name not in PRETRAINED_WEIGHT_URLS:
        raise ValueError(f'"{name}" does not have a valid pretrained weight file.')
    url = PRETRAINED_WEIGHT_URLS[name]

    local_path = os.path.join(nni_cache_home(), 'nashub', url.split('/')[-1])
    load_or_download_file(local_path, url, **kwargs)
    return local_path
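
# A hedged usage note: for example,
#   weight_file = load_pretrained_weight('spos', download=True, progress=True)
# returns a local path under the NNI cache directory; callers then load it with
# ``torch.load`` (as the search spaces above do).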
@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import *
@ -0,0 +1,124 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)

from nni.nas.execution import Model, Mutation, ModelStatus


__all__ = ['Sampler', 'Mutator', 'InvalidMutation']


Choice = Any


class Sampler:
    """
    Handles `Mutator.choice()` calls.
    """

    def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
        raise NotImplementedError()

    def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
        pass

    def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
        pass


class Mutator:
    """
    Mutates graphs in a model to generate a new model.
    The `Mutator` class is used in two places:

    1. Inherit `Mutator` to implement graph mutation logic.
    2. Use a `Mutator` subclass to implement a NAS strategy.

    In scenario 1, the subclass should implement the `Mutator.mutate()` interface with `Mutator.choice()`.
    In scenario 2, the strategy should use the constructor or `Mutator.bind_sampler()` to initialize the subclass,
    and then use `Mutator.apply()` to mutate the model.
    For certain mutator subclasses, the strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
    # Method names are open for discussion.

    If a mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
    """

    def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
        self.sampler: Optional[Sampler] = sampler
        if label is None:
            warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
        self.label: str = label
        self._cur_model: Optional[Model] = None
        self._cur_choice_idx: Optional[int] = None

    def bind_sampler(self, sampler: Sampler) -> 'Mutator':
        """
        Set the sampler which will handle `Mutator.choice` calls.
        """
        self.sampler = sampler
        return self

    def apply(self, model: Model) -> Model:
        """
        Apply this mutator on a model.
        Returns the mutated model.
        The model will be copied before mutation and the original model will not be modified.
        """
        assert self.sampler is not None
        copy = model.fork()
        self._cur_model = copy
        self._cur_choice_idx = 0
        self._cur_samples = []
        self.sampler.mutation_start(self, copy)
        self.mutate(copy)
        self.sampler.mutation_end(self, copy)
        copy.history.append(Mutation(self, self._cur_samples, model, copy))
        copy.status = ModelStatus.Frozen
        self._cur_model = None
        self._cur_choice_idx = None
        return copy

    def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
        """
        Dry run the mutator on a model to collect choice candidates.
        If you invoke this method multiple times on the same or different models,
        it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`.
        """
        sampler_backup = self.sampler
        recorder = _RecorderSampler()
        self.sampler = recorder
        new_model = self.apply(model)
        self.sampler = sampler_backup
        return recorder.recorded_candidates, new_model

    def mutate(self, model: Model) -> None:
        """
        Abstract method to be implemented by the subclass.
        Mutate a model in place.
        """
        raise NotImplementedError()

    def choice(self, candidates: Iterable[Choice]) -> Choice:
        """
        Ask the sampler to make a choice.
        """
        assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
        ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
        self._cur_samples.append(ret)
        self._cur_choice_idx += 1
        return ret


class _RecorderSampler(Sampler):
    def __init__(self):
        self.recorded_candidates: List[List[Choice]] = []

    def choice(self, candidates: List[Choice], *args) -> Choice:
        self.recorded_candidates.append(candidates)
        return candidates[0]


class InvalidMutation(Exception):
    pass
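
A hedged sketch of scenario 2 from the docstring above: bind a uniform-random sampler to a concrete mutator and apply it to a model. ``MyMutator`` is a hypothetical stand-in for any ``Mutator`` subclass that implements ``mutate()``, and ``base_model`` for an existing ``Model``.

import random

class RandomSampler(Sampler):
    def choice(self, candidates, mutator, model, index):
        # sample uniformly among the candidates offered at this choice point
        return random.choice(candidates)

candidates, _ = MyMutator(label='example').dry_run(base_model)  # peek at choice candidates
new_model = MyMutator(label='example').bind_sampler(RandomSampler()).apply(base_model)
assert new_model.status == ModelStatus.Frozen  # apply() freezes the forked model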
@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.common.framework import shortcut_framework

shortcut_framework(__name__)

del shortcut_framework
@ -0,0 +1 @@
_layers.py