[Compression] evaluator - step2 (#4992)

This commit is contained in:
J-shang 2022-07-28 12:13:13 +08:00 коммит произвёл GitHub
Родитель a689e619c4
Коммит ed455174db
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
39 изменённых файлов: 2402 добавлений и 1255 удалений

Просмотреть файл

@ -5,3 +5,9 @@ Quickstart
PyTorch </tutorials/hpo_quickstart_pytorch/main>
TensorFlow </tutorials/hpo_quickstart_tensorflow/main>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/index
/tutorials/hpo_quickstart_tensorflow/index

Просмотреть файл

@ -0,0 +1,16 @@
Evaluator
=========
.. _compression-torch-evaluator:
TorchEvaluator
--------------
.. autoclass:: nni.compression.pytorch.TorchEvaluator
.. _compression-lightning-evaluator:
LightningEvaluator
------------------
.. autoclass:: nni.compression.pytorch.LightningEvaluator

Просмотреть файл

@ -8,5 +8,6 @@ Compression API Reference
Quantizer <quantizer>
Pruning Speedup <pruning_speedup>
Quantization Speedup <quantization_speedup>
Evaluator <evaluator>
Compression Utilities <utils>
Framework Related <framework>

Просмотреть файл

Просмотреть файл

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import LightningEvaluator, TorchEvaluator

Просмотреть файл

@ -119,7 +119,8 @@ class Compressor:
Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self._modules_to_compress is None:
self._modules_to_compress = []
@ -146,7 +147,8 @@ class Compressor:
Optional[Dict]
The retrieved configuration for this layer, if None, this layer should not be compressed.
"""
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
ret = None
for config in self.config_list:
@ -240,8 +242,10 @@ class Compressor:
Dict[int, List[str]]
A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
self._unwrap_model()
module_groups = {}
@ -323,6 +327,8 @@ class Compressor:
torch.nn.Module
model with specified modules compressed.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
return self.bound_model

Просмотреть файл

@ -43,8 +43,8 @@ class PrunerModuleWrapper(Module):
pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
pruning_target = getattr(self.module, pruning_target_name, None)
if hasattr(self.module, pruning_target_name) and pruning_target is not None:
setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
setattr(self, pruning_target_name, Parameter(torch.empty_like(pruning_target)))
self.register_buffer(pruning_target_mask_name, torch.ones_like(pruning_target))
else:
self.register_buffer(pruning_target_mask_name, None)
@ -67,11 +67,11 @@ class PrunerModuleWrapper(Module):
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr(self.module, 'weight')
self.module.weight = Parameter(torch.empty(self.weight.size()))
self.module.weight = Parameter(torch.empty_like(self.weight))
self.module.weight.data = torch.mul(self.weight, self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
delattr(self.module, 'bias')
self.module.bias = Parameter(torch.empty(self.bias.size()))
self.module.bias = Parameter(torch.empty_like(self.bias))
self.module.bias.data = torch.mul(self.bias, self.bias_mask)
def forward(self, *inputs):
@ -130,7 +130,8 @@ class Pruner(Compressor):
Wrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if not self.is_wrapped:
for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
@ -143,7 +144,8 @@ class Pruner(Compressor):
Unwrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self.is_wrapped:
for wrapper in self.get_modules_wrapper().values():
@ -165,8 +167,10 @@ class Pruner(Compressor):
self._unwrap_model()
parameter_name_map = {}
for name, param in self.bound_model.named_parameters():
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`, the name will not change after wrap.
# If the parameter name in under wrapped module is others, the name `xxx.param` will change to `xxx.module.param` after wrap.
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`,
# the name will not change after wrap.
# If the parameter name in under wrapped module is others,
# the name `xxx.param` will change to `xxx.module.param` after wrap.
parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
self._wrap_model()
return parameter_name_map
@ -183,14 +187,12 @@ class Pruner(Compressor):
The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
if layer_mask.get('weight') is not None:
assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
if layer_mask.get('bias') is not None:
assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))
for module_name, target_masks in masks.items():
assert module_name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(module_name)
for target_name, target_mask in target_masks.items():
assert hasattr(wrappers[module_name], f'{target_name}_mask'), f'There is no attribute {target_name}_mask in wrapper.'
target: Tensor = getattr(self.get_modules_wrapper()[module_name], target_name)
setattr(wrappers[module_name], f'{target_name}_mask', target_mask.to(target.device))
def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""

Просмотреть файл

@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional, cast
from typing import Dict, List, Callable, Optional, cast, overload
import json_tricks
import torch
@ -11,12 +13,13 @@ from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity, config_list_canonical
from nni.compression.pytorch.utils import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
from ..utils import LightningEvaluator, TorchEvaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING
class AMCTaskGenerator(TaskGenerator):
@ -41,8 +44,8 @@ class AMCTaskGenerator(TaskGenerator):
ddpg_params
The ddpg agent parameters.
target : str
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
This parameter is used to explain what the sparsity setting in config_list refers to.
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse,
but in AMC, you can choose flops sparse. This parameter is used to explain what the sparsity setting in config_list refers to.
"""
def __init__(self, total_episode: int, dummy_input: Tensor, origin_model: Module, origin_config_list: List[Dict],
@ -56,7 +59,7 @@ class AMCTaskGenerator(TaskGenerator):
self.config_list_copy = deepcopy(origin_config_list)
super().__init__(origin_model=origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
@ -82,6 +85,8 @@ class AMCTaskGenerator(TaskGenerator):
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
self.temp_config_list = self.temp_config_list if hasattr(self, 'temp_config_list') else []
# append experience & update agent policy
if self.action is not None:
action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
@ -106,7 +111,8 @@ class AMCTaskGenerator(TaskGenerator):
origin_model = torch.load(self._origin_model_path)
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.temp_config_list)
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.temp_config_list) # type: ignore
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity
@ -162,7 +168,7 @@ class AMCTaskGenerator(TaskGenerator):
class AMCPruner(IterativePruner):
r"""
__doc__ = r"""
AMC pruner leverages reinforcement learning to provide the model compression policy.
According to the author, this learning-based compression policy outperforms conventional rule-based compression policy by having a higher compression ratio,
better preserving the accuracy and freeing human labor.
@ -186,10 +192,11 @@ class AMCPruner(IterativePruner):
- op_names : Operation name to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
dummy_input : torch.Tensor
`dummy_input` is required for speedup and tracing the model in RL environment.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
pruning_algorithm : str
Supported pruning algorithm ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
@ -197,8 +204,6 @@ class AMCPruner(IterativePruner):
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
ddpg_params : Dict
Configuration dict to configure the DDPG agent, any key unset will be set to default implicitly.
- hidden1: hidden num of first fully connect layer. Default: 300
@ -223,23 +228,42 @@ class AMCPruner(IterativePruner):
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
This parameter is used to explain what the sparsity setting in config_list refers to.
Examples
--------
>>> from nni.compression.pytorch.pruning import AMCPruner
>>> config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
>>> dummy_input = torch.rand(...).to(device)
>>> evaluator = ...
>>> finetuner = ...
>>> pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner)
>>> pruner.compress()
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/amc_pruning_torch.py>`.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], dummy_input: Tensor,
evaluator: Callable[[Module], float], pruning_algorithm: str = 'l1', log_dir: str = '.',
keep_intermediate_result: bool = False, finetuner: Optional[Callable[[Module], None]] = None,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], *args, **kwargs):
new_api = ['evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'ddpg_params', 'pruning_params', 'target']
new_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
old_api = ['dummy_input', 'evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'finetuner', 'ddpg_params',
'pruning_params', 'target']
old_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
pruning_algorithm = init_kwargs['pruning_algorithm']
log_dir = init_kwargs['log_dir']
keep_intermediate_result = init_kwargs['keep_intermediate_result']
ddpg_params = init_kwargs['ddpg_params']
pruning_params = init_kwargs['pruning_params']
target = init_kwargs['target']
dummy_input = self.dummy_input if not self.using_evaluator else self.evaluator.get_dummy_input()
assert pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'], \
"Only support pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo']"
task_generator = AMCTaskGenerator(total_episode=total_episode,
@ -251,5 +275,9 @@ class AMCPruner(IterativePruner):
ddpg_params=ddpg_params,
target=target)
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=True, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=True, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=True, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore

Просмотреть файл

@ -1,18 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional
from typing import Dict, List, Callable, Optional, overload
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
from ..utils import LightningEvaluator, TorchEvaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
@ -21,10 +23,7 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, sa_params: Dict = {}, log_dir: str = '.',
keep_intermediate_result: bool = False):
self.iterative_pruner = SimulatedAnnealingPruner(model=None,
config_list=None,
log_dir=Path(log_dir, 'SA'),
**sa_params)
self._sa_params = sa_params
super().__init__(total_iteration=total_iteration,
origin_model=origin_model,
origin_config_list=origin_config_list,
@ -36,12 +35,20 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
return super().reset(model, config_list, masks)
def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
if not hasattr(self, 'iterative_pruner'):
self.iterative_pruner = SimulatedAnnealingPruner(model=model,
config_list=config_list,
log_dir=Path(self._log_dir_root, 'SA'),
**self._sa_params)
else:
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
self._iterative_pruner_reset(model, new_config_list, masks)
@ -53,8 +60,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
class AutoCompressPruner(IterativePruner):
r"""
__doc__ = r"""
For total iteration number :math:`N`, AutoCompressPruner prune the model that survive the previous iteration for a fixed sparsity ratio (e.g., :math:`1-{(1-0.8)}^{(1/N)}`) to achieve the overall sparsity (e.g., :math:`0.8`):
""" + r"""
.. code-block:: bash
@ -65,35 +73,27 @@ class AutoCompressPruner(IterativePruner):
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
total_iteration : int
total_iteration
The total iteration number.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
admm_params : Dict
admm_params
The parameters passed to the ADMMPruner.
- trainer : Callable[[Module, Optimizer, Callable].
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
- traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
- criterion : Callable[[Tensor, Tensor], Tensor].
The criterion function used in trainer. Take model output and target value as input, and return the loss.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- iterations : int.
The total iteration number in admm pruning algorithm.
- training_epochs : int.
The epoch number for training model in each iteration.
sa_params : Dict
sa_params
The parameters passed to the SimulatedAnnealingPruner.
- evaluator : Callable[[Module], float]. Required.
Evaluate the pruned model and give a score.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- start_temperature : float. Default: `100`.
Start temperature of the simulated annealing process.
- stop_temperature : float. Default: `20`.
@ -104,54 +104,50 @@ class AutoCompressPruner(IterativePruner):
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
- pruning_algorithm : str. Default: `'level'`.
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
- pruning_params : Dict. Default: `{}`.
- pruning_params : Dict. Default: dict().
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
log_dir : str
log_dir
The log directory used to save the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all finetune logic, takes a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import AutoCompressPruner
>>> model = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> evaluator = ...
>>> finetuner = ...
>>> admm_params = {
>>> 'trainer': trainer,
>>> 'traced_optimizer': traced_optimizer,
>>> 'criterion': criterion,
>>> 'iterations': 10,
>>> 'training_epochs': 1
>>> }
>>> sa_params = {
>>> 'evaluator': evaluator
>>> }
>>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/auto_compress_pruner.py>`.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
...
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
*args, **kwargs):
new_api = ['evaluator', 'speedup']
new_init_kwargs = {'evaluator': None, 'speedup': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
@ -175,6 +171,10 @@ class AutoCompressPruner(IterativePruner):
else:
admm_params['granularity'] = 'fine-grained'
pruner = ADMMPruner(None, None, **admm_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
pruner = ADMMPruner(None, None, **admm_params) # type: ignore
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,8 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional, Union
import logging
from typing import Any, Dict, List, Tuple, Callable, Optional, Union, overload
import torch
from torch import Tensor
@ -12,9 +15,63 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
from ..utils import Evaluator, LightningEvaluator, TorchEvaluator
_logger = logging.getLogger(__name__)
_LEGACY_FINETUNER = Callable[[Module], None]
_LEGACY_EVALUATOR = Callable[[Module], float]
class PruningScheduler(BasePruningScheduler):
# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
evaluator: LightningEvaluator | TorchEvaluator
using_evaluator: bool
finetuner: _LEGACY_FINETUNER
_evaluator: _LEGACY_EVALUATOR
dummy_input: Any
def _init_evaluator(self, model: Module, new_api: List[str], new_init_kwargs: Dict, old_api: List[str],
old_init_kwargs: Dict, args: Tuple, kwargs: Dict) -> Dict:
# for fake __init__ overload, parsing args and kwargs,
# initializing evaluator or [finetuner, evaluator, dummy_input], return the remaining arguments.
if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
(len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
else:
init_kwargs = self._parse_args(old_api, args, kwargs, old_init_kwargs)
self.finetuner: _LEGACY_FINETUNER = init_kwargs.pop('finetuner')
self._evaluator: _LEGACY_EVALUATOR = init_kwargs.pop('evaluator')
self.dummy_input = init_kwargs.pop('dummy_input')
self.using_evaluator = False
warn_msg = f'The old API ...{",".join(old_api)} will be deprecated after NNI v3.0,' +\
f'please using the new one ...{",".join(new_api)}'
_logger.warning(warn_msg)
return init_kwargs
def _parse_args(self, arg_names: List, args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
merged_kwargs = {arg_names[idx]: arg for idx, arg in enumerate(args)}
for key, value in kwargs.items():
if key in merged_kwargs:
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
merged_kwargs[key] = value
for key, value in def_kwargs.items():
if key not in merged_kwargs:
merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
return merged_kwargs
class PruningScheduler(EvaluatorBasedPruningScheduler):
"""
Parameters
----------
@ -25,7 +82,8 @@ class PruningScheduler(BasePruningScheduler):
Used to generate task for each iteration.
finetuner
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
It will be called at the end of each iteration if reset_weight is False,
will be called at the beginning of each iteration otherwise.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input
@ -36,16 +94,30 @@ class PruningScheduler(BasePruningScheduler):
reset_weight
If set True, the model weight will reset to the origin model weight at the end of each iteration step.
"""
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: LightningEvaluator | TorchEvaluator,
speedup: bool = False, reset_weight: bool = False):
...
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: _LEGACY_FINETUNER | None = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: _LEGACY_EVALUATOR | None = None,
reset_weight: bool = False):
...
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, *args, **kwargs) -> None:
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'speedup', 'reset_weight']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'reset_weight': False}
init_kwargs = self._init_evaluator(None, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs) # type: ignore
self.pruner = pruner
self.task_generator = task_generator
self.finetuner = finetuner
self.speedup = speedup
self.dummy_input = dummy_input
self.evaluator = evaluator
self.reset_weight = reset_weight
self.speedup = init_kwargs['speedup']
self.reset_weight = init_kwargs['reset_weight']
def reset(self, model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}):
self.task_generator.reset(model, config_list, masks)
@ -61,6 +133,7 @@ class PruningScheduler(BasePruningScheduler):
generate masks -> speedup -> finetune -> evaluate
"""
model, masks, config_list = task.load_data()
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
@ -74,28 +147,58 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# finetune
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(compact_model) # type: ignore
if self.speedup:
self.evaluator.finetune()
else:
self.pruner._wrap_model()
self.evaluator.finetune()
self.pruner._unwrap_model()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
# evaluate
if self.evaluator is not None and task.evaluate:
if self.speedup:
score = self.evaluator(compact_model)
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model) # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
score = None
else:
score = None
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model) # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model) # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
@ -107,13 +210,20 @@ class PruningScheduler(BasePruningScheduler):
finetune -> generate masks -> reset weight -> speedup -> evaluate
"""
model, masks, config_list = task.load_data()
checkpoint = deepcopy(model.state_dict())
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
# finetune
if self.finetuner is not None and task.finetune:
self.finetuner(model)
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(model) # type: ignore
self.evaluator.finetune()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
self.finetuner(model)
# pruning model
compact_model, pruner_generated_masks = self.pruner.compress()
@ -128,19 +238,38 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# evaluate
if self.evaluator is not None and task.evaluate:
if self.speedup:
score = self.evaluator(compact_model)
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model) # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
score = None
else:
score = None
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model) # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model) # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()

