Mirror of https://github.com/microsoft/nni.git

[Compression] evaluator - step2 (#4992)

This commit is contained in:
Parent: a689e619c4
Commit: ed455174db
@@ -5,3 +5,9 @@ Quickstart

    PyTorch </tutorials/hpo_quickstart_pytorch/main>
    TensorFlow </tutorials/hpo_quickstart_tensorflow/main>

.. toctree::
    :hidden:

    /tutorials/hpo_quickstart_pytorch/index
    /tutorials/hpo_quickstart_tensorflow/index

@@ -0,0 +1,16 @@
Evaluator
=========

.. _compression-torch-evaluator:

TorchEvaluator
--------------

.. autoclass:: nni.compression.pytorch.TorchEvaluator

.. _compression-lightning-evaluator:

LightningEvaluator
------------------

.. autoclass:: nni.compression.pytorch.LightningEvaluator

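As a quick illustration of the API documented above, a minimal sketch of building a ``TorchEvaluator`` is shown below. The model, training loop and random data are placeholders, and the constructor arguments (training function, traced optimizer, criterion, dummy input) are assumed from the autoclass reference; check the generated API docs for the authoritative signature.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import nni
    from nni.compression.pytorch import TorchEvaluator

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Flatten(), torch.nn.LazyLinear(10))

    def training_func(model, optimizer, criterion, lr_scheduler=None, max_steps=None, max_epochs=None):
        # stand-in training loop over random data; replace with your real dataloader
        model.train()
        for _ in range(10):
            data, target = torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,))
            optimizer.zero_grad()
            criterion(model(data), target).backward()
            optimizer.step()

    # the optimizer class must be wrapped with nni.trace so NNI can rebuild it internally
    traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)

    evaluator = TorchEvaluator(training_func=training_func, optimizers=traced_optimizer,
                               criterion=F.cross_entropy, dummy_input=torch.rand(4, 3, 32, 32))
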
@@ -8,5 +8,6 @@ Compression API Reference

    Quantizer <quantizer>
    Pruning Speedup <pruning_speedup>
    Quantization Speedup <quantization_speedup>
    Evaluator <evaluator>
    Compression Utilities <utils>
    Framework Related <framework>

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .utils import LightningEvaluator, TorchEvaluator

@@ -119,7 +119,8 @@ class Compressor:
        Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
        The model will be instrumented and user should never edit it after calling this method.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.bound_model is not None, err_msg

        if self._modules_to_compress is None:
            self._modules_to_compress = []

@@ -146,7 +147,8 @@ class Compressor:
        Optional[Dict]
            The retrieved configuration for this layer, if None, this layer should not be compressed.
        """
        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, err_msg

        ret = None
        for config in self.config_list:

@@ -240,8 +242,10 @@ class Compressor:
        Dict[int, List[str]]
            A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.bound_model is not None, err_msg
        err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, err_msg

        self._unwrap_model()
        module_groups = {}

@@ -323,6 +327,8 @@ class Compressor:
        torch.nn.Module
            model with specified modules compressed.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.bound_model is not None, err_msg
        err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.config_list is not None, err_msg
        return self.bound_model

@@ -43,8 +43,8 @@ class PrunerModuleWrapper(Module):
            pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
            pruning_target = getattr(self.module, pruning_target_name, None)
            if hasattr(self.module, pruning_target_name) and pruning_target is not None:
                setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
                self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
                setattr(self, pruning_target_name, Parameter(torch.empty_like(pruning_target)))
                self.register_buffer(pruning_target_mask_name, torch.ones_like(pruning_target))
            else:
                self.register_buffer(pruning_target_mask_name, None)

@@ -67,11 +67,11 @@ class PrunerModuleWrapper(Module):
        The best place to call this function is in `Pruner._unwrap_model()`.
        """
        delattr(self.module, 'weight')
        self.module.weight = Parameter(torch.empty(self.weight.size()))
        self.module.weight = Parameter(torch.empty_like(self.weight))
        self.module.weight.data = torch.mul(self.weight, self.weight_mask)
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            delattr(self.module, 'bias')
            self.module.bias = Parameter(torch.empty(self.bias.size()))
            self.module.bias = Parameter(torch.empty_like(self.bias))
            self.module.bias.data = torch.mul(self.bias, self.bias_mask)

    def forward(self, *inputs):

@@ -130,7 +130,8 @@ class Pruner(Compressor):
        Wrap all modules that needed to be compressed.
        Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.bound_model is not None, err_msg

        if not self.is_wrapped:
            for _, wrapper in reversed(list(self.get_modules_wrapper().items())):

@@ -143,7 +144,8 @@ class Pruner(Compressor):
        Unwrap all modules that needed to be compressed.
        Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
        """
        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
        assert self.bound_model is not None, err_msg

        if self.is_wrapped:
            for wrapper in self.get_modules_wrapper().values():

@@ -165,8 +167,10 @@ class Pruner(Compressor):
        self._unwrap_model()
        parameter_name_map = {}
        for name, param in self.bound_model.named_parameters():
            # If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`, the name will not change after wrap.
            # If the parameter name in under wrapped module is others, the name `xxx.param` will change to `xxx.module.param` after wrap.
            # If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`,
            # the name will not change after wrap.
            # If the parameter name in under wrapped module is others,
            # the name `xxx.param` will change to `xxx.module.param` after wrap.
            parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
        self._wrap_model()
        return parameter_name_map

@@ -183,14 +187,12 @@ class Pruner(Compressor):
            The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
        """
        wrappers = self.get_modules_wrapper()
        for name, layer_mask in masks.items():
            assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
            if layer_mask.get('weight') is not None:
                assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
                setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
            if layer_mask.get('bias') is not None:
                assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
                setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))
        for module_name, target_masks in masks.items():
            assert module_name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(module_name)
            for target_name, target_mask in target_masks.items():
                assert hasattr(wrappers[module_name], f'{target_name}_mask'), f'There is no attribute {target_name}_mask in wrapper.'
                target: Tensor = getattr(self.get_modules_wrapper()[module_name], target_name)
                setattr(wrappers[module_name], f'{target_name}_mask', target_mask.to(target.device))

    def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
        """

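The hunk above generalizes ``load_masks`` from hard-coded ``weight``/``bias`` keys to arbitrary pruning targets, and moves each mask onto the device of the wrapped target. A minimal sketch of the masks structure it expects, assuming a pruner is already bound to a model that contains a linear layer named ``fc``:

import torch

# masks are keyed by module name, then by pruning-target name (e.g. 'weight', 'bias');
# 1 keeps an element, 0 prunes it
masks = {
    'fc': {
        'weight': torch.ones(10, 20),
        'bias': torch.ones(10),
    }
}

# hypothetical usage: the pruner copies each mask to the wrapper's `<target>_mask` buffer
pruner.load_masks(masks)
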
@@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional, cast
from typing import Dict, List, Callable, Optional, cast, overload

import json_tricks
import torch

@@ -11,12 +13,13 @@ from torch import Tensor
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity, config_list_canonical
from nni.compression.pytorch.utils import count_flops_params

from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
from ..utils import LightningEvaluator, TorchEvaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING


class AMCTaskGenerator(TaskGenerator):

@@ -41,8 +44,8 @@ class AMCTaskGenerator(TaskGenerator):
    ddpg_params
        The ddpg agent parameters.
    target : str
        'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
        This parameter is used to explain what the sparsity setting in config_list refers to.
        'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse,
        but in AMC, you can choose flops sparse. This parameter is used to explain what the sparsity setting in config_list refers to.
    """

    def __init__(self, total_episode: int, dummy_input: Tensor, origin_model: Module, origin_config_list: List[Dict],

@@ -56,7 +59,7 @@ class AMCTaskGenerator(TaskGenerator):
        self.config_list_copy = deepcopy(origin_config_list)

        super().__init__(origin_model=origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')

    def init_pending_tasks(self) -> List[Task]:
        origin_model = torch.load(self._origin_model_path)

@@ -82,6 +85,8 @@ class AMCTaskGenerator(TaskGenerator):
            return self.generate_tasks(task_result)

    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        self.temp_config_list = self.temp_config_list if hasattr(self, 'temp_config_list') else []

        # append experience & update agent policy
        if self.action is not None:
            action, reward, observation, done = self.env.step(self.action, task_result.compact_model)

@@ -106,7 +111,8 @@ class AMCTaskGenerator(TaskGenerator):
            origin_model = torch.load(self._origin_model_path)
            compact_model = task_result.compact_model
            compact_model_masks = task_result.compact_model_masks
            current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.temp_config_list)
            current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
                                                             self.temp_config_list)  # type: ignore
            self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
            current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
            self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity

@@ -162,7 +168,7 @@ class AMCTaskGenerator(TaskGenerator):


class AMCPruner(IterativePruner):
    r"""
    __doc__ = r"""
    AMC pruner leverages reinforcement learning to provide the model compression policy.
    According to the author, this learning-based compression policy outperforms conventional rule-based compression policy by having a higher compression ratio,
    better preserving the accuracy and freeing human labor.

@@ -186,10 +192,11 @@ class AMCPruner(IterativePruner):
            - op_names : Operation name to be pruned.
            - op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
            - exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
    dummy_input : torch.Tensor
        `dummy_input` is required for speedup and tracing the model in RL environment.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and give a score.
    evaluator
        ``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
        {evaluator_docstring}
        The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
        If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
    pruning_algorithm : str
        Supported pruning algorithm ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.

@@ -197,8 +204,6 @@ class AMCPruner(IterativePruner):
        The log directory use to saving the result, you can find the best result under this folder.
    keep_intermediate_result : bool
        If keeping the intermediate result, including intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
    ddpg_params : Dict
        Configuration dict to configure the DDPG agent, any key unset will be set to default implicitly.
        - hidden1: hidden num of first fully connect layer. Default: 300

@@ -223,23 +228,42 @@ class AMCPruner(IterativePruner):
        'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
        This parameter is used to explain what the sparsity setting in config_list refers to.

    Examples
    --------
        >>> from nni.compression.pytorch.pruning import AMCPruner
        >>> config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
        >>> dummy_input = torch.rand(...).to(device)
        >>> evaluator = ...
        >>> finetuner = ...
        >>> pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner)
        >>> pruner.compress()

    Notes
    -----
    The full script can be found :githublink:`here <examples/model_compress/pruning/amc_pruning_torch.py>`.
    """
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
                 pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
                 ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
        ...

    @overload
    def __init__(self, total_episode: int, model: Module, config_list: List[Dict], dummy_input: Tensor,
                 evaluator: Callable[[Module], float], pruning_algorithm: str = 'l1', log_dir: str = '.',
                 keep_intermediate_result: bool = False, finetuner: Optional[Callable[[Module], None]] = None,
                 ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
        ...

    def __init__(self, total_episode: int, model: Module, config_list: List[Dict], *args, **kwargs):
        new_api = ['evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'ddpg_params', 'pruning_params', 'target']
        new_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False,
                           'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
        old_api = ['dummy_input', 'evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'finetuner', 'ddpg_params',
                   'pruning_params', 'target']
        old_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None,
                           'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
        init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)

        pruning_algorithm = init_kwargs['pruning_algorithm']
        log_dir = init_kwargs['log_dir']
        keep_intermediate_result = init_kwargs['keep_intermediate_result']
        ddpg_params = init_kwargs['ddpg_params']
        pruning_params = init_kwargs['pruning_params']
        target = init_kwargs['target']
        dummy_input = self.dummy_input if not self.using_evaluator else self.evaluator.get_dummy_input()

        assert pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'], \
            "Only support pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo']"
        task_generator = AMCTaskGenerator(total_episode=total_episode,

@@ -251,5 +275,9 @@ class AMCPruner(IterativePruner):
                                          ddpg_params=ddpg_params,
                                          target=target)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=True, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)

        if self.using_evaluator:
            super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=True, reset_weight=False)
        else:
            super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=True, dummy_input=self.dummy_input,
                             evaluator=self._evaluator, reset_weight=False)  # type: ignore

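Note that the Examples block in the docstring above still shows the legacy ``dummy_input``/``finetuner`` call. Under the new overload added in this diff, the same pruner is driven by an evaluator instead; a rough sketch, assuming ``model`` and a ``TorchEvaluator`` named ``evaluator`` are already set up as in the earlier evaluator example:

from nni.compression.pytorch.pruning import AMCPruner

config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]

# new-style call: the evaluator bundles training, scoring and the dummy input,
# so dummy_input and finetuner are no longer passed separately
pruner = AMCPruner(400, model, config_list, evaluator)
pruner.compress()
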
@@ -1,18 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional
from typing import Dict, List, Callable, Optional, overload

from torch import Tensor
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper

from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
from ..utils import LightningEvaluator, TorchEvaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING

_logger = logging.getLogger(__name__)

@@ -21,10 +23,7 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, sa_params: Dict = {}, log_dir: str = '.',
                 keep_intermediate_result: bool = False):
        self.iterative_pruner = SimulatedAnnealingPruner(model=None,
                                                         config_list=None,
                                                         log_dir=Path(log_dir, 'SA'),
                                                         **sa_params)
        self._sa_params = sa_params
        super().__init__(total_iteration=total_iteration,
                         origin_model=origin_model,
                         origin_config_list=origin_config_list,

@@ -36,12 +35,20 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
        # TODO: replace with validation here
        for config in config_list:
            if 'sparsity' in config or 'sparsity_per_layer' in config:
                _logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
                warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
                           '`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
                           'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
                _logger.warning(warn_msg)
        return super().reset(model, config_list, masks)

    def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
        self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
        if not hasattr(self, 'iterative_pruner'):
            self.iterative_pruner = SimulatedAnnealingPruner(model=model,
                                                             config_list=config_list,
                                                             log_dir=Path(self._log_dir_root, 'SA'),
                                                             **self._sa_params)
        else:
            self.iterative_pruner.reset(model, config_list=config_list, masks=masks)

    def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
        self._iterative_pruner_reset(model, new_config_list, masks)

@@ -53,8 +60,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):


class AutoCompressPruner(IterativePruner):
    r"""
    __doc__ = r"""
    For total iteration number :math:`N`, AutoCompressPruner prune the model that survive the previous iteration for a fixed sparsity ratio (e.g., :math:`1-{(1-0.8)}^{(1/N)}`) to achieve the overall sparsity (e.g., :math:`0.8`):
    """ + r"""

    .. code-block:: bash


@@ -65,35 +73,27 @@ class AutoCompressPruner(IterativePruner):

    Parameters
    ----------
    model : Module
    model
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
    config_list
        The origin config list provided by the user.
    total_iteration : int
    total_iteration
        The total iteration number.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and give a score.
    admm_params : Dict
    admm_params
        The parameters passed to the ADMMPruner.

        - trainer : Callable[[Module, Optimizer, Callable].
            A callable function used to train model or just inference. Take model, optimizer, criterion as input.
            The model will be trained or inferenced `training_epochs` epochs.
        - traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
            The traced optimizer instance which the optimizer class is wrapped by nni.trace.
            E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
        - criterion : Callable[[Tensor, Tensor], Tensor].
            The criterion function used in trainer. Take model output and target value as input, and return the loss.
        - evaluator : LightningEvaluator or TorchEvaluator.
            The same with the evaluator of AutoCompressPruner input parameter.
        - iterations : int.
            The total iteration number in admm pruning algorithm.
        - training_epochs : int.
            The epoch number for training model in each iteration.

    sa_params : Dict
    sa_params
        The parameters passed to the SimulatedAnnealingPruner.

        - evaluator : Callable[[Module], float]. Required.
            Evaluate the pruned model and give a score.
        - evaluator : LightningEvaluator or TorchEvaluator.
            The same with the evaluator of AutoCompressPruner input parameter.
        - start_temperature : float. Default: `100`.
            Start temperature of the simulated annealing process.
        - stop_temperature : float. Default: `20`.

@@ -104,54 +104,50 @@ class AutoCompressPruner(IterativePruner):
            Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
        - pruning_algorithm : str. Default: `'level'`.
            Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        - pruning_params : Dict. Default: `{}`.
        - pruning_params : Dict. Default: dict().
            If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.

    log_dir : str
    log_dir
        The log directory used to save the result, you can find the best result under this folder.
    keep_intermediate_result : bool
    keep_intermediate_result
        If keeping the intermediate result, including intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, takes a pytorch module as input.
        It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
    speedup : bool
    evaluator
        ``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
        {evaluator_docstring}
        The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
        If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
    speedup
        If set True, speedup the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speedup` is True, `dummy_input` is required for tracing the model in speedup.

    Examples
    --------
        >>> import nni
        >>> from nni.compression.pytorch.pruning import AutoCompressPruner
        >>> model = ...
        >>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
        >>> # make sure you have used nni.trace to wrap the optimizer class before initialize
        >>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
        >>> trainer = ...
        >>> criterion = ...
        >>> evaluator = ...
        >>> finetuner = ...
        >>> admm_params = {
        >>>     'trainer': trainer,
        >>>     'traced_optimizer': traced_optimizer,
        >>>     'criterion': criterion,
        >>>     'iterations': 10,
        >>>     'training_epochs': 1
        >>> }
        >>> sa_params = {
        >>>     'evaluator': evaluator
        >>> }
        >>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, finetuner=finetuner)
        >>> pruner.compress()
        >>> _, model, masks, _, _ = pruner.get_best_result()

    Notes
    -----
    The full script can be found :githublink:`here <examples/model_compress/pruning/auto_compress_pruner.py>`.
    """
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False):
        ...

    @overload
    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
        ...

    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 *args, **kwargs):
        new_api = ['evaluator', 'speedup']
        new_init_kwargs = {'evaluator': None, 'speedup': False}
        old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator']
        old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False}
        init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)

        speedup = init_kwargs['speedup']

        task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
                                                   origin_model=model,
                                                   origin_config_list=config_list,

@@ -175,6 +171,10 @@ class AutoCompressPruner(IterativePruner):
        else:
            admm_params['granularity'] = 'fine-grained'

        pruner = ADMMPruner(None, None, **admm_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
        pruner = ADMMPruner(None, None, **admm_params)  # type: ignore

        if self.using_evaluator:
            super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
        else:
            super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
                             evaluator=self._evaluator, reset_weight=False)  # type: ignore

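The docstring Examples above still use the legacy trainer/criterion/finetuner form. A hedged sketch of the evaluator-based call enabled by the new overload, assuming ``model`` and a shared ``LightningEvaluator`` or ``TorchEvaluator`` named ``evaluator`` already exist (the admm_params/sa_params keys mirror the updated docstring, not a verified end-to-end script):

from nni.compression.pytorch.pruning import AutoCompressPruner

config_list = [{'total_sparsity': 0.8, 'op_types': ['Conv2d']}]

# the same evaluator is passed to the pruner and referenced again inside
# admm_params / sa_params, replacing trainer, criterion and the old evaluator
admm_params = {'evaluator': evaluator, 'iterations': 10, 'training_epochs': 1}
sa_params = {'evaluator': evaluator}

pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params,
                            evaluator=evaluator, speedup=True)
pruner.compress()
_, model, masks, _, _ = pruner.get_best_result()
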
(File diff suppressed because it is too large.)

@@ -1,8 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional, Union
import logging
from typing import Any, Dict, List, Tuple, Callable, Optional, Union, overload

import torch
from torch import Tensor

@@ -12,9 +15,63 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningScheduler
from nni.compression.pytorch.speedup import ModelSpeedup

from .tools import TaskGenerator
from ..utils import Evaluator, LightningEvaluator, TorchEvaluator

_logger = logging.getLogger(__name__)

_LEGACY_FINETUNER = Callable[[Module], None]
_LEGACY_EVALUATOR = Callable[[Module], float]


class PruningScheduler(BasePruningScheduler):
# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
    evaluator: LightningEvaluator | TorchEvaluator
    using_evaluator: bool
    finetuner: _LEGACY_FINETUNER
    _evaluator: _LEGACY_EVALUATOR
    dummy_input: Any

    def _init_evaluator(self, model: Module, new_api: List[str], new_init_kwargs: Dict, old_api: List[str],
                        old_init_kwargs: Dict, args: Tuple, kwargs: Dict) -> Dict:
        # for fake __init__ overload, parsing args and kwargs,
        # initializing evaluator or [finetuner, evaluator, dummy_input], return the remaining arguments.
        if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
                (len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
            init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
            self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
            if not self.evaluator._initialization_complete:
                self.evaluator._init_optimizer_helpers(model)  # type: ignore
            self.using_evaluator = True
        else:
            init_kwargs = self._parse_args(old_api, args, kwargs, old_init_kwargs)
            self.finetuner: _LEGACY_FINETUNER = init_kwargs.pop('finetuner')
            self._evaluator: _LEGACY_EVALUATOR = init_kwargs.pop('evaluator')
            self.dummy_input = init_kwargs.pop('dummy_input')
            self.using_evaluator = False
            warn_msg = f'The old API ...{",".join(old_api)} will be deprecated after NNI v3.0,' +\
                       f'please using the new one ...{",".join(new_api)}'
            _logger.warning(warn_msg)
        return init_kwargs

    def _parse_args(self, arg_names: List, args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
        merged_kwargs = {arg_names[idx]: arg for idx, arg in enumerate(args)}
        for key, value in kwargs.items():
            if key in merged_kwargs:
                raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
            merged_kwargs[key] = value
        for key, value in def_kwargs.items():
            if key not in merged_kwargs:
                merged_kwargs[key] = value
        diff = set(arg_names).difference(merged_kwargs.keys())
        if diff:
            raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
        diff = set(merged_kwargs.keys()).difference(arg_names)
        if diff:
            raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
        return merged_kwargs


class PruningScheduler(EvaluatorBasedPruningScheduler):
    """
    Parameters
    ----------

@@ -25,7 +82,8 @@ class PruningScheduler(BasePruningScheduler):
        Used to generate task for each iteration.
    finetuner
        The finetuner handled all finetune logic, use a pytorch module as input.
        It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
        It will be called at the end of each iteration if reset_weight is False,
        will be called at the beginning of each iteration otherwise.
    speedup
        If set True, speedup the model at the end of each iteration to make the pruned model compact.
    dummy_input

@@ -36,16 +94,30 @@ class PruningScheduler(BasePruningScheduler):
    reset_weight
        If set True, the model weight will reset to the origin model weight at the end of each iteration step.
    """
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
                 speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,

    @overload
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: LightningEvaluator | TorchEvaluator,
                 speedup: bool = False, reset_weight: bool = False):
        ...

    @overload
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: _LEGACY_FINETUNER | None = None,
                 speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: _LEGACY_EVALUATOR | None = None,
                 reset_weight: bool = False):
        ...

    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, *args, **kwargs) -> None:
        # TODO: remove in nni v3.0. Fake overload.
        new_api = ['evaluator', 'speedup', 'reset_weight']
        new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': False}
        old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight']
        old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'reset_weight': False}
        init_kwargs = self._init_evaluator(None, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)  # type: ignore

        self.pruner = pruner
        self.task_generator = task_generator
        self.finetuner = finetuner
        self.speedup = speedup
        self.dummy_input = dummy_input
        self.evaluator = evaluator
        self.reset_weight = reset_weight
        self.speedup = init_kwargs['speedup']
        self.reset_weight = init_kwargs['reset_weight']

    def reset(self, model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.task_generator.reset(model, config_list, masks)

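To make the fake-overload mechanism above concrete, here is a small self-contained sketch of how ``_parse_args`` resolves a call against one signature; it mirrors the ``new_api``/``new_init_kwargs`` lists used in ``PruningScheduler.__init__`` but omits the error checks, and the call values are made up for illustration:

from typing import Dict, List, Tuple

def parse_args(arg_names: List[str], args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
    # positionals are mapped onto arg_names in order, explicit kwargs are merged,
    # and any remaining name falls back to its default value
    merged = {arg_names[idx]: arg for idx, arg in enumerate(args)}
    merged.update(kwargs)
    for key, value in def_kwargs.items():
        merged.setdefault(key, value)
    return merged

new_api = ['evaluator', 'speedup', 'reset_weight']
defaults = {'evaluator': None, 'speedup': False, 'reset_weight': False}

# e.g. PruningScheduler(pruner, task_generator, my_evaluator, speedup=True)
print(parse_args(new_api, ('my_evaluator',), {'speedup': True}, defaults))
# -> {'evaluator': 'my_evaluator', 'speedup': True, 'reset_weight': False}

``_init_evaluator`` then pops ``'evaluator'`` from this dict, sets ``using_evaluator`` accordingly, and hands the remaining keyword values back to the concrete ``__init__``.
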
@@ -61,6 +133,7 @@ class PruningScheduler(BasePruningScheduler):
        generate masks -> speedup -> finetune -> evaluate
        """
        model, masks, config_list = task.load_data()

        self.pruner.reset(model, config_list)
        self.pruner.load_masks(masks)

@@ -74,28 +147,58 @@ class PruningScheduler(BasePruningScheduler):

        # speedup
        if self.speedup and task.speedup:
            ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
            compact_model_masks = {}
            if self.using_evaluator:
                ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
                compact_model_masks = {}
            else:
                ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
                compact_model_masks = {}

        # finetune
        if self.finetuner is not None and task.finetune:
            if self.speedup:
                self.finetuner(compact_model)
            else:
                self.pruner._wrap_model()
                self.finetuner(compact_model)
                self.pruner._unwrap_model()
        if self.using_evaluator:
            if task.finetune:
                self.evaluator.bind_model(compact_model)  # type: ignore
                if self.speedup:
                    self.evaluator.finetune()
                else:
                    self.pruner._wrap_model()
                    self.evaluator.finetune()
                    self.pruner._unwrap_model()
                self.evaluator.unbind_model()
        else:
            if self.finetuner is not None and task.finetune:
                if self.speedup:
                    self.finetuner(compact_model)
                else:
                    self.pruner._wrap_model()
                    self.finetuner(compact_model)
                    self.pruner._unwrap_model()

        # evaluate
        if self.evaluator is not None and task.evaluate:
            if self.speedup:
                score = self.evaluator(compact_model)
        if self.using_evaluator:
            if task.evaluate:
                self.evaluator.bind_model(compact_model)  # type: ignore
                # TODO: support saving customized score
                if self.speedup:
                    score = self.evaluator.evaluate()
                else:
                    self.pruner._wrap_model()
                    score = self.evaluator.evaluate()
                    self.pruner._unwrap_model()
                score = score[0] if isinstance(score, tuple) else score
                self.evaluator.unbind_model()
            else:
                self.pruner._wrap_model()
                score = self.evaluator(compact_model)
                self.pruner._unwrap_model()
                score = None
        else:
            score = None
            if self._evaluator is not None and task.evaluate:
                if self.speedup:
                    score = self._evaluator(compact_model)  # type: ignore
                else:
                    self.pruner._wrap_model()
                    score = self._evaluator(compact_model)  # type: ignore
                    self.pruner._unwrap_model()
            else:
                score = None

        # clear model references
        self.pruner.clear_model_references()

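The new branch above drives the evaluator through an explicit bind/act/unbind lifecycle. A rough sketch of that protocol in isolation, assuming ``evaluator`` is a ``LightningEvaluator`` or ``TorchEvaluator`` and ``compact_model`` is the pruned model, exactly as the scheduler uses it:

# bind the current model before asking the evaluator to do anything with it
evaluator.bind_model(compact_model)

# finetune and/or score while the model is bound
evaluator.finetune()
score = evaluator.evaluate()
# evaluate() may return a tuple; the scheduler keeps only the leading scalar
score = score[0] if isinstance(score, tuple) else score

# always release the model so the next iteration can bind a fresh one
evaluator.unbind_model()
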
@@ -107,13 +210,20 @@ class PruningScheduler(BasePruningScheduler):
        finetune -> generate masks -> reset weight -> speedup -> evaluate
        """
        model, masks, config_list = task.load_data()

        checkpoint = deepcopy(model.state_dict())
        self.pruner.reset(model, config_list)
        self.pruner.load_masks(masks)

        # finetune
        if self.finetuner is not None and task.finetune:
            self.finetuner(model)
        if self.using_evaluator:
            if task.finetune:
                self.evaluator.bind_model(model)  # type: ignore
                self.evaluator.finetune()
                self.evaluator.unbind_model()
        else:
            if self.finetuner is not None and task.finetune:
                self.finetuner(model)

        # pruning model
        compact_model, pruner_generated_masks = self.pruner.compress()

@@ -128,19 +238,38 @@ class PruningScheduler(BasePruningScheduler):

        # speedup
        if self.speedup and task.speedup:
            ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
            compact_model_masks = {}
            if self.using_evaluator:
                ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
                compact_model_masks = {}
            else:
                ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
                compact_model_masks = {}

        # evaluate
        if self.evaluator is not None and task.evaluate:
            if self.speedup:
                score = self.evaluator(compact_model)
        if self.using_evaluator:
            if task.evaluate:
                self.evaluator.bind_model(compact_model)  # type: ignore
                # TODO: support saving customized score
                if self.speedup:
                    score = self.evaluator.evaluate()
                else:
                    self.pruner._wrap_model()
                    score = self.evaluator.evaluate()
                    self.pruner._unwrap_model()
                score = score[0] if isinstance(score, tuple) else score
                self.evaluator.unbind_model()
            else:
                self.pruner._wrap_model()
                score = self.evaluator(compact_model)
                self.pruner._unwrap_model()
                score = None
        else:
            score = None
            if self._evaluator is not None and task.evaluate:
                if self.speedup:
                    score = self._evaluator(compact_model)  # type: ignore
                else:
                    self.pruner._wrap_model()
                    score = self._evaluator(compact_model)  # type: ignore
                    self.pruner._unwrap_model()
            else:
                score = None

        # clear model references
        self.pruner.clear_model_references()

|
@ -1,15 +1,15 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Callable, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, overload
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
|
||||
|
||||
from .basic_pruner import (
|
||||
LevelPruner,
|
||||
L1NormPruner,
|
||||
|
@ -21,13 +21,19 @@ from .basic_pruner import (
|
|||
TaylorFOWeightPruner,
|
||||
ADMMPruner
|
||||
)
|
||||
from .basic_scheduler import PruningScheduler
|
||||
from .basic_scheduler import PruningScheduler, _LEGACY_FINETUNER, _LEGACY_EVALUATOR
|
||||
from .tools import (
|
||||
LinearTaskGenerator,
|
||||
AGPTaskGenerator,
|
||||
LotteryTicketTaskGenerator,
|
||||
SimulatedAnnealingTaskGenerator
|
||||
)
|
||||
from ..utils import (
|
||||
OptimizerConstructHelper,
|
||||
LightningEvaluator,
|
||||
TorchEvaluator
|
||||
)
|
||||
from ..utils.docstring import _EVALUATOR_DOCSTRING
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -71,55 +77,67 @@ class IterativePruner(PruningScheduler):
|
|||
|
||||
|
||||
class LinearPruner(IterativePruner):
|
||||
r"""
|
||||
__doc__ = r"""
|
||||
Linear pruner is an iterative pruner, it will increase sparsity evenly from scratch during each iteration.
|
||||
|
||||
For example, the final sparsity is set as 0.5, and the iteration number is 5, then the sparsity used in each iteration are ``[0, 0.1, 0.2, 0.3, 0.4, 0.5]``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : Module
|
||||
model
|
||||
The origin unwrapped pytorch model to be pruned.
|
||||
config_list : List[Dict]
|
||||
config_list
|
||||
The origin config list provided by the user.
|
||||
pruning_algorithm : str
|
||||
pruning_algorithm
|
||||
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
|
||||
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
|
||||
total_iteration : int
|
||||
total_iteration
|
||||
The total iteration number.
|
||||
log_dir : str
|
||||
log_dir
|
||||
The log directory use to saving the result, you can find the best result under this folder.
|
||||
keep_intermediate_result : bool
|
||||
keep_intermediate_result
|
||||
If keeping the intermediate result, including intermediate model and masks during each iteration.
|
||||
finetuner : Optional[Callable[[Module], None]]
|
||||
The finetuner handled all finetune logic, use a pytorch module as input.
|
||||
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
|
||||
speedup : bool
|
||||
evaluator
|
||||
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
|
||||
{evaluator_docstring}
|
||||
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
|
||||
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
|
||||
speedup
|
||||
If set True, speedup the model at the end of each iteration to make the pruned model compact.
|
||||
dummy_input : Optional[torch.Tensor]
|
||||
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
|
||||
evaluator : Optional[Callable[[Module], float]]
|
||||
Evaluate the pruned model and give a score.
|
||||
If evaluator is None, the best result refers to the latest result.
|
||||
pruning_params : Dict
|
||||
pruning_params
|
||||
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from nni.compression.pytorch.pruning import LinearPruner
|
||||
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
|
||||
>>> finetuner = ...
|
||||
>>> pruner = LinearPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
|
||||
>>> pruner.compress()
|
||||
>>> _, model, masks, _, _ = pruner.get_best_result()
|
||||
|
||||
Notes
|
||||
-----
|
||||
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
|
||||
"""
|
||||
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
|
||||
pruning_params: Dict = {}):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Any | None = None,
|
||||
evaluator: _LEGACY_EVALUATOR | None = None, pruning_params: Dict = {}):
|
||||
...
|
||||
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
|
||||
evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
|
||||
*args, **kwargs):
|
||||
new_api = ['evaluator', 'speedup', 'pruning_params']
|
||||
new_init_kwargs = {'evaluator': None, 'speedup': False, 'pruning_params': {}}
|
||||
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'pruning_params']
|
||||
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'pruning_params': {}}
|
||||
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
|
||||
|
||||
speedup = init_kwargs['speedup']
|
||||
pruning_params = init_kwargs['pruning_params']
|
||||
|
||||
task_generator = LinearTaskGenerator(total_iteration=total_iteration,
|
||||
origin_model=model,
|
||||
origin_config_list=config_list,
|
||||
|
@ -128,63 +146,80 @@ class LinearPruner(IterativePruner):
|
|||
if 'traced_optimizer' in pruning_params:
|
||||
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
|
||||
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
|
||||
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=evaluator, reset_weight=False)
|
||||
|
||||
if self.using_evaluator:
|
||||
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
|
||||
else:
|
||||
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
|
||||
evaluator=self._evaluator, reset_weight=False) # type: ignore
|
||||
|
||||
|
||||
class AGPPruner(IterativePruner):
|
||||
r"""
|
||||
__doc__ = r"""
|
||||
This is an iterative pruner, which the sparsity is increased from an initial sparsity value :math:`s_{i}` (usually 0) to a final sparsity value :math:`s_{f}` over a span of :math:`n` pruning iterations,
|
||||
starting at training step :math:`t_{0}` and with pruning frequency :math:`\Delta t`:
|
||||
|
||||
:math:`s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n \Delta t}\right)^{3} \text { for } t \in\left\{t_{0}, t_{0}+\Delta t, \ldots, t_{0} + n \Delta t\right\}`
|
||||
""" + r"""
|
||||
|
||||
For more details please refer to `To prune, or not to prune: exploring the efficacy of pruning for model compression <https://arxiv.org/abs/1710.01878>`__\.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : Module
|
||||
model
|
||||
The origin unwrapped pytorch model to be pruned.
|
||||
config_list : List[Dict]
|
||||
config_list
|
||||
The origin config list provided by the user.
|
||||
pruning_algorithm : str
|
||||
pruning_algorithm
|
||||
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
|
||||
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
|
||||
total_iteration : int
|
||||
total_iteration
|
||||
The total iteration number.
|
||||
log_dir : str
|
||||
log_dir
|
||||
The log directory use to saving the result, you can find the best result under this folder.
|
||||
keep_intermediate_result : bool
|
||||
keep_intermediate_result
|
||||
If keeping the intermediate result, including intermediate model and masks during each iteration.
|
||||
finetuner : Optional[Callable[[Module], None]]
|
||||
The finetuner handled all finetune logic, use a pytorch module as input.
|
||||
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
|
||||
speedup : bool
|
||||
evaluator
|
||||
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
|
||||
{evaluator_docstring}
|
||||
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
|
||||
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
|
||||
speedup
|
||||
If set True, speedup the model at the end of each iteration to make the pruned model compact.
|
||||
dummy_input : Optional[torch.Tensor]
|
||||
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
|
||||
evaluator : Optional[Callable[[Module], float]]
|
||||
Evaluate the pruned model and give a score.
|
||||
If evaluator is None, the best result refers to the latest result.
|
||||
pruning_params : Dict
|
||||
pruning_params
|
||||
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from nni.compression.pytorch.pruning import AGPPruner
|
||||
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
|
||||
>>> finetuner = ...
|
||||
>>> pruner = AGPPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
|
||||
>>> pruner.compress()
|
||||
>>> _, model, masks, _, _ = pruner.get_best_result()
|
||||
|
||||
Notes
|
||||
-----
|
||||
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
|
||||
"""
|
||||
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
|
||||
pruning_params: Dict = {}):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Any | None = None,
|
||||
evaluator: _LEGACY_EVALUATOR | None = None, pruning_params: Dict = {}):
|
||||
...
|
||||
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
|
||||
evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
|
||||
*args, **kwargs):
|
||||
new_api = ['evaluator', 'speedup', 'pruning_params']
|
||||
new_init_kwargs = {'evaluator': None, 'speedup': False, 'pruning_params': {}}
|
||||
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'pruning_params']
|
||||
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'pruning_params': {}}
|
||||
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
|
||||
|
||||
speedup = init_kwargs['speedup']
|
||||
pruning_params = init_kwargs['pruning_params']
|
||||
|
||||
task_generator = AGPTaskGenerator(total_iteration=total_iteration,
|
||||
origin_model=model,
|
||||
origin_config_list=config_list,
|
||||
|
@ -193,12 +228,16 @@ class AGPPruner(IterativePruner):
|
|||
if 'traced_optimizer' in pruning_params:
|
||||
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
|
||||
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
|
||||
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=evaluator, reset_weight=False)
|
||||
|
||||
if self.using_evaluator:
|
||||
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
|
||||
else:
|
||||
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
|
||||
evaluator=self._evaluator, reset_weight=False) # type: ignore
|
||||
|
||||
|
||||
class LotteryTicketPruner(IterativePruner):
|
||||
r"""
|
||||
__doc__ = r"""
|
||||
`The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks <https://arxiv.org/abs/1803.03635>`__\ ,
|
||||
authors Jonathan Frankle and Michael Carbin,provides comprehensive measurement and analysis,
|
||||
and articulate the *lottery ticket hypothesis*\ : dense, randomly-initialized, feed-forward networks contain subnetworks (*winning tickets*\ ) that
|
||||
|
@ -216,55 +255,69 @@ class LotteryTicketPruner(IterativePruner):
|
|||
|
||||
If the configured final sparsity is P (e.g., 0.8) and there are n times iterative pruning,
|
||||
each iteration prunes 1-(1-P)^(1/n) of the weights that survived the previous round.
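A quick numeric sketch of that schedule (plain Python, hypothetical values)::

    P, n = 0.8, 5                        # configured final sparsity and number of iterations
    per_round = 1 - (1 - P) ** (1 / n)   # ~0.275 of the surviving weights pruned each round
    remaining = 1.0
    for _ in range(n):
        remaining *= 1 - per_round
    print(1 - remaining)                 # ~0.8, recovering the configured final sparsity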
|
||||
""" + r"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : Module
|
||||
model
|
||||
The origin unwrapped pytorch model to be pruned.
|
||||
config_list : List[Dict]
|
||||
config_list
|
||||
The origin config list provided by the user.
|
||||
pruning_algorithm : str
|
||||
pruning_algorithm
|
||||
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
|
||||
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
|
||||
total_iteration : int
|
||||
total_iteration
|
||||
The total iteration number.
|
||||
log_dir : str
|
||||
log_dir
|
||||
The log directory used to save the results; you can find the best result under this folder.
|
||||
keep_intermediate_result : bool
|
||||
keep_intermediate_result
|
||||
Whether to keep the intermediate results, including the intermediate model and masks from each iteration.
|
||||
finetuner : Optional[Callable[[Module], None]]
|
||||
The finetuner handles all fine-tuning logic, taking a PyTorch module as input.
|
||||
It will be called at the end of each iteration if reset_weight is False, and at the beginning of each iteration otherwise.
|
||||
speedup : bool
|
||||
evaluator
|
||||
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
|
||||
{evaluator_docstring}
|
||||
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
|
||||
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
|
||||
speedup
|
||||
If set to True, speed up the model at the end of each iteration to make the pruned model compact.
|
||||
dummy_input : Optional[torch.Tensor]
|
||||
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
|
||||
evaluator : Optional[Callable[[Module], float]]
|
||||
Evaluate the pruned model and give a score.
|
||||
If evaluator is None, the best result refers to the latest result.
|
||||
reset_weight : bool
|
||||
reset_weight
|
||||
If set to True, the model weights are reset to the original model weights at the end of each iteration step.
|
||||
pruning_params : Dict
|
||||
pruning_params
|
||||
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from nni.compression.pytorch.pruning import LotteryTicketPruner
|
||||
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
|
||||
>>> finetuner = ...
|
||||
>>> pruner = LotteryTicketPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner, reset_weight=True)
|
||||
>>> pruner.compress()
|
||||
>>> _, model, masks, _, _ = pruner.get_best_result()
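The same flow with the new evaluator-style API is sketched below; the ``TorchEvaluator`` arguments are elided because they depend on your training setup::

    >>> from nni.compression.pytorch import TorchEvaluator
    >>> evaluator = TorchEvaluator(...)  # see the evaluator reference for its constructor
    >>> pruner = LotteryTicketPruner(model, config_list, pruning_algorithm='l1', total_iteration=10,
    ...                              evaluator=evaluator, reset_weight=True)
    >>> pruner.compress()
    >>> _, model, masks, _, _ = pruner.get_best_result()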
|
||||
|
||||
Notes
|
||||
-----
|
||||
For a detailed example, please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
|
||||
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
|
||||
|
||||
"""
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
|
||||
reset_weight: bool = True, pruning_params: Dict = {}):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
|
||||
evaluator: _LEGACY_EVALUATOR | None = None, reset_weight: bool = True,
|
||||
pruning_params: Dict = {}):
|
||||
...
|
||||
|
||||
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
|
||||
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
|
||||
evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
|
||||
pruning_params: Dict = {}):
|
||||
*args, **kwargs):
|
||||
new_api = ['evaluator', 'speedup', 'reset_weight', 'pruning_params']
|
||||
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': True, 'pruning_params': {}}
|
||||
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight', 'pruning_params']
|
||||
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False,
|
||||
'reset_weight': True, 'pruning_params': {}}
|
||||
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
|
||||
|
||||
speedup = init_kwargs['speedup']
|
||||
reset_weight = init_kwargs['reset_weight']
|
||||
pruning_params = init_kwargs['pruning_params']
|
||||
|
||||
task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
|
||||
origin_model=model,
|
||||
origin_config_list=config_list,
|
||||
|
@ -273,12 +326,16 @@ class LotteryTicketPruner(IterativePruner):
|
|||
if 'traced_optimizer' in pruning_params:
|
||||
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
|
||||
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
|
||||
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=evaluator, reset_weight=reset_weight)
|
||||
|
||||
if self.using_evaluator:
|
||||
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=reset_weight)
|
||||
else:
|
||||
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
|
||||
evaluator=self._evaluator, reset_weight=reset_weight) # type: ignore
|
||||
|
||||
|
||||
class SimulatedAnnealingPruner(IterativePruner):
|
||||
"""
|
||||
__doc__ = r"""
|
||||
We implement a guided heuristic search method, the Simulated Annealing (SA) algorithm. As mentioned in the paper, this method enhances guided search with prior experience.
|
||||
The enhanced SA technique is based on the observation that a DNN layer with a larger number of weights can often tolerate a higher degree of compression with less impact on overall accuracy.
|
||||
|
||||
|
@ -294,54 +351,81 @@ class SimulatedAnnealingPruner(IterativePruner):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
model : Optional[Module]
|
||||
model
|
||||
The origin unwrapped pytorch model to be pruned.
|
||||
config_list : Optional[List[Dict]]
|
||||
config_list
|
||||
The origin config list provided by the user.
|
||||
evaluator : Callable[[Module], float]
|
||||
Evaluate the pruned model and give a score.
|
||||
start_temperature : float
|
||||
evaluator
|
||||
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
|
||||
{evaluator_docstring}
|
||||
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
|
||||
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
|
||||
start_temperature
|
||||
Start temperature of the simulated annealing process.
|
||||
stop_temperature : float
|
||||
stop_temperature
|
||||
Stop temperature of the simulated annealing process.
|
||||
cool_down_rate : float
|
||||
cool_down_rate
|
||||
Cool down rate of the temperature.
|
||||
perturbation_magnitude : float
|
||||
perturbation_magnitude
|
||||
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
|
||||
pruning_algorithm : str
|
||||
pruning_algorithm
|
||||
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
|
||||
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
|
||||
pruning_params : Dict
|
||||
pruning_params
|
||||
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
|
||||
log_dir : Union[str, Path]
|
||||
log_dir
|
||||
The log directory used to save the results; you can find the best result under this folder.
|
||||
keep_intermediate_result : bool
|
||||
keep_intermediate_result
|
||||
Whether to keep the intermediate results, including the intermediate model and masks from each iteration.
|
||||
finetuner : Optional[Callable[[Module], None]]
|
||||
The finetuner handles all fine-tuning logic, taking a PyTorch module as input; it will be called in each iteration.
|
||||
speedup : bool
|
||||
speedup
|
||||
If set to True, speed up the model at the end of each iteration to make the pruned model compact.
|
||||
dummy_input : Optional[torch.Tensor]
|
||||
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from nni.compression.pytorch.pruning import SimulatedAnnealingPruner
|
||||
>>> model = ...
|
||||
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
|
||||
>>> evaluator = ...
|
||||
>>> finetuner = ...
|
||||
>>> pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm='l1', evaluator=evaluator, cool_down_rate=0.9, finetuner=finetuner)
|
||||
>>> pruner.compress()
|
||||
>>> _, model, masks, _, _ = pruner.get_best_result()
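A sketch of the equivalent call through the new evaluator API (the ``TorchEvaluator`` construction is elided because it depends on your training code)::

    >>> from nni.compression.pytorch import TorchEvaluator
    >>> evaluator = TorchEvaluator(...)  # replaces the old finetuner / evaluator callables
    >>> pruner = SimulatedAnnealingPruner(model, config_list, evaluator=evaluator,
    ...                                   pruning_algorithm='l1', cool_down_rate=0.9)
    >>> pruner.compress()
    >>> _, model, masks, _, _ = pruner.get_best_result()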
|
||||
|
||||
Notes
|
||||
-----
|
||||
For a detailed example, please refer to :githublink:`examples/model_compress/pruning/simulated_anealing_pruning_torch.py <examples/model_compress/pruning/simulated_anealing_pruning_torch.py>`
|
||||
"""
|
||||
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
|
||||
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
|
||||
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
|
||||
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False, speedup: bool = False):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], evaluator: _LEGACY_EVALUATOR,
|
||||
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
|
||||
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
|
||||
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False,
|
||||
dummy_input: Optional[Tensor] = None):
|
||||
...
|
||||
|
||||
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
|
||||
new_api = ['evaluator', 'start_temperature', 'stop_temperature', 'cool_down_rate', 'perturbation_magnitude',
|
||||
'pruning_algorithm', 'pruning_params', 'log_dir', 'keep_intermediate_result', 'speedup']
|
||||
new_init_kwargs = {'start_temperature': 100, 'stop_temperature': 20, 'cool_down_rate': 0.9,
|
||||
'perturbation_magnitude': 0.35, 'pruning_algorithm': 'level', 'pruning_params': {},
|
||||
'log_dir': '.', 'keep_intermediate_result': False, 'speedup': False}
|
||||
old_api = ['evaluator', 'start_temperature', 'stop_temperature', 'cool_down_rate', 'perturbation_magnitude',
|
||||
'pruning_algorithm', 'pruning_params', 'log_dir', 'keep_intermediate_result', 'finetuner',
|
||||
'speedup', 'dummy_input']
|
||||
old_init_kwargs = {'start_temperature': 100, 'stop_temperature': 20, 'cool_down_rate': 0.9,
|
||||
'perturbation_magnitude': 0.35, 'pruning_algorithm': 'level', 'pruning_params': {},
|
||||
'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None, 'speedup': False,
|
||||
'dummy_input': None}
|
||||
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
|
||||
|
||||
start_temperature = init_kwargs['start_temperature']
|
||||
stop_temperature = init_kwargs['stop_temperature']
|
||||
cool_down_rate = init_kwargs['cool_down_rate']
|
||||
perturbation_magnitude = init_kwargs['perturbation_magnitude']
|
||||
pruning_algorithm = init_kwargs['pruning_algorithm']
|
||||
pruning_params = init_kwargs['pruning_params']
|
||||
log_dir = init_kwargs['log_dir']
|
||||
keep_intermediate_result = init_kwargs['keep_intermediate_result']
|
||||
speedup = init_kwargs['speedup']
|
||||
|
||||
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], evaluator: Callable[[Module], float], start_temperature: float = 100,
|
||||
stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
|
||||
pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
|
||||
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
|
||||
task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
|
||||
origin_config_list=config_list,
|
||||
start_temperature=start_temperature,
|
||||
|
@ -351,7 +435,12 @@ class SimulatedAnnealingPruner(IterativePruner):
|
|||
log_dir=log_dir,
|
||||
keep_intermediate_result=keep_intermediate_result)
|
||||
if 'traced_optimizer' in pruning_params:
|
||||
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
|
||||
pruning_params['traced_optimizer'] = \
|
||||
OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
|
||||
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
|
||||
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=evaluator, reset_weight=False)
|
||||
|
||||
if self.using_evaluator:
|
||||
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
|
||||
else:
|
||||
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup,
|
||||
dummy_input=self.dummy_input, evaluator=self._evaluator, reset_weight=False) # type: ignore
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Callable
|
||||
from typing import Dict, List, Tuple, Callable, overload
|
||||
|
||||
import torch
|
||||
from torch import autograd, Tensor
|
||||
|
@ -12,17 +14,23 @@ from torch.nn.parameter import Parameter
|
|||
from torch.optim import Optimizer, Adam
|
||||
|
||||
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
|
||||
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
|
||||
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
|
||||
from nni.common.serializer import Traceable
|
||||
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import EvaluatorBasedPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
|
||||
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
|
||||
|
||||
from .tools.base import TrainerBasedDataCollector
|
||||
from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
|
||||
|
||||
from .tools import (
|
||||
StraightMetricsCalculator,
|
||||
NormalSparsityAllocator
|
||||
NormalSparsityAllocator,
|
||||
StraightMetricsCalculator
|
||||
)
|
||||
|
||||
from ..utils import (
|
||||
LightningEvaluator,
|
||||
TorchEvaluator
|
||||
)
|
||||
|
||||
from ..utils.docstring import _EVALUATOR_DOCSTRING
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -47,8 +55,7 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
|
|||
|
||||
def forward(self, *inputs):
|
||||
# apply mask to weight, bias
|
||||
# NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach()`
|
||||
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach())) # type: ignore
|
||||
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask)) # type: ignore
|
||||
if hasattr(self.module, 'bias') and self.module.bias is not None:
|
||||
self.module.bias = torch.mul(self.bias, self.bias_mask) # type: ignore
|
||||
return self.module(*inputs)
|
||||
|
@ -77,13 +84,30 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
|
|||
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
|
||||
|
||||
data = {}
|
||||
target_name = 'weight'
|
||||
for _, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
data[wrapper.name] = wrapper.weight_score.data # type: ignore
|
||||
data[wrapper.name] = {target_name: wrapper.weight_score.data} # type: ignore
|
||||
return data
|
||||
|
||||
|
||||
class MovementPruner(BasicPruner):
|
||||
r"""
|
||||
class EvaluatorBasedScoreDataCollector(EvaluatorBasedDataCollector):
|
||||
"""
|
||||
Collect all weight_score in wrappers as data used to calculate metrics.
|
||||
"""
|
||||
def collect(self) -> Dict[str, Tensor]:
|
||||
assert self.compressor.bound_model is not None
|
||||
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
|
||||
|
||||
data = {}
|
||||
target_name = 'weight'
|
||||
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
target_score: Tensor = getattr(wrapper, f'{target_name}_score')
|
||||
data[module_name] = {target_name: target_score.data.clone()}
|
||||
return data
|
||||
|
||||
|
||||
class MovementPruner(EvaluatorBasedPruner):
|
||||
__doc__ = r"""
|
||||
Movement pruner is an implementation of movement pruning.
|
||||
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
|
||||
Each weight element is scored by the negative of the accumulated product of the weight and its gradient over the training steps.
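A rough sketch of that scoring rule for a single weight tensor (the training objects here are hypothetical placeholders, not the pruner's internal code)::

    import torch

    score = torch.zeros_like(weight)                 # weight: the parameter being scored
    for data, target in train_loader:                # one optimizer step per batch
        loss = criterion(model(data), target)
        loss.backward()
        score -= (weight * weight.grad).detach()     # movement score: minus the weight-gradient product
        optimizer.step()
        optimizer.zero_grad()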
|
||||
|
@ -110,30 +134,12 @@ class MovementPruner(BasicPruner):
|
|||
- op_names : Operation names to be pruned.
|
||||
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
|
||||
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
|
||||
trainer : Callable[[Module, Optimizer, Callable]
|
||||
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
|
||||
The model will be trained or inferenced `training_epochs` epochs.
|
||||
|
||||
Example::
|
||||
|
||||
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
|
||||
training = model.training
|
||||
model.train(mode=True)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
loss.backward()
|
||||
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
|
||||
optimizer.step()
|
||||
model.train(mode=training)
|
||||
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
|
||||
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
|
||||
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
|
||||
criterion : Callable[[Tensor, Tensor], Tensor]
|
||||
The criterion function used in trainer. Take model output and target value as input, and return the loss.
|
||||
evaluator
|
||||
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
|
||||
{evaluator_docstring}
|
||||
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
|
||||
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
|
||||
training_epochs : int
|
||||
The total epoch number for training the model.
|
||||
Make sure the total number of `optimizer.step()` calls within `training_epochs` is larger than `cool_down_beginning_step`.
|
||||
|
@ -145,33 +151,31 @@ class MovementPruner(BasicPruner):
|
|||
The sparsity after each `optimizer.step()` is:
|
||||
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
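A small sketch of that cubic schedule, using the step numbers from the example below::

    def movement_sparsity(step, total_sparsity=0.8, warm_up_step=3000, cool_down_beginning_step=27000):
        if step <= warm_up_step:
            return 0.0
        if step > cool_down_beginning_step:
            return total_sparsity
        progress = (step - warm_up_step) / (cool_down_beginning_step - warm_up_step)
        return total_sparsity * (1 - (1 - progress) ** 3)

    movement_sparsity(3000)    # 0.0, still warming up
    movement_sparsity(15000)   # ~0.7, on the cubic ramp
    movement_sparsity(27000)   # 0.8, sparsity stays at total_sparsity afterwards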
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import nni
|
||||
>>> from nni.compression.pytorch.pruning import MovementPruner
|
||||
>>> model = ...
|
||||
>>> # make sure you have used nni.trace to wrap the optimizer class before initializing it
|
||||
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
|
||||
>>> trainer = ...
|
||||
>>> criterion = ...
|
||||
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
|
||||
>>> pruner = MovementPruner(model, config_list, trainer, traced_optimizer, criterion, 10, 3000, 27000)
|
||||
>>> masked_model, masks = pruner.compress()
|
||||
|
||||
Notes
|
||||
-----
|
||||
For a detailed example, please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
|
||||
"""
|
||||
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
|
||||
warm_up_step: int, cool_down_beginning_step: int):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
|
||||
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
|
||||
cool_down_beginning_step: int):
|
||||
self.trainer = trainer
|
||||
if isinstance(traced_optimizer, OptimizerConstructHelper):
|
||||
self.optimizer_helper = traced_optimizer
|
||||
else:
|
||||
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
|
||||
self.criterion = criterion
|
||||
self.training_epochs = training_epochs
|
||||
self.warm_up_step = warm_up_step
|
||||
self.cool_down_beginning_step = cool_down_beginning_step
|
||||
...
|
||||
|
||||
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
|
||||
# TODO: remove in nni v3.0. Fake overload.
|
||||
new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
|
||||
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
|
||||
init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
|
||||
|
||||
self.training_epochs: int = init_kwargs['training_epochs']
|
||||
self.warm_up_step: int = init_kwargs['warm_up_step']
|
||||
self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
|
||||
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`'
|
||||
super().__init__(model, config_list)
|
||||
|
||||
|
@ -184,14 +188,16 @@ class MovementPruner(BasicPruner):
|
|||
if self.warm_up_step < current_step <= self.cool_down_beginning_step:
|
||||
wrapper_dict = self.get_modules_wrapper()
|
||||
for config in self.config_list:
|
||||
current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3)
|
||||
scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
|
||||
current_sparsity = config['total_sparsity'] * scale
|
||||
for op_name in config['op_names']:
|
||||
wrapper_dict[op_name].config['total_sparsity'] = current_sparsity
|
||||
wrapper = wrapper_dict[op_name]
|
||||
wrapper.config['total_sparsity'] = current_sparsity
|
||||
|
||||
def reset_tools(self):
|
||||
if self.metrics_calculator is None:
|
||||
if not hasattr(self, 'metrics_calculator'):
|
||||
self.metrics_calculator = StraightMetricsCalculator()
|
||||
if self.sparsity_allocator is None:
|
||||
if not hasattr(self, 'sparsity_allocator'):
|
||||
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
|
||||
|
||||
# use Adam to update the weight_score
|
||||
|
@ -208,16 +214,30 @@ class MovementPruner(BasicPruner):
|
|||
if self.step_counter > self.warm_up_step:
|
||||
self.cubic_schedule(self.step_counter)
|
||||
data = {}
|
||||
target_name = 'weight'
|
||||
for wrapper in self.get_modules_wrapper().values():
|
||||
data[wrapper.name] = wrapper.weight_score.data
|
||||
data[wrapper.name] = {target_name: wrapper.weight_score.data}
|
||||
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
|
||||
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
|
||||
self.load_masks(masks)
|
||||
|
||||
if self.data_collector is None:
|
||||
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
|
||||
if self.using_evaluator:
|
||||
# TODO: move to other place in nni v3.0
|
||||
self.evaluator.unbind_model()
|
||||
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
|
||||
if not hasattr(self, 'data_collector'):
|
||||
self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
|
||||
after_opt_step_tasks=[_optimizer_patch],
|
||||
max_epochs=self.training_epochs)
|
||||
else:
|
||||
self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
|
||||
else:
|
||||
self.data_collector.reset()
|
||||
if not hasattr(self, 'data_collector'):
|
||||
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper,
|
||||
self.criterion, self.training_epochs,
|
||||
opt_after_tasks=[_optimizer_patch])
|
||||
else:
|
||||
self.data_collector.reset()
|
||||
|
||||
def _wrap_modules(self, layer: LayerInfo, config: Dict):
|
||||
"""
|
||||
|
@ -243,7 +263,6 @@ class MovementPruner(BasicPruner):
|
|||
for wrapper in self.get_modules_wrapper().values():
|
||||
wrapper.config['total_sparsity'] = 0
|
||||
result = super().compress()
|
||||
# del weight_score
|
||||
for wrapper in self.get_modules_wrapper().values():
|
||||
wrapper.weight_score = None
|
||||
if self.using_evaluator:
|
||||
self.evaluator.unbind_model()
|
||||
return result
|
||||
|
|
|
@ -8,6 +8,12 @@ from .base import (
|
|||
SparsityAllocator,
|
||||
TaskGenerator
|
||||
)
|
||||
from .data_collector import (
|
||||
TargetDataCollector,
|
||||
EvaluatorBasedTargetDataCollector,
|
||||
EvaluatorBasedHookDataCollector
|
||||
)
|
||||
# TODO: remove in nni v3.0.
|
||||
from .data_collector import (
|
||||
WeightDataCollector,
|
||||
WeightTrainerBasedDataCollector,
|
||||
|
@ -16,7 +22,7 @@ from .data_collector import (
|
|||
from .metrics_calculator import (
|
||||
StraightMetricsCalculator,
|
||||
NormMetricsCalculator,
|
||||
MultiDataNormMetricsCalculator,
|
||||
HookDataNormMetricsCalculator,
|
||||
DistMetricsCalculator,
|
||||
APoZRankMetricsCalculator,
|
||||
MeanRankMetricsCalculator
|
||||
|
|
|
@ -6,7 +6,7 @@ from datetime import datetime
|
|||
import logging
|
||||
from pathlib import Path
|
||||
import types
|
||||
from typing import List, Dict, Tuple, Optional, Callable, Union
|
||||
from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
|
||||
|
||||
import json_tricks
|
||||
import torch
|
||||
|
@ -15,7 +15,7 @@ from torch.nn import Module
|
|||
from torch.optim import Optimizer
|
||||
|
||||
from ...base import Pruner, LayerInfo, Task, TaskResult
|
||||
from ...utils import OptimizerConstructHelper, Scaling
|
||||
from ...utils import Evaluator, Hook, OptimizerConstructHelper, Scaling
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,7 +45,7 @@ class DataCollector:
|
|||
def __init__(self, compressor: Pruner):
|
||||
self.compressor = compressor
|
||||
|
||||
def reset(self):
|
||||
def reset(self, *args, **kwargs):
|
||||
"""
|
||||
Reset the `DataCollector`.
|
||||
"""
|
||||
|
@ -63,9 +63,12 @@ class DataCollector:
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
# TODO: remove in nni v3.0.
|
||||
COLLECTOR_TYPE = Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]
|
||||
|
||||
class HookCollectorInfo:
|
||||
def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str,
|
||||
collector: Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]):
|
||||
collector: COLLECTOR_TYPE):
|
||||
"""
|
||||
This class used to aggregate the information of what kind of hook is placed on which layers.
|
||||
|
||||
|
@ -76,23 +79,24 @@ class HookCollectorInfo:
|
|||
hook_type
|
||||
'forward' or 'backward'.
|
||||
collector
|
||||
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor, the output is a hook function.
|
||||
The buffer is used to store the data wanted to hook.
|
||||
A hook function generator whose input is a buffer (an empty list), or a buffer and a tensor,
|
||||
and whose output is a hook function. The buffer is used to store the data collected by the hook.
|
||||
"""
|
||||
self.targets = targets
|
||||
self.hook_type = hook_type
|
||||
self.collector = collector
|
||||
|
||||
|
||||
# TODO: remove in nni v3.0.
|
||||
class TrainerBasedDataCollector(DataCollector):
|
||||
"""
|
||||
This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
|
||||
"""
|
||||
|
||||
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
|
||||
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
|
||||
opt_before_tasks: List = [], opt_after_tasks: List = [],
|
||||
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
|
||||
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None],
|
||||
optimizer_helper: OptimizerConstructHelper, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
|
||||
opt_before_tasks: List = [], opt_after_tasks: List = [], collector_infos: List[HookCollectorInfo] = [],
|
||||
criterion_patch: Optional[Callable[[Callable], Callable]] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -252,6 +256,47 @@ class TrainerBasedDataCollector(DataCollector):
|
|||
self._remove_hook(hook_id)
|
||||
|
||||
|
||||
class EvaluatorBasedDataCollector(DataCollector):
|
||||
"""
|
||||
This data collector is the base class for data collectors that use ``Evaluator`` to train or run inference.
|
||||
Three main usages are supported in this data collector:
|
||||
|
||||
1. Doing something before ``optimizer.step()`` and after ``optimizer.step()``. ``before_opt_step_tasks`` is a list of task functions
|
||||
that will execute before ``optimizer.step()``. ``after_opt_step_tasks`` is a list of task functions that will execute after
|
||||
``optimizer.step()``. All the task functions in the list should take no input arguments; a return value is allowed,
|
||||
but ``Evaluator`` will not catch it.
|
||||
2. Patch or modify the training loss. ``loss_patch`` is a function whose input is the original loss and whose output is the modified loss.
|
||||
3. Add hooks on ``torch.nn.Module``, ``Parameter`` or ``Buffer``. Three kinds of hooks are supported: ``TensorHook``, ``ForwardHook``
|
||||
and ``BackwardHook``. To initialize a ``Hook``, a hook function factory is needed; the factory function's input is an empty list,
|
||||
and its output is a hook function as defined by PyTorch (a small usage sketch follows this docstring).
|
||||
Please refer to `register_hook <https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html>`_,
|
||||
`register_forward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_,
|
||||
`register_backward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_backward_hook>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, compressor: Pruner, evaluator: Evaluator, before_opt_step_tasks: List[Callable] | None = None,
|
||||
after_opt_step_tasks: List[Callable] | None = None, loss_patch: Callable[[Tensor], Tensor] | None = None,
|
||||
hooks: Dict[str, Dict[str, Hook]] | None = None, max_steps: int | None = None, max_epochs: int | None = None):
|
||||
super().__init__(compressor)
|
||||
self.evaluator = evaluator
|
||||
self.max_steps = max_steps
|
||||
self.max_epochs = max_epochs
|
||||
self.reset(before_opt_step_tasks, after_opt_step_tasks, loss_patch, hooks)
|
||||
|
||||
def reset(self, before_opt_step_tasks: List[Callable] | None = None, after_opt_step_tasks: List[Callable] | None = None,
|
||||
loss_patch: Callable[[Tensor], Tensor] | None = None, hooks: Dict[str, Dict[str, Hook]] | None = None):
|
||||
if before_opt_step_tasks or after_opt_step_tasks:
|
||||
before_opt_step_tasks = before_opt_step_tasks if before_opt_step_tasks else []
|
||||
after_opt_step_tasks = after_opt_step_tasks if after_opt_step_tasks else []
|
||||
self.evaluator.patch_optimizer_step(before_opt_step_tasks, after_opt_step_tasks)
|
||||
if loss_patch:
|
||||
self.evaluator.patch_loss(loss_patch)
|
||||
if hooks:
|
||||
self._hooks = hooks
|
||||
hook_list = [hook for _ in hooks.values() for hook in _.values()]
|
||||
self.evaluator.register_hooks(hook_list)
|
||||
|
||||
|
||||
class MetricsCalculator:
|
||||
"""
|
||||
An abstract class for calculating a kind of metric from the given data.
|
||||
|
@ -260,7 +305,8 @@ class MetricsCalculator:
|
|||
----------
|
||||
scalers
|
||||
Scaler is used to scale the metrics' size. It scales the metric to the same size as the shrunk mask in the sparsity allocator.
|
||||
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
|
||||
If you want to use different scalers for different pruning targets in different modules,
|
||||
please use a dict `{module_name: {target_name: scaler}}`.
|
||||
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
|
||||
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
|
||||
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
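For instance, a ``scalers`` argument of that shape might look like this (module and target names are hypothetical, and the ``Scaling`` construction is elided)::

    conv_scaler = Scaling(...)   # build a Scaling suited to the target, see the Scaling reference
    scalers = {
        'conv1': {'weight': conv_scaler},        # scaler for conv1's weight metric
        '_default': {'_default': conv_scaler},   # fallback for any unspecified module or target name
    }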
|
||||
|
@ -268,7 +314,8 @@ class MetricsCalculator:
|
|||
"""
|
||||
|
||||
def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
|
||||
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
|
||||
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
|
||||
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
|
||||
|
||||
def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
|
||||
scaler = _get_scaler(self.scalers, module_name, target_name)
|
||||
|
@ -301,7 +348,8 @@ class SparsityAllocator:
|
|||
scalers
|
||||
Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric,
|
||||
or expands the mask of the same size as the metric to the same size as the pruning target.
|
||||
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
|
||||
If you want to use different scalers for different pruning targets in different modules,
|
||||
please use a dict `{module_name: {target_name: scaler}}`.
|
||||
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
|
||||
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
|
||||
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
|
||||
|
@ -313,7 +361,8 @@ class SparsityAllocator:
|
|||
|
||||
def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True):
|
||||
self.pruner = pruner
|
||||
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
|
||||
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
|
||||
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
|
||||
self.continuous_mask = continuous_mask
|
||||
|
||||
def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
|
||||
|
@ -335,25 +384,39 @@ class SparsityAllocator:
|
|||
mask = (scaler.shrink(mask) != 0).type_as(mask)
|
||||
return mask
|
||||
|
||||
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def _mask_metric(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
# Set the already masked part in the metric to the minimum value.
|
||||
target_name = 'weight'
|
||||
for module_name, targets_metric in metrics.items():
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
|
||||
shrinked_target_mask = self._shrink_mask(module_name, target_name, old_target_mask)
|
||||
# make sure the masked position has the minimum metric
|
||||
targets_metric[target_name] = targets_metric[target_name].to(shrinked_target_mask.device)
|
||||
min_value = targets_metric[target_name].min() - 1
|
||||
targets_metric[target_name] = torch.where(shrinked_target_mask != 0, targets_metric[target_name], min_value)
|
||||
return metrics
|
||||
|
||||
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
# Set the already masked part to zero in the new_masks.
|
||||
target_name = 'weight'
|
||||
for module_name, target_mask in new_masks.items():
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
old_target_mask = getattr(wrapper, f'{target_name}_mask', None)
|
||||
old_target_mask: Tensor | None = getattr(wrapper, f'{target_name}_mask', None)
|
||||
if old_target_mask is not None:
|
||||
new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask)
|
||||
new_masks[module_name][target_name] = torch.min(target_mask[target_name],
|
||||
old_target_mask.to(target_mask[target_name].device))
|
||||
return new_masks
|
||||
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
"""
|
||||
Generate masks for metrics-dependent targets.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metrics
|
||||
The format is {module_name: weight_metric}.
|
||||
The metric of `weight` usually has the same size with shrinked mask.
|
||||
The format is {module_name: {target_name: target_metric}}.
|
||||
The metric usually has the same size as the shrunk mask.
|
||||
|
||||
Return
|
||||
------
|
||||
|
@ -384,7 +447,7 @@ class SparsityAllocator:
|
|||
reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
|
||||
# count unmasked number of values on dim 0 (output channel) of weight
|
||||
unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
|
||||
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask)
|
||||
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(weight_mask)
|
||||
return masks
|
||||
|
||||
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
|
||||
|
@ -401,6 +464,8 @@ class SparsityAllocator:
|
|||
Dict[str, Dict[str, Tensor]]
|
||||
The masks format is {module_name: {target_name: mask}}.
|
||||
"""
|
||||
if self.continuous_mask:
|
||||
metrics = self._mask_metric(metrics)
|
||||
masks = self.common_target_masks_generation(metrics)
|
||||
masks = self.special_target_masks_generation(masks)
|
||||
if self.continuous_mask:
|
||||
|
@ -425,11 +490,22 @@ class TaskGenerator:
|
|||
The log directory used to save the task generator log.
|
||||
keep_intermediate_result
|
||||
Whether to keep the intermediate results, including the intermediate model and masks from each iteration.
|
||||
best_result_mode
|
||||
The way to decide which one is the best result. Three modes are supported; a small illustration follows this docstring.
|
||||
If the task results don't contain scores (task_result.score is None), it will fall back to ``latest``.
|
||||
|
||||
1. latest: The newest received result is the best result.
|
||||
2. maximize: The one with largest task result score is the best result.
|
||||
3. minimize: The one with smallest task result score is the best result.
|
||||
"""
|
||||
|
||||
def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
|
||||
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
|
||||
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
|
||||
best_result_mode: Literal['latest', 'maximize', 'minimize'] = 'maximize'):
|
||||
self._log_dir = log_dir
|
||||
self._keep_intermediate_result = keep_intermediate_result
|
||||
assert best_result_mode in ['latest', 'maximize', 'minimize'], f'Unsupported best_result_mode value: {best_result_mode}'
|
||||
self._best_result_mode = best_result_mode
|
||||
|
||||
if origin_model is not None and origin_config_list is not None and origin_masks is not None:
|
||||
self.reset(origin_model, origin_config_list, origin_masks)
|
||||
|
@ -472,13 +548,24 @@ class TaskGenerator:
|
|||
json_tricks.dump(config_list, f, indent=4)
|
||||
|
||||
def update_best_result(self, task_result: TaskResult):
|
||||
score = task_result.score
|
||||
task_id = task_result.task_id
|
||||
task = self._tasks[task_id]
|
||||
task.score = score
|
||||
if self._best_score is None or (score is not None and score > self._best_score):
|
||||
self._best_score = score
|
||||
self._best_task_id = task_id
|
||||
save_as_best_result = False
|
||||
task = self._tasks[task_result.task_id]
|
||||
task.score = task_result.score
|
||||
|
||||
if self._best_result_mode == 'latest':
|
||||
self._best_task_id, save_as_best_result = task_result.task_id, True
|
||||
|
||||
if self._best_result_mode == 'maximize':
|
||||
if self._best_score is None or (task.score is not None and task.score > self._best_score):
|
||||
self._best_score = task.score
|
||||
self._best_task_id, save_as_best_result = task_result.task_id, True
|
||||
|
||||
if self._best_result_mode == 'minimize':
|
||||
if self._best_score is None or (task.score is not None and task.score < self._best_score):
|
||||
self._best_score = task.score
|
||||
self._best_task_id, save_as_best_result = task_result.task_id, True
|
||||
|
||||
if save_as_best_result:
|
||||
with Path(task.config_list_path).open('r') as fr:
|
||||
best_config_list = json_tricks.load(fr)
|
||||
self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
|
||||
|
|
|
@ -6,13 +6,16 @@ from typing import Dict, List
|
|||
|
||||
from torch import Tensor
|
||||
|
||||
from .base import DataCollector, TrainerBasedDataCollector
|
||||
from .base import DataCollector, EvaluatorBasedDataCollector
|
||||
from .base import TrainerBasedDataCollector
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector']
|
||||
__all__ = ['TargetDataCollector', 'EvaluatorBasedTargetDataCollector', 'EvaluatorBasedHookDataCollector',
|
||||
'WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector'] # TODO: remove in nni v3.0.
|
||||
|
||||
|
||||
# TODO: remove in nni v3.0.
|
||||
class WeightDataCollector(DataCollector):
|
||||
"""
|
||||
Collect all wrapper weights.
|
||||
|
@ -21,40 +24,102 @@ class WeightDataCollector(DataCollector):
|
|||
def reset(self):
|
||||
pass
|
||||
|
||||
def collect(self) -> Dict[str, Tensor]:
|
||||
def collect(self) -> Dict[str, Dict[str, Tensor]]:
|
||||
data = {}
|
||||
for _, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
data[wrapper.name] = wrapper.weight.data
|
||||
target_name = 'weight'
|
||||
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
target: Tensor = getattr(wrapper, target_name)
|
||||
data[module_name] = {target_name: target.data.clone()}
|
||||
return data
|
||||
|
||||
|
||||
# TODO: remove in nni v3.0.
|
||||
class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
|
||||
"""
|
||||
Collect all wrapper weights after training or inference.
|
||||
"""
|
||||
|
||||
def collect(self) -> Dict[str, Tensor]:
|
||||
def collect(self) -> Dict[str, Dict[str, Tensor]]:
|
||||
assert self.compressor.bound_model is not None
|
||||
for _ in range(self.training_epochs):
|
||||
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
|
||||
|
||||
data = {}
|
||||
for _, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
data[wrapper.name] = wrapper.weight.data
|
||||
target_name = 'weight'
|
||||
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
target: Tensor = getattr(wrapper, target_name)
|
||||
data[module_name] = {target_name: target.data.clone()}
|
||||
return data
|
||||
|
||||
|
||||
# TODO: remove in nni v3.0.
|
||||
class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
|
||||
"""
|
||||
Add hooks and collect data during training or inference.
|
||||
Single means each wrapper only has one hook to collect data.
|
||||
"""
|
||||
|
||||
def collect(self) -> Dict[str, List[Tensor]]:
|
||||
def collect(self) -> Dict[str, Dict[str, List[Tensor]]]:
|
||||
assert self.compressor.bound_model is not None
|
||||
for _ in range(self.training_epochs):
|
||||
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
|
||||
|
||||
data = {}
|
||||
[data.update(buffer_dict) for _, buffer_dict in self._hook_buffer.items()]
|
||||
target_name = 'weight'
|
||||
for _, buffer_dict in self._hook_buffer.items():
|
||||
for module_name, target_data in buffer_dict.items():
|
||||
data[module_name] = {target_name: target_data}
|
||||
return data
|
||||
|
||||
|
||||
class TargetDataCollector(DataCollector):
|
||||
"""
|
||||
Collect all wrapper targets.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
# No need to reset anything in this data collector.
|
||||
pass
|
||||
|
||||
def collect(self) -> Dict[str, Dict[str, Tensor]]:
|
||||
data = {}
|
||||
target_name = 'weight'
|
||||
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
target: Tensor = getattr(wrapper, target_name)
|
||||
data[module_name] = {target_name: target.data.clone()}
|
||||
return data
|
||||
|
||||
|
||||
class EvaluatorBasedTargetDataCollector(EvaluatorBasedDataCollector):
|
||||
"""
|
||||
Collect all wrapper pruning targets after training or inference.
|
||||
"""
|
||||
|
||||
def collect(self) -> Dict[str, Dict[str, Tensor]]:
|
||||
assert self.compressor.bound_model is not None
|
||||
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
|
||||
|
||||
data = {}
|
||||
target_name = 'weight'
|
||||
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
|
||||
target: Tensor = getattr(wrapper, target_name)
|
||||
data[module_name] = {target_name: target.data.clone()}
|
||||
return data
|
||||
|
||||
|
||||
class EvaluatorBasedHookDataCollector(EvaluatorBasedDataCollector):
|
||||
"""
|
||||
Add hooks and collect data during training or inference.
|
||||
NOTE: Only support one target has one hook right now.
|
||||
"""
|
||||
|
||||
def collect(self) -> Dict[str, Dict[str, List]]:
|
||||
assert self.compressor.bound_model is not None
|
||||
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
|
||||
|
||||
data = {}
|
||||
for module_name, hooks in self._hooks.items():
|
||||
data[module_name] = {}
|
||||
for target_name, hook in hooks.items():
|
||||
data[module_name][target_name] = hook.buffer
|
||||
return data
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
|||
from .base import MetricsCalculator
|
||||
from ...utils import Scaling
|
||||
|
||||
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
|
||||
__all__ = ['NormMetricsCalculator', 'HookDataNormMetricsCalculator', 'DistMetricsCalculator',
|
||||
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
|
||||
|
||||
|
||||
|
@ -19,11 +19,12 @@ class StraightMetricsCalculator(MetricsCalculator):
|
|||
"""
|
||||
This metrics calculator directly returns a copy of data as metrics.
|
||||
"""
|
||||
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
metrics = {}
|
||||
for name, tensor in data.items():
|
||||
# use inplace detach `detach_` here to avoid creating a new tensor
|
||||
metrics[name] = tensor.clone().detach_()
|
||||
for module_name, targets_data in data.items():
|
||||
metrics[module_name] = {}
|
||||
for target_name, target_data in targets_data.items():
|
||||
metrics[module_name][target_name] = target_data.clone().detach()
|
||||
return metrics
|
||||
|
||||
|
||||
|
@ -44,27 +45,32 @@ class NormMetricsCalculator(MetricsCalculator):
|
|||
super().__init__(scalers=scalers)
|
||||
self.p = p if p is not None else 'fro'
|
||||
|
||||
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def reduce_func(t: Tensor) -> Tensor:
|
||||
return t.norm(p=self.p, dim=-1) # type: ignore
|
||||
|
||||
metrics = {}
|
||||
target_name = 'weight'
|
||||
for module_name, target_data in data.items():
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name] = scaler.shrink(target_data, reduce_func)
|
||||
for module_name, targets_data in data.items():
|
||||
metrics[module_name] = {}
|
||||
for target_name, target_data in targets_data.items():
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
|
||||
return metrics
|
||||
|
||||
|
||||
class MultiDataNormMetricsCalculator(NormMetricsCalculator):
|
||||
class HookDataNormMetricsCalculator(NormMetricsCalculator):
|
||||
"""
|
||||
The data value format is a two-element list [batch_number, cumulative_data].
|
||||
The hook data value format is a two-element list [batch_number, cumulative_data].
|
||||
Directly use the cumulative_data as new_data to calculate norm metric.
|
||||
TaylorFO pruner uses this to calculate metric.
|
||||
"""
|
||||
|
||||
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
|
||||
new_data = {name: buffer[1] for name, buffer in data.items()}
|
||||
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
new_data = {}
|
||||
for module_name, targets_data in data.items():
|
||||
new_data[module_name] = {}
|
||||
for target_name, (_, target_data) in targets_data.items():
|
||||
new_data[module_name][target_name] = target_data
|
||||
return super().calculate_metrics(new_data)
|
||||
|
||||
|
||||
|
@ -85,7 +91,7 @@ class DistMetricsCalculator(MetricsCalculator):
|
|||
super().__init__(scalers=scalers)
|
||||
self.p = p if p is not None else 'fro'
|
||||
|
||||
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def reduce_func(t: Tensor) -> Tensor:
|
||||
reshape_data = t.reshape(-1, t.shape[-1])
|
||||
metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
|
||||
|
@ -94,10 +100,11 @@ class DistMetricsCalculator(MetricsCalculator):
|
|||
return metric.reshape(t.shape[:-1])
|
||||
|
||||
metrics = {}
|
||||
target_name = 'weight'
|
||||
for module_name, target_data in data.items():
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name] = scaler.shrink(target_data, reduce_func)
|
||||
for module_name, targets_data in data.items():
|
||||
metrics[module_name] = {}
|
||||
for target_name, target_data in targets_data.items():
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
|
||||
return metrics
|
||||
|
||||
|
||||
|
@ -108,16 +115,18 @@ class APoZRankMetricsCalculator(MetricsCalculator):
|
|||
Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
|
||||
APoZRank pruner uses this to calculate metric.
|
||||
"""
|
||||
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
|
||||
|
||||
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def reduce_func(t: Tensor) -> Tensor:
|
||||
return 1 - t.mean(dim=-1)
|
||||
|
||||
metrics = {}
|
||||
target_name = 'weight'
|
||||
for module_name, target_data in data.items():
|
||||
target_data = target_data[1] / target_data[0]
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name] = scaler.shrink(target_data, reduce_func)
|
||||
for module_name, targets_data in data.items():
|
||||
metrics[module_name] = {}
|
||||
for target_name, target_data in targets_data.items():
|
||||
target_data = target_data[1] / target_data[0]
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
|
||||
return metrics
|
||||
|
||||
|
||||
|
@ -127,14 +136,15 @@ class MeanRankMetricsCalculator(MetricsCalculator):
|
|||
This metric simply calculates the average along `self.dim`, then divides by the batch_number.
|
||||
MeanRank pruner uses this to calculate metric.
|
||||
"""
|
||||
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
|
||||
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def reduce_func(t: Tensor) -> Tensor:
|
||||
return t.mean(dim=-1)
|
||||
|
||||
metrics = {}
|
||||
target_name = 'weight'
|
||||
for module_name, target_data in data.items():
|
||||
target_data = target_data[1] / target_data[0]
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name] = scaler.shrink(target_data, reduce_func)
|
||||
for module_name, targets_data in data.items():
|
||||
metrics[module_name] = {}
|
||||
for target_name, target_data in targets_data.items():
|
||||
target_data = target_data[1] / target_data[0]
|
||||
scaler = self._get_scaler(module_name, target_name)
|
||||
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
|
||||
return metrics
|
||||
|
|
|
@ -17,7 +17,8 @@ _logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AMCEnv:
|
||||
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
|
||||
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float,
|
||||
max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
|
||||
pruning_op_names = []
|
||||
[pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)]
|
||||
self.pruning_ops = OrderedDict()
|
||||
|
@ -26,7 +27,10 @@ class AMCEnv:
|
|||
if name in pruning_op_names:
|
||||
op_type = type(layer).__name__
|
||||
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
|
||||
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore
|
||||
if hasattr(layer, 'kernel_size'):
|
||||
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) # type: ignore
|
||||
else:
|
||||
kernel_size = 1
|
||||
self.pruning_ops[name] = (i, op_type, stride, kernel_size)
|
||||
self.pruning_types.append(op_type)
|
||||
self.pruning_types = list(set(self.pruning_types))
|
||||
|
@ -60,15 +64,18 @@ class AMCEnv:
|
|||
|
||||
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
|
||||
previous_pruning_target = self.under_pruning_target - total_current_target
|
||||
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] for name in self.pruning_op_names[index + 1:]])
|
||||
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] \
|
||||
for name in self.pruning_op_names[index + 1:]])
|
||||
min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target
|
||||
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - (self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
|
||||
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - \
|
||||
(self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
|
||||
max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target
|
||||
max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2)
|
||||
min_action = min_current_pruning_target / current_statistics[op_name][self.target]
|
||||
max_action = max_current_pruning_target / current_statistics[op_name][self.target]
|
||||
if min_action > self.max_sparsity_per_layer[op_name]:
|
||||
_logger.warning('[%s] min action > max sparsity per layer: %f > %f', op_name, min_action, self.max_sparsity_per_layer[op_name])
|
||||
warn_msg = f'[{op_name}] min action > max sparsity per layer: {min_action} > {self.max_sparsity_per_layer[op_name]}'
|
||||
_logger.warning(warn_msg)
|
||||
action = max(0., min(max_action, max(min_action, action)))
|
||||
|
||||
self.current_op_name = op_name
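
The clamping above keeps the agent's proposed sparsity inside the feasible range; a small sketch of the same expression wrapped in a hypothetical helper, with made-up numbers::

    def clamp_action(action, min_action, max_action):
        # never below zero, never outside [min_action, max_action]
        return max(0., min(max_action, max(min_action, action)))

    clamp_action(0.9, min_action=0.2, max_action=0.6)  # -> 0.6
    clamp_action(0.1, min_action=0.2, max_action=0.6)  # -> 0.2
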
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -23,22 +23,22 @@ class NormalSparsityAllocator(SparsityAllocator):
|
|||
This allocator directly masks the locations of each pruning target with lower metric values.
|
||||
"""
|
||||
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
masks = {}
|
||||
# TODO: Support more target type in wrapper & config list refactor
|
||||
target_name = 'weight'
|
||||
for module_name, target_metric in metrics.items():
|
||||
for module_name, targets_metric in metrics.items():
|
||||
masks[module_name] = {}
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
sparsity_rate = wrapper.config['total_sparsity']
|
||||
prune_num = int(sparsity_rate * target_metric.numel())
|
||||
if prune_num != 0:
|
||||
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
|
||||
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
|
||||
else:
|
||||
# target_metric should have the same size as shrinked_mask
|
||||
shrinked_mask = torch.ones_like(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
sparsity_rate = wrapper.config['total_sparsity']
|
||||
prune_num = int(sparsity_rate * target_metric.numel())
|
||||
if prune_num != 0:
|
||||
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
|
||||
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
|
||||
else:
|
||||
# target_metric should have the same size as shrinked_mask
|
||||
shrinked_mask = torch.ones_like(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
return masks
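
The threshold-and-mask step used by ``NormalSparsityAllocator`` can be seen in isolation with plain torch; the metric values below are invented for illustration::

    import torch

    target_metric = torch.tensor([0.3, 0.9, 0.1, 0.7, 0.5, 0.2])
    sparsity_rate = 0.5
    prune_num = int(sparsity_rate * target_metric.numel())  # 3 positions to prune
    threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
    shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
    # -> tensor([0., 1., 0., 1., 1., 0.]), the three lowest-metric positions are masked
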
|
||||
|
||||
|
||||
|
@ -46,7 +46,7 @@ class BankSparsityAllocator(SparsityAllocator):
|
|||
"""
|
||||
In bank pruner, all values in weight are divided into different sub blocks each shape
|
||||
aligned with balance_gran. Each sub block has the same sparsity which equal to the overall sparsity.
|
||||
This allocator pruned the weight in the granularity of block.
|
||||
This allocator pruned the weight in the granularity of block.
|
||||
"""
|
||||
|
||||
def __init__(self, pruner: Pruner, balance_gran: list):
|
||||
|
@ -56,101 +56,108 @@ class BankSparsityAllocator(SparsityAllocator):
|
|||
assert isinstance(gran, int) and gran > 0, 'All values in list balance_gran \
|
||||
should be type int and bigger than zero'
|
||||
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
masks = {}
|
||||
# TODO: Support more target type in wrapper & config list refactor
|
||||
target_name = 'weight'
|
||||
for module_name, target_metric in metrics.items():
|
||||
for module_name, targets_metric in metrics.items():
|
||||
masks[module_name] = {}
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
sparsity_rate = wrapper.config['total_sparsity']
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
sparsity_rate = wrapper.config['total_sparsity']
|
||||
|
||||
n_dim = len(target_metric.shape)
|
||||
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should not be larger than the dimension of metric'
|
||||
# make up for balance_gran
|
||||
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
|
||||
for i, j in zip(target_metric.shape, balance_gran):
|
||||
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
|
||||
n_dim = len(target_metric.shape)
|
||||
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should not be larger than the dimension of metric'
|
||||
# make up for balance_gran
|
||||
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
|
||||
for i, j in zip(target_metric.shape, balance_gran):
|
||||
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
|
||||
|
||||
# FIXME: The following code need refactor, do it after scaling refactor is done.
|
||||
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
|
||||
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
|
||||
for iter_params in itertools.product(*loop_iters):
|
||||
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
|
||||
for iter_param, gran in zip(iter_params, balance_gran)]
|
||||
index_str = ",".join(index_str_list)
|
||||
sub_metric_str = "target_metric[{}]".format(index_str)
|
||||
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
|
||||
metric_bank: Tensor = eval(sub_metric_str)
|
||||
prune_num = int(sparsity_rate * metric_bank.numel())
|
||||
# mask_bank will be used in exec(sub_mask_str)
|
||||
if prune_num != 0:
|
||||
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
|
||||
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank)
|
||||
else:
|
||||
mask_bank = torch.ones_like(metric_bank)
|
||||
exec(sub_mask_str)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
# FIXME: The following code need refactor, do it after scaling refactor is done.
|
||||
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
|
||||
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
|
||||
for iter_params in itertools.product(*loop_iters):
|
||||
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
|
||||
for iter_param, gran in zip(iter_params, balance_gran)]
|
||||
index_str = ",".join(index_str_list)
|
||||
sub_metric_str = "target_metric[{}]".format(index_str)
|
||||
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
|
||||
metric_bank: Tensor = eval(sub_metric_str)
|
||||
prune_num = int(sparsity_rate * metric_bank.numel())
|
||||
# mask_bank will be used in exec(sub_mask_str)
|
||||
if prune_num != 0:
|
||||
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
|
||||
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank) # type: ignore
|
||||
else:
|
||||
mask_bank = torch.ones_like(metric_bank) # type: ignore
|
||||
mask_bank = mask_bank # `type: ignore` is useless for unused-variable error, add this line to workaround
|
||||
exec(sub_mask_str)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
return masks
|
||||
|
||||
|
||||
class GlobalSparsityAllocator(SparsityAllocator):
|
||||
"""
|
||||
This allocator sorts all metrics as a whole, and masks the locations of pruning targets with lower metric values.
|
||||
By default, this allocator will prevent each module from being over-pruned with upper sparsity 0.99.
|
||||
"""
|
||||
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
masks = {}
|
||||
if not metrics:
|
||||
return masks
|
||||
# TODO: support more target type in wrapper & config list refactor
|
||||
target_name = 'weight'
|
||||
|
||||
# validate all wrapper setting the same sparsity
|
||||
# validate that all wrappers have the same sparsity setting
|
||||
# TODO: move validation logic to pruner
|
||||
global_sparsity_rate = self.pruner.get_modules_wrapper()[list(metrics.keys())[0]].config['total_sparsity']
|
||||
for module_name, target_metric in metrics.items():
|
||||
for module_name in metrics.keys():
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
assert global_sparsity_rate == wrapper.config['total_sparsity']
|
||||
|
||||
# find the largest metric value among all metrics
|
||||
max_metric_value = list(metrics.values())[0].max()
|
||||
for module_name, target_metric in metrics.items():
|
||||
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
|
||||
max_metric_value = list(list(metrics.values())[0].values())[0].max()
|
||||
for targets_metric in metrics.values():
|
||||
for target_metric in targets_metric.values():
|
||||
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
|
||||
|
||||
# prevent each module from being over-pruned; the cap ratio is 'max_sparsity_per_layer'
|
||||
for module_name, target_metric in metrics.items():
|
||||
for module_name, targets_metric in metrics.items():
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
|
||||
assert 0 <= max_sparsity <= 1
|
||||
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
|
||||
expand_times = old_target_mask.numel() // target_metric.numel()
|
||||
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
|
||||
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
|
||||
metrics[module_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
|
||||
assert 0 <= max_sparsity <= 1
|
||||
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
|
||||
expand_times = old_target_mask.numel() // target_metric.numel()
|
||||
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
|
||||
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
|
||||
metrics[module_name][target_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
|
||||
|
||||
# build the global_metric & calculate the global threshold
|
||||
metric_list = []
|
||||
for module_name, target_metric in metrics.items():
|
||||
for module_name, targets_metric in metrics.items():
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
|
||||
expand_times = old_target_mask.numel() // target_metric.numel()
|
||||
metric_list.append(target_metric.reshape(-1).unsqueeze(0).expand(expand_times, -1).reshape(-1))
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
|
||||
expand_times = old_target_mask.numel() // target_metric.numel()
|
||||
metric_list.append(target_metric.reshape(-1).repeat_interleave(expand_times))
|
||||
global_metric = torch.cat(metric_list)
|
||||
max_pruning_num = int((global_metric != max_metric_value).sum().item())
|
||||
total_pruning_num = min(int(global_sparsity_rate * global_metric.numel()), max_pruning_num)
|
||||
global_threshold = torch.topk(global_metric.reshape(-1), total_pruning_num, largest=False)[0].max()
|
||||
|
||||
# generate masks for each target
|
||||
for module_name, target_metric in metrics.items():
|
||||
for module_name, targets_metric in metrics.items():
|
||||
masks[module_name] = {}
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
return masks
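
A short sketch (plain torch, invented metrics) of why a single global threshold can prune one layer much harder than another, which is exactly what the ``max_sparsity_per_layer`` cap above guards against::

    import torch

    metric_a = torch.tensor([0.1, 0.2, 0.3, 0.4])  # layer A, uniformly low importance
    metric_b = torch.tensor([0.5, 0.6, 0.7, 0.8])  # layer B, uniformly high importance
    global_metric = torch.cat([metric_a, metric_b])
    total_pruning_num = int(0.5 * global_metric.numel())  # prune 4 of 8 positions overall
    global_threshold = torch.topk(global_metric, total_pruning_num, largest=False)[0].max()
    mask_a = torch.gt(metric_a, global_threshold).float()  # -> tensor([0., 0., 0., 0.]): layer A fully pruned
    mask_b = torch.gt(metric_b, global_threshold).float()  # -> tensor([1., 1., 1., 1.]): layer B untouched
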
|
||||
|
||||
|
||||
class DependencyAwareAllocator(NormalSparsityAllocator):
|
||||
# TODO: This allocator will trace the model, which means the model will be run for inference during initialization;
# sometimes we may not be aware of this inference and it may lead to some errors.
|
||||
class DependencyAwareAllocator(SparsityAllocator):
|
||||
"""
|
||||
A specific allocator for Conv2d & Linear modules with dependency awareness.
|
||||
It will generate a public mask for the modules that have dependencies,
|
||||
|
@ -170,52 +177,79 @@ class DependencyAwareAllocator(NormalSparsityAllocator):
|
|||
# group dependency format: {module_name: group_num}
|
||||
self.pruner._unwrap_model()
|
||||
graph = TorchModuleGraph(model=self.pruner.bound_model, dummy_input=dummy_input)
|
||||
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
|
||||
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
|
||||
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
|
||||
traced_model=graph.trace).dependency_sets
|
||||
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
|
||||
traced_model=graph.trace).dependency_sets
|
||||
self.pruner._wrap_model()
|
||||
return channel_dependency, group_dependency
|
||||
|
||||
def _metric_fuse(self, metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
|
||||
def _metric_fuse(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
|
||||
# Sum all metric values at the same position.
|
||||
metrics = list(metrics.values()) if isinstance(metrics, dict) else metrics
|
||||
assert all(metrics[0].size() == metric.size() for metric in metrics), 'Metrics size do not match.'
|
||||
fused_metric = torch.zeros_like(metrics[0])
|
||||
for metric in metrics:
|
||||
fused_metric += metric
|
||||
return fused_metric
|
||||
fused_metrics = {}
|
||||
for targets_metric in metrics.values():
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
if target_name in fused_metrics:
|
||||
fused_metrics[target_name] += target_metric
|
||||
else:
|
||||
fused_metrics[target_name] = target_metric
|
||||
return fused_metrics
|
||||
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
# placeholder, here we need more discussion about dependency sparsity, Plan A or Plan B.
|
||||
masks = {}
|
||||
# generate public part for modules that have dependencies
|
||||
for module_names in self.channel_dependency:
|
||||
sub_metrics = {module_name: metrics[module_name] for module_name in module_names if module_name in metrics}
|
||||
if not sub_metrics:
|
||||
continue
|
||||
fused_metric = self._metric_fuse(sub_metrics)
|
||||
fused_metrics = self._metric_fuse(sub_metrics)
|
||||
|
||||
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] for module_name in sub_metrics.keys()}
|
||||
min_sparsity_rate = min(sparsity_rates.values())
|
||||
for target_name, fused_metric in fused_metrics.items():
|
||||
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] \
|
||||
for module_name in sub_metrics.keys()}
|
||||
min_sparsity_rate = min(sparsity_rates.values())
|
||||
|
||||
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
|
||||
max_group_nums = int(np.lcm.reduce(group_nums))
|
||||
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
|
||||
group_step = fused_metric.shape[0] // max_group_nums
|
||||
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
|
||||
max_group_nums = int(np.lcm.reduce(group_nums))
|
||||
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
|
||||
group_step = fused_metric.shape[0] // max_group_nums
|
||||
|
||||
# get the public part of the mask of the module with dependencies
|
||||
sub_masks = []
|
||||
for gid in range(max_group_nums):
|
||||
_start = gid * group_step
|
||||
_end = (gid + 1) * group_step
|
||||
if pruned_numel_per_group > 0:
|
||||
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
|
||||
sub_mask = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
|
||||
# get the public part of the mask of the module with dependencies
|
||||
dependency_mask = torch.ones_like(fused_metric)
|
||||
for gid in range(max_group_nums):
|
||||
_start = gid * group_step
|
||||
_end = (gid + 1) * group_step
|
||||
if pruned_numel_per_group > 0:
|
||||
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
|
||||
dependency_mask[_start: _end] = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
|
||||
|
||||
# change the metric value corresponding to the public mask part to the minimum value
|
||||
for module_name, targets_metric in sub_metrics.items():
|
||||
if target_name in targets_metric:
|
||||
# Following is Plan A, generate the dependency mask first, and then fill in the sparsity,
|
||||
# the final mask is group unbalanced. - 1 ensures the dependency metric is the minimum, and it will be masked first.
|
||||
# min_value = targets_metric[target_name].min() - 1
|
||||
# metrics[module_name][target_name] = torch.where(dependency_mask!=0, targets_metric[target_name], min_value)
|
||||
|
||||
# Following is Plan B, just generate the dependency mask, the final mask is group balanced.
|
||||
masks.setdefault(module_name, {})
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, dependency_mask)
|
||||
|
||||
# generate masks for layers without dependencies
|
||||
for module_name, targets_metric in metrics.items():
|
||||
masks.setdefault(module_name, {})
|
||||
wrapper = self.pruner.get_modules_wrapper()[module_name]
|
||||
for target_name, target_metric in targets_metric.items():
|
||||
if target_name in masks[module_name]:
|
||||
continue
|
||||
sparsity_rate = wrapper.config['total_sparsity']
|
||||
prune_num = int(sparsity_rate * target_metric.numel())
|
||||
if prune_num != 0:
|
||||
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
|
||||
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
|
||||
else:
|
||||
sub_mask = torch.ones_like(fused_metric[_start:_end])
|
||||
sub_masks.append(sub_mask)
|
||||
dependency_mask = torch.cat(sub_masks, dim=0)
|
||||
|
||||
# change the metric value corresponding to the public mask part to the minimum value
|
||||
for module_name, target_metric in sub_metrics.items():
|
||||
min_value = target_metric.min()
|
||||
metrics[module_name] = torch.where(dependency_mask!=0, target_metric, min_value)
|
||||
|
||||
return super().common_target_masks_generation(metrics)
|
||||
# target_metric should have the same size as shrinked_mask
|
||||
shrinked_mask = torch.ones_like(target_metric)
|
||||
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
|
||||
return masks
@ -51,7 +51,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
|
|||
self.total_iteration = total_iteration
|
||||
self.skip_first_iteration = skip_first_iteration
|
||||
super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks,
|
||||
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
|
||||
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='latest')
|
||||
|
||||
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
|
||||
self.current_iteration = 1 if self.skip_first_iteration else 0
|
||||
|
@ -78,10 +78,14 @@ class FunctionBasedTaskGenerator(TaskGenerator):
|
|||
|
||||
# get current2origin_sparsity and compact2origin_sparsity
|
||||
origin_model = torch.load(self._origin_model_path)
|
||||
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity)
|
||||
_logger.debug('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
|
||||
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
|
||||
self.target_sparsity)
|
||||
debug_msg = f'\nTask {task_result.task_id} total real sparsity compared with original model is:\n' + \
|
||||
f'{json_tricks.dumps(current2origin_sparsity, indent=4)}'
|
||||
_logger.debug(debug_msg)
|
||||
if task_result.task_id != 'origin':
|
||||
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
|
||||
task = self._tasks[task_result.task_id]
|
||||
task.state['current2origin_sparsity'] = current2origin_sparsity
|
||||
|
||||
# if reach the total_iteration, no more task will be generated
|
||||
if self.current_iteration > self.total_iteration:
|
||||
|
@ -116,7 +120,8 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
|
|||
for target, mo in zip(target_sparsity, compact2origin_sparsity):
|
||||
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
|
||||
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
|
||||
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
assert 0 <= sparsity <= 1, err_msg
|
||||
config_list.append(deepcopy(target))
|
||||
config_list[-1]['total_sparsity'] = sparsity
|
||||
return config_list
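
The cubic schedule used by ``AGPTaskGenerator`` above ramps sparsity quickly at first and flattens out near the target; a tiny sketch with assumed numbers::

    total_iteration, target_sparsity = 10, 0.8
    schedule = [(1 - (1 - it / total_iteration) ** 3) * target_sparsity
                for it in range(total_iteration + 1)]
    # schedule[0] == 0.0, schedule[5] == 0.7, schedule[10] == 0.8:
    # most of the pruning happens in the early iterations
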
|
||||
|
@ -128,7 +133,8 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
|
|||
for target, mo in zip(target_sparsity, compact2origin_sparsity):
|
||||
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
|
||||
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
|
||||
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
assert 0 <= sparsity <= 1, err_msg
|
||||
config_list.append(deepcopy(target))
|
||||
config_list[-1]['total_sparsity'] = sparsity
|
||||
return config_list
|
||||
|
@ -149,16 +155,18 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
|
|||
# The following is the formula in paper.
|
||||
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
|
||||
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
|
||||
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
|
||||
assert 0 <= sparsity <= 1, err_msg
|
||||
config_list.append(deepcopy(target))
|
||||
config_list[-1]['total_sparsity'] = sparsity
|
||||
return config_list
|
||||
|
||||
|
||||
class SimulatedAnnealingTaskGenerator(TaskGenerator):
|
||||
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
|
||||
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
|
||||
perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
|
||||
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]],
|
||||
origin_masks: Dict[str, Dict[str, Tensor]] = {}, start_temperature: float = 100, stop_temperature: float = 20,
|
||||
cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.',
|
||||
keep_intermediate_result: bool = False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
@ -188,7 +196,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
|
|||
self.perturbation_magnitude = perturbation_magnitude
|
||||
|
||||
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
|
||||
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
|
||||
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
|
||||
|
||||
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
|
||||
self.current_temperature = self.start_temperature
|
||||
|
@ -196,7 +204,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
|
|||
# TODO: replace with validation here
|
||||
for config in config_list:
|
||||
if 'sparsity' in config or 'sparsity_per_layer' in config:
|
||||
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
|
||||
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
|
||||
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
|
||||
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
|
||||
_logger.warning(warn_msg)
|
||||
|
||||
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
|
||||
self.target_sparsity_list = config_list_canonical(model, config_list)
|
||||
|
@ -259,11 +270,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
|
|||
|
||||
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
|
||||
sparsity = sorted(random_sparsity)
|
||||
|
||||
|
||||
# calculate the scale
|
||||
total_weights = np.sum(num_weights)
|
||||
total_weights_pruned = np.sum([int(num_weight * sparsity[idx]) for idx, num_weight in enumerate(num_weights)])
|
||||
|
||||
|
||||
if total_weights_pruned == 0:
|
||||
return None
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch.nn import Module
|
|||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from nni.common.serializer import Traceable, is_traceable
|
||||
from nni.common.serializer import is_traceable
|
||||
|
||||
__all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
|
||||
|
||||
|
@ -86,7 +86,8 @@ class OptimizerConstructHelper(ConstructHelper):
|
|||
'Please use nni.trace to wrap the optimizer class before initializing the optimizer.'
|
||||
assert isinstance(optimizer_trace, Optimizer), \
|
||||
'It is not an instance of torch.optim.Optimizer.'
|
||||
return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, **optimizer_trace.trace_kwargs) # type: ignore
|
||||
return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, # type: ignore
|
||||
**optimizer_trace.trace_kwargs) # type: ignore
|
||||
|
||||
|
||||
class LRSchedulerConstructHelper(ConstructHelper):
|
||||
|
@ -115,4 +116,5 @@ class LRSchedulerConstructHelper(ConstructHelper):
|
|||
'Please use nni.trace to wrap the lr scheduler class before initializing the scheduler.'
|
||||
assert isinstance(lr_scheduler_trace, _LRScheduler), \
|
||||
'It is not an instance of torch.optim.lr_scheduler._LRScheduler.'
|
||||
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, **lr_scheduler_trace.trace_kwargs) # type: ignore
|
||||
return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, # type: ignore
|
||||
**lr_scheduler_trace.trace_kwargs) # type: ignore
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
_EVALUATOR_DOCSTRING = r"""NNI will use the evaluator to intervene in the model training process,
so as to perform training-aware model compression.
All training-aware model compression will use the evaluator as the entry for intervention training in the future.
Usually you just need to wrap some classes with ``nni.trace`` or package the training process as a function to initialize the evaluator.
Please refer ... for a full tutorial on how to initialize an ``evaluator``.

The following are two simple examples. If you use pytorch_lightning, please refer to :class:`nni.compression.pytorch.LightningEvaluator`;
if you use native pytorch, please refer to :class:`nni.compression.pytorch.TorchEvaluator`::

    # LightningEvaluator example
    import pytorch_lightning
    lightning_trainer = nni.trace(pytorch_lightning.Trainer)(max_epochs=1, max_steps=50, logger=TensorBoardLogger(...))
    lightning_data_module = nni.trace(pytorch_lightning.LightningDataModule)(...)

    from nni.compression.pytorch import LightningEvaluator
    evaluator = LightningEvaluator(lightning_trainer, lightning_data_module)

    # TorchEvaluator example
    import torch
    import torch.nn.functional as F

    def training_model(model, optimizer, criterion, lr_scheduler, max_steps, max_epochs, *args, **kwargs):
        # max_steps, max_epochs might be None, which means unlimited training time,
        # so here we need to set a default termination condition (by default, total_epochs=10, total_steps=100000).
        total_epochs = max_epochs if max_epochs else 10
        total_steps = max_steps if max_steps else 100000
        current_step = 0

        # init dataloader
        train_dataloader = ...

        for epoch in range(total_epochs):
            ...
            for input_data, target in train_dataloader:
                optimizer.zero_grad()
                result = model(input_data)
                loss = criterion(result, target)
                loss.backward()
                optimizer.step()
                current_step += 1
                if current_step >= total_steps:
                    return
            lr_scheduler.step()

    traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
    criterion = F.nll_loss

    from nni.compression.pytorch import TorchEvaluator
    evaluator = TorchEvaluator(training_func=training_model, optimizers=traced_optimizer, criterion=criterion)
"""
@ -73,16 +73,16 @@ class TensorHook(Hook):
|
|||
return hook
|
||||
"""
|
||||
|
||||
def __init__(self, target: Tensor, target_name: str, hook_factory: Callable[[List], Callable[[Tensor], Any]]):
|
||||
def __init__(self, target: Tensor, target_name: str, hook_factory: Callable[[List], Callable[[Tensor], Tensor | None]]):
|
||||
assert isinstance(target, Tensor)
|
||||
super().__init__(target, target_name, hook_factory)
|
||||
|
||||
def _register(self, hook_func: Callable[[Tensor], Any]) -> RemovableHandle:
|
||||
def _register(self, hook_func: Callable[[Tensor], Tensor | None]) -> RemovableHandle:
|
||||
return self.target.register_hook(hook_func) # type: ignore
|
||||
|
||||
|
||||
class ModuleHook(Hook):
|
||||
def __init__(self, target: Module, target_name: str, hook_factory: Callable[[List], Callable[[Module, Tensor, Tensor], Any]]):
|
||||
def __init__(self, target: Module, target_name: str, hook_factory: Callable[[List], Callable[[Module, Any, Any], Any]]):
|
||||
assert isinstance(target, Module)
|
||||
super().__init__(target, target_name, hook_factory)
|
||||
|
||||
|
@ -97,7 +97,7 @@ class ForwardHook(ModuleHook):
|
|||
return hook
|
||||
"""
|
||||
|
||||
def _register(self, hook_func: Callable[[Module, Tensor, Tensor], Any]):
|
||||
def _register(self, hook_func: Callable[[Module, Tuple[Any], Any], Any]):
|
||||
return self.target.register_forward_hook(hook_func) # type: ignore
|
||||
|
||||
|
||||
|
@ -111,7 +111,7 @@ class BackwardHook(ModuleHook):
|
|||
return hook
|
||||
"""
|
||||
|
||||
def _register(self, hook_func: Callable[[Module, Tensor, Tensor], Any]):
|
||||
def _register(self, hook_func: Callable[[Module, Tuple[Tensor] | Tensor, Tuple[Tensor] | Tensor], Any]):
|
||||
return self.target.register_backward_hook(hook_func) # type: ignore
|
||||
|
||||
|
||||
|
@ -148,7 +148,8 @@ class Evaluator:
|
|||
|
||||
def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
|
||||
"""
|
||||
Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification, model training, and model evaluation.
|
||||
Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
|
||||
model training, and model evaluation.
|
||||
|
||||
Parameters
----------
|
||||
|
@ -246,10 +247,12 @@ class Evaluator:
|
|||
def evaluate(self) -> float | None | Tuple[float, Any] | Tuple[None, Any]:
|
||||
"""
|
||||
NNI assumes the evaluation function the user passed in should return a float number or a dict as the metric.
If the evaluation function returned a dict, take the value with dict key ``default`` as the first element of the ``evaluate`` returned value,
If the evaluation function returned a dict, take the value with dict key ``default``
as the first element of the ``evaluate`` returned value,
and put the dict as the second element of the returned value.
For any other type of metric returned by the evaluation function, ``evaluate`` will return it directly
(it should be a float, but NNI does not prevent other types from being returned, this will be handled by the object calling ``evaluate``).
(it should be a float, but NNI does not prevent other types from being returned,
this will be handled by the object calling ``evaluate``).
|
||||
"""
|
||||
# Note that the first item of the returned value will be used as the default metric used by NNI.
|
||||
raise NotImplementedError
|
||||
|
@ -287,9 +290,11 @@ class LightningEvaluator(Evaluator):
|
|||
|
||||
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
|
||||
dummy_input: Any | None = None):
|
||||
err_msg = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
|
||||
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
|
||||
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
|
||||
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
|
||||
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
|
||||
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
|
||||
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
|
||||
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
|
||||
self.trainer = trainer
|
||||
self.data_module = data_module
|
||||
self._dummy_input = dummy_input
|
||||
|
@ -314,18 +319,20 @@ class LightningEvaluator(Evaluator):
|
|||
optimizers_lr_schedulers: Any = pure_model.configure_optimizers()
|
||||
# 1. None - Fit will run without any optimizer.
|
||||
if optimizers_lr_schedulers is None:
|
||||
err_msg = 'NNI does not support `LightningModule.configure_optimizers` returned None, '
|
||||
err_msg += 'if you have a reason why you must, please file an issue at https://github.com/microsoft/nni/issues'
|
||||
err_msg = 'NNI does not support `LightningModule.configure_optimizers` returning None, ' + \
|
||||
'if you have a reason why you must, please file an issue at https://github.com/microsoft/nni/issues'
|
||||
raise ValueError(err_msg)
|
||||
# 2. Single optimizer.
|
||||
# 3. Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
|
||||
# 3. Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose
|
||||
# value is a single LR scheduler or lr_scheduler_config.
|
||||
elif isinstance(optimizers_lr_schedulers, (Optimizer, dict)):
|
||||
optimizers_lr_schedulers = [optimizers_lr_schedulers]
|
||||
|
||||
err_msg = f'Got a wrong return value type of `LightningModule.configure_optimizers`: {type(optimizers_lr_schedulers).__name__}'
|
||||
assert isinstance(optimizers_lr_schedulers, (list, tuple)), err_msg
|
||||
|
||||
# 4. Two lists - the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
|
||||
# 4. Two lists - the first list has multiple optimizers,
|
||||
# and the second has multiple LR schedulers (or multiple lr_scheduler_config).
|
||||
if isinstance(optimizers_lr_schedulers[0], (list, tuple)):
|
||||
optimizers, lr_schedulers = optimizers_lr_schedulers
|
||||
self._optimizer_helpers = [OptimizerConstructHelper.from_trace(pure_model, optimizer) for optimizer in optimizers]
|
||||
|
@ -364,7 +371,8 @@ class LightningEvaluator(Evaluator):
|
|||
self._initialization_complete = True
|
||||
|
||||
def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
|
||||
assert self._initialization_complete is True, 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
|
||||
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
|
||||
assert self._initialization_complete is True, err_msg
|
||||
assert isinstance(model, pl.LightningModule)
|
||||
if self.model is not None:
|
||||
_logger.warning('Already bound a model, will unbind it before bind a new model.')
|
||||
|
@ -397,7 +405,8 @@ class LightningEvaluator(Evaluator):
|
|||
if self._opt_returned_dicts:
|
||||
def new_configure_optimizers(_): # type: ignore
|
||||
optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers] # type: ignore
|
||||
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
|
||||
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]])
|
||||
for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
|
||||
opt_lrs_dicts = deepcopy(self._opt_returned_dicts)
|
||||
for opt_lrs_dict in opt_lrs_dicts:
|
||||
opt_lrs_dict['optimizer'] = optimizers[opt_lrs_dict['optimizer']]
|
||||
|
@ -407,7 +416,8 @@ class LightningEvaluator(Evaluator):
|
|||
elif self._lr_scheduler_helpers:
|
||||
def new_configure_optimizers(_): # type: ignore
|
||||
optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers] # type: ignore
|
||||
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
|
||||
lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]])
|
||||
for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
|
||||
return optimizers, lr_schedulers
|
||||
else:
|
||||
def new_configure_optimizers(_):
|
||||
|
@ -442,7 +452,8 @@ class LightningEvaluator(Evaluator):
|
|||
assert isinstance(self.model, pl.LightningModule)
|
||||
|
||||
class OptimizerCallback(Callback):
|
||||
def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer, opt_idx: int) -> None:
|
||||
def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
|
||||
optimizer: Optimizer, opt_idx: int) -> None:
|
||||
for task in before_step_tasks:
|
||||
task()
|
||||
|
||||
|
@ -486,10 +497,12 @@ class LightningEvaluator(Evaluator):
|
|||
|
||||
def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
|
||||
"""
|
||||
NNI will use metric with key ``default`` for evaluating model, please make sure you have this key in your ``Trainer.test()`` returned metric dicts.
|
||||
If ``Trainer.test()`` returned list contains multiple dicts with key ``default``, NNI will take their average as the final metric.
|
||||
E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}, {'default': 0.7, 'loss': 2.3}]``,
|
||||
NNI will take the final metric ``(0.8 + 0.6 + 0.7) / 3 = 0.7``.
|
||||
NNI will use the metric with key ``default`` for evaluating the model,
please make sure you have this key in the metric dicts returned by ``Trainer.test()``.
If the list returned by ``Trainer.test()`` contains multiple dicts with key ``default``,
NNI will take their average as the final metric.
|
||||
E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
|
||||
NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
|
||||
"""
|
||||
assert isinstance(self.model, pl.LightningModule)
|
||||
# reset trainer
|
||||
|
@ -514,9 +527,11 @@ class LightningEvaluator(Evaluator):
|
|||
raise e
|
||||
|
||||
|
||||
_OPTIMIZERS = Union[Optimizer, List[Optimizer]]
|
||||
_CRITERION = Callable[[Any, Any], Any]
|
||||
_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
|
||||
_EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
|
||||
_TRAINING_FUNC = Callable[[Module, Union[Optimizer, List[Optimizer]], _CRITERION, Union[None, _LRScheduler, List[_LRScheduler]], Optional[int], Optional[int]], None]
|
||||
_TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _CRITERION, _SCHEDULERS, Optional[int], Optional[int]], None]
|
||||
|
||||
|
||||
class TorchEvaluator(Evaluator):
|
||||
|
@ -528,8 +543,10 @@ class TorchEvaluator(Evaluator):
|
|||
----------
|
||||
training_func
|
||||
The training function is used to train the model, note that this is an entire optimization training loop.
It should have three required parameters [model, optimizers, criterion] and three optional parameters [schedulers, max_steps, max_epochs].
``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``, it belongs to the ``optimizers`` passed to ``TorchEvaluator``.
It should have three required parameters [model, optimizers, criterion]
and three optional parameters [schedulers, max_steps, max_epochs].
``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``,
it belongs to the ``optimizers`` passed to ``TorchEvaluator``.
``criterion`` and ``schedulers`` also correspond to the ``criterion`` and ``schedulers`` passed to ``TorchEvaluator``.
``max_steps`` and ``max_epochs`` are used to control the training duration.
|
||||
|
||||
|
@ -574,7 +591,8 @@ class TorchEvaluator(Evaluator):
|
|||
Optional. The traced _LRScheduler instance, whose lr scheduler class is wrapped by nni.trace.
|
||||
E.g. ``traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)``.
|
||||
dummy_input
|
||||
Optional. The dummy_input is used to trace the graph, the same with ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
|
||||
Optional. The dummy_input is used to trace the graph,
|
||||
the same with ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
|
||||
evaluating_func
|
||||
Optional. A function that takes the model as input and returns the evaluation metric.
|
||||
The return value can be a single float or a tuple (float, Any).
|
||||
|
@ -634,14 +652,16 @@ class TorchEvaluator(Evaluator):
|
|||
self._lr_scheduler_helpers = [LRSchedulerConstructHelper.from_trace(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers]
|
||||
optimizer_ids_map = {id(optimizer): i for i, optimizer in enumerate(self._tmp_optimizers)}
|
||||
# record i-th lr_scheduler scheduling j-th optimizer lr
|
||||
self._lrs_opt_map = {i: optimizer_ids_map[id(lr_scheduler.optimizer)] for i, lr_scheduler in enumerate(self._tmp_lr_schedulers)} # type: ignore
|
||||
self._lrs_opt_map = {i: optimizer_ids_map[id(lr_scheduler.optimizer)] # type: ignore
|
||||
for i, lr_scheduler in enumerate(self._tmp_lr_schedulers)} # type: ignore
|
||||
|
||||
delattr(self, '_tmp_optimizers')
|
||||
delattr(self, '_tmp_lr_schedulers')
|
||||
self._initialization_complete = True
|
||||
|
||||
def bind_model(self, model: Module, param_names_map: Dict[str, str] | None = None):
|
||||
assert self._initialization_complete is True, 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
|
||||
err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
|
||||
assert self._initialization_complete is True, err_msg
|
||||
assert isinstance(model, Module)
|
||||
if self.model is not None:
|
||||
_logger.warning('Already bound a model, will unbind it before bind a new model.')
|
||||
|
@ -651,7 +671,8 @@ class TorchEvaluator(Evaluator):
|
|||
self._param_names_map = param_names_map
|
||||
# initialize optimizers & lr_schedulers for the bound model here
|
||||
self._optimizers = [helper.call(model, param_names_map) for helper in self._optimizer_helpers]
|
||||
self._lr_schedulers = [lrs_helper.call(self._optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
|
||||
self._lr_schedulers = [lrs_helper.call(self._optimizers[self._lrs_opt_map[i]]) \
|
||||
for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
|
||||
self._first_optimizer_step = self._optimizers[0].step
|
||||
|
||||
def unbind_model(self):
|
||||
|
@ -717,7 +738,8 @@ class TorchEvaluator(Evaluator):
|
|||
if isinstance(metric, dict):
|
||||
nni_used_metric = metric.get('default', None)
|
||||
if nni_used_metric is None:
|
||||
warn_msg = f'Evaluation function returns a dict metric without key `default`, will return None as the model evaluation metric value.'
|
||||
warn_msg = f'Evaluation function returns a dict metric without key `default`,' + \
|
||||
'will return None as the model evaluation metric value.'
|
||||
_logger.warning(warn_msg)
|
||||
return nni_used_metric, metric
|
||||
else:
|
||||
|
|
|
@ -229,7 +229,8 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
|
|||
return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity
|
||||
|
||||
|
||||
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
|
||||
def get_model_weights_numel(model: Module, config_list: List[Dict],
|
||||
masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
|
||||
"""
|
||||
Count the layer weight elements number in config_list.
|
||||
If masks is not empty, the masked weight will not be counted.
|
||||
|
|
|
@ -53,18 +53,24 @@ class Scaling:
|
|||
# for the `-1` in kernel_size, then expand size (4, 3, 1) to size (4, 6, 2).
|
||||
kernel_padding_mode
|
||||
'front' or 'back', default is 'front'.
|
||||
If set 'front', for a given tensor when shrinking, padding `1` at front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
|
||||
for a given expand size when expanding, padding `1` at front of kernel_size until `len(expand_size) == len(kernel_size)`.
|
||||
If set 'back', for a given tensor when shrinking, padding `-1` at back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
|
||||
for a given expand size when expanding, padding `-1` at back of kernel_size until `len(expand_size) == len(kernel_size)`.
|
||||
If set 'front', for a given tensor when shrinking,
|
||||
padding `1` at front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
|
||||
for a given expand size when expanding,
|
||||
padding `1` at front of kernel_size until `len(expand_size) == len(kernel_size)`.
|
||||
If set 'back', for a given tensor when shrinking,
|
||||
padding `-1` at back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
|
||||
for a given expand size when expanding,
|
||||
padding `-1` at back of kernel_size until `len(expand_size) == len(kernel_size)`.
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size: List[int], kernel_padding_mode: Literal['front', 'back'] = 'front') -> None:
|
||||
self.kernel_size = kernel_size
|
||||
assert kernel_padding_mode in ['front', 'back'], f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."
|
||||
err_msg = f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."
|
||||
assert kernel_padding_mode in ['front', 'back'], err_msg
|
||||
self.kernel_padding_mode = kernel_padding_mode
|
||||
|
||||
def _padding(self, _list: List[int], length: int, padding_value: int = -1, padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
|
||||
def _padding(self, _list: List[int], length: int, padding_value: int = -1,
|
||||
padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
|
||||
"""
|
||||
Padding the `_list` to a specific length with `padding_value`.
|
||||
|
||||
|
@ -144,10 +150,12 @@ class Scaling:
|
|||
assert b % a == 0, f'Can not expand tensor with {target.shape} to {expand_size} with kernel size {kernel_size}.'
|
||||
_expand_size.append(b // a)
|
||||
_expand_size.append(a)
|
||||
new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [new_target] + [2 * _ + 1 for _ in range(len(expand_size))]) # type: ignore
|
||||
new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim),
|
||||
[new_target] + [2 * _ + 1 for _ in range(len(expand_size))]) # type: ignore
|
||||
|
||||
# step 3: expanding the new target to _expand_size and reshape to expand_size.
|
||||
# Note that we can also give an interface for how to expand the tensor, like `reduce_func` in `_shrink`, currently we don't have that need.
|
||||
# Note that we can also give an interface for how to expand the tensor, like `reduce_func` in `_shrink`,
|
||||
# currently we don't have that need.
|
||||
result = new_target.expand(_expand_size).reshape(expand_size).clone()
|
||||
|
||||
return result
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from nni.algorithms.compression.v2.pytorch import TorchEvaluator, LightningEvaluator
|
||||
from .speedup import ModelSpeedup
|
||||
from .compressor import Compressor, Pruner, Quantizer
|
||||
from .utils.apply_compression import apply_compression_results
|
||||
@ -3,6 +3,7 @@
|
|||
"nni/algorithms/compression/pytorch",
|
||||
"nni/algorithms/compression/tensorflow",
|
||||
"nni/algorithms/compression/v2/pytorch/base/pruner.py",
|
||||
"nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py",
|
||||
"nni/algorithms/feature_engineering",
|
||||
"nni/algorithms/hpo",
|
||||
"nni/algorithms/nas",
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from .device import device
|
||||
from .simple_mnist import SimpleLightningModel, SimpleTorchModel
|
||||
from .utils import unfold_config_list
|
||||
|
||||
|
||||
log_dir = Path(__file__).parent.parent / 'logs'
|
||||
|
||||
|
||||
def create_model(model_type: str):
|
||||
torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
|
||||
{'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.5},
|
||||
{'op_names': ['fc2'], 'exclude': True}]
|
||||
|
||||
lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
|
||||
{'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.5},
|
||||
{'op_names': ['model.fc2'], 'exclude': True}]
|
||||
|
||||
if model_type == 'lightning':
|
||||
model = SimpleLightningModel()
|
||||
config_list = lightning_config_list
|
||||
dummy_input = torch.rand(8, 1, 28, 28)
|
||||
elif model_type == 'pytorch':
|
||||
model = SimpleTorchModel().to(device)
|
||||
config_list = torch_config_list
|
||||
dummy_input = torch.rand(8, 1, 28, 28, device=device)
|
||||
else:
|
||||
raise ValueError(f'wrong model_type: {model_type}')
|
||||
return model, config_list, dummy_input
|
||||
|
||||
|
||||
def validate_masks(masks: Dict[str, Dict[str, torch.Tensor]], model: torch.nn.Module, config_list: List[Dict[str, Any]],
|
||||
is_global: bool = False):
|
||||
config_dict = unfold_config_list(model, config_list)
|
||||
# validate if all configured layers have generated mask.
|
||||
mismatched_op_names = set(config_dict.keys()).symmetric_difference(masks.keys())
|
||||
assert not mismatched_op_names, f'mismatched op_names: {mismatched_op_names}'
|
||||
|
||||
target_name = 'weight'
|
||||
total_masked_numel = 0
|
||||
total_target_numel = 0
|
||||
for module_name, target_masks in masks.items():
|
||||
mask = target_masks[target_name]
|
||||
assert mask.numel() == (mask == 0).sum().item() + (mask == 1).sum().item(), f'{module_name} {target_name} mask has values other than 0 and 1.'
|
||||
if not is_global:
|
||||
excepted_sparsity = config_dict[module_name].get('sparsity', config_dict[module_name].get('total_sparsity'))
|
||||
real_sparsity = (mask == 0).sum().item() / mask.numel()
|
||||
err_msg = f'{module_name} {target_name} excepted sparsity: {excepted_sparsity}, but real sparsity: {real_sparsity}'
|
||||
assert excepted_sparsity * 0.9 < real_sparsity < excepted_sparsity * 1.1, err_msg
|
||||
else:
|
||||
total_masked_numel += (mask == 0).sum().item()
|
||||
total_target_numel += mask.numel()
|
||||
if is_global:
|
||||
global_config = next(iter(config_dict.values()))
expected_sparsity = global_config.get('sparsity', global_config.get('total_sparsity'))
|
||||
real_sparsity = total_masked_numel / total_target_numel
|
||||
err_msg = f'expected global sparsity: {expected_sparsity}, but real global sparsity: {real_sparsity}.'
|
||||
assert expected_sparsity * 0.9 < real_sparsity < expected_sparsity * 1.1, err_msg
|
||||
|
||||
|
||||
def validate_dependency_aware(model_type: str, masks: Dict[str, Dict[str, torch.Tensor]]):
|
||||
# only for simple_mnist model
|
||||
if model_type == 'lightning':
|
||||
assert torch.equal(masks['model.conv2']['weight'].mean([1, 2, 3]), masks['model.conv3']['weight'].mean([1, 2, 3]))
|
||||
if model_type == 'pytorch':
|
||||
assert torch.equal(masks['conv2']['weight'].mean([1, 2, 3]), masks['conv3']['weight'].mean([1, 2, 3]))
|
|
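A sketch of how these helpers are meant to be combined in a test, assuming `create_model` and `validate_masks` from this module are in scope; `L1NormPruner` mirrors the basic-pruner tests later in this change and is only illustrative:

from nni.compression.pytorch.pruning import L1NormPruner

model, config_list, dummy_input = create_model('pytorch')
pruner = L1NormPruner(model=model, config_list=config_list)
_, masks = pruner.compress()          # wrap modules and generate masks
model(dummy_input)                    # the wrapped model must still run
pruner._unwrap_model()
validate_masks(masks, model, config_list)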
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .simple_lightning_model import SimpleLightningModel, MNISTDataModule
|
||||
from .simple_torch_model import SimpleTorchModel, training_model, evaluating_model, finetuning_model
|
||||
from .simple_evaluator import create_lighting_evaluator, create_pytorch_evaluator
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
import nni
|
||||
from nni.algorithms.compression.v2.pytorch import LightningEvaluator, TorchEvaluator
|
||||
|
||||
from .simple_torch_model import training_model, evaluating_model
|
||||
from .simple_lightning_model import MNISTDataModule
|
||||
from ..common import device
|
||||
|
||||
|
||||
def create_lighting_evaluator() -> LightningEvaluator:
|
||||
pl_trainer = nni.trace(pl.Trainer)(
|
||||
accelerator='auto',
|
||||
devices=1,
|
||||
max_epochs=1,
|
||||
max_steps=50,
|
||||
logger=TensorBoardLogger(Path(__file__).parent.parent / 'lightning_logs', name="resnet"),
|
||||
)
|
||||
|
||||
pl_trainer.num_sanity_val_steps = 0
|
||||
pl_data = nni.trace(MNISTDataModule)(data_dir='data/mnist')
|
||||
evaluator = LightningEvaluator(pl_trainer, pl_data, dummy_input=torch.rand(8, 1, 28, 28))
|
||||
return evaluator
|
||||
|
||||
|
||||
def create_pytorch_evaluator(model: torch.nn.Module) -> TorchEvaluator:
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
|
||||
evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler,
|
||||
dummy_input=torch.rand(8, 1, 28, 28, device=device), evaluating_func=evaluating_model)
|
||||
return evaluator
|
|
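A sketch of how these factory functions are consumed by an evaluator-based pruner, mirroring the pruner tests later in this change; `SimpleTorchModel` comes from the sibling asset module and the config values are illustrative:

from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from .simple_torch_model import SimpleTorchModel

model = SimpleTorchModel().to(device)
evaluator = create_pytorch_evaluator(model)
pruner = TaylorFOWeightPruner(model=model, config_list=[{'op_types': ['Conv2d'], 'sparsity': 0.5}],
                              evaluator=evaluator, training_steps=20)
_, masks = pruner.compress()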
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
from torch.utils.data import random_split, DataLoader
|
||||
from torchmetrics.functional import accuracy
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
|
||||
import nni
|
||||
|
||||
from .simple_torch_model import SimpleTorchModel
|
||||
|
||||
|
||||
class SimpleLightningModel(pl.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = SimpleTorchModel()
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = F.nll_loss(logits, y)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def evaluate(self, batch, stage=None):
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = F.nll_loss(logits, y)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
acc = accuracy(preds, y)
|
||||
|
||||
if stage:
|
||||
self.log(f"default", loss, prog_bar=False)
|
||||
self.log(f"{stage}_loss", loss, prog_bar=True)
|
||||
self.log(f"{stage}_acc", acc, prog_bar=True)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.evaluate(batch, "val")
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
self.evaluate(batch, "test")
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = nni.trace(torch.optim.SGD)(
|
||||
self.parameters(),
|
||||
lr=0.01,
|
||||
momentum=0.9,
|
||||
weight_decay=5e-4,
|
||||
)
|
||||
scheduler_dict = {
|
||||
"scheduler": nni.trace(ExponentialLR)(
|
||||
optimizer,
|
||||
0.1,
|
||||
),
|
||||
"interval": "epoch",
|
||||
}
|
||||
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
|
||||
|
||||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, data_dir: str = "./"):
|
||||
super().__init__()
|
||||
self.data_dir = 'data/mnist'
|
||||
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
|
||||
def prepare_data(self):
|
||||
# download
|
||||
MNIST(self.data_dir, train=True, download=True)
|
||||
MNIST(self.data_dir, train=False, download=True)
|
||||
|
||||
def setup(self, stage: str | None = None):
|
||||
# Assign train/val datasets for use in dataloaders
|
||||
if stage == "fit" or stage is None:
|
||||
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
|
||||
|
||||
# Assign test dataset for use in dataloader(s)
|
||||
if stage == "test" or stage is None:
|
||||
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
|
||||
|
||||
if stage == "predict" or stage is None:
|
||||
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.mnist_train, batch_size=32)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.mnist_val, batch_size=32)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.mnist_test, batch_size=32)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(self.mnist_predict, batch_size=32)
|
|
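A minimal sketch for sanity-checking the LightningModule and DataModule above outside of NNI (plain PyTorch Lightning usage, nothing NNI-specific assumed):

model = SimpleLightningModel()
data = MNISTDataModule(data_dir='data/mnist')
trainer = pl.Trainer(max_epochs=1, max_steps=10)
trainer.fit(model, datamodule=data)
trainer.test(model, datamodule=data)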
@ -0,0 +1,92 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
|
||||
from ..device import device
|
||||
|
||||
|
||||
class SimpleTorchModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(1, 16, 3)
|
||||
self.bn1 = torch.nn.BatchNorm2d(16)
|
||||
self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
|
||||
self.bn2 = torch.nn.BatchNorm2d(8)
|
||||
self.conv3 = torch.nn.Conv2d(16, 8, 3)
|
||||
self.bn3 = torch.nn.BatchNorm2d(8)
|
||||
self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
|
||||
self.fc2 = torch.nn.Linear(100, 10)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.bn1(self.conv1(x))
|
||||
x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
|
||||
x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
|
||||
return F.log_softmax(x, -1)
|
||||
|
||||
|
||||
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler | None = None,
|
||||
max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
|
||||
model.train()
|
||||
|
||||
# prepare data
|
||||
MNIST(root='data/mnist', train=True, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
|
||||
train_dataloader = DataLoader(mnist_train, batch_size=32)
|
||||
|
||||
max_epochs = max_epochs if max_epochs else 1
|
||||
max_steps = max_steps if max_steps else 50
|
||||
current_steps = 0
|
||||
|
||||
# training
|
||||
for _ in range(max_epochs):
|
||||
for x, y in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
loss: torch.Tensor = criterion(logits, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
current_steps += 1
|
||||
if max_steps and current_steps == max_steps:
|
||||
return
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
|
||||
def finetuning_model(model: Module):
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
training_model(model, optimizer, F.nll_loss)
|
||||
|
||||
|
||||
def evaluating_model(model: Module, device: torch.device = device):
|
||||
model.eval()
|
||||
|
||||
# prepare data
|
||||
MNIST(root='data/mnist', train=False, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
|
||||
test_dataloader = DataLoader(mnist_test, batch_size=32)
|
||||
|
||||
# testing
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for x, y in test_dataloader:
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
correct += preds.eq(y.view_as(preds)).sum().item()
|
||||
return correct / len(mnist_test)
|
|
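A sketch exercising the plain-PyTorch helpers above directly; hyperparameters are copied from `finetuning_model`:

model = SimpleTorchModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
training_model(model, optimizer, F.nll_loss, max_steps=10)
acc = evaluating_model(model)  # top-1 accuracy on the MNIST test split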
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
######################################################################################
|
||||
# NOTE: copy from branch wrapper-refactor, will rm this file in this or next release.#
|
||||
######################################################################################
|
||||
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _unfold_op_partial_names(model: Module, config_list: List[Dict]) -> List[Dict]:
|
||||
config_list = deepcopy(config_list)
|
||||
full_op_names = [op_name for op_name, _ in model.named_modules()]
|
||||
for config in config_list:
|
||||
op_names = config.pop('op_names', [])
|
||||
op_partial_names = config.pop('op_partial_names', [])
|
||||
for op_partial_name in op_partial_names:
|
||||
op_names.extend([op_name for op_name in full_op_names if op_partial_name in op_name])
|
||||
config['op_names'] = list(set(op_names))
|
||||
return config_list
|
||||
|
||||
|
||||
def unfold_config_list(model: Module, config_list: List[Dict]) -> Dict[str, Dict[str, Any]]:
|
||||
'''
|
||||
Unfold config_list to op_names level, return a config_dict {op_name: config}.
|
||||
'''
|
||||
config_list = _unfold_op_partial_names(model=model, config_list=config_list)
|
||||
config_dict = {}
|
||||
for config in config_list:
|
||||
for key in ['op_types', 'op_names', 'exclude_op_names']:
|
||||
config.setdefault(key, [])
|
||||
op_names = []
|
||||
for module_name, module in model.named_modules():
|
||||
module_type = type(module).__name__
|
||||
if (module_type in config['op_types'] or module_name in config['op_names']) and module_name not in config['exclude_op_names']:
|
||||
op_names.append(module_name)
|
||||
config_template = deepcopy(config)
|
||||
for key in ['op_types', 'op_names', 'exclude_op_names']:
|
||||
config_template.pop(key, [])
|
||||
for op_name in op_names:
|
||||
if op_name in config_dict:
|
||||
warn_msg = f'Duplicate config definition for {op_name}, replacing the old config:\n' + \
|
||||
f'{config_dict[op_name]}\n' + \
|
||||
f'with new config:\n{config_template}\n'
|
||||
_logger.warning(warn_msg)
|
||||
config_dict[op_name] = deepcopy(config_template)
|
||||
return config_dict
|
|
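For illustration, what `unfold_config_list` yields for the `SimpleTorchModel` defined in the assets above (the expected dict is written out by hand from the logic in this function):

config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}, {'op_names': ['fc2'], 'exclude': True}]
config_dict = unfold_config_list(SimpleTorchModel(), config_list)
# -> {'conv1': {'sparsity': 0.5}, 'conv2': {'sparsity': 0.5}, 'conv3': {'sparsity': 0.5}, 'fc2': {'exclude': True}}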
@ -3,193 +3,23 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
import pytest
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
|
||||
from torch.utils.data import random_split, DataLoader
|
||||
from torchmetrics.functional import accuracy
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
|
||||
import nni
|
||||
from nni.algorithms.compression.v2.pytorch.utils.evaluator import (
|
||||
TorchEvaluator,
|
||||
LightningEvaluator,
|
||||
TensorHook,
|
||||
ForwardHook,
|
||||
BackwardHook,
|
||||
)
|
||||
|
||||
|
||||
class SimpleTorchModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(1, 16, 3)
|
||||
self.bn1 = torch.nn.BatchNorm2d(16)
|
||||
self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
|
||||
self.bn2 = torch.nn.BatchNorm2d(8)
|
||||
self.conv3 = torch.nn.Conv2d(16, 8, 3)
|
||||
self.bn3 = torch.nn.BatchNorm2d(8)
|
||||
self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
|
||||
self.fc2 = torch.nn.Linear(100, 10)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.bn1(self.conv1(x))
|
||||
x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
|
||||
x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
|
||||
return F.log_softmax(x, -1)
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler,
|
||||
max_steps: int | None = None, max_epochs: int | None = None):
|
||||
model.train()
|
||||
|
||||
# prepare data
|
||||
data_dir = Path(__file__).parent / 'data'
|
||||
MNIST(data_dir, train=True, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_train = MNIST(data_dir, train=True, transform=transform)
|
||||
train_dataloader = DataLoader(mnist_train, batch_size=32)
|
||||
|
||||
max_epochs = max_epochs if max_epochs else 1
|
||||
max_steps = max_steps if max_steps else 10
|
||||
current_steps = 0
|
||||
|
||||
# training
|
||||
for _ in range(max_epochs):
|
||||
for x, y in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
loss: torch.Tensor = criterion(logits, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
current_steps += 1
|
||||
if max_steps and current_steps == max_steps:
|
||||
return
|
||||
scheduler.step()
|
||||
|
||||
|
||||
def evaluating_model(model: Module):
|
||||
model.eval()
|
||||
|
||||
# prepare data
|
||||
data_dir = Path(__file__).parent / 'data'
|
||||
MNIST(data_dir, train=False, download=True)
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
mnist_test = MNIST(data_dir, train=False, transform=transform)
|
||||
test_dataloader = DataLoader(mnist_test, batch_size=32)
|
||||
|
||||
# testing
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for x, y in test_dataloader:
|
||||
x, y = x.to(device), y.to(device)
|
||||
logits = model(x)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
correct += preds.eq(y.view_as(preds)).sum().item()
|
||||
return correct / len(mnist_test)
|
||||
|
||||
|
||||
class SimpleLightningModel(pl.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = SimpleTorchModel()
|
||||
self.count = 0
|
||||
|
||||
def forward(self, x):
|
||||
print(self.count)
|
||||
self.count += 1
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = F.nll_loss(logits, y)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def evaluate(self, batch, stage=None):
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = F.nll_loss(logits, y)
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
acc = accuracy(preds, y)
|
||||
|
||||
if stage:
|
||||
self.log(f"{stage}_loss", loss, prog_bar=True)
|
||||
self.log(f"{stage}_acc", acc, prog_bar=True)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.evaluate(batch, "val")
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
self.evaluate(batch, "test")
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = nni.trace(torch.optim.SGD)(
|
||||
self.parameters(),
|
||||
lr=0.01,
|
||||
momentum=0.9,
|
||||
weight_decay=5e-4,
|
||||
)
|
||||
scheduler_dict = {
|
||||
"scheduler": nni.trace(ExponentialLR)(
|
||||
optimizer,
|
||||
0.1,
|
||||
),
|
||||
"interval": "epoch",
|
||||
}
|
||||
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
|
||||
|
||||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, data_dir: str = "./"):
|
||||
super().__init__()
|
||||
self.data_dir = data_dir
|
||||
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
||||
|
||||
def prepare_data(self):
|
||||
# download
|
||||
MNIST(self.data_dir, train=True, download=True)
|
||||
MNIST(self.data_dir, train=False, download=True)
|
||||
|
||||
def setup(self, stage: str | None = None):
|
||||
# Assign train/val datasets for use in dataloaders
|
||||
if stage == "fit" or stage is None:
|
||||
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
|
||||
|
||||
# Assign test dataset for use in dataloader(s)
|
||||
if stage == "test" or stage is None:
|
||||
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
|
||||
|
||||
if stage == "predict" or stage is None:
|
||||
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.mnist_train, batch_size=32)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.mnist_val, batch_size=32)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.mnist_test, batch_size=32)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(self.mnist_predict, batch_size=32)
|
||||
from ..assets.device import device
|
||||
from ..assets.simple_mnist import (
|
||||
SimpleLightningModel,
|
||||
SimpleTorchModel,
|
||||
create_lighting_evaluator,
|
||||
create_pytorch_evaluator
|
||||
)
|
||||
|
||||
|
||||
optimizer_before_step_flag = False
|
||||
|
@ -237,41 +67,20 @@ def assert_flags():
|
|||
assert loss_flag, 'Evaluator patch loss failed.'
|
||||
|
||||
|
||||
def create_lighting_evaluator():
|
||||
pl_model = SimpleLightningModel()
|
||||
pl_trainer = nni.trace(pl.Trainer)(
|
||||
max_epochs=1,
|
||||
max_steps=10,
|
||||
logger=TensorBoardLogger(Path(__file__).parent / 'lightning_logs', name="resnet"),
|
||||
)
|
||||
pl_trainer.num_sanity_val_steps = 0
|
||||
pl_data = nni.trace(MNISTDataModule)(data_dir=Path(__file__).parent / 'data')
|
||||
evaluator = LightningEvaluator(pl_trainer, pl_data)
|
||||
evaluator._init_optimizer_helpers(pl_model)
|
||||
return evaluator
|
||||
|
||||
|
||||
def create_pytorch_evaluator():
|
||||
model = SimpleTorchModel()
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
|
||||
evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler, evaluating_func=evaluating_model)
|
||||
evaluator._init_optimizer_helpers(model)
|
||||
return evaluator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch'])
|
||||
def test_evaluator(evaluator_type: str):
|
||||
if evaluator_type == 'lightning':
|
||||
evaluator = create_lighting_evaluator()
|
||||
model = SimpleLightningModel()
|
||||
evaluator = create_lighting_evaluator()
|
||||
evaluator._init_optimizer_helpers(model)
|
||||
evaluator.bind_model(model)
|
||||
tensor_hook = TensorHook(model.model.conv1.weight, 'model.conv1.weight', tensor_hook_factory)
|
||||
forward_hook = ForwardHook(model.model.conv1, 'model.conv1', forward_hook_factory)
|
||||
backward_hook = BackwardHook(model.model.conv1, 'model.conv1', backward_hook_factory)
|
||||
elif evaluator_type == 'pytorch':
|
||||
evaluator = create_pytorch_evaluator()
|
||||
model = SimpleTorchModel().to(device)
|
||||
evaluator = create_pytorch_evaluator(model)
|
||||
evaluator._init_optimizer_helpers(model)
|
||||
evaluator.bind_model(model)
|
||||
tensor_hook = TensorHook(model.conv1.weight, 'conv1.weight', tensor_hook_factory)
|
||||
forward_hook = ForwardHook(model.conv1, 'conv1', forward_hook_factory)
|
||||
|
@ -296,4 +105,4 @@ def test_evaluator(evaluator_type: str):
|
|||
|
||||
evaluator.finetune()
|
||||
assert_flags()
|
||||
assert all([len(hook.buffer) == 10 for hook in [tensor_hook, forward_hook, backward_hook]])
|
||||
assert all([len(hook.buffer) == 50 for hook in [tensor_hook, forward_hook, backward_hook]])
|
||||
|
|
|
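The sequence this test relies on, written out as a minimal sketch (only calls that already appear in the test are used):

model = SimpleTorchModel().to(device)
evaluator = create_pytorch_evaluator(model)
evaluator._init_optimizer_helpers(model)  # resolve the nni.trace'd optimizer/scheduler against this model
evaluator.bind_model(model)               # attach the model before finetune/train calls
evaluator.finetune()                      # runs the wrapped training loop (50 steps here)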
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import nni
|
||||
from nni.compression.pytorch.pruning import (
|
||||
LinearPruner,
|
||||
AGPPruner,
|
||||
LotteryTicketPruner,
|
||||
SimulatedAnnealingPruner,
|
||||
AutoCompressPruner
|
||||
)
|
||||
|
||||
from ..assets.common import create_model, log_dir, validate_masks, validate_dependency_aware
|
||||
from ..assets.device import device
|
||||
from ..assets.simple_mnist import (
|
||||
create_lighting_evaluator,
|
||||
create_pytorch_evaluator,
|
||||
training_model,
|
||||
finetuning_model,
|
||||
evaluating_model
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
@pytest.mark.parametrize('pruning_type', ['linear', 'agp', 'lottery'])
|
||||
@pytest.mark.parametrize('speedup', [True, False])
|
||||
def test_functional_pruner(model_type: str, using_evaluator: bool, pruning_type: str, speedup: bool):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
if pruning_type == 'linear':
|
||||
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
|
||||
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
|
||||
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
|
||||
elif pruning_type == 'agp':
|
||||
pruner = AGPPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
|
||||
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
|
||||
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
|
||||
elif pruning_type == 'lottery':
|
||||
pruner = LotteryTicketPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
|
||||
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
|
||||
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
|
||||
else:
|
||||
model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
if pruning_type == 'linear':
|
||||
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
|
||||
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
|
||||
elif pruning_type == 'agp':
|
||||
pruner = AGPPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
|
||||
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
|
||||
elif pruning_type == 'lottery':
|
||||
pruner = LotteryTicketPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
|
||||
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
|
||||
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
|
||||
|
||||
pruner.compress()
|
||||
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
|
||||
best_model(dummy_input)
|
||||
validate_masks(best_masks, best_model, config_list)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
def test_sa_pruner(model_type: str, using_evaluator: bool):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
pruner = SimulatedAnnealingPruner(model=model, config_list=config_list, evaluator=evaluator, start_temperature=100,
|
||||
stop_temperature=80, cool_down_rate=0.9, perturbation_magnitude=0.35, pruning_algorithm='l1',
|
||||
pruning_params={}, log_dir=log_dir, keep_intermediate_result=False, speedup=False)
|
||||
else:
|
||||
model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
pruner = SimulatedAnnealingPruner(model=model, config_list=config_list, evaluator=evaluating_model, start_temperature=100,
|
||||
stop_temperature=80, cool_down_rate=0.9, perturbation_magnitude=0.35, pruning_algorithm='l1',
|
||||
pruning_params={}, log_dir=log_dir, keep_intermediate_result=False, speedup=False)
|
||||
|
||||
pruner.compress()
|
||||
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
|
||||
best_model(dummy_input)
|
||||
validate_masks(best_masks, best_model, config_list)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
def test_auto_compress_pruner(model_type: str, using_evaluator: bool):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
admm_params = {'evaluator': evaluator, 'iterations': 2, 'training_epochs': 1, 'granularity': 'coarse-grained'}
|
||||
sa_params = {'evaluator': evaluator, 'start_temperature': 100, 'stop_temperature': 80, 'pruning_algorithm': 'l1'}
|
||||
pruner = AutoCompressPruner(model=model, config_list=config_list, total_iteration=2, admm_params=admm_params, sa_params=sa_params,
|
||||
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=False)
|
||||
else:
|
||||
model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
admm_params = {'trainer': training_model, 'traced_optimizer': optimizer, 'criterion': F.nll_loss, 'iterations': 2,
|
||||
'training_epochs': 1, 'granularity': 'coarse-grained'}
|
||||
sa_params = {'evaluator': evaluating_model, 'start_temperature': 100, 'stop_temperature': 80, 'pruning_algorithm': 'l1'}
|
||||
pruner = AutoCompressPruner(model=model, config_list=config_list, total_iteration=2, admm_params=admm_params, sa_params=sa_params,
|
||||
log_dir=log_dir, keep_intermediate_result=False, finetuner=finetuning_model, speedup=False,
|
||||
dummy_input=dummy_input, evaluator=evaluating_model)
|
||||
|
||||
pruner.compress()
|
||||
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
|
||||
best_model(dummy_input)
|
||||
validate_masks(best_masks, best_model, config_list)
|
||||
|
||||
|
||||
# we still need an AMCPruner test, but it costs a lot; will add it after we have a GPU pool.
|
|
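For readers skimming the parametrization, a non-parametrized version of the evaluator path exercised above (arguments mirror `test_functional_pruner` with `speedup=False`):

model, config_list, dummy_input = create_model('pytorch')
evaluator = create_pytorch_evaluator(model)
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
                      log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=False)
pruner.compress()
_, best_model, best_masks, _, _ = pruner.get_best_result()
validate_masks(best_masks, best_model, config_list)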
@ -0,0 +1,172 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import nni
|
||||
from nni.compression.pytorch.pruning import (
|
||||
LevelPruner,
|
||||
L1NormPruner,
|
||||
L2NormPruner,
|
||||
SlimPruner,
|
||||
FPGMPruner,
|
||||
ActivationAPoZRankPruner,
|
||||
ActivationMeanRankPruner,
|
||||
TaylorFOWeightPruner,
|
||||
ADMMPruner,
|
||||
MovementPruner
|
||||
)
|
||||
|
||||
from ..assets.device import device
|
||||
from ..assets.simple_mnist import (
|
||||
create_lighting_evaluator,
|
||||
create_pytorch_evaluator,
|
||||
training_model
|
||||
)
|
||||
from ..assets.common import create_model, validate_masks, validate_dependency_aware
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
def test_level_pruner(model_type: str):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
pruner = LevelPruner(model=model, config_list=config_list)
|
||||
|
||||
_, masks = pruner.compress()
|
||||
model(dummy_input)
|
||||
pruner._unwrap_model()
|
||||
validate_masks(masks, model, config_list)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('pruning_type', ['l1', 'l2', 'fpgm'])
|
||||
@pytest.mark.parametrize('mode', ['normal', 'dependency_aware'])
|
||||
def test_norm_pruner(model_type: str, pruning_type: str, mode: str):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if pruning_type == 'l1':
|
||||
pruner = L1NormPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
|
||||
elif pruning_type == 'l2':
|
||||
pruner = L2NormPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
|
||||
elif pruning_type == 'fpgm':
|
||||
pruner = FPGMPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
|
||||
else:
|
||||
raise ValueError(f'wrong norm: {pruning_type}')
|
||||
|
||||
_, masks = pruner.compress()
|
||||
model(dummy_input)
|
||||
pruner._unwrap_model()
|
||||
validate_masks(masks, model, config_list)
|
||||
if mode == 'dependency_aware':
|
||||
validate_dependency_aware(model_type, masks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
@pytest.mark.parametrize('mode', ['global', 'normal'])
|
||||
def test_slim_pruner(model_type: str, using_evaluator: bool, mode: str):
|
||||
model, _, dummy_input = create_model(model_type)
|
||||
config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.5}]
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
pruner = SlimPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, scale=0.0001, mode=mode)
|
||||
else:
|
||||
model = model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
pruner = SlimPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
|
||||
criterion=F.nll_loss, training_epochs=1, scale=0.0001, mode=mode)
|
||||
|
||||
_, masks = pruner.compress()
|
||||
model(dummy_input)
|
||||
pruner._unwrap_model()
|
||||
validate_masks(masks, model, config_list, is_global=(mode == 'global'))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('pruning_type', ['apoz', 'mean', 'taylor'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
@pytest.mark.parametrize('mode', ['normal', 'dependency_aware'])
|
||||
def test_hook_based_pruner(model_type: str, pruning_type: str, using_evaluator: bool, mode: str):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
if pruning_type == 'apoz':
|
||||
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
|
||||
activation='relu', mode=mode, dummy_input=dummy_input)
|
||||
elif pruning_type == 'mean':
|
||||
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
|
||||
activation='relu', mode=mode, dummy_input=dummy_input)
|
||||
elif pruning_type == 'taylor':
|
||||
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
|
||||
mode=mode, dummy_input=dummy_input)
|
||||
else:
|
||||
model = model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
if pruning_type == 'apoz':
|
||||
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
|
||||
criterion=F.nll_loss, training_batches=20, activation='relu', mode=mode, dummy_input=dummy_input)
|
||||
elif pruning_type == 'mean':
|
||||
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
|
||||
criterion=F.nll_loss, training_batches=20, activation='relu', mode=mode, dummy_input=dummy_input)
|
||||
elif pruning_type == 'taylor':
|
||||
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
|
||||
criterion=F.nll_loss, training_batches=20, mode=mode, dummy_input=dummy_input)
|
||||
|
||||
_, masks = pruner.compress()
|
||||
model(dummy_input)
|
||||
pruner._unwrap_model()
|
||||
validate_masks(masks, model, config_list)
|
||||
if mode == 'dependency_aware':
|
||||
validate_dependency_aware(model_type, masks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
@pytest.mark.parametrize('granularity', ['fine-grained', 'coarse-grained'])
|
||||
def test_admm_pruner(model_type: str, using_evaluator: bool, granularity: str):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
pruner = ADMMPruner(model=model, config_list=config_list, evaluator=evaluator, iterations=2, training_epochs=1, granularity=granularity)
|
||||
else:
|
||||
model = model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
pruner = ADMMPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer, criterion=F.nll_loss,
|
||||
iterations=2, training_epochs=1, granularity=granularity)
|
||||
|
||||
_, masks = pruner.compress()
|
||||
model(dummy_input)
|
||||
pruner._unwrap_model()
|
||||
validate_masks(masks, model, config_list)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
|
||||
@pytest.mark.parametrize('using_evaluator', [True, False])
|
||||
def test_movement_pruner(model_type: str, using_evaluator: bool):
|
||||
model, config_list, dummy_input = create_model(model_type)
|
||||
|
||||
if using_evaluator:
|
||||
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
|
||||
pruner = MovementPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, warm_up_step=10, cool_down_beginning_step=40)
|
||||
else:
|
||||
model = model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
|
||||
pruner = MovementPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer, criterion=F.nll_loss,
|
||||
training_epochs=1, warm_up_step=10, cool_down_beginning_step=40)
|
||||
|
||||
_, masks = pruner.compress()
|
||||
model(dummy_input)
|
||||
pruner._unwrap_model()
|
||||
validate_masks(masks, model, config_list)
|
|
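For reference, a minimal standalone version of the evaluator-based Slim path above; SlimPruner sparsifies BatchNorm2d scaling factors, hence the BatchNorm-only config (values mirror `test_slim_pruner`):

model, _, dummy_input = create_model('pytorch')
evaluator = create_pytorch_evaluator(model)
config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.5}]
pruner = SlimPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, scale=0.0001, mode='global')
_, masks = pruner.compress()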
@ -8,14 +8,20 @@ import torch.nn.functional as F
|
|||
|
||||
import nni
|
||||
from nni.algorithms.compression.v2.pytorch.base import Pruner
|
||||
# TODO: remove in nni v3.0.
|
||||
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
|
||||
WeightDataCollector,
|
||||
WeightTrainerBasedDataCollector,
|
||||
SingleHookTrainerBasedDataCollector
|
||||
)
|
||||
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
|
||||
TargetDataCollector,
|
||||
EvaluatorBasedTargetDataCollector,
|
||||
EvaluatorBasedHookDataCollector
|
||||
)
|
||||
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
|
||||
NormMetricsCalculator,
|
||||
MultiDataNormMetricsCalculator,
|
||||
HookDataNormMetricsCalculator,
|
||||
DistMetricsCalculator,
|
||||
APoZRankMetricsCalculator,
|
||||
MeanRankMetricsCalculator
|
||||
|
@ -84,7 +90,7 @@ class PruningToolsTestCase(unittest.TestCase):
|
|||
# Test WeightDataCollector
|
||||
data_collector = WeightDataCollector(pruner)
|
||||
data = data_collector.collect()
|
||||
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
|
||||
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]['weight']) for module_name in ['conv1', 'conv2'])
|
||||
|
||||
# Test WeightTrainerBasedDataCollector
|
||||
def opt_after():
|
||||
|
@ -94,8 +100,8 @@ class PruningToolsTestCase(unittest.TestCase):
|
|||
optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
|
||||
data_collector = WeightTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 1, opt_after_tasks=[opt_after])
|
||||
data = data_collector.collect()
|
||||
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
|
||||
assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())
|
||||
assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]['weight']) for module_name in ['conv1', 'conv2'])
|
||||
assert all(t['weight'].numel() == (t['weight'] == 1).type_as(t['weight']).sum().item() for t in data.values())
|
||||
|
||||
# Test SingleHookTrainerBasedDataCollector
|
||||
def _collector(buffer, weight_tensor):
|
||||
|
@ -109,73 +115,73 @@ class PruningToolsTestCase(unittest.TestCase):
|
|||
optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
|
||||
data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 2, collector_infos=[collector_info])
|
||||
data = data_collector.collect()
|
||||
assert all(len(t) == 2 for t in data.values())
|
||||
assert all(len(t['weight']) == 2 for t in data.values())
|
||||
|
||||
def test_metrics_calculator(self):
|
||||
# Test NormMetricsCalculator
|
||||
metrics_calculator = NormMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
|
||||
data = {
|
||||
'1': torch.ones(3, 3, 3),
|
||||
'2': torch.ones(4, 4) * 2
|
||||
'1': {'target_name': torch.ones(3, 3, 3)},
|
||||
'2': {'target_name': torch.ones(4, 4) * 2}
|
||||
}
|
||||
result = {
|
||||
'1': torch.ones(3) * 3,
|
||||
'2': torch.ones(4) * 4
|
||||
'1': {'target_name': torch.ones(3) * 3},
|
||||
'2': {'target_name': torch.ones(4) * 4}
|
||||
}
|
||||
metrics = metrics_calculator.calculate_metrics(data)
|
||||
assert all(torch.equal(result[k], v) for k, v in metrics.items())
|
||||
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
|
||||
|
||||
# Test DistMetricsCalculator
|
||||
metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
|
||||
data = {
|
||||
'1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
|
||||
'2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
|
||||
'1': {'target_name': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32)},
|
||||
'2': {'target_name': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)}
|
||||
}
|
||||
result = {
|
||||
'1': torch.tensor([5, 5], dtype=torch.float32),
|
||||
'2': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))
|
||||
'1': {'target_name': torch.tensor([5, 5], dtype=torch.float32)},
|
||||
'2': {'target_name': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))}
|
||||
}
|
||||
metrics = metrics_calculator.calculate_metrics(data)
|
||||
assert all(torch.equal(result[k], v) for k, v in metrics.items())
|
||||
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
|
||||
|
||||
# Test MultiDataNormMetricsCalculator
|
||||
metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
|
||||
# Test HookDataNormMetricsCalculator
|
||||
metrics_calculator = HookDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
|
||||
data = {
|
||||
'1': [2, torch.ones(3, 3, 3) * 2],
|
||||
'2': [2, torch.ones(4, 4) * 2]
|
||||
'1': {'target_name': [2, torch.ones(3, 3, 3) * 2]},
|
||||
'2': {'target_name': [2, torch.ones(4, 4) * 2]}
|
||||
}
|
||||
result = {
|
||||
'1': torch.ones(3) * 18,
|
||||
'2': torch.ones(4) * 8
|
||||
'1': {'target_name': torch.ones(3) * 18},
|
||||
'2': {'target_name': torch.ones(4) * 8}
|
||||
}
|
||||
metrics = metrics_calculator.calculate_metrics(data)
|
||||
assert all(torch.equal(result[k], v) for k, v in metrics.items())
|
||||
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
|
||||
|
||||
# Test APoZRankMetricsCalculator
|
||||
metrics_calculator = APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
|
||||
data = {
|
||||
'1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
|
||||
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
|
||||
'1': {'target_name': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)]},
|
||||
'2': {'target_name': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]}
|
||||
}
|
||||
result = {
|
||||
'1': torch.tensor([0.5, 0.5], dtype=torch.float32),
|
||||
'2': torch.tensor([1, 1, 0.75], dtype=torch.float32)
|
||||
'1': {'target_name': torch.tensor([0.5, 0.5], dtype=torch.float32)},
|
||||
'2': {'target_name': torch.tensor([1, 1, 0.75], dtype=torch.float32)}
|
||||
}
|
||||
metrics = metrics_calculator.calculate_metrics(data)
|
||||
assert all(torch.equal(result[k], v) for k, v in metrics.items())
|
||||
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
|
||||
|
||||
# Test MeanRankMetricsCalculator
|
||||
metrics_calculator = MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
|
||||
data = {
|
||||
'1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
|
||||
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
|
||||
'1': {'target_name': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)]},
|
||||
'2': {'target_name': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]}
|
||||
}
|
||||
result = {
|
||||
'1': torch.tensor([0.25, 0.25], dtype=torch.float32),
|
||||
'2': torch.tensor([0, 0, 0.25], dtype=torch.float32)
|
||||
'1': {'target_name': torch.tensor([0.25, 0.25], dtype=torch.float32)},
|
||||
'2': {'target_name': torch.tensor([0, 0, 0.25], dtype=torch.float32)}
|
||||
}
|
||||
metrics = metrics_calculator.calculate_metrics(data)
|
||||
assert all(torch.equal(result[k], v) for k, v in metrics.items())
|
||||
assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())
|
||||
|
||||
def test_sparsity_allocator(self):
|
||||
# Test NormalSparsityAllocator
|
||||
|
@ -183,8 +189,8 @@ class PruningToolsTestCase(unittest.TestCase):
|
|||
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
|
||||
pruner = Pruner(model, config_list)
|
||||
metrics = {
|
||||
'conv1': torch.rand(5, 1, 5, 5),
|
||||
'conv2': torch.rand(10, 5, 5, 5)
|
||||
'conv1': {'weight': torch.rand(5, 1, 5, 5)},
|
||||
'conv2': {'weight': torch.rand(10, 5, 5, 5)}
|
||||
}
|
||||
sparsity_allocator = NormalSparsityAllocator(pruner)
|
||||
masks = sparsity_allocator.generate_sparsity(metrics)
|
||||
|
|
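The thrust of this test update: data collector outputs and metrics moved from a flat `{module_name: tensor}` layout to a nested `{module_name: {target_name: tensor}}` layout, e.g. (illustrative values):

metrics = {'conv1': {'weight': torch.rand(5, 1, 5, 5)},
           'conv2': {'weight': torch.rand(10, 5, 5, 5)}}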