[Compression] pruning stage 3: add scheduled/movement pruner (#5389)

J-shang 2023-03-07 16:04:45 +08:00 committed by GitHub
Parent 80017c015b
Commit 34f3a14d4a
No known key found for this signature
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 425 additions and 2 deletions

View File

@@ -42,7 +42,7 @@ def _fill_one_on_dims(mask: torch.Tensor, dims: int | List[int]) -> torch.Tensor
continue
dim_mask = (mask.sum([_ for _ in range(len(mask.shape)) if _ != i]) == 0.)
new_mask = new_mask.transpose(0, i)
-new_mask[torch.arange(len(dim_mask))[dim_mask].long().tolist()] = 0.
+new_mask[torch.arange(len(dim_mask), device=new_mask.device)[dim_mask].long().tolist()] = 0.
new_mask = new_mask.transpose(0, i)
return new_mask
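
A minimal sketch (not part of the diff) of why this fix pins ``torch.arange`` to ``new_mask.device``: before the change the index tensor is created on CPU while ``dim_mask`` may live on CUDA, and indexing across devices raises a RuntimeError.

import torch

if torch.cuda.is_available():
    dim_mask = torch.tensor([True, False, True], device='cuda')
    # Before the fix: a CPU arange indexed with a CUDA boolean mask -> device-mismatch error.
    try:
        torch.arange(len(dim_mask))[dim_mask]
    except RuntimeError as err:
        print(f'device mismatch: {err}')
    # After the fix: build the index on the same device as the mask, then convert to a list.
    idx = torch.arange(len(dim_mask), device=dim_mask.device)[dim_mask].long().tolist()
    print(idx)  # [0, 2]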

View File

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .basic_pruner import LevelPruner, L1NormPruner, L2NormPruner
from .movement_pruner import MovementPruner
from .scheduled_pruner import LinearPruner, AGPPruner
from .slim_pruner import SlimPruner
from .taylor_pruner import TaylorPruner
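
A hedged usage note, not part of the diff: with these re-exports the new pruners can be imported directly from the pruning subpackage. The absolute import path below is illustrative and depends on where this ``__init__.py`` sits in the package tree.

# Illustrative import path; adjust to the actual package location of this __init__.py.
from nni.contrib.compression.pruning import MovementPruner, LinearPruner, AGPPruner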

View File

@@ -37,9 +37,11 @@ class _NormPruner(Pruner):
# `skip_first_step` controls whether masks are generated at the first step.
# `interval_steps` is the optimizer-step interval between mask generations.
# `total_times` is the total number of mask generations.
self.first_step_gen = False
self.interval_steps = -1
self.total_times: int | Literal['unlimited'] = 1
# This is a reserved interface for potential iterative pruning needs;
# `first_step_gen` controls whether the masks are generated on the first step.
self.first_step_gen = False
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):

View File

@@ -0,0 +1,226 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from collections import defaultdict
import logging
from typing import Dict, List, overload
import torch
from torch.optim import Adam
from .scheduled_pruner import ScheduledPruner
from .tools import is_active_target, generate_sparsity
from ..base.compressor import Compressor
from ..base.target_space import TargetType
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator
MOVEMENT_SCORE_PNAME = '{}_mvp_score'
_logger = logging.getLogger(__name__)
class MovementPruner(ScheduledPruner):
"""
Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element is scored by the negative of the accumulated product of the weight and its gradient over the training steps.
This means weight elements moving towards zero accumulate negative scores,
while weight elements moving away from zero accumulate positive scores.
The weight elements with low scores will be masked during inference.
The following figure from the paper illustrates how weights are pruned by movement pruning.
.. image:: ../../../img/movement_pruning.png
:target: ../../../img/movement_pruning.png
:alt:
For more details, please refer to `Movement Pruning: Adaptive Sparsity by Fine-Tuning <https://arxiv.org/abs/2005.07683>`__.
Parameters
----------
model
Model to be pruned.
config_list
A list of dicts; each dict configures which modules need to be pruned and how to prune them.
Please refer :doc:`Compression Config Specification </compression/compression_config_list>` for more information.
evaluator
TODO: {evaluator_docstring}
warmup_step
The total number of ``optimizer.step()`` calls before pruning starts, used for warm-up.
Make sure ``warmup_step`` is smaller than ``cooldown_begin_step``.
cooldown_begin_step
The step number at which sparsity stops growing; note that sparsity ceasing to grow does not mean the masks stop changing.
The sparse ratio or sparse threshold after each `optimizer.step()` is::
final_sparse * (1 - (1 - (current_step - warmup_step) / (cooldown_begin_step - warmup_step)) ** 3)
regular_scale
A scale factor used to control the movement-score regularization loss.
This factor only applies to pruning targets controlled by ``sparse_threshold``;
pruning targets controlled by ``sparse_ratio`` will not be regularized.
Examples
--------
TODO
"""
@overload
def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
cooldown_begin_step: int, regular_scale: float = 1.):
...
@overload
def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
cooldown_begin_step: int, regular_scale: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
...
def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
cooldown_begin_step: int, regular_scale: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
super().__init__(model, config_list, evaluator, existed_wrappers)
self.evaluator: Evaluator
assert 0 <= warmup_step < cooldown_begin_step
self.warmup_step = warmup_step
self.cooldown_begin_step = cooldown_begin_step
self.regular_scale = regular_scale
self._init_sparse_goals()
self._set_apply_method()
self.interval_steps = 1
self.total_times = (self.cooldown_begin_step - self.warmup_step) // self.interval_steps
self._remaining_times: int
self.scores: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], warmup_step: int,
cooldown_begin_step: int, regular_scale: float = 1., evaluator: Evaluator | None = None):
return super().from_compressor(compressor, new_config_list, warmup_step=warmup_step, cooldown_begin_step=cooldown_begin_step,
regular_scale=regular_scale, evaluator=evaluator)
def _set_apply_method(self):
for _, ts in self._target_spaces.items():
for _, target_space in ts.items():
if target_space.apply_method == 'mul':
target_space.apply_method = 'movement_mul'
if target_space.apply_method == 'add':
target_space.apply_method = 'movement_add'
def _register_movement_scores(self):
for module_name, ts in self._target_spaces.items():
for target_name, target_space in ts.items():
if is_active_target(target_space):
# TODO: add input / output
if target_space.type is TargetType.PARAMETER:
# TODO: a shrunken score is used here to save memory, but the speed still needs to be tested.
score_val = torch.zeros_like(target_space.target) # type: ignore
if target_space._scaler is not None:
score_val = target_space._scaler.shrink(score_val)
target_space._wrapper.register_parameter(MOVEMENT_SCORE_PNAME.format(target_name),
torch.nn.Parameter(score_val))
score = target_space._get_wrapper_attr(MOVEMENT_SCORE_PNAME.format(target_name))
self.scores[module_name][target_name] = score
else:
raise NotImplementedError()
def _register_scores_optimization(self, evaluator: Evaluator):
scores = []
for _, target_scores in self.scores.items():
for _, score in target_scores.items():
scores.append(score)
if not scores:
return
params = [{"params": scores}]
optimizer = Adam(params, 1e-2)
def optimizer_task():
optimizer.step()
optimizer.zero_grad()
evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])
def _patch_loss(self, evaluator: Evaluator):
def loss_patch(original_loss, batch):
reg_loss = 0.
count = 0
for module_name, target_scores in self.scores.items():
for target_name, score in target_scores.items():
target_space = self._target_spaces[module_name][target_name]
if target_space.sparse_threshold is not None:
reg_loss += torch.norm(score.sigmoid(), p=1) / score.numel() # type: ignore
count += 1
ratio = max(0., min(1., 1 - (self._remaining_times / self.total_times) ** 3))
if count > 0:
reg_loss = self.regular_scale * ratio * reg_loss / count
return original_loss + reg_loss
evaluator.patch_loss(loss_patch)
def _register_trigger(self, evaluator: Evaluator):
self._current_step = 0
self._iterial_step = 0
self._remaining_times = self.total_times
def optimizer_task():
self._current_step += 1
if self.warmup_step < self._current_step <= self.cooldown_begin_step:
self._iterial_step += 1
if self._iterial_step == self.interval_steps:
self._remaining_times -= 1
self.update_sparse_goals(self.total_times - self._remaining_times)
debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
_logger.debug(debug_msg)
if self._remaining_times > 0:
self._iterial_step = 0
if self.warmup_step < self._current_step:
self.update_masks(self.generate_masks())
evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])
def update_sparse_goals(self, current_times: int):
ratio = max(0., min(1., 1 - (1 - current_times / self.total_times) ** 3))
self._update_sparse_goals_by_ratio(ratio)
def _collect_data(self) -> Dict[str, Dict[str, torch.Tensor]]:
data = defaultdict(dict)
for module_name, ts in self._target_spaces.items():
for target_name, target_space in ts.items():
score: torch.Tensor = getattr(target_space._wrapper, MOVEMENT_SCORE_PNAME.format(target_name), None) # type: ignore
if score is not None:
data[module_name][target_name] = score.clone().detach()
return data
def _calculate_metrics(self, data: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
metrics = defaultdict(dict)
for module_name, td in data.items():
for target_name, target_data in td.items():
if self._target_spaces[module_name][target_name].sparse_threshold is not None:
metrics[module_name][target_name] = target_data.sigmoid()
else:
metrics[module_name][target_name] = target_data
return metrics
def _generate_sparsity(self, metrics: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
return generate_sparsity(metrics=metrics, target_spaces=self._target_spaces)
def _single_compress(self, max_steps: int | None, max_epochs: int | None):
self._fusion_compress(max_steps, max_epochs)
def _fuse_preprocess(self, evaluator: Evaluator):
self._update_sparse_goals_by_ratio(0.)
self._register_movement_scores()
self._patch_loss(evaluator)
self._register_scores_optimization(evaluator)
self._register_trigger(evaluator)
def _fuse_postprocess(self, evaluator: Evaluator):
pass
def compress(self, max_steps: int | None, max_epochs: int | None):
if max_steps is not None:
assert max_steps >= self.cooldown_begin_step
else:
warn_msg = 'Using the number of epochs as the training duration, please make sure the total number of training steps is larger than `cooldown_begin_step`.'
_logger.warning(warn_msg)
return super().compress(max_steps, max_epochs)
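
A hedged usage sketch (not part of the diff) for the new MovementPruner. The config keys follow NNI's compression config specification, but the values and the placeholder ``build_evaluator()`` are illustrative; the instantiation is commented out because an Evaluator has to be constructed first. The second half just evaluates the cubic sparsity schedule quoted in the docstring.

config_list = [{
    'op_types': ['Linear'],
    'sparse_ratio': 0.5,  # targets driven by sparse_ratio are not touched by the regularization term
}]

# pruner = MovementPruner(model, config_list, evaluator=build_evaluator(),
#                         warmup_step=1000, cooldown_begin_step=9000, regular_scale=1.)
# pruner.compress(max_steps=12000, max_epochs=None)

# Cubic schedule from the docstring: sparsity ramps from 0 at warmup_step to the final
# goal at cooldown_begin_step, then stays flat (although the masks can still change).
def scheduled_sparsity(step: int, warmup_step: int, cooldown_begin_step: int, final_sparse: float) -> float:
    if step <= warmup_step:
        return 0.
    progress = min(1., (step - warmup_step) / (cooldown_begin_step - warmup_step))
    return final_sparse * (1 - (1 - progress) ** 3)

print(scheduled_sparsity(5000, 1000, 9000, 0.5))  # 0.4375, already 87.5% of the final goal halfway through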

View File

@@ -0,0 +1,187 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from collections import defaultdict
import logging
from typing import Dict, List
import torch
from ..base.compressor import Pruner
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator
from .basic_pruner import LevelPruner, L1NormPruner, L2NormPruner
from .slim_pruner import SlimPruner
from .taylor_pruner import TaylorPruner
_logger = logging.getLogger(__name__)
class ScheduledPruner(Pruner):
def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator | None = None,
existed_wrappers: Dict[str, ModuleWrapper] | None = None):
super().__init__(model, config_list, evaluator, existed_wrappers)
self.evaluator: Evaluator
self.sparse_goals: Dict[str, Dict[str, Dict[str, float]]] = defaultdict(dict)
self._goals_initialized = False
self._scheduled_keys = ['sparse_ratio', 'sparse_threshold', 'max_sparse_ratio', 'min_sparse_ratio']
def _init_sparse_goals(self):
if self._goals_initialized:
_logger.warning('Sparse goals have already been initialized.')
return
for module_name, ts in self._target_spaces.items():
for target_name, target_space in ts.items():
self.sparse_goals[module_name][target_name] = {}
for scheduled_key in self._scheduled_keys:
if getattr(target_space, scheduled_key) is not None:
self.sparse_goals[module_name][target_name][scheduled_key] = getattr(target_space, scheduled_key)
self._goals_initialized = True
def update_sparse_goals(self, current_times: int):
raise NotImplementedError()
def _update_sparse_goals_by_ratio(self, ratio: float):
for module_name, tg in self.sparse_goals.items():
for target_name, target_goals in tg.items():
for scheduled_key, goal in target_goals.items():
setattr(self._target_spaces[module_name][target_name], scheduled_key, goal * ratio)
class _ComboPruner(ScheduledPruner):
def __init__(self, pruner: Pruner, interval_steps: int, total_times: int, evaluator: Evaluator | None = None):
assert isinstance(pruner, Pruner)
assert hasattr(pruner, 'interval_steps') and hasattr(pruner, 'total_times')
if not isinstance(pruner, (LevelPruner, L1NormPruner, L2NormPruner, SlimPruner, TaylorPruner)):
warning_msg = f'Compatibility not tested with pruner type {pruner.__class__.__name__}.'
_logger.warning(warning_msg)
if pruner._is_wrapped:
pruner.unwrap_model()
model = pruner.bound_model
existed_wrappers = pruner._module_wrappers
if pruner.evaluator is not None and evaluator is not None:
_logger.warning('The pruner already has an evaluator; the new evaluator passed to this function will be ignored.')
evaluator = pruner.evaluator if pruner.evaluator else evaluator
assert isinstance(evaluator, Evaluator)
super().__init__(model=model, config_list=[], evaluator=evaluator, existed_wrappers=existed_wrappers)
# skip the pruner passed in
self.fused_compressors.extend(pruner.fused_compressors[1:])
self._target_spaces = pruner._target_spaces
self.interval_steps = interval_steps
self.total_times = total_times
self.bound_pruner = pruner
self._init_sparse_goals()
self._initial_ratio = 0.0
@classmethod
def from_compressor(cls, *args, **kwargs):
raise NotImplementedError(f'{cls.__name__} can not be initialized from any compressor.')
def _initialize_state(self):
self._update_sparse_goals_by_ratio(self._initial_ratio)
self.bound_pruner.interval_steps = self.interval_steps # type: ignore
self.bound_pruner.total_times = self.total_times # type: ignore
def _register_trigger(self, evaluator: Evaluator):
self._current_step = 0
self._remaining_times = self.total_times
def optimizer_task():
self._current_step += 1
if self._current_step == self.interval_steps:
self._remaining_times -= 1
self.update_sparse_goals(self.total_times - self._remaining_times)
debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
_logger.debug(debug_msg)
if self._remaining_times > 0:
self._current_step = 0
evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])
def _single_compress(self, max_steps: int | None, max_epochs: int | None):
self._fusion_compress(max_steps, max_epochs)
def _fuse_preprocess(self, evaluator: Evaluator) -> None:
self._initialize_state()
self._register_trigger(evaluator)
self.bound_pruner._fuse_preprocess(evaluator)
def _fuse_postprocess(self, evaluator: Evaluator) -> None:
self.bound_pruner._fuse_postprocess(evaluator)
def compress(self, max_steps: int | None, max_epochs: int | None):
if max_steps is not None:
assert max_steps >= self.total_times * self.interval_steps
else:
warn_msg = 'Using the number of epochs as the training duration, please make sure the total number of training steps is larger than `total_times * interval_steps`.'
_logger.warning(warn_msg)
return super().compress(max_steps, max_epochs)
class LinearPruner(_ComboPruner):
"""
The sparse ratio or sparse threshold in the bound pruner will increase linearly from 0 to the final value::
current_sparse = (1 - initial_ratio) * current_times / total_times * final_sparse
If a min/max sparse ratio is also set in the target settings, it will increase linearly in the same way.
Note that this pruner can not be initialized by ``LinearPruner.from_compressor(...)``.
Parameters
----------
pruner
The bound pruner.
interval_steps
An integer; the sparse goal is updated every ``interval_steps`` training steps.
total_times
An integer; the total number of times the sparse goal is updated.
evaluator
TODO
Examples
--------
TODO
"""
def update_sparse_goals(self, current_times: int):
ratio = (1 - self._initial_ratio) * current_times / self.total_times
self._update_sparse_goals_by_ratio(ratio)
class AGPPruner(_ComboPruner):
"""
The sparse ratio or sparse threshold in the bound pruner will increase following the AGP schedule from 0 to the final value::
current_sparse = (1 - (1 - self._initial_ratio) * (1 - current_times / self.total_times) ** 3) * final_sparse
If a min/max sparse ratio is also set in the target settings, it will increase following the same AGP schedule.
Note that this pruner can not be initialized by ``AGPPruner.from_compressor(...)``.
Parameters
----------
pruner
The bound pruner.
interval_steps
An integer; the sparse goal is updated every ``interval_steps`` training steps.
total_times
An integer; the total number of times the sparse goal is updated.
evaluator
TODO
Examples
--------
TODO
"""
def update_sparse_goals(self, current_times: int):
ratio = 1 - (1 - self._initial_ratio) * (1 - current_times / self.total_times) ** 3
self._update_sparse_goals_by_ratio(ratio)
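
A hedged sketch, not part of the diff: the two ratio schedules implemented by ``LinearPruner.update_sparse_goals`` and ``AGPPruner.update_sparse_goals``, evaluated side by side with ``initial_ratio = 0`` as in the code, followed by an illustrative (commented-out) wrapping call whose bound pruner and evaluator are assumed to exist already.

def linear_ratio(current_times: int, total_times: int, initial_ratio: float = 0.) -> float:
    return (1 - initial_ratio) * current_times / total_times

def agp_ratio(current_times: int, total_times: int, initial_ratio: float = 0.) -> float:
    return 1 - (1 - initial_ratio) * (1 - current_times / total_times) ** 3

total_times = 10
for t in range(total_times + 1):
    print(t, round(linear_ratio(t, total_times), 3), round(agp_ratio(t, total_times), 3))
# AGP front-loads sparsity growth: at t=5 the linear ratio is 0.5,
# while the AGP ratio is already 1 - 0.5 ** 3 = 0.875.

# Illustrative wrapping; `taylor_pruner` is an already-constructed TaylorPruner bound to a model.
# scheduled = AGPPruner(pruner=taylor_pruner, interval_steps=100, total_times=10)
# scheduled.compress(max_steps=1000, max_epochs=None)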