Просмотреть файл

@ -1,15 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional, Union
from typing import Any, Dict, List, Optional, Union, overload
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from .basic_pruner import (
LevelPruner,
L1NormPruner,
@ -21,13 +21,19 @@ from .basic_pruner import (
TaylorFOWeightPruner,
ADMMPruner
)
from .basic_scheduler import PruningScheduler
from .basic_scheduler import PruningScheduler, _LEGACY_FINETUNER, _LEGACY_EVALUATOR
from .tools import (
LinearTaskGenerator,
AGPTaskGenerator,
LotteryTicketTaskGenerator,
SimulatedAnnealingTaskGenerator
)
from ..utils import (
OptimizerConstructHelper,
LightningEvaluator,
TorchEvaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
@ -71,55 +77,67 @@ class IterativePruner(PruningScheduler):
class LinearPruner(IterativePruner):
r"""
__doc__ = r"""
Linear pruner is an iterative pruner, it will increase sparsity evenly from scratch during each iteration.
For example, the final sparsity is set as 0.5, and the iteration number is 5, then the sparsity used in each iteration are ``[0, 0.1, 0.2, 0.3, 0.4, 0.5]``.
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
total_iteration
The total iteration number.
log_dir : str
log_dir
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
Examples
--------
>>> from nni.compression.pytorch.pruning import LinearPruner
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> finetuner = ...
>>> pruner = LinearPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
pruning_params: Dict = {}):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Any | None = None,
evaluator: _LEGACY_EVALUATOR | None = None, pruning_params: Dict = {}):
...
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
*args, **kwargs):
new_api = ['evaluator', 'speedup', 'pruning_params']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'pruning_params': {}}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'pruning_params']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'pruning_params': {}}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
pruning_params = init_kwargs['pruning_params']
task_generator = LinearTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
@ -128,63 +146,80 @@ class LinearPruner(IterativePruner):
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
class AGPPruner(IterativePruner):
r"""
__doc__ = r"""
This is an iterative pruner, which the sparsity is increased from an initial sparsity value :math:`s_{i}` (usually 0) to a final sparsity value :math:`s_{f}` over a span of :math:`n` pruning iterations,
starting at training step :math:`t_{0}` and with pruning frequency :math:`\Delta t`:
:math:`s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n \Delta t}\right)^{3} \text { for } t \in\left\{t_{0}, t_{0}+\Delta t, \ldots, t_{0} + n \Delta t\right\}`
""" + r"""
For more details please refer to `To prune, or not to prune: exploring the efficacy of pruning for model compression <https://arxiv.org/abs/1710.01878>`__\.
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
total_iteration
The total iteration number.
log_dir : str
log_dir
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
Examples
--------
>>> from nni.compression.pytorch.pruning import AGPPruner
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> finetuner = ...
>>> pruner = AGPPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
pruning_params: Dict = {}):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Any | None = None,
evaluator: _LEGACY_EVALUATOR | None = None, pruning_params: Dict = {}):
...
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
*args, **kwargs):
new_api = ['evaluator', 'speedup', 'pruning_params']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'pruning_params': {}}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'pruning_params']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'pruning_params': {}}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
pruning_params = init_kwargs['pruning_params']
task_generator = AGPTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
@ -193,12 +228,16 @@ class AGPPruner(IterativePruner):
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
class LotteryTicketPruner(IterativePruner):
r"""
__doc__ = r"""
`The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks <https://arxiv.org/abs/1803.03635>`__\ ,
authors Jonathan Frankle and Michael Carbin,provides comprehensive measurement and analysis,
and articulate the *lottery ticket hypothesis*\ : dense, randomly-initialized, feed-forward networks contain subnetworks (*winning tickets*\ ) that
@ -216,55 +255,69 @@ class LotteryTicketPruner(IterativePruner):
If the configured final sparsity is P (e.g., 0.8) and there are n times iterative pruning,
each iterative pruning prunes 1-(1-P)^(1/n) of the weights that survive the previous round.
""" + r"""
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
total_iteration
The total iteration number.
log_dir : str
log_dir
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
reset_weight : bool
reset_weight
If set True, the model weight will reset to the original model weight at the end of each iteration step.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
Examples
--------
>>> from nni.compression.pytorch.pruning import LotteryTicketPruner
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> finetuner = ...
>>> pruner = LotteryTicketPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner, reset_weight=True)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
"""
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
reset_weight: bool = True, pruning_params: Dict = {}):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: _LEGACY_EVALUATOR | None = None, reset_weight: bool = True,
pruning_params: Dict = {}):
...
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
pruning_params: Dict = {}):
*args, **kwargs):
new_api = ['evaluator', 'speedup', 'reset_weight', 'pruning_params']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': True, 'pruning_params': {}}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight', 'pruning_params']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False,
'reset_weight': True, 'pruning_params': {}}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
reset_weight = init_kwargs['reset_weight']
pruning_params = init_kwargs['pruning_params']
task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
@ -273,12 +326,16 @@ class LotteryTicketPruner(IterativePruner):
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=reset_weight)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=reset_weight)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=reset_weight) # type: ignore
class SimulatedAnnealingPruner(IterativePruner):
"""
__doc__ = r"""
We implement a guided heuristic search method, Simulated Annealing (SA) algorithm. As mentioned in the paper, this method is enhanced on guided search based on prior experience.
The enhanced SA technique is based on the observation that a DNN layer with more number of weights often has a higher degree of model compression with less impact on overall accuracy.
@ -294,54 +351,81 @@ class SimulatedAnnealingPruner(IterativePruner):
Parameters
----------
model : Optional[Module]
model
The origin unwrapped pytorch model to be pruned.
config_list : Optional[List[Dict]]
config_list
The origin config list provided by the user.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
start_temperature : float
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
start_temperature
Start temperature of the simulated annealing process.
stop_temperature : float
stop_temperature
Stop temperature of the simulated annealing process.
cool_down_rate : float
cool_down_rate
Cool down rate of the temperature.
perturbation_magnitude : float
perturbation_magnitude
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
log_dir : Union[str, Path]
log_dir
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
speedup : bool
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
Examples
--------
>>> from nni.compression.pytorch.pruning import SimulatedAnnealingPruner
>>> model = ...
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> evaluator = ...
>>> finetuner = ...
>>> pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm='l1', evaluator=evaluator, cool_down_rate=0.9, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/simulated_anealing_pruning_torch.py <examples/model_compress/pruning/simulated_anealing_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False, speedup: bool = False):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: _LEGACY_EVALUATOR,
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None):
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
new_api = ['evaluator', 'start_temperature', 'stop_temperature', 'cool_down_rate', 'perturbation_magnitude',
'pruning_algorithm', 'pruning_params', 'log_dir', 'keep_intermediate_result', 'speedup']
new_init_kwargs = {'start_temperature': 100, 'stop_temperature': 20, 'cool_down_rate': 0.9,
'perturbation_magnitude': 0.35, 'pruning_algorithm': 'level', 'pruning_params': {},
'log_dir': '.', 'keep_intermediate_result': False, 'speedup': False}
old_api = ['evaluator', 'start_temperature', 'stop_temperature', 'cool_down_rate', 'perturbation_magnitude',
'pruning_algorithm', 'pruning_params', 'log_dir', 'keep_intermediate_result', 'finetuner',
'speedup', 'dummy_input']
old_init_kwargs = {'start_temperature': 100, 'stop_temperature': 20, 'cool_down_rate': 0.9,
'perturbation_magnitude': 0.35, 'pruning_algorithm': 'level', 'pruning_params': {},
'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None, 'speedup': False,
'dummy_input': None}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
start_temperature = init_kwargs['start_temperature']
stop_temperature = init_kwargs['stop_temperature']
cool_down_rate = init_kwargs['cool_down_rate']
perturbation_magnitude = init_kwargs['perturbation_magnitude']
pruning_algorithm = init_kwargs['pruning_algorithm']
pruning_params = init_kwargs['pruning_params']
log_dir = init_kwargs['log_dir']
keep_intermediate_result = init_kwargs['keep_intermediate_result']
speedup = init_kwargs['speedup']
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], evaluator: Callable[[Module], float], start_temperature: float = 100,
stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
origin_config_list=config_list,
start_temperature=start_temperature,
@ -351,7 +435,12 @@ class SimulatedAnnealingPruner(IterativePruner):
log_dir=log_dir,
keep_intermediate_result=keep_intermediate_result)
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
pruning_params['traced_optimizer'] = \
OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup,
dummy_input=self.dummy_input, evaluator=self._evaluator, reset_weight=False) # type: ignore

Просмотреть файл

@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable
from typing import Dict, List, Tuple, Callable, overload
import torch
from torch import autograd, Tensor
@ -12,17 +14,23 @@ from torch.nn.parameter import Parameter
from torch.optim import Optimizer, Adam
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
from nni.common.serializer import Traceable
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import EvaluatorBasedPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
from .tools.base import TrainerBasedDataCollector
from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
from .tools import (
StraightMetricsCalculator,
NormalSparsityAllocator
NormalSparsityAllocator,
StraightMetricsCalculator
)
from ..utils import (
LightningEvaluator,
TorchEvaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
@ -47,8 +55,7 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
def forward(self, *inputs):
# apply mask to weight, bias
# NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach()`
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach())) # type: ignore
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask)) # type: ignore
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias = torch.mul(self.bias, self.bias_mask) # type: ignore
return self.module(*inputs)
@ -77,13 +84,30 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
target_name = 'weight'
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data # type: ignore
data[wrapper.name] = {target_name: wrapper.weight_score.data} # type: ignore
return data
class MovementPruner(BasicPruner):
r"""
class EvaluatorBasedScoreDataCollector(EvaluatorBasedDataCollector):
"""
Collect all weight_score in wrappers as data used to calculate metrics.
"""
def collect(self) -> Dict[str, Tensor]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target_score: Tensor = getattr(wrapper, f'{target_name}_score')
data[module_name] = {target_name: target_score.data.clone()}
return data
class MovementPruner(EvaluatorBasedPruner):
__doc__ = r"""
Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element will be scored by the opposite of the sum of the product of weight and its gradient during each step.
@ -110,30 +134,12 @@ class MovementPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_epochs : int
The total epoch number for training the model.
Make sure the total `optimizer.step()` in `training_epochs` is bigger than `cool_down_beginning_step`.
@ -145,33 +151,31 @@ class MovementPruner(BasicPruner):
The sparsity after each `optimizer.step()` is:
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import MovementPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = MovementPruner(model, config_list, trainer, traced_optimizer, criterion, 10, 3000, 27000)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
warm_up_step: int, cool_down_beginning_step: int):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
cool_down_beginning_step: int):
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.training_epochs = training_epochs
self.warm_up_step = warm_up_step
self.cool_down_beginning_step = cool_down_beginning_step
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
self.training_epochs: int = init_kwargs['training_epochs']
self.warm_up_step: int = init_kwargs['warm_up_step']
self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should smaller than `cool_down_beginning_step`'
super().__init__(model, config_list)
@ -184,14 +188,16 @@ class MovementPruner(BasicPruner):
if self.warm_up_step < current_step <= self.cool_down_beginning_step:
wrapper_dict = self.get_modules_wrapper()
for config in self.config_list:
current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3)
scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
current_sparsity = config['total_sparsity'] * scale
for op_name in config['op_names']:
wrapper_dict[op_name].config['total_sparsity'] = current_sparsity
wrapper = wrapper_dict[op_name]
wrapper.config['total_sparsity'] = current_sparsity
def reset_tools(self):
if self.metrics_calculator is None:
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = StraightMetricsCalculator()
if self.sparsity_allocator is None:
if not hasattr(self, 'sparsity_allocator'):
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
# use Adam to update the weight_score
@ -208,16 +214,30 @@ class MovementPruner(BasicPruner):
if self.step_counter > self.warm_up_step:
self.cubic_schedule(self.step_counter)
data = {}
target_name = 'weight'
for wrapper in self.get_modules_wrapper().values():
data[wrapper.name] = wrapper.weight_score.data
data[wrapper.name] = {target_name: wrapper.weight_score.data}
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
self.load_masks(masks)
if self.data_collector is None:
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
after_opt_step_tasks=[_optimizer_patch],
max_epochs=self.training_epochs)
else:
self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
if not hasattr(self, 'data_collector'):
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper,
self.criterion, self.training_epochs,
opt_after_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
def _wrap_modules(self, layer: LayerInfo, config: Dict):
"""
@ -243,7 +263,6 @@ class MovementPruner(BasicPruner):
for wrapper in self.get_modules_wrapper().values():
wrapper.config['total_sparsity'] = 0
result = super().compress()
# del weight_score
for wrapper in self.get_modules_wrapper().values():
wrapper.weight_score = None
if self.using_evaluator:
self.evaluator.unbind_model()
return result

