Mirror of https://github.com/microsoft/nni.git
[Compression] pruning stage 3: add scheduled/movement pruner (#5389)
Parent: 80017c015b
Commit: 34f3a14d4a
@@ -42,7 +42,7 @@ def _fill_one_on_dims(mask: torch.Tensor, dims: int | List[int]) -> torch.Tensor
            continue
        dim_mask = (mask.sum([_ for _ in range(len(mask.shape)) if _ != i]) == 0.)
        new_mask = new_mask.transpose(0, i)
-       new_mask[torch.arange(len(dim_mask))[dim_mask].long().tolist()] = 0.
+       new_mask[torch.arange(len(dim_mask), device=new_mask.device)[dim_mask].long().tolist()] = 0.
        new_mask = new_mask.transpose(0, i)
    return new_mask
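The one-line change above builds the index tensor on the mask's device. A minimal illustration of why that matters (hypothetical, not part of the commit): boolean-mask indexing requires the index tensor and the mask to live on the same device, so an index built with a bare torch.arange (CPU) fails against a CUDA-resident mask.

import torch

# Hypothetical illustration of the device mismatch fixed above (runs only if CUDA is available).
if torch.cuda.is_available():
    dim_mask = torch.tensor([True, False, True], device='cuda')
    # torch.arange(len(dim_mask))[dim_mask]  # would raise: index tensor on CPU, mask on CUDA
    idx = torch.arange(len(dim_mask), device=dim_mask.device)[dim_mask]
    print(idx.long().tolist())  # [0, 2]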
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .basic_pruner import LevelPruner, L1NormPruner, L2NormPruner
from .movement_pruner import MovementPruner
from .scheduled_pruner import LinearPruner, AGPPruner
from .slim_pruner import SlimPruner
from .taylor_pruner import TaylorPruner
@@ -37,9 +37,11 @@ class _NormPruner(Pruner):
-        # `skip_first_step` controls if generating masks at the first step.
         # `interval_steps` is the optimizer step interval for generating masks.
         # `total_times` is the total number of times masks are generated.
-        self.first_step_gen = False
         self.interval_steps = -1
         self.total_times: int | Literal['unlimited'] = 1
+        # here is a reserved interface for potential iterative pruning needs,
+        # `first_step_gen` controls if the masks are generated on the first step.
+        self.first_step_gen = False

     @classmethod
     def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
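The comments above describe a reserved scheduling interface: a pruner exposing ``interval_steps`` and ``total_times`` regenerates masks once every ``interval_steps`` optimizer steps, ``total_times`` times in total. A minimal, self-contained trace of that counting logic, mirroring the trigger added in the scheduled pruner file below (the function name is local to this sketch, not part of the pruner API):

def mask_update_steps(interval_steps: int, total_times: int, max_steps: int) -> list:
    """Return the optimizer steps at which a scheduled pruner would regenerate masks."""
    update_steps, counter, remaining = [], 0, total_times
    for step in range(1, max_steps + 1):
        counter += 1
        if counter == interval_steps and remaining > 0:
            remaining -= 1
            update_steps.append(step)
            counter = 0
    return update_steps

print(mask_update_steps(interval_steps=100, total_times=3, max_steps=500))  # [100, 200, 300]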
@@ -0,0 +1,226 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from collections import defaultdict
import logging
from typing import Dict, List, overload

import torch
from torch.optim import Adam

from .scheduled_pruner import ScheduledPruner
from .tools import is_active_target, generate_sparsity
from ..base.compressor import Compressor
from ..base.target_space import TargetType
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator

MOVEMENT_SCORE_PNAME = '{}_mvp_score'
_logger = logging.getLogger(__name__)


class MovementPruner(ScheduledPruner):
    """
    Movement pruner is an implementation of movement pruning.
    This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
    Each weight element is scored by the opposite of the sum of the product of the weight and its gradient at each step.
    This means the weight elements moving towards zero will accumulate negative scores,
    and the weight elements moving away from zero will accumulate positive scores.
    The weight elements with low scores will be masked during inference.

    The following figure from the paper shows how weights are pruned by movement pruning.

    .. image:: ../../../img/movement_pruning.png
        :target: ../../../img/movement_pruning.png
        :alt:

    For more details, please refer to `Movement Pruning: Adaptive Sparsity by Fine-Tuning <https://arxiv.org/abs/2005.07683>`__.

    Parameters
    ----------
    model
        Model to be pruned.
    config_list
        A list of dicts; each dict configures which modules need to be pruned and how to prune them.
        Please refer to :doc:`Compression Config Specification </compression/compression_config_list>` for more information.
    evaluator
        TODO: {evaluator_docstring}
    warmup_step
        The total number of ``optimizer.step()`` calls before pruning starts, used for warm-up.
        Make sure ``warmup_step`` is smaller than ``cooldown_begin_step``.
    cooldown_begin_step
        The step number at which sparsity stops growing; note that sparsity stopping growing does not mean the masks stop changing.
        The sparse ratio or sparse threshold after each ``optimizer.step()`` is::

            final_sparse * (1 - (1 - (current_step - warmup_step) / (cooldown_begin_step - warmup_step)) ** 3)

    regular_scale
        A scale factor used to control the movement score regularization loss.
        This factor only works on pruning targets controlled by ``sparse_threshold``;
        pruning targets controlled by ``sparse_ratio`` will not be regularized.

    Examples
    --------
    TODO
    """
    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
                 cooldown_begin_step: int, regular_scale: float = 1.):
        ...

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
                 cooldown_begin_step: int, regular_scale: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        ...

    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
                 cooldown_begin_step: int, regular_scale: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        super().__init__(model, config_list, evaluator, existed_wrappers)
        self.evaluator: Evaluator
        assert 0 <= warmup_step < cooldown_begin_step
        self.warmup_step = warmup_step
        self.cooldown_begin_step = cooldown_begin_step
        self.regular_scale = regular_scale
        self._init_sparse_goals()
        self._set_apply_method()

        self.interval_steps = 1
        self.total_times = (self.cooldown_begin_step - self.warmup_step) // self.interval_steps
        self._remaining_times: int
        self.scores: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)

    @classmethod
    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], warmup_step: int,
                        cooldown_begin_step: int, regular_scale: float = 1., evaluator: Evaluator | None = None):
        return super().from_compressor(compressor, new_config_list, warmup_step=warmup_step, cooldown_begin_step=cooldown_begin_step,
                                       regular_scale=regular_scale, evaluator=evaluator)

    def _set_apply_method(self):
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                if target_space.apply_method == 'mul':
                    target_space.apply_method = 'movement_mul'
                if target_space.apply_method == 'add':
                    target_space.apply_method = 'movement_add'

    def _register_movement_scores(self):
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                if is_active_target(target_space):
                    # TODO: add input / output
                    if target_space.type is TargetType.PARAMETER:
                        # TODO: a shrunken score is used here to save memory, but the speed still needs to be tested.
                        score_val = torch.zeros_like(target_space.target)  # type: ignore
                        if target_space._scaler is not None:
                            score_val = target_space._scaler.shrink(score_val)
                        target_space._wrapper.register_parameter(MOVEMENT_SCORE_PNAME.format(target_name),
                                                                 torch.nn.Parameter(score_val))
                        score = target_space._get_wrapper_attr(MOVEMENT_SCORE_PNAME.format(target_name))
                        self.scores[module_name][target_name] = score
                    else:
                        raise NotImplementedError()

    def _register_scores_optimization(self, evaluator: Evaluator):
        scores = []
        for _, target_scores in self.scores.items():
            for _, score in target_scores.items():
                scores.append(score)

        if not scores:
            return

        params = [{"params": scores}]
        optimizer = Adam(params, 1e-2)

        def optimizer_task():
            optimizer.step()
            optimizer.zero_grad()

        evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])

    def _patch_loss(self, evaluator: Evaluator):
        def loss_patch(original_loss, batch):
            reg_loss = 0.
            count = 0
            for module_name, target_scores in self.scores.items():
                for target_name, score in target_scores.items():
                    target_space = self._target_spaces[module_name][target_name]
                    if target_space.sparse_threshold is not None:
                        reg_loss += torch.norm(score.sigmoid(), p=1) / score.numel()  # type: ignore
                        count += 1
            ratio = max(0., min(1., 1 - (self._remaining_times / self.total_times) ** 3))
            if count > 0:
                reg_loss = self.regular_scale * ratio * reg_loss / count
            return original_loss + reg_loss

        evaluator.patch_loss(loss_patch)

    def _register_trigger(self, evaluator: Evaluator):
        self._current_step = 0
        self._iterial_step = 0
        self._remaining_times = self.total_times

        def optimizer_task():
            self._current_step += 1
            if self.warmup_step < self._current_step <= self.cooldown_begin_step:
                self._iterial_step += 1
                if self._iterial_step == self.interval_steps:
                    self._remaining_times -= 1
                    self.update_sparse_goals(self.total_times - self._remaining_times)
                    debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
                    _logger.debug(debug_msg)
                    if self._remaining_times > 0:
                        self._iterial_step = 0
            if self.warmup_step < self._current_step:
                self.update_masks(self.generate_masks())

        evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])

    def update_sparse_goals(self, current_times: int):
        ratio = max(0., min(1., 1 - (1 - current_times / self.total_times) ** 3))
        self._update_sparse_goals_by_ratio(ratio)

    def _collect_data(self) -> Dict[str, Dict[str, torch.Tensor]]:
        data = defaultdict(dict)
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                score: torch.Tensor = getattr(target_space._wrapper, MOVEMENT_SCORE_PNAME.format(target_name), None)  # type: ignore
                if score is not None:
                    data[module_name][target_name] = score.clone().detach()
        return data

    def _calculate_metrics(self, data: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
        metrics = defaultdict(dict)
        for module_name, td in data.items():
            for target_name, target_data in td.items():
                if self._target_spaces[module_name][target_name].sparse_threshold is not None:
                    metrics[module_name][target_name] = target_data.sigmoid()
                else:
                    metrics[module_name][target_name] = target_data
        return metrics

    def _generate_sparsity(self, metrics: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
        return generate_sparsity(metrics=metrics, target_spaces=self._target_spaces)

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator):
        self._update_sparse_goals_by_ratio(0.)
        self._register_movement_scores()
        self._patch_loss(evaluator)
        self._register_scores_optimization(evaluator)
        self._register_trigger(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator):
        pass

    def compress(self, max_steps: int | None, max_epochs: int | None):
        if max_steps is not None:
            assert max_steps >= self.cooldown_begin_step
        else:
            warn_msg = 'Using the number of epochs as the training duration, please make sure ' \
                       'the total number of training steps is larger than `cooldown_begin_step`.'
            _logger.warning(warn_msg)
        return super().compress(max_steps, max_epochs)
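The ``MovementPruner`` docstring above leaves its example as a TODO; a runnable check of the cubic sparsity schedule it states can stand in. The helper below is local to this sketch and only re-implements the documented formula, it is not part of the pruner API:

def scheduled_sparse(final_sparse: float, current_step: int, warmup_step: int, cooldown_begin_step: int) -> float:
    """Sparse ratio / threshold after `current_step`, per the MovementPruner docstring."""
    if current_step <= warmup_step:
        return 0.
    if current_step >= cooldown_begin_step:
        return final_sparse
    progress = (current_step - warmup_step) / (cooldown_begin_step - warmup_step)
    return final_sparse * (1 - (1 - progress) ** 3)

# The ramp is steep right after warm-up and flattens towards cooldown_begin_step
# (values here are roughly 0.0, 0.52, 0.79, 0.89, 0.9):
for step in (1000, 3000, 5000, 7000, 9000):
    print(step, round(scheduled_sparse(0.9, step, warmup_step=1000, cooldown_begin_step=9000), 3))

In actual use, the pruner is constructed as ``MovementPruner(model, config_list, evaluator, warmup_step, cooldown_begin_step, regular_scale)`` and driven by ``compress(max_steps, max_epochs)``; as the ``compress`` override above asserts, ``max_steps`` must be at least ``cooldown_begin_step``.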
@@ -0,0 +1,187 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from collections import defaultdict
import logging
from typing import Dict, List

import torch

from ..base.compressor import Pruner
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator

from .basic_pruner import LevelPruner, L1NormPruner, L2NormPruner
from .slim_pruner import SlimPruner
from .taylor_pruner import TaylorPruner

_logger = logging.getLogger(__name__)


class ScheduledPruner(Pruner):
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator | None = None,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        super().__init__(model, config_list, evaluator, existed_wrappers)
        self.evaluator: Evaluator

        self.sparse_goals: Dict[str, Dict[str, Dict[str, float]]] = defaultdict(dict)
        self._goals_initialized = False
        self._scheduled_keys = ['sparse_ratio', 'sparse_threshold', 'max_sparse_ratio', 'min_sparse_ratio']

    def _init_sparse_goals(self):
        if self._goals_initialized:
            _logger.warning('Sparse goals have already been initialized.')
            return
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                self.sparse_goals[module_name][target_name] = {}
                for scheduled_key in self._scheduled_keys:
                    if getattr(target_space, scheduled_key) is not None:
                        self.sparse_goals[module_name][target_name][scheduled_key] = getattr(target_space, scheduled_key)
        self._goals_initialized = True

    def update_sparse_goals(self, current_times: int):
        raise NotImplementedError()

    def _update_sparse_goals_by_ratio(self, ratio: float):
        for module_name, tg in self.sparse_goals.items():
            for target_name, target_goals in tg.items():
                for scheduled_key, goal in target_goals.items():
                    setattr(self._target_spaces[module_name][target_name], scheduled_key, goal * ratio)


class _ComboPruner(ScheduledPruner):
    def __init__(self, pruner: Pruner, interval_steps: int, total_times: int, evaluator: Evaluator | None = None):
        assert isinstance(pruner, Pruner)
        assert hasattr(pruner, 'interval_steps') and hasattr(pruner, 'total_times')
        if not isinstance(pruner, (LevelPruner, L1NormPruner, L2NormPruner, SlimPruner, TaylorPruner)):
            warning_msg = f'Compatibility not tested with pruner type {pruner.__class__.__name__}.'
            _logger.warning(warning_msg)
        if pruner._is_wrapped:
            pruner.unwrap_model()

        model = pruner.bound_model
        existed_wrappers = pruner._module_wrappers
        if pruner.evaluator is not None and evaluator is not None:
            _logger.warning('The pruner already has an evaluator, the new evaluator passed to this function will be ignored.')
        evaluator = pruner.evaluator if pruner.evaluator else evaluator
        assert isinstance(evaluator, Evaluator)

        super().__init__(model=model, config_list=[], evaluator=evaluator, existed_wrappers=existed_wrappers)
        # skip the pruner passed in
        self.fused_compressors.extend(pruner.fused_compressors[1:])
        self._target_spaces = pruner._target_spaces
        self.interval_steps = interval_steps
        self.total_times = total_times
        self.bound_pruner = pruner

        self._init_sparse_goals()
        self._initial_ratio = 0.0

    @classmethod
    def from_compressor(cls, *args, **kwargs):
        raise NotImplementedError(f'{cls.__name__} cannot be initialized from any compressor.')

    def _initialize_state(self):
        self._update_sparse_goals_by_ratio(self._initial_ratio)
        self.bound_pruner.interval_steps = self.interval_steps  # type: ignore
        self.bound_pruner.total_times = self.total_times  # type: ignore

    def _register_trigger(self, evaluator: Evaluator):
        self._current_step = 0
        self._remaining_times = self.total_times

        def optimizer_task():
            self._current_step += 1
            if self._current_step == self.interval_steps:
                self._remaining_times -= 1
                self.update_sparse_goals(self.total_times - self._remaining_times)
                debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
                _logger.debug(debug_msg)
                if self._remaining_times > 0:
                    self._current_step = 0

        evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator) -> None:
        self._initialize_state()
        self._register_trigger(evaluator)
        self.bound_pruner._fuse_preprocess(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator) -> None:
        self.bound_pruner._fuse_postprocess(evaluator)

    def compress(self, max_steps: int | None, max_epochs: int | None):
        if max_steps is not None:
            assert max_steps >= self.total_times * self.interval_steps
        else:
            warn_msg = 'Using the number of epochs as the training duration, ' + \
                       'please make sure the total number of training steps is larger than total_times * interval_steps.'
            _logger.warning(warn_msg)
        return super().compress(max_steps, max_epochs)


class LinearPruner(_ComboPruner):
    """
    The sparse ratio or sparse threshold in the bound pruner will increase linearly from 0. to the final value::

        current_sparse = (1 - initial_ratio) * current_times / total_times * final_sparse

    If a min/max sparse ratio is also set in the target setting, it will increase synchronously in the same linear way.

    Note that this pruner cannot be initialized by ``LinearPruner.from_compressor(...)``.

    Parameters
    ----------
    pruner
        The bound pruner.
    interval_steps
        An integer; the sparse goal will be updated once every ``interval_steps`` training steps.
    total_times
        An integer; the total number of times the sparse goal will be updated.
    evaluator
        TODO

    Examples
    --------
    TODO
    """

    def update_sparse_goals(self, current_times: int):
        ratio = (1 - self._initial_ratio) * current_times / self.total_times
        self._update_sparse_goals_by_ratio(ratio)


class AGPPruner(_ComboPruner):
    """
    The sparse ratio or sparse threshold in the bound pruner will increase following the AGP schedule from 0. to the final value::

        current_sparse = (1 - (1 - initial_ratio) * (1 - current_times / total_times) ** 3) * final_sparse

    If a min/max sparse ratio is also set in the target setting, it will increase synchronously in the same AGP way.

    Note that this pruner cannot be initialized by ``AGPPruner.from_compressor(...)``.

    Parameters
    ----------
    pruner
        The bound pruner.
    interval_steps
        An integer; the sparse goal will be updated once every ``interval_steps`` training steps.
    total_times
        An integer; the total number of times the sparse goal will be updated.
    evaluator
        TODO

    Examples
    --------
    TODO
    """
    def update_sparse_goals(self, current_times: int):
        ratio = 1 - (1 - self._initial_ratio) * (1 - current_times / self.total_times) ** 3
        self._update_sparse_goals_by_ratio(ratio)
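The ``LinearPruner`` and ``AGPPruner`` docstrings above also leave their examples as TODOs; the sketch below simply evaluates the two documented schedules side by side. The helper names are local to this sketch, not pruner API:

def linear_sparse(final_sparse: float, current_times: int, total_times: int, initial_ratio: float = 0.) -> float:
    # formula from the LinearPruner docstring
    return (1 - initial_ratio) * current_times / total_times * final_sparse

def agp_sparse(final_sparse: float, current_times: int, total_times: int, initial_ratio: float = 0.) -> float:
    # formula from the AGPPruner docstring
    return (1 - (1 - initial_ratio) * (1 - current_times / total_times) ** 3) * final_sparse

# With final_sparse = 0.8 and total_times = 4, AGP front-loads sparsity (roughly 0.46, 0.70, 0.79, 0.80)
# while the linear schedule grows evenly (0.20, 0.40, 0.60, 0.80).
for t in range(1, 5):
    print(t, round(linear_sparse(0.8, t, 4), 2), round(agp_sparse(0.8, t, 4), 2))

Either scheduler wraps an already-configured one-shot pruner, e.g. ``AGPPruner(pruner, interval_steps, total_times)`` where ``pruner`` is a bound pruner such as a ``TaylorPruner``; per the ``compress`` assertion above, ``max_steps`` must be at least ``total_times * interval_steps``.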