[Compression V2] Movement pruning (#4308)

Authored by J-shang on 2021-11-29 11:01:39 +08:00; committed by GitHub
Parent 40fc466743
Commit 1eced0a7bf
12 changed files: 536 additions and 30 deletions


@ -19,6 +19,7 @@ and how to schedule sparsity in each iteration are implemented as iterative prun
* `Activation Mean Rank Pruner <#activation-mean-rank-pruner>`__
* `Taylor FO Weight Pruner <#taylor-fo-weight-pruner>`__
* `ADMM Pruner <#admm-pruner>`__
* `Movement Pruner <#movement-pruner>`__
**Iterative Pruner**
@ -292,6 +293,58 @@ User configuration for ADMM Pruner
.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.ADMMPruner
Movement Pruner
---------------
Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element is scored by the opposite of the accumulated product of the weight and its gradient at each fine-tuning step.
This means that weight elements moving towards zero accumulate negative scores, while weight elements moving away from zero accumulate positive scores.
The weight elements with low scores will be masked during inference.
The following figure from the paper illustrates how movement pruning selects weights.
.. image:: ../../img/movement_pruning.png
:target: ../../img/movement_pruning.png
:alt:
For more details, please refer to `Movement Pruning: Adaptive Sparsity by Fine-Tuning <https://arxiv.org/abs/2005.07683>`__.
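The scoring rule can be sketched in a few lines of plain PyTorch. Note that this is only an illustration of the idea, not the NNI implementation (which feeds the score through a straight-through estimator and updates it with an internal Adam optimizer); the tensor shapes, stand-in loss, and learning rate below are made up.
.. code-block:: python
import torch

# toy weight matrix and a stand-in regression loss
weight = torch.randn(64, 64, requires_grad=True)
proj, target = torch.randn(64, 8), torch.randn(64, 8)
score = torch.zeros_like(weight)   # movement importance scores, start at zero
lr = 0.01

for _ in range(100):               # fine-tuning steps
    loss = ((weight @ proj - target) ** 2).mean()
    grad, = torch.autograd.grad(loss, weight)
    # the score accumulates the opposite of weight * gradient:
    # elements moving towards zero collect negative scores,
    # elements moving away from zero collect positive scores
    score -= lr * grad * weight.detach()
    with torch.no_grad():
        weight -= lr * grad        # plain SGD step on the weight itself

# mask the lowest-scored 90% of the elements (sparsity 0.9)
threshold = score.flatten().kthvalue(int(0.9 * score.numel())).values
mask = (score > threshold).float()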
Usage
^^^^^^
.. code-block:: python
from nni.algorithms.compression.v2.pytorch.pruning import MovementPruner
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
pruner = MovementPruner(model, config_list, p_trainer, optimizer, criterion, training_epochs=10, warm_up_step=3000, cool_down_beginning_step=27000)
masked_model, masks = pruner.compress()
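``warm_up_step`` and ``cool_down_beginning_step`` are counted in ``optimizer.step()`` calls. Between them, the pruner ramps the sparsity up with a cubic schedule (see the class documentation below). As a small sketch, the helper below (not part of the NNI API) reproduces that schedule for the configuration above:
.. code-block:: python
def scheduled_sparsity(step, total_sparsity=0.9, warm_up_step=3000, cool_down_beginning_step=27000):
    # cubic sparsity schedule applied after each optimizer step
    if step <= warm_up_step:
        return 0.0
    if step >= cool_down_beginning_step:
        return total_sparsity
    progress = (step - warm_up_step) / (cool_down_beginning_step - warm_up_step)
    return total_sparsity * (1 - (1 - progress) ** 3)

print(scheduled_sparsity(3000))    # 0.0    -> pruning is about to start
print(scheduled_sparsity(15000))   # 0.7875 -> halfway through the ramp
print(scheduled_sparsity(27000))   # 0.9    -> target reached; masks may still change afterwards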
User configuration for Movement Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch**
.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.MovementPruner
Reproduced Experiment
^^^^^^^^^^^^^^^^^^^^^
.. list-table::
:header-rows: 1
:widths: auto
* - Model
- Dataset
- Remaining Weights
- MaP (magnitude pruning) acc. (paper / ours)
- MvP (movement pruning) acc. (paper / ours)
* - Bert base
- MNLI - Dev
- 10%
- 77.8% / 73.6%
- 79.3% / 78.8%
Linear Pruner
-------------

Binary file added: docs/img/movement_pruning.png (106 KiB)


@ -0,0 +1,122 @@
import functools
from tqdm import tqdm
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets import load_metric, load_dataset
from transformers import (
BertForSequenceClassification,
BertTokenizerFast,
DataCollatorWithPadding,
set_seed
)
from nni.algorithms.compression.v2.pytorch.pruning import MovementPruner
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gradient_accumulation_steps = 16
# a dummy criterion: the HuggingFace model output already contains the loss
def criterion(input, target):
return input.loss
def trainer(model, optimizer, criterion, train_dataloader):
model.train()
counter = 0
for batch in tqdm(train_dataloader):
counter += 1
batch.to(device)
optimizer.zero_grad()
outputs = model(**batch)
# the pruner may patch the criterion (e.g., loss = original_loss + norm(weight)), so always call criterion to get the loss
loss = criterion(outputs, None)
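# scale the loss so that gradients accumulated over `gradient_accumulation_steps` mini-batches are averaged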
loss = loss / gradient_accumulation_steps
loss.backward()
if counter % gradient_accumulation_steps == 0 or counter == len(train_dataloader):
optimizer.step()
if counter % 16000 == 0:
print('Step {}: {}'.format(counter // gradient_accumulation_steps, evaluator(model, metric, is_regression, validate_dataloader)))
def evaluator(model, metric, is_regression, eval_dataloader):
model.eval()
for batch in tqdm(eval_dataloader):
batch.to(device)
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
metric.add_batch(
predictions=predictions,
references=batch["labels"],
)
return metric.compute()
if __name__ == '__main__':
task_name = 'mnli'
is_regression = False
num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
train_batch_size = 8
eval_batch_size = 8
set_seed(1024)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
sentence1_key, sentence2_key = task_to_keys[task_name]
# used to preprocess the raw data
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=False, max_length=128, truncation=True)
if "label" in examples:
# In all cases, rename the column to labels because the model will expect that.
result["labels"] = examples["label"]
return result
raw_datasets = load_dataset('glue', task_name, cache_dir='./data')
processed_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names)
train_dataset = processed_datasets['train']
validate_dataset = processed_datasets['validation_matched' if task_name == "mnli" else 'validation']
data_collator = DataCollatorWithPadding(tokenizer)
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
validate_dataloader = DataLoader(validate_dataset, collate_fn=data_collator, batch_size=eval_batch_size)
metric = load_metric("glue", task_name)
model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=num_labels).to(device)
print('Initial: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))
config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
p_trainer = functools.partial(trainer, train_dataloader=train_dataloader)
optimizer = Adam(model.parameters(), lr=2e-5)
pruner = MovementPruner(model, config_list, p_trainer, optimizer, criterion, training_epochs=10,
warm_up_step=3000, cool_down_beginning_step=27000)
_, masks = pruner.compress()
pruner.show_pruned_weights()
print('Final: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))
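# fine-tune the masked model for one more epoch to recover accuracy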
optimizer = Adam(model.parameters(), lr=2e-5)
trainer(model, optimizer, criterion, train_dataloader)
print('After 1 epoch finetuning: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))


@ -18,7 +18,7 @@ __all__ = ['Pruner']
class PrunerModuleWrapper(Module):
def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Wrap a module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------


@ -1,4 +1,5 @@
from .basic_pruner import *
from .basic_scheduler import PruningScheduler
from .iterative_pruner import *
from .movement_pruner import MovementPruner
from .auto_compress_pruner import AutoCompressPruner


@ -0,0 +1,294 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable
import torch
from torch import autograd, Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer, Adam
from nni.algorithms.compression.v2.pytorch.base.compressor import Compressor, _setattr, LayerInfo
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
from .tools.base import TrainerBasedDataCollector
from .tools import (
StraightMetricsCalculator,
NormalSparsityAllocator
)
_logger = logging.getLogger(__name__)
class PrunerScoredModuleWrapper(Module):
"""
Wrap a module to enable data parallel, forward method customization and buffer registration.
Different from `PrunerModuleWrapper`, `PrunerScoredModuleWrapper` additionally records a score for the weight, updated through the weight's gradient.
Parameters
----------
module
The module user wants to compress.
config
The configurations that users specify for compression.
module_name
The name of the module to compress, wrapper module shares same name.
pruner
The pruner used to calculate mask.
"""
def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
super().__init__()
# origin layer information
self.module = module
self.name = module_name
# config and pruner
self.config = config
self.pruner = pruner
self.weight = Parameter(torch.empty(self.module.weight.size()))
self.weight.data = self.module.weight.data
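# per-element importance score, trained through the straight-through estimator below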
self.weight_score = Parameter(torch.empty(self.weight.size()))
torch.nn.init.constant_(self.weight_score, val=0.0)
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
self.bias = Parameter(torch.empty(self.module.bias.size()))
self.bias.data = self.module.bias.data
else:
self.register_buffer("bias_mask", None)
def _weight2buffer(self):
"""
When using this wrapper for inference, call `_weight2buffer()` to make the original weight untrainable.
The best place to call this function is in `Pruner._wrap_model()`.
"""
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.weight.data)
if hasattr(self.module, 'bias') and self.module.bias is not None:
delattr(self.module, 'bias')
self.module.register_buffer('bias', self.bias.data)
def _weight2parameter(self):
"""
When the score no longer needs to be recorded, or the model needs to be exported, call `_weight2parameter()` to make the original weight trainable again.
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr(self.module, 'weight')
self.module.weight = Parameter(torch.empty(self.weight.size()))
self.module.weight.data = torch.mul(self.weight, self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
delattr(self.module, 'bias')
self.module.bias = Parameter(torch.empty(self.bias.size()))
self.module.bias.data = torch.mul(self.bias, self.bias_mask)
def forward(self, *inputs):
# apply mask to weight, bias
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias = torch.mul(self.bias, self.bias_mask)
return self.module(*inputs)
class _StraightThrough(autograd.Function):
"""
Pass the gradient straight through to the score, so that score = initial_score + sum(-lr * grad(weight) * weight).
"""
@staticmethod
def forward(ctx, score, masks):
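# the forward pass ignores the score and simply applies the given mask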
return masks
@staticmethod
def backward(ctx, gradOutput):
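# pass the upstream gradient straight through to the score; the mask receives no gradient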
return gradOutput, None
class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Collect all weight_score in wrappers as data used to calculate metrics.
"""
def _reset_optimizer(self):
"""
Filter out `weight_score` from the parameters passed to the optimizer, so that the original optimizer state dict can still be loaded.
"""
if self._origin_optimizer_cls is not None:
optimizer_grouped_parameters = [{
"params": [p for n, p in self.compressor.bound_model.named_parameters() if "weight_score" not in n and p.requires_grad]
}]
if self._origin_optimizer_cls.__name__ == 'SGD':
self.optimizer = self._origin_optimizer_cls(optimizer_grouped_parameters, lr=0.001)
else:
self.optimizer = self._origin_optimizer_cls(optimizer_grouped_parameters)
self.optimizer.load_state_dict(self._origin_optimizer_state_dict)
else:
self.optimizer = None
def collect(self) -> Dict[str, Tensor]:
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data.clone().detach()
return data
class MovementPruner(BasicPruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned.
config_list : List[Dict]
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Operation types to prune.
- op_names : Operation names to prune.
- exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable], None]
A callable used to train the model or just run inference. It takes the model, optimizer, and criterion as input.
The model will be trained or run for `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer : torch.optim.Optimizer
The optimizer instance used in the trainer. Note that this optimizer might be patched during data collection,
so do not use this optimizer elsewhere.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_epochs : int
The total epoch number for training the model.
Make sure the total number of `optimizer.step()` calls within `training_epochs` is larger than `cool_down_beginning_step`.
warm_up_step : int
The total number of `optimizer.step()` calls used for warm up, before pruning starts.
Make sure `warm_up_step` is smaller than `cool_down_beginning_step`.
cool_down_beginning_step: int
The step at which the sparsity stops growing; note that even after the sparsity stops growing, the masks can still change.
The sparsity after each `optimizer.step()` is:
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
"""
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
cool_down_beginning_step: int):
self.trainer = trainer
self.optimizer = optimizer
self.criterion = criterion
self.training_epochs = training_epochs
self.warm_up_step = warm_up_step
self.cool_down_beginning_step = cool_down_beginning_step
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`'
super().__init__(model, config_list)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
def cubic_schedule(self, current_step: int):
if self.warm_up_step < current_step <= self.cool_down_beginning_step:
wrapper_dict = self.get_modules_wrapper()
for config in self.config_list:
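# cubic ramp: sparsity grows from 0 at `warm_up_step` to `total_sparsity` at `cool_down_beginning_step`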
current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3)
for op_name in config['op_names']:
wrapper_dict[op_name].config['total_sparsity'] = current_sparsity
def reset_tools(self):
if self.metrics_calculator is None:
self.metrics_calculator = StraightMetricsCalculator()
if self.sparsity_allocator is None:
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
# use Adam to update the weight_score
params = [{"params": [p for n, p in self.bound_model.named_parameters() if "weight_score" in n and p.requires_grad]}]
optimizer = Adam(params, 1e-2)
self.step_counter = 0
# update the masks after each optimizer step
def _optimizer_patch():
optimizer.step()
optimizer.zero_grad()
self.step_counter += 1
if self.step_counter > self.warm_up_step:
self.cubic_schedule(self.step_counter)
data = {}
for _, wrapper in self.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data
metrics = self.metrics_calculator.calculate_metrics(data)
masks = self.sparsity_allocator.generate_sparsity(metrics)
self.load_masks(masks)
if self.data_collector is None:
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
def _wrap_model(self):
"""
Wrap all modules that need to be compressed.
Different from the parent method, call `wrapper._weight2buffer()` after replacing the original module with the wrapper.
"""
if not self.is_wrapped:
for _, wrapper in reversed(self.get_modules_wrapper().items()):
_setattr(self.bound_model, wrapper.name, wrapper)
wrapper._weight2buffer()
self.is_wrapped = True
def _unwrap_model(self):
"""
Unwrap all modules that need to be compressed.
Different from the parent method, call `wrapper._weight2parameter()` after replacing the wrapper with the original module.
"""
if self.is_wrapped:
for _, wrapper in self.get_modules_wrapper().items():
_setattr(self.bound_model, wrapper.name, wrapper.module)
wrapper._weight2parameter()
self.is_wrapped = False
def _wrap_modules(self, layer: LayerInfo, config: Dict):
"""
Create a wrapper module to replace the original one.
Different from the parent function, use `PrunerScoredModuleWrapper` instead of `PrunerModuleWrapper`.
Parameters
----------
layer
The layer to instrument the mask.
config
The configuration for generating the mask.
"""
_logger.debug("Module detected to compress : %s.", layer.name)
wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config, self)
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
# move newly registered buffers to the same device as the weight
wrapper.to(layer.module.weight.device)
return wrapper
def compress(self) -> Tuple[Module, Dict]:
# sparsity grows from 0
for _, wrapper in self.get_modules_wrapper().items():
wrapper.config['total_sparsity'] = 0
result = super().compress()
# del weight_score
for _, wrapper in self.get_modules_wrapper().items():
wrapper.weight_score = None
return result


@ -11,6 +11,7 @@ from .data_collector import (
SingleHookTrainerBasedDataCollector
)
from .metrics_calculator import (
StraightMetricsCalculator,
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
DistMetricsCalculator,


@ -133,7 +133,8 @@ class TrainerBasedDataCollector(DataCollector):
super().__init__(compressor)
self.trainer = trainer
self.training_epochs = training_epochs
self._origin_optimizer = optimizer
self._origin_optimizer_cls = optimizer.__class__ if optimizer is not None else None
self._origin_optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None
self._origin_criterion = criterion
self._opt_before_tasks = opt_before_tasks
self._opt_after_tasks = opt_after_tasks
@ -146,22 +147,12 @@ class TrainerBasedDataCollector(DataCollector):
def reset(self):
# refresh optimizer and criterion
self.compressor._unwrap_model()
if self._origin_optimizer is not None:
optimizer_cls = self._origin_optimizer.__class__
if optimizer_cls.__name__ == 'SGD':
self.optimizer = optimizer_cls(self.compressor.bound_model.parameters(), lr=0.001)
else:
self.optimizer = optimizer_cls(self.compressor.bound_model.parameters())
self.optimizer.load_state_dict(self._origin_optimizer.state_dict())
else:
self.optimizer = None
self._reset_optimizer()
if self._criterion_patch is not None:
self.criterion = self._criterion_patch(self._origin_criterion)
else:
self.criterion = self._origin_criterion
self.compressor._wrap_model()
# patch optimizer
self._patch_optimizer()
@ -173,6 +164,18 @@ class TrainerBasedDataCollector(DataCollector):
self._hook_buffer = {}
self._add_all_hook()
def _reset_optimizer(self):
self.compressor._unwrap_model()
if self._origin_optimizer_cls is not None:
if self._origin_optimizer_cls.__name__ == 'SGD':
self.optimizer = self._origin_optimizer_cls(self.compressor.bound_model.parameters(), lr=0.001)
else:
self.optimizer = self._origin_optimizer_cls(self.compressor.bound_model.parameters())
self.optimizer.load_state_dict(self._origin_optimizer_state_dict)
else:
self.optimizer = None
self.compressor._wrap_model()
def _patch_optimizer(self):
def patch_step(old_step):
def new_step(_, *args, **kwargs):
@ -315,7 +318,7 @@ class SparsityAllocator:
"""
def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None):
block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
"""
Parameters
----------
@ -339,6 +342,8 @@ class SparsityAllocator:
Example:
The metric size is (12,), and block_sparse_size=[64], then the mask will expand to (768,) at first before expand with `dim`.
continuous_mask
Inherit the mask already stored in the wrapper if set to True.
"""
self.pruner = pruner
self.dim = dim if not isinstance(dim, int) else [dim]
@ -350,6 +355,7 @@ class SparsityAllocator:
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
self.continuous_mask = continuous_mask
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
"""


@ -9,7 +9,18 @@ from torch import Tensor
from .base import MetricsCalculator
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator']
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
class StraightMetricsCalculator(MetricsCalculator):
"""
This metrics calculator directly returns a copy of data as metrics.
"""
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor in data.items():
metrics[name] = tensor.clone().detach()
return metrics
class NormMetricsCalculator(MetricsCalculator):


@ -24,7 +24,9 @@ class NormalSparsityAllocator(SparsityAllocator):
sparsity_rate = wrapper.config['total_sparsity']
assert name in metrics, 'Metric of %s is not calculated.'
metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
metric = metrics[name]
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask)
prune_num = int(sparsity_rate * metric.numel())
if prune_num == 0:
threshold = metric.min() - 1
@ -64,7 +66,8 @@ class GlobalSparsityAllocator(SparsityAllocator):
for name, metric in group_metric_dict.items():
wrapper = self.pruner.get_modules_wrapper()[name]
metric = metric * self._compress_mask(wrapper.weight_mask)
if self.continuous_mask:
metric = metric * self._compress_mask(wrapper.weight_mask)
layer_weight_num = wrapper.module.weight.data.numel()
total_weight_num += layer_weight_num
expend_times = int(layer_weight_num / metric.numel())
@ -113,7 +116,10 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
masks = {}
grouped_metrics = {}
for idx, names in enumerate(self.channel_depen):
grouped_metric = {name: metrics[name] * self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) for name in names if name in metrics}
grouped_metric = {name: metrics[name] for name in names if name in metrics}
if self.continuous_mask:
for name, metric in grouped_metric.items():
metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)
if len(grouped_metric) > 0:
grouped_metrics[idx] = grouped_metric
for _, group_metric_dict in grouped_metrics.items():


@ -281,7 +281,7 @@ class Compressor:
class PrunerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, pruner):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Wrap a module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------
@ -495,7 +495,7 @@ class Pruner(Compressor):
class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.
Wrap a module to enable data parallel, forward method customization and buffer registeration.
Parameters
----------


@ -15,7 +15,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
ActivationAPoZRankPruner,
ActivationMeanRankPruner,
TaylorFOWeightPruner,
ADMMPruner
ADMMPruner,
MovementPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact
@ -67,7 +68,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_l1_norm_pruner(self):
model = TorchModel()
@ -77,7 +78,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_l2_norm_pruner(self):
model = TorchModel()
@ -87,7 +88,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_fpgm_pruner(self):
model = TorchModel()
@ -97,7 +98,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_slim_pruner(self):
model = TorchModel()
@ -107,7 +108,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_activation_apoz_rank_pruner(self):
model = TorchModel()
@ -119,7 +120,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_activation_mean_rank_pruner(self):
model = TorchModel()
@ -131,7 +132,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_taylor_fo_pruner(self):
model = TorchModel()
@ -142,7 +143,7 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_admm_pruner(self):
model = TorchModel()
@ -152,7 +153,18 @@ class PrunerTestCase(unittest.TestCase):
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_movement_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = MovementPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
criterion=criterion, training_epochs=5, warm_up_step=0, cool_down_beginning_step=4)
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
if __name__ == '__main__':
unittest.main()