Просмотреть файл

@ -8,6 +8,12 @@ from .base import (
SparsityAllocator,
TaskGenerator
)
from .data_collector import (
TargetDataCollector,
EvaluatorBasedTargetDataCollector,
EvaluatorBasedHookDataCollector
)
# TODO: remove in nni v3.0.
from .data_collector import (
WeightDataCollector,
WeightTrainerBasedDataCollector,
@ -16,7 +22,7 @@ from .data_collector import (
from .metrics_calculator import (
StraightMetricsCalculator,
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
HookDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator

Просмотреть файл

@ -6,7 +6,7 @@ from datetime import datetime
import logging
from pathlib import Path
import types
from typing import List, Dict, Tuple, Optional, Callable, Union
from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
import json_tricks
import torch
@ -15,7 +15,7 @@ from torch.nn import Module
from torch.optim import Optimizer
from ...base import Pruner, LayerInfo, Task, TaskResult
from ...utils import OptimizerConstructHelper, Scaling
from ...utils import Evaluator, Hook, OptimizerConstructHelper, Scaling
_logger = logging.getLogger(__name__)
@ -45,7 +45,7 @@ class DataCollector:
def __init__(self, compressor: Pruner):
self.compressor = compressor
def reset(self):
def reset(self, *args, **kwargs):
"""
Reset the `DataCollector`.
"""
@ -63,9 +63,12 @@ class DataCollector:
raise NotImplementedError()
# TODO: remove in nni v3.0.
COLLECTOR_TYPE = Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]
class HookCollectorInfo:
def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str,
collector: Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]):
collector: COLLECTOR_TYPE):
"""
This class used to aggregate the information of what kind of hook is placed on which layers.
@ -76,23 +79,24 @@ class HookCollectorInfo:
hook_type
'forward' or 'backward'.
collector
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor, the output is a hook function.
The buffer is used to store the data wanted to hook.
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor,
the output is a hook function. The buffer is used to store the data wanted to hook.
"""
self.targets = targets
self.hook_type = hook_type
self.collector = collector
# TODO: remove in nni v3.0.
class TrainerBasedDataCollector(DataCollector):
"""
This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
"""
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [],
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None],
optimizer_helper: OptimizerConstructHelper, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [], collector_infos: List[HookCollectorInfo] = [],
criterion_patch: Optional[Callable[[Callable], Callable]] = None):
"""
Parameters
----------
@ -252,6 +256,47 @@ class TrainerBasedDataCollector(DataCollector):
self._remove_hook(hook_id)
class EvaluatorBasedDataCollector(DataCollector):
"""
This data collector is the base class for the data collectors that want to use ``Evaluator`` to train or inference.
Three main usages are supported in this data collector:
1. Doing something before ``optimzer.step()`` and after ``optimzer.step()``. ``before_opt_step_tasks`` is a list of task functions
that will execute before ``optimzer.step()``. ``after_opt_step_tasks`` is a list of task functions that will execute after
``optimzer.step()``. All the task functions in the list should not have input arguments, function return value is allowed,
but ``Evaluator`` will not catch it.
2. Patch or modify the training loss. ``loss_patch`` is a function with input is the original loss and the output is the modified loss.
3. Add hooks on ``torch.nn.Module`` or ``Parameter`` or ``Buffer``. Three kinds of hook are supported, ``TensorHook``, ``ForwardHook``
and ``BackwardHook``. For initializing a ``Hook``, a hook function factory is needed, the factory function's input is an empty list,
and the output is a hook function defined by Pytorch.
Please refer `register_hook <https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html>`_,
`register_forward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_,
`register_backward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_backward_hook>`_.
"""
def __init__(self, compressor: Pruner, evaluator: Evaluator, before_opt_step_tasks: List[Callable] | None = None,
after_opt_step_tasks: List[Callable] | None = None, loss_patch: Callable[[Tensor], Tensor] | None = None,
hooks: Dict[str, Dict[str, Hook]] | None = None, max_steps: int | None = None, max_epochs: int | None = None):
super().__init__(compressor)
self.evaluator = evaluator
self.max_steps = max_steps
self.max_epochs = max_epochs
self.reset(before_opt_step_tasks, after_opt_step_tasks, loss_patch, hooks)
def reset(self, before_opt_step_tasks: List[Callable] | None = None, after_opt_step_tasks: List[Callable] | None = None,
loss_patch: Callable[[Tensor], Tensor] | None = None, hooks: Dict[str, Dict[str, Hook]] | None = None):
if before_opt_step_tasks or after_opt_step_tasks:
before_opt_step_tasks = before_opt_step_tasks if before_opt_step_tasks else []
after_opt_step_tasks = after_opt_step_tasks if after_opt_step_tasks else []
self.evaluator.patch_optimizer_step(before_opt_step_tasks, after_opt_step_tasks)
if loss_patch:
self.evaluator.patch_loss(loss_patch)
if hooks:
self._hooks = hooks
hook_list = [hook for _ in hooks.values() for hook in _.values()]
self.evaluator.register_hooks(hook_list)
class MetricsCalculator:
"""
An abstract class for calculate a kind of metrics of the given data.
@ -260,7 +305,8 @@ class MetricsCalculator:
----------
scalers
Scaler is used to scale the metrics' size. It scaling metric to the same size as the shrinked mask in the sparsity allocator.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
@ -268,7 +314,8 @@ class MetricsCalculator:
"""
def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
scaler = _get_scaler(self.scalers, module_name, target_name)
@ -301,7 +348,8 @@ class SparsityAllocator:
scalers
Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric,
or expands the mask of the same size as the metric to the same size as the pruning target.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
@ -313,7 +361,8 @@ class SparsityAllocator:
def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True):
self.pruner = pruner
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.continuous_mask = continuous_mask
def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
@ -335,25 +384,39 @@ class SparsityAllocator:
mask = (scaler.shrink(mask) != 0).type_as(mask)
return mask
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def _mask_metric(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part in the metric to the minimum value.
target_name = 'weight'
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
shrinked_target_mask = self._shrink_mask(module_name, target_name, old_target_mask)
# make sure the masked position has the minimum metric
targets_metric[target_name] = targets_metric[target_name].to(shrinked_target_mask.device)
min_value = targets_metric[target_name].min() - 1
targets_metric[target_name] = torch.where(shrinked_target_mask != 0, targets_metric[target_name], min_value)
return metrics
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part to zero in the new_masks.
target_name = 'weight'
for module_name, target_mask in new_masks.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask = getattr(wrapper, f'{target_name}_mask', None)
old_target_mask: Tensor | None = getattr(wrapper, f'{target_name}_mask', None)
if old_target_mask is not None:
new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask)
new_masks[module_name][target_name] = torch.min(target_mask[target_name],
old_target_mask.to(target_mask[target_name].device))
return new_masks
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
"""
Generate masks for metrics-dependent targets.
Parameters
----------
metrics
The format is {module_name: weight_metric}.
The metric of `weight` usually has the same size with shrinked mask.
The format is {module_name: {target_name: target_metric}}.
The metric of usually has the same size with shrinked mask.
Return
------
@ -384,7 +447,7 @@ class SparsityAllocator:
reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
# count unmasked number of values on dim 0 (output channel) of weight
unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask)
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(weight_mask)
return masks
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
@ -401,6 +464,8 @@ class SparsityAllocator:
Dict[str, Dict[str, Tensor]]
The masks format is {module_name: {target_name: mask}}.
"""
if self.continuous_mask:
metrics = self._mask_metric(metrics)
masks = self.common_target_masks_generation(metrics)
masks = self.special_target_masks_generation(masks)
if self.continuous_mask:
@ -425,11 +490,22 @@ class TaskGenerator:
The log directory use to saving the task generator log.
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
best_result_mode
The way to decide which one is the best result. Three modes are supported.
If the task results don't contain scores (task_result.score is None), it will fall back to ``latest``.
1. latest: The newest received result is the best result.
2. maximize: The one with largest task result score is the best result.
3. minimize: The one with smallest task result score is the best result.
"""
def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
best_result_mode: Literal['latest', 'maximize', 'minimize'] = 'maximize'):
self._log_dir = log_dir
self._keep_intermediate_result = keep_intermediate_result
assert best_result_mode in ['latest', 'maximize', 'minimize'], f'Unsupported best_result_mode value: {best_result_mode}'
self._best_result_mode = best_result_mode
if origin_model is not None and origin_config_list is not None and origin_masks is not None:
self.reset(origin_model, origin_config_list, origin_masks)
@ -472,13 +548,24 @@ class TaskGenerator:
json_tricks.dump(config_list, f, indent=4)
def update_best_result(self, task_result: TaskResult):
score = task_result.score
task_id = task_result.task_id
task = self._tasks[task_id]
task.score = score
if self._best_score is None or (score is not None and score > self._best_score):
self._best_score = score
self._best_task_id = task_id
save_as_best_result = False
task = self._tasks[task_result.task_id]
task.score = task_result.score
if self._best_result_mode == 'latest':
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'maximize':
if self._best_score is None or (task.score is not None and task.score > self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'minimize':
if self._best_score is None or (task.score is not None and task.score < self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if save_as_best_result:
with Path(task.config_list_path).open('r') as fr:
best_config_list = json_tricks.load(fr)
self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)

Просмотреть файл

@ -6,13 +6,16 @@ from typing import Dict, List
from torch import Tensor
from .base import DataCollector, TrainerBasedDataCollector
from .base import DataCollector, EvaluatorBasedDataCollector
from .base import TrainerBasedDataCollector
_logger = logging.getLogger(__name__)
__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector']
__all__ = ['TargetDataCollector', 'EvaluatorBasedTargetDataCollector', 'EvaluatorBasedHookDataCollector',
'WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector'] # TODO: remove in nni v3.0.
# TODO: remove in nni v3.0.
class WeightDataCollector(DataCollector):
"""
Collect all wrapper weights.
@ -21,40 +24,102 @@ class WeightDataCollector(DataCollector):
def reset(self):
pass
def collect(self) -> Dict[str, Tensor]:
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight.data
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
# TODO: remove in nni v3.0.
class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Collect all wrapper weights after training or inference.
"""
def collect(self) -> Dict[str, Tensor]:
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight.data
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
# TODO: remove in nni v3.0.
class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Add hooks and collect data during training or inference.
Single means each wrapper only has one hook to collect data.
"""
def collect(self) -> Dict[str, List[Tensor]]:
def collect(self) -> Dict[str, Dict[str, List[Tensor]]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
[data.update(buffer_dict) for _, buffer_dict in self._hook_buffer.items()]
target_name = 'weight'
for _, buffer_dict in self._hook_buffer.items():
for module_name, target_data in buffer_dict.items():
data[module_name] = {target_name: target_data}
return data
class TargetDataCollector(DataCollector):
"""
Collect all wrapper targets.
"""
def reset(self):
# No need to reset anything in this data collector.
pass
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedTargetDataCollector(EvaluatorBasedDataCollector):
"""
Collect all wrapper pruning target after training or inference.
"""
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedHookDataCollector(EvaluatorBasedDataCollector):
"""
Add hooks and collect data during training or inference.
NOTE: Only support one target has one hook right now.
"""
def collect(self) -> Dict[str, Dict[str, List]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
for module_name, hooks in self._hooks.items():
data[module_name] = {}
for target_name, hook in hooks.items():
data[module_name][target_name] = hook.buffer
return data

Просмотреть файл

@ -11,7 +11,7 @@ from torch import Tensor
from .base import MetricsCalculator
from ...utils import Scaling
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
__all__ = ['NormMetricsCalculator', 'HookDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
@ -19,11 +19,12 @@ class StraightMetricsCalculator(MetricsCalculator):
"""
This metrics calculator directly returns a copy of data as metrics.
"""
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
metrics = {}
for name, tensor in data.items():
# use inplace detach `detach_` here to avoid creating a new tensor
metrics[name] = tensor.clone().detach_()
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
metrics[module_name][target_name] = target_data.clone().detach()
return metrics
@ -44,27 +45,32 @@ class NormMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return t.norm(p=self.p, dim=-1) # type: ignore
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
class MultiDataNormMetricsCalculator(NormMetricsCalculator):
class HookDataNormMetricsCalculator(NormMetricsCalculator):
"""
The data value format is a two-element list [batch_number, cumulative_data].
The hook data value format is a two-element list [batch_number, cumulative_data].
Directly use the cumulative_data as new_data to calculate norm metric.
TaylorFO pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
new_data = {name: buffer[1] for name, buffer in data.items()}
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
new_data = {}
for module_name, targets_data in data.items():
new_data[module_name] = {}
for target_name, (_, target_data) in targets_data.items():
new_data[module_name][target_name] = target_data
return super().calculate_metrics(new_data)
@ -85,7 +91,7 @@ class DistMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
reshape_data = t.reshape(-1, t.shape[-1])
metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
@ -94,10 +100,11 @@ class DistMetricsCalculator(MetricsCalculator):
return metric.reshape(t.shape[:-1])
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
@ -108,16 +115,18 @@ class APoZRankMetricsCalculator(MetricsCalculator):
Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
APoZRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return 1 - t.mean(dim=-1)
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
@ -127,14 +136,15 @@ class MeanRankMetricsCalculator(MetricsCalculator):
This metric simply calculate the average on `self.dim`, then divide by the batch_number.
MeanRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return t.mean(dim=-1)
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics

Просмотреть файл

@ -17,7 +17,8 @@ _logger = logging.getLogger(__name__)
class AMCEnv:
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float,
max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
pruning_op_names = []
[pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)]
self.pruning_ops = OrderedDict()
@ -26,7 +27,10 @@ class AMCEnv:
if name in pruning_op_names:
op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore
if hasattr(layer, 'kernel_size'):
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) # type: ignore
else:
kernel_size = 1
self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types))
@ -60,15 +64,18 @@ class AMCEnv:
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
previous_pruning_target = self.under_pruning_target - total_current_target
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] for name in self.pruning_op_names[index + 1:]])
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] \
for name in self.pruning_op_names[index + 1:]])
min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - (self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - \
(self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target
max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2)
min_action = min_current_pruning_target / current_statistics[op_name][self.target]
max_action = max_current_pruning_target / current_statistics[op_name][self.target]
if min_action > self.max_sparsity_per_layer[op_name]:
_logger.warning('[%s] min action > max sparsity per layer: %f > %f', op_name, min_action, self.max_sparsity_per_layer[op_name])
warn_msg = f'[{op_name}] min action > max sparsity per layer: {min_action} > {self.max_sparsity_per_layer[op_name]}'
_logger.warning(warn_msg)
action = max(0., min(max_action, max(min_action, action)))
self.current_op_name = op_name

Просмотреть файл

@ -4,7 +4,7 @@
from __future__ import annotations
import itertools
from typing import Any, Dict, List, Union
from typing import Any, Dict
import numpy as np
import torch
@ -23,22 +23,22 @@ class NormalSparsityAllocator(SparsityAllocator):
This allocator directly masks the locations of each pruning target with lower metric values.
"""
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# TODO: Support more target type in wrapper & config list refactor
target_name = 'weight'
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
for target_name, target_metric in targets_metric.items():
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
@ -46,7 +46,7 @@ class BankSparsityAllocator(SparsityAllocator):
"""
In bank pruner, all values in weight are divided into different sub blocks each shape
aligned with balance_gran. Each sub block has the same sparsity which equal to the overall sparsity.
This allocator pruned the weight in the granularity of block.
This allocator pruned the weight in the granularity of block.
"""
def __init__(self, pruner: Pruner, balance_gran: list):
@ -56,101 +56,108 @@ class BankSparsityAllocator(SparsityAllocator):
assert isinstance(gran, int) and gran > 0, 'All values in list balance_gran \
should be type int and bigger than zero'
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# TODO: Support more target type in wrapper & config list refactor
target_name = 'weight'
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
sparsity_rate = wrapper.config['total_sparsity']
for target_name, target_metric in targets_metric.items():
sparsity_rate = wrapper.config['total_sparsity']
n_dim = len(target_metric.shape)
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
# make up for balance_gran
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
for i, j in zip(target_metric.shape, balance_gran):
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
n_dim = len(target_metric.shape)
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
# make up for balance_gran
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
for i, j in zip(target_metric.shape, balance_gran):
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
# FIXME: The following code need refactor, do it after scaling refactor is done.
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters):
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
for iter_param, gran in zip(iter_params, balance_gran)]
index_str = ",".join(index_str_list)
sub_metric_str = "target_metric[{}]".format(index_str)
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
metric_bank: Tensor = eval(sub_metric_str)
prune_num = int(sparsity_rate * metric_bank.numel())
# mask_bank will be used in exec(sub_mask_str)
if prune_num != 0:
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank)
else:
mask_bank = torch.ones_like(metric_bank)
exec(sub_mask_str)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
# FIXME: The following code need refactor, do it after scaling refactor is done.
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters):
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
for iter_param, gran in zip(iter_params, balance_gran)]
index_str = ",".join(index_str_list)
sub_metric_str = "target_metric[{}]".format(index_str)
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
metric_bank: Tensor = eval(sub_metric_str)
prune_num = int(sparsity_rate * metric_bank.numel())
# mask_bank will be used in exec(sub_mask_str)
if prune_num != 0:
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank) # type: ignore
else:
mask_bank = torch.ones_like(metric_bank) # type: ignore
mask_bank = mask_bank # `type: ignore` is useless for unused-variable error, add this line to workaround
exec(sub_mask_str)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
class GlobalSparsityAllocator(SparsityAllocator):
"""
This allocator sorts all metrics as a whole, mask the locations of pruning target with lower metric value.
By default, this allocator will prevent each module from being over-pruned with upper sparsity 0.99.
"""
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
if not metrics:
return masks
# TODO: support more target type in wrapper & config list refactor
target_name = 'weight'
# validate all wrapper setting the same sparsity
# validate all wrapper setting have the same sparsity
# TODO: move validation logic to pruner
global_sparsity_rate = self.pruner.get_modules_wrapper()[list(metrics.keys())[0]].config['total_sparsity']
for module_name, target_metric in metrics.items():
for module_name in metrics.keys():
wrapper = self.pruner.get_modules_wrapper()[module_name]
assert global_sparsity_rate == wrapper.config['total_sparsity']
# find the largest metric value among all metrics
max_metric_value = list(metrics.values())[0].max()
for module_name, target_metric in metrics.items():
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
max_metric_value = list(list(metrics.values())[0].values())[0].max()
for targets_metric in metrics.values():
for target_metric in targets_metric.values():
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
# prevent each module from being over-pruned, prevent ratio is 'max_sparsity_per_layer'
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
assert 0 <= max_sparsity <= 1
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
metrics[module_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
for target_name, target_metric in targets_metric.items():
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
assert 0 <= max_sparsity <= 1
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
metrics[module_name][target_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
# build the global_matric & calculate global threshold
metric_list = []
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
metric_list.append(target_metric.reshape(-1).unsqueeze(0).expand(expand_times, -1).reshape(-1))
for target_name, target_metric in targets_metric.items():
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
metric_list.append(target_metric.reshape(-1).repeat_interleave(expand_times))
global_metric = torch.cat(metric_list)
max_pruning_num = int((global_metric != max_metric_value).sum().item())
total_pruning_num = min(int(global_sparsity_rate * global_metric.numel()), max_pruning_num)
global_threshold = torch.topk(global_metric.reshape(-1), total_pruning_num, largest=False)[0].max()
# generate masks for each target
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
for target_name, target_metric in targets_metric.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
class DependencyAwareAllocator(NormalSparsityAllocator):
# TODO: This allocator will trace the model, means the model will be inference during initialization,
# sometime we may not aware of this inference and it may lead to some error.
class DependencyAwareAllocator(SparsityAllocator):
"""
An specific allocator for Conv2d & Linear module with dependency-aware.
It will generate a public mask for the modules that have dependencies,
@ -170,52 +177,79 @@ class DependencyAwareAllocator(NormalSparsityAllocator):
# group dependency format: {module_name: group_num}
self.pruner._unwrap_model()
graph = TorchModuleGraph(model=self.pruner.bound_model, dummy_input=dummy_input)
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
traced_model=graph.trace).dependency_sets
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
traced_model=graph.trace).dependency_sets
self.pruner._wrap_model()
return channel_dependency, group_dependency
def _metric_fuse(self, metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
def _metric_fuse(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
# Sum all metric value in the same position.
metrics = list(metrics.values()) if isinstance(metrics, dict) else metrics
assert all(metrics[0].size() == metric.size() for metric in metrics), 'Metrics size do not match.'
fused_metric = torch.zeros_like(metrics[0])
for metric in metrics:
fused_metric += metric
return fused_metric
fused_metrics = {}
for targets_metric in metrics.values():
for target_name, target_metric in targets_metric.items():
if target_name in fused_metrics:
fused_metrics[target_name] += target_metric
else:
fused_metrics[target_name] = target_metric
return fused_metrics
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# placeholder, here we need more discussion about dependence sparsity, Plan A or Plan B.
masks = {}
# generate public part for modules that have dependencies
for module_names in self.channel_dependency:
sub_metrics = {module_name: metrics[module_name] for module_name in module_names if module_name in metrics}
if not sub_metrics:
continue
fused_metric = self._metric_fuse(sub_metrics)
fused_metrics = self._metric_fuse(sub_metrics)
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] for module_name in sub_metrics.keys()}
min_sparsity_rate = min(sparsity_rates.values())
for target_name, fused_metric in fused_metrics.items():
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] \
for module_name in sub_metrics.keys()}
min_sparsity_rate = min(sparsity_rates.values())
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
max_group_nums = int(np.lcm.reduce(group_nums))
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
group_step = fused_metric.shape[0] // max_group_nums
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
max_group_nums = int(np.lcm.reduce(group_nums))
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
group_step = fused_metric.shape[0] // max_group_nums
# get the public part of the mask of the module with dependencies
sub_masks = []
for gid in range(max_group_nums):
_start = gid * group_step
_end = (gid + 1) * group_step
if pruned_numel_per_group > 0:
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
sub_mask = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
# get the public part of the mask of the module with dependencies
dependency_mask = torch.ones_like(fused_metric)
for gid in range(max_group_nums):
_start = gid * group_step
_end = (gid + 1) * group_step
if pruned_numel_per_group > 0:
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
dependency_mask[_start: _end] = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
# change the metric value corresponding to the public mask part to the minimum value
for module_name, targets_metric in sub_metrics.items():
if target_name in targets_metric:
# Following is Plan A, generate the dependency mask first, and then fill in the sparsity,
# the final mask is group unbalanced. - 1 ensure the denpendency metric is the minimum, and will be masked first.
# min_value = targets_metric[target_name].min() - 1
# metrics[module_name][target_name] = torch.where(dependency_mask!=0, targets_metric[target_name], min_value)
# Following is Plan B, just generate the dependency mask, the final mask is group balanced.
masks.setdefault(module_name, {})
masks[module_name][target_name] = self._expand_mask(module_name, target_name, dependency_mask)
# generate masks for layers without dependencies
for module_name, targets_metric in metrics.items():
masks.setdefault(module_name, {})
wrapper = self.pruner.get_modules_wrapper()[module_name]
for target_name, target_metric in targets_metric.items():
if target_name in masks[module_name]:
continue
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
sub_mask = torch.ones_like(fused_metric[_start:_end])
sub_masks.append(sub_mask)
dependency_mask = torch.cat(sub_masks, dim=0)
# change the metric value corresponding to the public mask part to the minimum value
for module_name, target_metric in sub_metrics.items():
min_value = target_metric.min()
metrics[module_name] = torch.where(dependency_mask!=0, target_metric, min_value)
return super().common_target_masks_generation(metrics)
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks

Просмотреть файл

@ -51,7 +51,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
self.total_iteration = total_iteration
self.skip_first_iteration = skip_first_iteration
super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='latest')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_iteration = 1 if self.skip_first_iteration else 0
@ -78,10 +78,14 @@ class FunctionBasedTaskGenerator(TaskGenerator):
# get current2origin_sparsity and compact2origin_sparsity
origin_model = torch.load(self._origin_model_path)
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity)
_logger.debug('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.target_sparsity)
debug_msg = f'\nTask {task_result.task_id} total real sparsity compared with original model is:\n' + \
f'{json_tricks.dumps(current2origin_sparsity, indent=4)}'
_logger.debug(debug_msg)
if task_result.task_id != 'origin':
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
task = self._tasks[task_result.task_id]
task.state['current2origin_sparsity'] = current2origin_sparsity
# if reach the total_iteration, no more task will be generated
if self.current_iteration > self.total_iteration:
@ -116,7 +120,8 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
@ -128,7 +133,8 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
@ -149,16 +155,18 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
# The following is the formula in paper.
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, start_temperature: float = 100, stop_temperature: float = 20,
cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.',
keep_intermediate_result: bool = False):
"""
Parameters
----------
@ -188,7 +196,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self.perturbation_magnitude = perturbation_magnitude
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_temperature = self.start_temperature
@ -196,7 +204,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
self.target_sparsity_list = config_list_canonical(model, config_list)
@ -259,11 +270,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
sparsity = sorted(random_sparsity)
# calculate the scale
total_weights = np.sum(num_weights)
total_weights_pruned = np.sum([int(num_weight * sparsity[idx]) for idx, num_weight in enumerate(num_weights)])
if total_weights_pruned == 0:
return None

Просмотреть файл

@ -11,7 +11,7 @@ from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from nni.common.serializer import Traceable, is_traceable
from nni.common.serializer import is_traceable
__all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
@ -86,7 +86,8 @@ class OptimizerConstructHelper(ConstructHelper):
'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
assert isinstance(optimizer_trace, Optimizer), \
'It is not an instance of torch.nn.Optimizer.'
return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, **optimizer_trace.trace_kwargs) # type: ignore
return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, # type: ignore
**optimizer_trace.trace_kwargs) # type: ignore
class LRSchedulerConstructHelper(ConstructHelper):
@ -115,4 +116,5 @@ class LRSchedulerConstructHelper(ConstructHelper):
'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
assert isinstance(lr_scheduler_trace, _LRScheduler), \
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, **lr_scheduler_trace.trace_kwargs) # type: ignore
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, # type: ignore
**lr_scheduler_trace.trace_kwargs) # type: ignore

Просмотреть файл

@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
_EVALUATOR_DOCSTRING = r"""NNI will use the evaluator to intervene in the model training process,
so as to perform training-aware model compression.
All training-aware model compression will use the evaluator as the entry for intervention training in the future.
Usually you just need to wrap some classes with ``nni.trace`` or package the training process as a function to initialize the evaluator.
Please refer ... for a full tutorial on how to initialize a ``evaluator``.
The following are two simple examples, if you use pytorch_lightning, please refer to :class:`nni.compression.pytorch.LightningEvaluator`,
if you use native pytorch, please refer to :class:`nni.compression.pytorch.TorchEvaluator`::
# LightningEvaluator example
import pytorch_lightning
lightning_trainer = nni.trace(pytorch_lightning.Trainer)(max_epochs=1, max_steps=50, logger=TensorBoardLogger(...))
lightning_data_module = nni.trace(pytorch_lightning.LightningDataModule)(...)
from nni.compression.pytorch import LightningEvaluator
evaluator = LightningEvaluator(lightning_trainer, lightning_data_module)
# TorchEvaluator example
import torch
import torch.nn.functional as F
def training_model(model, optimizer, criterion, lr_scheduler, max_steps, max_epochs, *args, **kwargs):
# max_steps, max_epochs might be None, which means unlimited training time,
# so here we need set a default termination condition (by default, total_epochs=10, total_steps=100000).
total_epochs = max_epochs if max_epochs else 10
total_steps = max_steps if max_steps else 100000
current_step = 0
# init dataloader
train_dataloader = ...
for epoch in range(total_epochs):
...
for input_data, target in train_dataloader:
optimizer.zero_grad()
result = model(input_data)
loss = criterion(result, target)
loss.backward()
optimizer.step()
current_step += 1
if current_step >= total_steps:
return
lr_scheduler.step()
traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
criterion = F.nll_loss
from nni.compression.pytorch import TorchEvaluator
evaluator = TorchEvaluator(training_func=training_model, optimziers=traced_optimizer, criterion=criterion)
"""

Просмотреть файл

@ -73,16 +73,16 @@ class TensorHook(Hook):
return hook
"""
def __init__(self, target: Tensor, target_name: str, hook_factory: Callable[[List], Callable[[Tensor], Any]]):
def __init__(self, target: Tensor, target_name: str, hook_factory: Callable[[List], Callable[[Tensor], Tensor | None]]):
assert isinstance(target, Tensor)
super().__init__(target, target_name, hook_factory)
def _register(self, hook_func: Callable[[Tensor], Any]) -> RemovableHandle:
def _register(self, hook_func: Callable[[Tensor], Tensor | None]) -> RemovableHandle:
return self.target.register_hook(hook_func) # type: ignore
class ModuleHook(Hook):
def __init__(self, target: Module, target_name: str, hook_factory: Callable[[List], Callable[[Module, Tensor, Tensor], Any]]):
def __init__(self, target: Module, target_name: str, hook_factory: Callable[[List], Callable[[Module, Any, Any], Any]]):
assert isinstance(target, Module)
super().__init__(target, target_name, hook_factory)
@ -97,7 +97,7 @@ class ForwardHook(ModuleHook):
return hook
"""
def _register(self, hook_func: Callable[[Module, Tensor, Tensor], Any]):
def _register(self, hook_func: Callable[[Module, Tuple[Any], Any], Any]):
return self.target.register_forward_hook(hook_func) # type: ignore
@ -111,7 +111,7 @@ class BackwardHook(ModuleHook):
return hook
"""
def _register(self, hook_func: Callable[[Module, Tensor, Tensor], Any]):
def _register(self, hook_func: Callable[[Module, Tuple[Tensor] | Tensor, Tuple[Tensor] | Tensor], Any]):
return self.target.register_backward_hook(hook_func) # type: ignore
@ -148,7 +148,8 @@ class Evaluator:
def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
"""
Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification, model training, and model evaluation.
Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
model training, and model evaluation.
Parameter
---------
@ -246,10 +247,12 @@ class Evaluator:
def evaluate(self) -> float | None | Tuple[float, Any] | Tuple[None, Any]:
"""
NNI assume the evaluation function user passed in should return a float number or a dict as metric.
If the evaluation function returned a dict, take the value with dict key ``default`` as the first element of ``evaluate`` returned value,
If the evaluation function returned a dict, take the value with dict key ``default``
as the first element of ``evaluate`` returned value,
and put the dict as the second element of the returned value.
For any other type of the metric returned by evaluation function, ``evaluate`` will directly returned
(it should be a float, but NNI does not prevent other types from being returned, this will handle by the object calling ``evaluate``).
(it should be a float, but NNI does not prevent other types from being returned,
this will handle by the object calling ``evaluate``).
"""
# Note that the first item of the returned value will be used as the default metric used by NNI.
raise NotImplementedError
@ -287,9 +290,11 @@ class LightningEvaluator(Evaluator):
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
err_msg = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
self.trainer = trainer
self.data_module = data_module
self._dummy_input = dummy_input
@ -314,18 +319,20 @@ class LightningEvaluator(Evaluator):
optimizers_lr_schedulers: Any = pure_model.configure_optimizers()
# 1. None - Fit will run without any optimizer.
if optimizers_lr_schedulers is None:
err_msg = 'NNI does not support `LightningModule.configure_optimizers` returned None, '
err_msg += 'if you have a reason why you must, please file an issue at https://github.com/microsoft/nni/issues'
err_msg = 'NNI does not support `LightningModule.configure_optimizers` returned None, ' + \
'if you have a reason why you must, please file an issue at https://github.com/microsoft/nni/issues'
raise ValueError(err_msg)
# 2. Single optimizer.
# 3. Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
# 3. Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose
# value is a single LR scheduler or lr_scheduler_config.
elif isinstance(optimizers_lr_schedulers, (Optimizer, dict)):
optimizers_lr_schedulers = [optimizers_lr_schedulers]
err_msg = f'Got an wrong returned value type of `LightningModule.configure_optimizers`: {type(optimizers_lr_schedulers).__name__}'
assert isinstance(optimizers_lr_schedulers, (list, tuple)), err_msg
# 4. Two lists - the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
# 4. Two lists - the first list has multiple optimizers,
# and the second has multiple LR schedulers (or multiple lr_scheduler_config).
if isinstance(optimizers_lr_schedulers[0], (list, tuple)):
optimizers, lr_schedulers = optimizers_lr_schedulers
self._optimizer_helpers = [OptimizerConstructHelper.from_trace(pure_model, optimizer) for optimizer in optimizers]
@ -364,7 +371,8 @@ class LightningEvaluator(Evaluator):
self._initialization_complete = True
def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
assert self._initialization_complete is True, 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
assert self._initialization_complete is True, err_msg
assert isinstance(model, pl.LightningModule)
if self.model is not None:
_logger.warning('Already bound a model, will unbind it before bind a new model.')
@ -397,7 +405,8 @@ class LightningEvaluator(Evaluator):
if self._opt_returned_dicts:
def new_configure_optimizers(_): # type: ignore
optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers] # type: ignore
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]])
for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
opt_lrs_dicts = deepcopy(self._opt_returned_dicts)
for opt_lrs_dict in opt_lrs_dicts:
opt_lrs_dict['optimizer'] = optimizers[opt_lrs_dict['optimizer']]
@ -407,7 +416,8 @@ class LightningEvaluator(Evaluator):
elif self._lr_scheduler_helpers:
def new_configure_optimizers(_): # type: ignore
optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers] # type: ignore
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]])
for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
return optimizers, lr_schedulers
else:
def new_configure_optimizers(_):
@ -442,7 +452,8 @@ class LightningEvaluator(Evaluator):
assert isinstance(self.model, pl.LightningModule)
class OptimizerCallback(Callback):
def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer, opt_idx: int) -> None:
def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
optimizer: Optimizer, opt_idx: int) -> None:
for task in before_step_tasks:
task()
@ -486,10 +497,12 @@ class LightningEvaluator(Evaluator):
def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
"""
NNI will use metric with key ``default`` for evaluating model, please make sure you have this key in your ``Trainer.test()`` returned metric dicts.
If ``Trainer.test()`` returned list contains multiple dicts with key ``default``, NNI will take their average as the final metric.
E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}, {'default': 0.7, 'loss': 2.3}]``,
NNI will take the final metric ``(0.8 + 0.6 + 0.7) / 3 = 0.7``.
NNI will use metric with key ``default`` for evaluating model,
please make sure you have this key in your ``Trainer.test()`` returned metric dicts.
If ``Trainer.test()`` returned list contains multiple dicts with key ``default``,
NNI will take their average as the final metric.
E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
"""
assert isinstance(self.model, pl.LightningModule)
# reset trainer
@ -514,9 +527,11 @@ class LightningEvaluator(Evaluator):
raise e
_OPTIMIZERS = Union[Optimizer, List[Optimizer]]
_CRITERION = Callable[[Any, Any], Any]
_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
_EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
_TRAINING_FUNC = Callable[[Module, Union[Optimizer, List[Optimizer]], _CRITERION, Union[None, _LRScheduler, List[_LRScheduler]], Optional[int], Optional[int]], None]
_TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _CRITERION, _SCHEDULERS, Optional[int], Optional[int]], None]
class TorchEvaluator(Evaluator):
@ -528,8 +543,10 @@ class TorchEvaluator(Evaluator):
----------
training_func
The training function is used to train the model, note that this a entire optimization training loop.
It should have three required parameters [model, optimizers, criterion] and three optional parameters [schedulers, max_steps, max_epochs].
``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``, it belongs to the ``optimizers`` pass to ``TorchEvaluator``.
It should have three required parameters [model, optimizers, criterion]
and three optional parameters [schedulers, max_steps, max_epochs].
``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``,
it belongs to the ``optimizers`` pass to ``TorchEvaluator``.
``criterion`` and ``schedulers`` are also belonging to the ``criterion`` and ``schedulers`` pass to ``TorchEvaluator``.
``max_steps`` and ``max_epochs`` are used to control the training duration.
@ -574,7 +591,8 @@ class TorchEvaluator(Evaluator):
Optional. The traced _LRScheduler instance which the lr scheduler class is wrapped by nni.trace.
E.g. ``traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)``.
dummy_input
Optional. The dummy_input is used to trace the graph, the same with ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
Optional. The dummy_input is used to trace the graph,
the same with ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
evaluating_func
Optional. A function that input is model and return the evaluation metric.
The return value can be a single float or a tuple (float, Any).
@ -634,14 +652,16 @@ class TorchEvaluator(Evaluator):
self._lr_scheduler_helpers = [LRSchedulerConstructHelper.from_trace(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers]
optimizer_ids_map = {id(optimizer): i for i, optimizer in enumerate(self._tmp_optimizers)}
# record i-th lr_scheduler scheduling j-th optimizer lr
self._lrs_opt_map = {i: optimizer_ids_map[id(lr_scheduler.optimizer)] for i, lr_scheduler in enumerate(self._tmp_lr_schedulers)} # type: ignore
self._lrs_opt_map = {i: optimizer_ids_map[id(lr_scheduler.optimizer)] # type: ignore
for i, lr_scheduler in enumerate(self._tmp_lr_schedulers)} # type: ignore
delattr(self, '_tmp_optimizers')
delattr(self, '_tmp_lr_schedulers')
self._initialization_complete = True
def bind_model(self, model: Module, param_names_map: Dict[str, str] | None = None):
assert self._initialization_complete is True, 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
assert self._initialization_complete is True, err_msg
assert isinstance(model, Module)
if self.model is not None:
_logger.warning('Already bound a model, will unbind it before bind a new model.')
@ -651,7 +671,8 @@ class TorchEvaluator(Evaluator):
self._param_names_map = param_names_map
# initialize optimizers & lr_schedulers for the bound model here
self._optimizers = [helper.call(model, param_names_map) for helper in self._optimizer_helpers]
self._lr_schedulers = [lrs_helper.call(self._optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
self._lr_schedulers = [lrs_helper.call(self._optimizers[self._lrs_opt_map[i]]) \
for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
self._first_optimizer_step = self._optimizers[0].step
def unbind_model(self):
@ -717,7 +738,8 @@ class TorchEvaluator(Evaluator):
if isinstance(metric, dict):
nni_used_metric = metric.get('default', None)
if nni_used_metric is None:
warn_msg = f'Evaluation function returns a dict metric without key `default`, will return None as the model evaluation metric value.'
warn_msg = f'Evaluation function returns a dict metric without key `default`,' + \
'will return None as the model evaluation metric value.'
_logger.warning(warn_msg)
return nni_used_metric, metric
else:

Просмотреть файл

@ -229,7 +229,8 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
def get_model_weights_numel(model: Module, config_list: List[Dict],
masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
"""
Count the layer weight elements number in config_list.
If masks is not empty, the masked weight will not be counted.

Просмотреть файл

@ -53,18 +53,24 @@ class Scaling:
# for the `-1` in kernel_size, then expand size (4, 3, 1) to size (4, 6, 2).
kernel_padding_mode
'front' or 'back', default is 'front'.
If set 'front', for a given tensor when shrinking, padding `1` at front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
for a given expand size when expanding, padding `1` at front of kernel_size until `len(expand_size) == len(kernel_size)`.
If set 'back', for a given tensor when shrinking, padding `-1` at back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
for a given expand size when expanding, padding `-1` at back of kernel_size until `len(expand_size) == len(kernel_size)`.
If set 'front', for a given tensor when shrinking,
padding `1` at front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
for a given expand size when expanding,
padding `1` at front of kernel_size until `len(expand_size) == len(kernel_size)`.
If set 'back', for a given tensor when shrinking,
padding `-1` at back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
for a given expand size when expanding,
padding `-1` at back of kernel_size until `len(expand_size) == len(kernel_size)`.
"""
def __init__(self, kernel_size: List[int], kernel_padding_mode: Literal['front', 'back'] = 'front') -> None:
self.kernel_size = kernel_size
assert kernel_padding_mode in ['front', 'back'], f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."
err_msg = f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."
assert kernel_padding_mode in ['front', 'back'], err_msg
self.kernel_padding_mode = kernel_padding_mode
def _padding(self, _list: List[int], length: int, padding_value: int = -1, padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
def _padding(self, _list: List[int], length: int, padding_value: int = -1,
padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
"""
Padding the `_list` to a specific length with `padding_value`.
@ -144,10 +150,12 @@ class Scaling:
assert b % a == 0, f'Can not expand tensor with {target.shape} to {expand_size} with kernel size {kernel_size}.'
_expand_size.append(b // a)
_expand_size.append(a)
new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [new_target] + [2 * _ + 1 for _ in range(len(expand_size))]) # type: ignore
new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim),
[new_target] + [2 * _ + 1 for _ in range(len(expand_size))]) # type: ignore
# step 3: expanding the new target to _expand_size and reshape to expand_size.
# Note that we can also give an interface for how to expand the tensor, like `reduce_func` in `_shrink`, currently we don't have that need.
# Note that we can also give an interface for how to expand the tensor, like `reduce_func` in `_shrink`,
# currently we don't have that need.
result = new_target.expand(_expand_size).reshape(expand_size).clone()
return result

Просмотреть файл

@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.algorithms.compression.v2.pytorch import TorchEvaluator, LightningEvaluator
from .speedup import ModelSpeedup
from .compressor import Compressor, Pruner, Quantizer
from .utils.apply_compression import apply_compression_results

Просмотреть файл

@ -3,6 +3,7 @@
"nni/algorithms/compression/pytorch",
"nni/algorithms/compression/tensorflow",
"nni/algorithms/compression/v2/pytorch/base/pruner.py",
"nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py",
"nni/algorithms/feature_engineering",
"nni/algorithms/hpo",
"nni/algorithms/nas",

Просмотреть файл

Просмотреть файл

@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
from typing import Any, Dict, List
import torch
from .device import device
from .simple_mnist import SimpleLightningModel, SimpleTorchModel
from .utils import unfold_config_list
log_dir = Path(__file__).parent.parent / 'logs'
def create_model(model_type: str):
torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
{'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.5},
{'op_names': ['fc2'], 'exclude': True}]
lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
{'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.5},
{'op_names': ['model.fc2'], 'exclude': True}]
if model_type == 'lightning':
model = SimpleLightningModel()
config_list = lightning_config_list
dummy_input = torch.rand(8, 1, 28, 28)
elif model_type == 'pytorch':
model = SimpleTorchModel().to(device)
config_list = torch_config_list
dummy_input = torch.rand(8, 1, 28, 28, device=device)
else:
raise ValueError(f'wrong model_type: {model_type}')
return model, config_list, dummy_input
def validate_masks(masks: Dict[str, Dict[str, torch.Tensor]], model: torch.nn.Module, config_list: List[Dict[str, Any]],
is_global: bool = False):
config_dict = unfold_config_list(model, config_list)
# validate if all configured layers have generated mask.
mismatched_op_names = set(config_dict.keys()).symmetric_difference(masks.keys())
assert f'mismatched op_names: {mismatched_op_names}'
target_name = 'weight'
total_masked_numel = 0
total_target_numel = 0
for module_name, target_masks in masks.items():
mask = target_masks[target_name]
assert mask.numel() == (mask == 0).sum().item() + (mask == 1).sum().item(), f'{module_name} {target_name} mask has values other than 0 and 1.'
if not is_global:
excepted_sparsity = config_dict[module_name].get('sparsity', config_dict[module_name].get('total_sparsity'))
real_sparsity = (mask == 0).sum().item() / mask.numel()
err_msg = f'{module_name} {target_name} excepted sparsity: {excepted_sparsity}, but real sparsity: {real_sparsity}'
assert excepted_sparsity * 0.9 < real_sparsity < excepted_sparsity * 1.1, err_msg
else:
total_masked_numel += (mask == 0).sum().item()
total_target_numel += mask.numel()
if is_global:
excepted_sparsity = next(iter(config_dict.values())).get('sparsity', config_dict[module_name].get('total_sparsity'))
real_sparsity = total_masked_numel / total_target_numel
err_msg = f'excepted global sparsity: {excepted_sparsity}, but real global sparsity: {real_sparsity}.'
assert excepted_sparsity * 0.9 < real_sparsity < excepted_sparsity * 1.1, err_msg
def validate_dependency_aware(model_type: str, masks: Dict[str, Dict[str, torch.Tensor]]):
# only for simple_mnist model
if model_type == 'lightning':
assert torch.equal(masks['model.conv2']['weight'].mean([1, 2, 3]), masks['model.conv3']['weight'].mean([1, 2, 3]))
if model_type == 'pytorch':
assert torch.equal(masks['conv2']['weight'].mean([1, 2, 3]), masks['conv3']['weight'].mean([1, 2, 3]))

Просмотреть файл

@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Просмотреть файл

@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .simple_lightning_model import SimpleLightningModel, MNISTDataModule
from .simple_torch_model import SimpleTorchModel, training_model, evaluating_model, finetuning_model
from .simple_evaluator import create_lighting_evaluator, create_pytorch_evaluator

Просмотреть файл

@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import nni
from nni.algorithms.compression.v2.pytorch import LightningEvaluator, TorchEvaluator
from .simple_torch_model import training_model, evaluating_model
from .simple_lightning_model import MNISTDataModule
from ..common import device
def create_lighting_evaluator() -> LightningEvaluator:
pl_trainer = nni.trace(pl.Trainer)(
accelerator='auto',
devices=1,
max_epochs=1,
max_steps=50,
logger=TensorBoardLogger(Path(__file__).parent.parent / 'lightning_logs', name="resnet"),
)
pl.Trainer()
pl_trainer.num_sanity_val_steps = 0
pl_data = nni.trace(MNISTDataModule)(data_dir='data/mnist')
evaluator = LightningEvaluator(pl_trainer, pl_data, dummy_input=torch.rand(8, 1, 28, 28))
return evaluator
def create_pytorch_evaluator(model: torch.nn.Module) -> TorchEvaluator:
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler,
dummy_input=torch.rand(8, 1, 28, 28, device=device), evaluating_func=evaluating_model)
return evaluator

Просмотреть файл

@ -0,0 +1,105 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split, DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
import nni
from .simple_torch_model import SimpleTorchModel
class SimpleLightningModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = SimpleTorchModel()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
self.log("train_loss", loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
if stage:
self.log(f"default", loss, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
def configure_optimizers(self):
optimizer = nni.trace(torch.optim.SGD)(
self.parameters(),
lr=0.01,
momentum=0.9,
weight_decay=5e-4,
)
scheduler_dict = {
"scheduler": nni.trace(ExponentialLR)(
optimizer,
0.1,
),
"interval": "epoch",
}
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = 'data/mnist'
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str | None = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict" or stage is None:
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)

Просмотреть файл

@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from pathlib import Path
from typing import Callable
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from ..device import device
class SimpleTorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 16, 3)
self.bn1 = torch.nn.BatchNorm2d(16)
self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
self.bn2 = torch.nn.BatchNorm2d(8)
self.conv3 = torch.nn.Conv2d(16, 8, 3)
self.bn3 = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x: torch.Tensor):
x = self.bn1(self.conv1(x))
x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
return F.log_softmax(x, -1)
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler = None,
max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
model.train()
# prepare data
MNIST(root='data/mnist', train=True, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=32)
max_epochs = max_epochs if max_epochs else 1
max_steps = max_steps if max_steps else 50
current_steps = 0
# training
for _ in range(max_epochs):
for x, y in train_dataloader:
optimizer.zero_grad()
x, y = x.to(device), y.to(device)
logits = model(x)
loss: torch.Tensor = criterion(logits, y)
loss.backward()
optimizer.step()
current_steps += 1
if max_steps and current_steps == max_steps:
return
if scheduler is not None:
scheduler.step()
def finetuning_model(model: Module):
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
training_model(model, optimizer, F.nll_loss)
def evaluating_model(model: Module, device: torch.device = device):
model.eval()
# prepare data
MNIST(root='data/mnist', train=False, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
test_dataloader = DataLoader(mnist_test, batch_size=32)
# testing
correct = 0
with torch.no_grad():
for x, y in test_dataloader:
x, y = x.to(device), y.to(device)
logits = model(x)
preds = torch.argmax(logits, dim=1)
correct += preds.eq(y.view_as(preds)).sum().item()
return correct / len(mnist_test)

Просмотреть файл

@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
######################################################################################
# NOTE: copy from branch wrapper-refactor, will rm this file in this or next release.#
######################################################################################
from copy import deepcopy
import logging
from typing import Any, Dict, List
from torch.nn import Module
_logger = logging.getLogger(__name__)
def _unfold_op_partial_names(model: Module, config_list: List[Dict]) -> List[Dict]:
config_list = deepcopy(config_list)
full_op_names = [op_name for op_name, _ in model.named_modules()]
for config in config_list:
op_names = config.pop('op_names', [])
op_partial_names = config.pop('op_partial_names', [])
for op_partial_name in op_partial_names:
op_names.extend([op_name for op_name in full_op_names if op_partial_name in op_name])
config['op_names'] = list(set(op_names))
return config_list
def unfold_config_list(model: Module, config_list: List[Dict]) -> Dict[str, Dict[str, Any]]:
'''
Unfold config_list to op_names level, return a config_dict {op_name: config}.
'''
config_list = _unfold_op_partial_names(model=model, config_list=config_list)
config_dict = {}
for config in config_list:
for key in ['op_types', 'op_names', 'exclude_op_names']:
config.setdefault(key, [])
op_names = []
for module_name, module in model.named_modules():
module_type = type(module).__name__
if (module_type in config['op_types'] or module_name in config['op_names']) and module_name not in config['exclude_op_names']:
op_names.append(module_name)
config_template = deepcopy(config)
for key in ['op_types', 'op_names', 'exclude_op_names']:
config_template.pop(key, [])
for op_name in op_names:
if op_name in config_dict:
warn_msg = f'{op_name} duplicate definition of config, replace old config:\n' + \
f'{config_dict[op_name]}\n' + \
f'with new config:\n{config_template}\n'
_logger.warning(warn_msg)
config_dict[op_name] = deepcopy(config_template)
return config_dict

Просмотреть файл

@ -3,193 +3,23 @@
from __future__ import annotations
from pathlib import Path
from typing import Callable
import pytest
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
from torch.utils.data import random_split, DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
import nni
from nni.algorithms.compression.v2.pytorch.utils.evaluator import (
TorchEvaluator,
LightningEvaluator,
TensorHook,
ForwardHook,
BackwardHook,
)
class SimpleTorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 16, 3)
self.bn1 = torch.nn.BatchNorm2d(16)
self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
self.bn2 = torch.nn.BatchNorm2d(8)
self.conv3 = torch.nn.Conv2d(16, 8, 3)
self.bn3 = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x: torch.Tensor):
x = self.bn1(self.conv1(x))
x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
return F.log_softmax(x, -1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler,
max_steps: int | None = None, max_epochs: int | None = None):
model.train()
# prepare data
data_dir = Path(__file__).parent / 'data'
MNIST(data_dir, train=True, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(data_dir, train=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=32)
max_epochs = max_epochs if max_epochs else 1
max_steps = max_steps if max_steps else 10
current_steps = 0
# training
for _ in range(max_epochs):
for x, y in train_dataloader:
optimizer.zero_grad()
x, y = x.to(device), y.to(device)
logits = model(x)
loss: torch.Tensor = criterion(logits, y)
loss.backward()
optimizer.step()
current_steps += 1
if max_steps and current_steps == max_steps:
return
scheduler.step()
def evaluating_model(model: Module):
model.eval()
# prepare data
data_dir = Path(__file__).parent / 'data'
MNIST(data_dir, train=False, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_test = MNIST(data_dir, train=False, transform=transform)
test_dataloader = DataLoader(mnist_test, batch_size=32)
# testing
correct = 0
with torch.no_grad():
for x, y in test_dataloader:
x, y = x.to(device), y.to(device)
logits = model(x)
preds = torch.argmax(logits, dim=1)
correct += preds.eq(y.view_as(preds)).sum().item()
return correct / len(mnist_test)
class SimpleLightningModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = SimpleTorchModel()
self.count = 0
def forward(self, x):
print(self.count)
self.count += 1
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
self.log("train_loss", loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
if stage:
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
def configure_optimizers(self):
optimizer = nni.trace(torch.optim.SGD)(
self.parameters(),
lr=0.01,
momentum=0.9,
weight_decay=5e-4,
)
scheduler_dict = {
"scheduler": nni.trace(ExponentialLR)(
optimizer,
0.1,
),
"interval": "epoch",
}
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str | None = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict" or stage is None:
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)
from ..assets.device import device
from ..assets.simple_mnist import (
SimpleLightningModel,
SimpleTorchModel,
create_lighting_evaluator,
create_pytorch_evaluator
)
optimizer_before_step_flag = False
@ -237,41 +67,20 @@ def assert_flags():
assert loss_flag, 'Evaluator patch loss failed.'
def create_lighting_evaluator():
pl_model = SimpleLightningModel()
pl_trainer = nni.trace(pl.Trainer)(
max_epochs=1,
max_steps=10,
logger=TensorBoardLogger(Path(__file__).parent / 'lightning_logs', name="resnet"),
)
pl_trainer.num_sanity_val_steps = 0
pl_data = nni.trace(MNISTDataModule)(data_dir=Path(__file__).parent / 'data')
evaluator = LightningEvaluator(pl_trainer, pl_data)
evaluator._init_optimizer_helpers(pl_model)
return evaluator
def create_pytorch_evaluator():
model = SimpleTorchModel()
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler, evaluating_func=evaluating_model)
evaluator._init_optimizer_helpers(model)
return evaluator
@pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch'])
def test_evaluator(evaluator_type: str):
if evaluator_type == 'lightning':
evaluator = create_lighting_evaluator()
model = SimpleLightningModel()
evaluator = create_lighting_evaluator()
evaluator._init_optimizer_helpers(model)
evaluator.bind_model(model)
tensor_hook = TensorHook(model.model.conv1.weight, 'model.conv1.weight', tensor_hook_factory)
forward_hook = ForwardHook(model.model.conv1, 'model.conv1', forward_hook_factory)
backward_hook = BackwardHook(model.model.conv1, 'model.conv1', backward_hook_factory)
elif evaluator_type == 'pytorch':
evaluator = create_pytorch_evaluator()
model = SimpleTorchModel().to(device)
evaluator = create_pytorch_evaluator(model)
evaluator._init_optimizer_helpers(model)
evaluator.bind_model(model)
tensor_hook = TensorHook(model.conv1.weight, 'conv1.weight', tensor_hook_factory)
forward_hook = ForwardHook(model.conv1, 'conv1', forward_hook_factory)
@ -296,4 +105,4 @@ def test_evaluator(evaluator_type: str):
evaluator.finetune()
assert_flags()
assert all([len(hook.buffer) == 10 for hook in [tensor_hook, forward_hook, backward_hook]])
assert all([len(hook.buffer) == 50 for hook in [tensor_hook, forward_hook, backward_hook]])

Просмотреть файл

@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import pytest
import torch
import torch.nn.functional as F
import nni
from nni.compression.pytorch.pruning import (
LinearPruner,
AGPPruner,
LotteryTicketPruner,
SimulatedAnnealingPruner,
AutoCompressPruner
)
from ..assets.common import create_model, log_dir, validate_masks, validate_dependency_aware
from ..assets.device import device
from ..assets.simple_mnist import (
create_lighting_evaluator,
create_pytorch_evaluator,
training_model,
finetuning_model,
evaluating_model
)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('pruning_type', ['linear', 'agp', 'lottory'])
@pytest.mark.parametrize('speedup', [True, False])
def test_functional_pruner(model_type: str, using_evaluator: bool, pruning_type: str, speedup: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
if pruning_type == 'linear':
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
elif pruning_type == 'agp':
pruner = AGPPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
elif pruning_type == 'lottory':
pruner = LotteryTicketPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
else:
model.to(device)
dummy_input = dummy_input.to(device)
if pruning_type == 'linear':
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
elif pruning_type == 'agp':
pruner = AGPPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
elif pruning_type == 'lottory':
pruner = LotteryTicketPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
pruner.compress()
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
best_model(dummy_input)
validate_masks(best_masks, best_model, config_list)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
def test_sa_pruner(model_type: str, using_evaluator: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = SimulatedAnnealingPruner(model=model, config_list=config_list, evaluator=evaluator, start_temperature=100,
stop_temperature=80, cool_down_rate=0.9, perturbation_magnitude=0.35, pruning_algorithm='l1',
pruning_params={}, log_dir=log_dir, keep_intermediate_result=False, speedup=False)
else:
model.to(device)
dummy_input = dummy_input.to(device)
pruner = SimulatedAnnealingPruner(model=model, config_list=config_list, evaluator=evaluating_model, start_temperature=100,
stop_temperature=80, cool_down_rate=0.9, perturbation_magnitude=0.35, pruning_algorithm='l1',
pruning_params={}, log_dir=log_dir, keep_intermediate_result=False, speedup=False)
pruner.compress()
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
best_model(dummy_input)
validate_masks(best_masks, best_model, config_list)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
def test_auto_compress_pruner(model_type: str, using_evaluator: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
admm_params = {'evaluator': evaluator, 'iterations': 2, 'training_epochs': 1, 'granularity': 'coarse-grained'}
sa_params = {'evaluator': evaluator, 'start_temperature': 100, 'stop_temperature': 80, 'pruning_algorithm': 'l1'}
pruner = AutoCompressPruner(model=model, config_list=config_list, total_iteration=2, admm_params=admm_params, sa_params=sa_params,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=False)
else:
model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
admm_params = {'trainer': training_model, 'traced_optimizer': optimizer, 'criterion': F.nll_loss, 'iterations': 2,
'training_epochs': 1, 'granularity': 'coarse-grained'}
sa_params = {'evaluator': evaluating_model, 'start_temperature': 100, 'stop_temperature': 80, 'pruning_algorithm': 'l1'}
pruner = AutoCompressPruner(model=model, config_list=config_list, total_iteration=2, admm_params=admm_params, sa_params=sa_params,
log_dir=log_dir, keep_intermediate_result=False, finetuner=finetuning_model, speedup=False,
dummy_input=dummy_input, evaluator=evaluating_model)
pruner.compress()
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
best_model(dummy_input)
validate_masks(best_masks, best_model, config_list)
# we still need AMCPruner test, but it cost a lot, will add after we have GPU pool.

Просмотреть файл

@ -0,0 +1,172 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import pytest
import torch
import torch.nn.functional as F
import nni
from nni.compression.pytorch.pruning import (
LevelPruner,
L1NormPruner,
L2NormPruner,
SlimPruner,
FPGMPruner,
ActivationAPoZRankPruner,
ActivationMeanRankPruner,
TaylorFOWeightPruner,
ADMMPruner,
MovementPruner
)
from ..assets.device import device
from ..assets.simple_mnist import (
create_lighting_evaluator,
create_pytorch_evaluator,
training_model
)
from ..assets.common import create_model, validate_masks, validate_dependency_aware
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
def test_level_pruner(model_type: str):
model, config_list, dummy_input = create_model(model_type)
pruner = LevelPruner(model=model, config_list=config_list)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('pruning_type', ['l1', 'l2', 'fpgm'])
@pytest.mark.parametrize('mode', ['normal', 'dependency_aware'])
def test_norm_pruner(model_type: str, pruning_type: str, mode: str):
model, config_list, dummy_input = create_model(model_type)
if pruning_type == 'l1':
pruner = L1NormPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
elif pruning_type == 'l2':
pruner = L2NormPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
elif pruning_type == 'fpgm':
pruner = FPGMPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
else:
raise ValueError(f'wrong norm: {pruning_type}')
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
if mode == 'dependency_aware':
validate_dependency_aware(model_type, masks)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('mode', ['global', 'normal'])
def test_slim_pruner(model_type: str, using_evaluator: bool, mode: str):
model, _, dummy_input = create_model(model_type)
config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.5}]
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = SlimPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, scale=0.0001, mode=mode)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = SlimPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_epochs=1, scale=0.0001, mode=mode)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list, is_global=(mode == 'global'))
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('pruning_type', ['apoz', 'mean', 'taylor'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('mode', ['normal', 'dependency_aware'])
def test_hook_based_pruner(model_type: str, pruning_type: str, using_evaluator: bool, mode: str):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
if pruning_type == 'apoz':
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'mean':
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'taylor':
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
mode=mode, dummy_input=dummy_input)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
if pruning_type == 'apoz':
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_batches=20, activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'mean':
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_batches=20, activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'taylor':
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_batches=20, mode=mode, dummy_input=dummy_input)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
if mode == 'dependency_aware':
validate_dependency_aware(model_type, masks)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('granularity', ['fine-grained', 'coarse-grained'])
def test_admm_pruner(model_type: str, using_evaluator: bool, granularity: str):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = ADMMPruner(model=model, config_list=config_list, evaluator=evaluator, iterations=2, training_epochs=1, granularity=granularity)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = ADMMPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer, criterion=F.nll_loss,
iterations=2, training_epochs=1, granularity=granularity)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
def test_movement_pruner(model_type: str, using_evaluator: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = MovementPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, warm_up_step=10, cool_down_beginning_step=40)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = MovementPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer, criterion=F.nll_loss,
training_epochs=1, warm_up_step=10, cool_down_beginning_step=40)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)

Просмотреть файл

@ -8,14 +8,20 @@ import torch.nn.functional as F
import nni
from nni.algorithms.compression.v2.pytorch.base import Pruner
# TODO: remove in nni v3.0.
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
WeightDataCollector,
WeightTrainerBasedDataCollector,
SingleHookTrainerBasedDataCollector
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
TargetDataCollector,
EvaluatorBasedTargetDataCollector,
EvaluatorBasedHookDataCollector
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
HookDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
@ -84,7 +90,7 @@ class PruningToolsTestCase(unittest.TestCase):
# Test WeightDataCollector
data_collector = WeightDataCollector(pruner)
data = data_collector.collect()
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]['weight']) for module_name in ['conv1', 'conv2'])
# Test WeightTrainerBasedDataCollector
def opt_after():
@ -94,8 +100,8 @@ class PruningToolsTestCase(unittest.TestCase):
optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
data_collector = WeightTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 1, opt_after_tasks=[opt_after])
data = data_collector.collect()
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]['weight']) for module_name in ['conv1', 'conv2'])
assert all(t['weight'].numel() == (t['weight'] == 1).type_as(t['weight']).sum().item() for t in data.values())
# Test SingleHookTrainerBasedDataCollector
def _collector(buffer, weight_tensor):
@ -109,73 +115,73 @@ class PruningToolsTestCase(unittest.TestCase):
optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 2, collector_infos=[collector_info])
data = data_collector.collect()
assert all(len(t) == 2 for t in data.values())
assert all(len(t['weight']) == 2 for t in data.values())
def test_metrics_calculator(self):
# Test NormMetricsCalculator
metrics_calculator = NormMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
data = {
'1': torch.ones(3, 3, 3),
'2': torch.ones(4, 4) * 2
'1': {'target_name': torch.ones(3, 3, 3)},
'2': {'target_name': torch.ones(4, 4) * 2}
}
result = {
'1': torch.ones(3) * 3,
'2': torch.ones(4) * 4
'1': {'target_name': torch.ones(3) * 3},
'2': {'target_name': torch.ones(4) * 4}
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
# Test DistMetricsCalculator
metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
data = {
'1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
'2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
'1': {'target_name': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32)},
'2': {'target_name': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)}
}
result = {
'1': torch.tensor([5, 5], dtype=torch.float32),
'2': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))
'1': {'target_name': torch.tensor([5, 5], dtype=torch.float32)},
'2': {'target_name': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))}
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
# Test MultiDataNormMetricsCalculator
metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
# Test HookDataNormMetricsCalculator
metrics_calculator = HookDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
data = {
'1': [2, torch.ones(3, 3, 3) * 2],
'2': [2, torch.ones(4, 4) * 2]
'1': {'target_name': [2, torch.ones(3, 3, 3) * 2]},
'2': {'target_name': [2, torch.ones(4, 4) * 2]}
}
result = {
'1': torch.ones(3) * 18,
'2': torch.ones(4) * 8
'1': {'target_name': torch.ones(3) * 18},
'2': {'target_name': torch.ones(4) * 8}
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
# Test APoZRankMetricsCalculator
metrics_calculator = APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
data = {
'1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
'1': {'target_name': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)]},
'2': {'target_name': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]}
}
result = {
'1': torch.tensor([0.5, 0.5], dtype=torch.float32),
'2': torch.tensor([1, 1, 0.75], dtype=torch.float32)
'1': {'target_name': torch.tensor([0.5, 0.5], dtype=torch.float32)},
'2': {'target_name': torch.tensor([1, 1, 0.75], dtype=torch.float32)}
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
# Test MeanRankMetricsCalculator
metrics_calculator = MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
data = {
'1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
'1': {'target_name': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)]},
'2': {'target_name': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]}
}
result = {
'1': torch.tensor([0.25, 0.25], dtype=torch.float32),
'2': torch.tensor([0, 0, 0.25], dtype=torch.float32)
'1': {'target_name': torch.tensor([0.25, 0.25], dtype=torch.float32)},
'2': {'target_name': torch.tensor([0, 0, 0.25], dtype=torch.float32)}
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
def test_sparsity_allocator(self):
# Test NormalSparsityAllocator
@ -183,8 +189,8 @@ class PruningToolsTestCase(unittest.TestCase):
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
pruner = Pruner(model, config_list)
metrics = {
'conv1': torch.rand(5, 1, 5, 5),
'conv2': torch.rand(10, 5, 5, 5)
'conv1': {'weight': torch.rand(5, 1, 5, 5)},
'conv2': {'weight': torch.rand(10, 5, 5, 5)}
}
sparsity_allocator = NormalSparsityAllocator(pruner)
masks = sparsity_allocator.generate_sparsity(metrics)