Mirror of https://github.com/microsoft/nni.git
[Compression V2] Movement pruning (#4308)
This commit is contained in:
Parent
40fc466743
Commit
1eced0a7bf
@@ -19,6 +19,7 @@ and how to schedule sparsity in each iteration are implemented as iterative prun
* `Activation Mean Rank Pruner <#activation-mean-rank-pruner>`__
* `Taylor FO Weight Pruner <#taylor-fo-weight-pruner>`__
* `ADMM Pruner <#admm-pruner>`__
* `Movement Pruner <#movement-pruner>`__

**Iterative Pruner**

@@ -292,6 +293,58 @@ User configuration for ADMM Pruner

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.ADMMPruner

Movement Pruner
---------------

Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element is scored by the negative of the accumulated product of the weight and its gradient over the training steps.
This means that weight elements moving towards zero accumulate negative scores, while weight elements moving away from zero accumulate positive scores.
The weight elements with low scores will be masked during inference.
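
The following toy snippet sketches that score accumulation for a single weight matrix; it is illustrative only (all names are hypothetical) and is not how the pruner is implemented internally, which works through module wrappers and a straight-through estimator.

.. code-block:: python

    import torch

    torch.manual_seed(0)
    weight = torch.randn(4, 4, requires_grad=True)
    score = torch.zeros_like(weight)         # movement score, one entry per weight element
    x, y = torch.randn(8, 4), torch.randn(8, 4)

    for _ in range(10):
        loss = torch.nn.functional.mse_loss(x @ weight.t(), y)
        grad, = torch.autograd.grad(loss, weight)
        with torch.no_grad():
            score -= grad * weight           # elements moving towards zero accumulate negative scores
            weight -= 0.1 * grad             # plain SGD update

    k = int(0.9 * score.numel())             # target sparsity 0.9: mask the lowest-scoring 90%
    threshold = score.flatten().kthvalue(k).values
    mask = (score > threshold).float()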

The following figure from the paper shows how weights are pruned by movement pruning.

.. image:: ../../img/movement_pruning.png
   :target: ../../img/movement_pruning.png
   :alt:

For more details, please refer to `Movement Pruning: Adaptive Sparsity by Fine-Tuning <https://arxiv.org/abs/2005.07683>`__.

Usage
^^^^^

.. code-block:: python

    from nni.algorithms.compression.v2.pytorch.pruning import MovementPruner

    # model, p_trainer, optimizer and criterion are prepared by the user; see the full example script in this commit
    config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
    pruner = MovementPruner(model, config_list, p_trainer, optimizer, criterion, 10, 3000, 27000)
    masked_model, masks = pruner.compress()
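
The last three positional arguments in the snippet above are ``training_epochs=10``, ``warm_up_step=3000`` and ``cool_down_beginning_step=27000``.
Between the warm-up step and the cool-down beginning step, the target sparsity grows following a cubic schedule.
The function below simply re-implements the formula given in the API docstring, purely for illustration (the function name is hypothetical):

.. code-block:: python

    def scheduled_sparsity(current_step, total_sparsity=0.9, warm_up_step=3000, cool_down_beginning_step=27000):
        # 0 during warm up, then cubic growth up to total_sparsity at cool_down_beginning_step
        if current_step <= warm_up_step:
            return 0.0
        if current_step >= cool_down_beginning_step:
            return total_sparsity
        progress = (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)
        return total_sparsity * (1 - (1 - progress) ** 3)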

User configuration for Movement Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**PyTorch**

.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.MovementPruner
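
Note that ``MovementPruner`` calls the trainer as ``trainer(model, optimizer, criterion)``.
If your training loop needs extra arguments, such as a dataloader, bind them in advance with ``functools.partial`` as the full example script in this commit does.
A minimal sketch of that adapter pattern, with toy placeholder objects, is shown below for illustration only:

.. code-block:: python

    import functools

    def trainer(model, optimizer, criterion, train_dataloader):
        # stand-in for a real training loop
        for batch in train_dataloader:
            pass

    # toy placeholders, only to demonstrate the call pattern
    model, optimizer, criterion = object(), object(), (lambda output, target: output)
    train_dataloader = [0, 1, 2]

    # MovementPruner expects a three-argument callable, so pre-bind the dataloader:
    p_trainer = functools.partial(trainer, train_dataloader=train_dataloader)
    p_trainer(model, optimizer, criterion)  # same as trainer(model, optimizer, criterion, train_dataloader)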

Reproduced Experiment
^^^^^^^^^^^^^^^^^^^^^

.. list-table::
   :header-rows: 1
   :widths: auto

   * - Model
     - Dataset
     - Remaining Weights
     - MaP acc. (paper/ours)
     - MvP acc. (paper/ours)
   * - BERT base
     - MNLI - Dev
     - 10%
     - 77.8% / 73.6%
     - 79.3% / 78.8%

Linear Pruner
-------------

Binary file not shown.
After  Width: | Height: | Size: 106 KiB

@@ -0,0 +1,122 @@
import functools
from tqdm import tqdm

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from datasets import load_metric, load_dataset
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    DataCollatorWithPadding,
    set_seed
)

from nni.algorithms.compression.v2.pytorch.pruning import MovementPruner


task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

gradient_accumulation_steps = 16

# a fake criterion because huggingface output already has loss
def criterion(input, target):
    return input.loss

def trainer(model, optimizer, criterion, train_dataloader):
    model.train()
    counter = 0
    for batch in tqdm(train_dataloader):
        counter += 1
        batch.to(device)
        optimizer.zero_grad()
        outputs = model(**batch)
        # pruner may wrap the criterion, for example, loss = origin_loss + norm(weight), so call criterion to get loss here
        loss = criterion(outputs, None)
        loss = loss / gradient_accumulation_steps
        loss.backward()
        if counter % gradient_accumulation_steps == 0 or counter == len(train_dataloader):
            optimizer.step()
        if counter % 16000 == 0:
            print('Step {}: {}'.format(counter // gradient_accumulation_steps, evaluator(model, metric, is_regression, validate_dataloader)))

def evaluator(model, metric, is_regression, eval_dataloader):
    model.eval()
    for batch in tqdm(eval_dataloader):
        batch.to(device)
        outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
        metric.add_batch(
            predictions=predictions,
            references=batch["labels"],
        )
    return metric.compute()

if __name__ == '__main__':
    task_name = 'mnli'
    is_regression = False
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    train_batch_size = 8
    eval_batch_size = 8

    set_seed(1024)

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    sentence1_key, sentence2_key = task_to_keys[task_name]

    # used to preprocess the raw data
    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if "label" in examples:
            # In all cases, rename the column to labels because the model will expect that.
            result["labels"] = examples["label"]
        return result

    raw_datasets = load_dataset('glue', task_name, cache_dir='./data')
    processed_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names)

    train_dataset = processed_datasets['train']
    validate_dataset = processed_datasets['validation_matched' if task_name == "mnli" else 'validation']

    data_collator = DataCollatorWithPadding(tokenizer)
    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
    validate_dataloader = DataLoader(validate_dataset, collate_fn=data_collator, batch_size=eval_batch_size)

    metric = load_metric("glue", task_name)

    model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=num_labels).to(device)

    print('Initial: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))

    config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
    p_trainer = functools.partial(trainer, train_dataloader=train_dataloader)
    optimizer = Adam(model.parameters(), lr=2e-5)
    pruner = MovementPruner(model, config_list, p_trainer, optimizer, criterion, training_epochs=10,
                            warm_up_step=3000, cool_down_beginning_step=27000)

    _, masks = pruner.compress()
    pruner.show_pruned_weights()

    print('Final: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))

    optimizer = Adam(model.parameters(), lr=2e-5)
    trainer(model, optimizer, criterion, train_dataloader)
    print('After 1 epoch finetuning: {}'.format(evaluator(model, metric, is_regression, validate_dataloader)))

@@ -18,7 +18,7 @@ __all__ = ['Pruner']
class PrunerModuleWrapper(Module):
    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
        """
        Wrap an module to enable data parallel, forward method customization and buffer registeration.
        Wrap a module to enable data parallel, forward method customization and buffer registeration.

        Parameters
        ----------

@@ -1,4 +1,5 @@
from .basic_pruner import *
from .basic_scheduler import PruningScheduler
from .iterative_pruner import *
from .movement_pruner import MovementPruner
from .auto_compress_pruner import AutoCompressPruner

@@ -0,0 +1,294 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable

import torch
from torch import autograd, Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer, Adam

from nni.algorithms.compression.v2.pytorch.base.compressor import Compressor, _setattr, LayerInfo
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema

from .tools.base import TrainerBasedDataCollector

from .tools import (
    StraightMetricsCalculator,
    NormalSparsityAllocator
)

_logger = logging.getLogger(__name__)


class PrunerScoredModuleWrapper(Module):
    """
    Wrap a module to enable data parallel, forward method customization and buffer registration.
    Different from `PrunerModuleWrapper`, `PrunerScoredModuleWrapper` will record the gradient.

    Parameters
    ----------
    module
        The module user wants to compress.
    config
        The configurations that users specify for compression.
    module_name
        The name of the module to compress, wrapper module shares same name.
    pruner
        The pruner used to calculate mask.
    """
    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
        super().__init__()
        # origin layer information
        self.module = module
        self.name = module_name
        # config and pruner
        self.config = config
        self.pruner = pruner

        self.weight = Parameter(torch.empty(self.module.weight.size()))
        self.weight.data = self.module.weight.data

        self.weight_score = Parameter(torch.empty(self.weight.size()))
        torch.nn.init.constant_(self.weight_score, val=0.0)

        # register buffer for mask
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
            self.bias = Parameter(torch.empty(self.module.bias.size()))
            self.bias.data = self.module.bias.data
        else:
            self.register_buffer("bias_mask", None)

    def _weight2buffer(self):
        """
        When using this wrapper for inference, call `_weight2buffer()` to make the original weight untrainable.
        The best place to call this function is in `Pruner._wrap_model()`.
        """
        delattr(self.module, 'weight')
        self.module.register_buffer('weight', self.weight.data)
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            delattr(self.module, 'bias')
            self.module.register_buffer('bias', self.bias.data)

    def _weight2parameter(self):
        """
        When there is no need to record the score, or the model needs to be exported, call `_weight2parameter()` to make the original weight trainable.
        The best place to call this function is in `Pruner._unwrap_model()`.
        """
        delattr(self.module, 'weight')
        self.module.weight = Parameter(torch.empty(self.weight.size()))
        self.module.weight.data = torch.mul(self.weight, self.weight_mask)
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            delattr(self.module, 'bias')
            self.module.bias = Parameter(torch.empty(self.bias.size()))
            self.module.bias.data = torch.mul(self.bias, self.bias_mask)

    def forward(self, *inputs):
        # apply mask to weight, bias
        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.module.bias = torch.mul(self.bias, self.bias_mask)
        return self.module(*inputs)


class _StraightThrough(autograd.Function):
    """
    Straight through the gradient to the score, then the score = initial_score + sum(-lr * grad(weight) * weight).
    """
    @staticmethod
    def forward(self, score, masks):
        return masks

    @staticmethod
    def backward(ctx, gradOutput):
        return gradOutput, None


class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
    """
    Collect all weight_score in wrappers as data used to calculate metrics.
    """
    def _reset_optimizer(self):
        """
        Weed out the weight_score from the parameters passed to the optimizer, so that the optimizer state dict can still be loaded.
        """
        if self._origin_optimizer_cls is not None:
            optimizer_grouped_parameters = [{
                "params": [p for n, p in self.compressor.bound_model.named_parameters() if "weight_score" not in n and p.requires_grad]
            }]
            if self._origin_optimizer_cls.__name__ == 'SGD':
                self.optimizer = self._origin_optimizer_cls(optimizer_grouped_parameters, lr=0.001)
            else:
                self.optimizer = self._origin_optimizer_cls(optimizer_grouped_parameters)
            self.optimizer.load_state_dict(self._origin_optimizer_state_dict)
        else:
            self.optimizer = None

    def collect(self) -> Dict[str, Tensor]:
        for _ in range(self.training_epochs):
            self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

        data = {}
        for _, wrapper in self.compressor.get_modules_wrapper().items():
            data[wrapper.name] = wrapper.weight_score.data.clone().detach()
        return data


class MovementPruner(BasicPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : List[Dict]
        Supported keys:
            - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
            - sparsity_per_layer : Equals to sparsity.
            - op_types : Operation types to prune.
            - op_names : Operation names to prune.
            - exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
    trainer : Callable[[Module, Optimizer, Callable], None]
        A callable function used to train the model or just do inference. Take model, optimizer, criterion as input.
        The model will be trained or inferenced for `training_epochs` epochs.

        Example::

            def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
                training = model.training
                model.train(mode=True)
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
                    optimizer.step()
                model.train(mode=training)
    optimizer : torch.optim.Optimizer
        The optimizer instance used in trainer. Note that this optimizer might be patched during collecting data,
        so do not use this optimizer in other places.
    criterion : Callable[[Tensor, Tensor], Tensor]
        The criterion function used in trainer. Take model output and target value as input, and return the loss.
    training_epochs : int
        The total epoch number for training the model.
        Make sure the total number of `optimizer.step()` calls during `training_epochs` is larger than `cool_down_beginning_step`.
    warm_up_step : int
        The total number of `optimizer.step()` calls before pruning starts, used for warm up.
        Make sure `warm_up_step` is smaller than `cool_down_beginning_step`.
    cool_down_beginning_step : int
        The step at which the sparsity stops growing; note that the sparsity stopping to grow doesn't mean the masks stop changing.
        The sparsity after each `optimizer.step()` is:
        total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
    """
    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
                 cool_down_beginning_step: int):
        self.trainer = trainer
        self.optimizer = optimizer
        self.criterion = criterion
        self.training_epochs = training_epochs
        self.warm_up_step = warm_up_step
        self.cool_down_beginning_step = cool_down_beginning_step
        assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`'
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)

    def cubic_schedule(self, current_step: int):
        if self.warm_up_step < current_step <= self.cool_down_beginning_step:
            wrapper_dict = self.get_modules_wrapper()
            for config in self.config_list:
                current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3)
                for op_name in config['op_names']:
                    wrapper_dict[op_name].config['total_sparsity'] = current_sparsity

    def reset_tools(self):
        if self.metrics_calculator is None:
            self.metrics_calculator = StraightMetricsCalculator()
        if self.sparsity_allocator is None:
            self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)

        # use Adam to update the weight_score
        params = [{"params": [p for n, p in self.bound_model.named_parameters() if "weight_score" in n and p.requires_grad]}]
        optimizer = Adam(params, 1e-2)
        self.step_counter = 0

        # update the masks after each optimizer step
        def _optimizer_patch():
            optimizer.step()
            optimizer.zero_grad()
            self.step_counter += 1
            if self.step_counter > self.warm_up_step:
                self.cubic_schedule(self.step_counter)
                data = {}
                for _, wrapper in self.get_modules_wrapper().items():
                    data[wrapper.name] = wrapper.weight_score.data
                metrics = self.metrics_calculator.calculate_metrics(data)
                masks = self.sparsity_allocator.generate_sparsity(metrics)
                self.load_masks(masks)

        if self.data_collector is None:
            self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
        else:
            self.data_collector.reset()

    def _wrap_model(self):
        """
        Wrap all modules that need to be compressed.
        Different from the parent function, call `wrapper._weight2buffer()` after replacing the origin module with the wrapper.
        """
        if not self.is_wrapped:
            for _, wrapper in reversed(self.get_modules_wrapper().items()):
                _setattr(self.bound_model, wrapper.name, wrapper)
                wrapper._weight2buffer()
            self.is_wrapped = True

    def _unwrap_model(self):
        """
        Unwrap all modules that need to be compressed.
        Different from the parent function, call `wrapper._weight2parameter()` after replacing the wrapper with the origin module.
        """
        if self.is_wrapped:
            for _, wrapper in self.get_modules_wrapper().items():
                _setattr(self.bound_model, wrapper.name, wrapper.module)
                wrapper._weight2parameter()
            self.is_wrapped = False

    def _wrap_modules(self, layer: LayerInfo, config: Dict):
        """
        Create a wrapper module to replace the original one.
        Different from the parent function, use `PrunerScoredModuleWrapper` instead of `PrunerModuleWrapper`.

        Parameters
        ----------
        layer
            The layer to instrument the mask.
        config
            The configuration for generating the mask.
        """
        _logger.debug("Module detected to compress : %s.", layer.name)
        wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config, self)
        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
        # move newly registered buffers to the same device of weight
        wrapper.to(layer.module.weight.device)
        return wrapper

    def compress(self) -> Tuple[Module, Dict]:
        # sparsity grows from 0
        for _, wrapper in self.get_modules_wrapper().items():
            wrapper.config['total_sparsity'] = 0
        result = super().compress()
        # del weight_score
        for _, wrapper in self.get_modules_wrapper().items():
            wrapper.weight_score = None
        return result

@@ -11,6 +11,7 @@ from .data_collector import (
    SingleHookTrainerBasedDataCollector
)
from .metrics_calculator import (
    StraightMetricsCalculator,
    NormMetricsCalculator,
    MultiDataNormMetricsCalculator,
    DistMetricsCalculator,

@@ -133,7 +133,8 @@ class TrainerBasedDataCollector(DataCollector):
        super().__init__(compressor)
        self.trainer = trainer
        self.training_epochs = training_epochs
        self._origin_optimizer = optimizer
        self._origin_optimizer_cls = optimizer.__class__ if optimizer is not None else None
        self._origin_optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None
        self._origin_criterion = criterion
        self._opt_before_tasks = opt_before_tasks
        self._opt_after_tasks = opt_after_tasks

@@ -146,22 +147,12 @@ class TrainerBasedDataCollector(DataCollector):

    def reset(self):
        # refresh optimizer and criterion
        self.compressor._unwrap_model()
        if self._origin_optimizer is not None:
            optimizer_cls = self._origin_optimizer.__class__
            if optimizer_cls.__name__ == 'SGD':
                self.optimizer = optimizer_cls(self.compressor.bound_model.parameters(), lr=0.001)
            else:
                self.optimizer = optimizer_cls(self.compressor.bound_model.parameters())
            self.optimizer.load_state_dict(self._origin_optimizer.state_dict())
        else:
            self.optimizer = None
        self._reset_optimizer()

        if self._criterion_patch is not None:
            self.criterion = self._criterion_patch(self._origin_criterion)
        else:
            self.criterion = self._origin_criterion
        self.compressor._wrap_model()

        # patch optimizer
        self._patch_optimizer()

@@ -173,6 +164,18 @@
        self._hook_buffer = {}
        self._add_all_hook()

    def _reset_optimizer(self):
        self.compressor._unwrap_model()
        if self._origin_optimizer_cls is not None:
            if self._origin_optimizer_cls.__name__ == 'SGD':
                self.optimizer = self._origin_optimizer_cls(self.compressor.bound_model.parameters(), lr=0.001)
            else:
                self.optimizer = self._origin_optimizer_cls(self.compressor.bound_model.parameters())
            self.optimizer.load_state_dict(self._origin_optimizer_state_dict)
        else:
            self.optimizer = None
        self.compressor._wrap_model()

    def _patch_optimizer(self):
        def patch_step(old_step):
            def new_step(_, *args, **kwargs):

@@ -315,7 +318,7 @@ class SparsityAllocator:
    """

    def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
                 block_sparse_size: Optional[Union[int, List[int]]] = None):
                 block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
        """
        Parameters
        ----------

@@ -339,6 +342,8 @@ class SparsityAllocator:
            Example:

            The metric size is (12,), and block_sparse_size=[64], then the mask will expand to (768,) at first before expand with `dim`.
        continuous_mask
            Inherit the mask already in the wrapper if set True.
        """
        self.pruner = pruner
        self.dim = dim if not isinstance(dim, int) else [dim]

@@ -350,6 +355,7 @@ class SparsityAllocator:
        if self.dim is not None:
            assert all(i >= 0 for i in self.dim)
            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
        self.continuous_mask = continuous_mask

    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
        """

@@ -9,7 +9,18 @@ from torch import Tensor
from .base import MetricsCalculator

__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
           'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator']
           'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']


class StraightMetricsCalculator(MetricsCalculator):
    """
    This metrics calculator directly returns a copy of data as metrics.
    """
    def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        metrics = {}
        for name, tensor in data.items():
            metrics[name] = tensor.clone().detach()
        return metrics


class NormMetricsCalculator(MetricsCalculator):

@@ -24,7 +24,9 @@ class NormalSparsityAllocator(SparsityAllocator):
            sparsity_rate = wrapper.config['total_sparsity']

            assert name in metrics, 'Metric of %s is not calculated.'
            metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
            metric = metrics[name]
            if self.continuous_mask:
                metric *= self._compress_mask(wrapper.weight_mask)
            prune_num = int(sparsity_rate * metric.numel())
            if prune_num == 0:
                threshold = metric.min() - 1

@@ -64,7 +66,8 @@ class GlobalSparsityAllocator(SparsityAllocator):

            for name, metric in group_metric_dict.items():
                wrapper = self.pruner.get_modules_wrapper()[name]
                metric = metric * self._compress_mask(wrapper.weight_mask)
                if self.continuous_mask:
                    metric = metric * self._compress_mask(wrapper.weight_mask)
                layer_weight_num = wrapper.module.weight.data.numel()
                total_weight_num += layer_weight_num
                expend_times = int(layer_weight_num / metric.numel())

@@ -113,7 +116,10 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
        masks = {}
        grouped_metrics = {}
        for idx, names in enumerate(self.channel_depen):
            grouped_metric = {name: metrics[name] * self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) for name in names if name in metrics}
            grouped_metric = {name: metrics[name] for name in names if name in metrics}
            if self.continuous_mask:
                for name, metric in grouped_metric.items():
                    metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask)
            if len(grouped_metric) > 0:
                grouped_metrics[idx] = grouped_metric
        for _, group_metric_dict in grouped_metrics.items():

@@ -281,7 +281,7 @@ class Compressor:
class PrunerModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_name, module_type, config, pruner):
        """
        Wrap an module to enable data parallel, forward method customization and buffer registeration.
        Wrap a module to enable data parallel, forward method customization and buffer registeration.

        Parameters
        ----------

@@ -495,7 +495,7 @@ class Pruner(Compressor):
class QuantizerModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
        """
        Wrap an module to enable data parallel, forward method customization and buffer registeration.
        Wrap a module to enable data parallel, forward method customization and buffer registeration.

        Parameters
        ----------

@@ -15,7 +15,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
    ActivationAPoZRankPruner,
    ActivationMeanRankPruner,
    TaylorFOWeightPruner,
    ADMMPruner
    ADMMPruner,
    MovementPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact


@@ -67,7 +68,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_l1_norm_pruner(self):
        model = TorchModel()

@@ -77,7 +78,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_l2_norm_pruner(self):
        model = TorchModel()

@@ -87,7 +88,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_fpgm_pruner(self):
        model = TorchModel()

@@ -97,7 +98,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_slim_pruner(self):
        model = TorchModel()

@@ -107,7 +108,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_activation_apoz_rank_pruner(self):
        model = TorchModel()

@@ -119,7 +120,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_activation_mean_rank_pruner(self):
        model = TorchModel()

@@ -131,7 +132,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_taylor_fo_pruner(self):
        model = TorchModel()

@@ -142,7 +143,7 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_admm_pruner(self):
        model = TorchModel()

@@ -152,7 +153,18 @@ class PrunerTestCase(unittest.TestCase):
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_movement_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = MovementPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
                                criterion=criterion, training_epochs=5, warm_up_step=0, cool_down_beginning_step=4)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82


if __name__ == '__main__':
    unittest.main()