зеркало из https://github.com/microsoft/nni.git
[Compression] pruning stage 1: add pruning tools (#5387)
This commit is contained in:
Родитель
1777368ad8
Коммит
8b7dac3f8c
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .collect_data import _DATA, active_sparse_targets_filter
|
||||
from .calculate_metrics import _METRICS, norm_metrics, fpgm_metrics
|
||||
from .sparse_gen import _MASKS, generate_sparsity
|
||||
from .utils import is_active_target
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from .collect_data import _DATA
|
||||
from ...base.compressor import _PRUNING_TARGET_SPACES
|
||||
|
||||
|
||||
_METRICS = Dict[str, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
def norm_metrics(p: str | int, data: _DATA, target_spaces: _PRUNING_TARGET_SPACES) -> _METRICS:
    """
    Compute a per-block norm metric for every value in ``data``.

    Parameters
    ----------
    p
        The order of norm. Please refer `torch.norm <https://pytorch.org/docs/stable/generated/torch.norm.html>`__.
    data
        {module_name: {target_name: val}}.
    target_spaces
        {module_name: {target_name: pruning_target_space}}. Used to get the related scaler for each value in data.
    """
    def compute_norm(block: torch.Tensor) -> torch.Tensor:
        # norm over the last dim: one metric value per shrunken block
        return block.norm(p=p, dim=-1)  # type: ignore

    result = defaultdict(dict)
    for module_name, module_values in data.items():
        for target_name, value in module_values.items():
            scaler = target_spaces[module_name][target_name]._scaler
            if scaler is None:
                # no scaler configured: fine-grained metric, element-wise absolute value
                result[module_name][target_name] = value.abs()
            else:
                result[module_name][target_name] = scaler.shrink(value, compute_norm)
    return result
|
||||
|
||||
|
||||
def fpgm_metrics(p: str | int, data: _DATA, target_spaces: _PRUNING_TARGET_SPACES) -> _METRICS:
    """
    Compute FPGM-style metrics: for each block, the summed ``p``-norm distance to
    every other block. Blocks close to the "geometric median" get small metrics
    and are pruned first.

    Parameters
    ----------
    p
        The order of norm used for the pairwise distances.
    data
        {module_name: {target_name: val}}.
    target_spaces
        {module_name: {target_name: pruning_target_space}}. Supplies the scaler
        that shrinks each value; FPGM requires one (coarse-grained pattern only).
    """
    def distance_sum(t: torch.Tensor) -> torch.Tensor:
        flat = t.reshape(-1, t.shape[-1])
        metric = torch.zeros(flat.shape[0]).type_as(flat)
        # summed distance from row i to every row (including itself, which adds 0)
        for idx in range(flat.shape[0]):
            metric[idx] = (flat - flat[idx]).norm(p=p, dim=-1).sum()  # type: ignore
        return metric.reshape(t.shape[:-1])

    result = defaultdict(dict)
    for module_name, module_values in data.items():
        for target_name, value in module_values.items():
            scaler = target_spaces[module_name][target_name]._scaler
            assert scaler is not None, 'FPGM metric do not support finegrained sparse pattern.'
            result[module_name][target_name] = scaler.shrink(value, distance_sum)
    return result
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import is_active_target
|
||||
from ...base.compressor import _PRUNING_TARGET_SPACES
|
||||
|
||||
|
||||
_DATA = Dict[str, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
def active_sparse_targets_filter(target_spaces: _PRUNING_TARGET_SPACES) -> _DATA:
    """
    Collect a detached copy of the target tensor for every target space that is
    configured to actively generate sparsity (see ``is_active_target``).

    Returns {module_name: {target_name: tensor}}.
    """
    active_targets = defaultdict(dict)
    for module_name, spaces in target_spaces.items():
        for target_name, space in spaces.items():
            if not is_active_target(space):
                continue
            assert space.target is not None
            # clone + detach: downstream metric computation must not touch the live tensor
            active_targets[module_name][target_name] = space.target.clone().detach()
    return active_targets
|
|
@ -0,0 +1,305 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import reduce
|
||||
import heapq
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from .calculate_metrics import _METRICS
|
||||
from ...base.compressor import _PRUNING_TARGET_SPACES
|
||||
from ...base.target_space import PruningTargetSpace, TargetType
|
||||
|
||||
|
||||
_MASKS = Dict[str, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
def generate_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    There are many ways to generate masks, in this function, most of the common generation rules are implemented,
    and these rules execute in a certain order.

    The following rules are included in this function:

    * Threshold. If sparse_threshold is set in target space, the mask will be generated by metric >= threshold is 1,
      and metric < threshold is 0.
    * Dependency. If dependency_group_id is set in target space, the metrics of the targets in the same group will be
      meaned as the group metric, then if target_space.internal_metric_block is set, all internal_metric_block of the
      targets will be put in one set to compute the lcm as the group internal block number.
      Split the group metric to group internal block number parts, compute mask for each part and merge as a group mask.
      All targets in this group share the group mask value.
    * Global. If global_group_id is set in target space, the metrics of the targets in the same group will be global ranked
      and generate the mask by taking smaller metric values as 0, others as 1 by sparse_ratio.
    * Ratio. The most common rule, directly generate the mask by taking the smaller metric values as 0, others as 1 by sparse_ratio.
    * Align. If align is set in target space, the mask will be generated by another existed mask.
    """

    # Predicates selecting which generation rule applies to a target space.
    def condition_dependency(target_space: PruningTargetSpace) -> bool:
        return target_space.dependency_group_id is not None

    def condition_global(target_space: PruningTargetSpace) -> bool:
        return target_space.global_group_id is not None

    def condition_ratio(target_space: PruningTargetSpace) -> bool:
        return target_space.sparse_ratio is not None

    def condition_threshold(target_space: PruningTargetSpace) -> bool:
        return target_space.sparse_threshold is not None

    def condition_align(target_space: PruningTargetSpace) -> bool:
        return target_space.align is not None

    masks = defaultdict(dict)

    # Stage 1: threshold rule. Masks are expanded from metric shape to target shape,
    # then merged multiplicatively into the accumulated masks.
    threshold_target_spaces, remained_target_spaces = target_spaces_filter(target_spaces, condition_threshold)
    update_masks = _generate_threshold_sparsity(metrics, threshold_target_spaces)
    _nested_multiply_update_masks(masks, _expand_masks(update_masks, threshold_target_spaces))

    # Stage 2: dependency rule.
    # NOTE(review): this filter re-scans the full ``target_spaces`` (not the remainder of the
    # threshold stage), so a target can receive both a threshold mask and a dependency mask,
    # which are then multiplied together — confirm this overlap is intended.
    dependency_target_spaces, remained_target_spaces = target_spaces_filter(target_spaces, condition_dependency)
    update_masks = _generate_dependency_sparsity(metrics, dependency_target_spaces)
    _nested_multiply_update_masks(masks, _expand_masks(update_masks, dependency_target_spaces))

    # Stage 3: global ranking rule, applied only to targets not consumed by stage 2.
    global_target_spaces, remained_target_spaces = target_spaces_filter(remained_target_spaces, condition_global)
    update_masks = _generate_global_sparsity(metrics, global_target_spaces)
    _nested_multiply_update_masks(masks, _expand_masks(update_masks, global_target_spaces))

    # Stage 4: plain per-target ratio rule on what is still unhandled.
    ratio_target_spaces, remained_target_spaces = target_spaces_filter(remained_target_spaces, condition_ratio)
    update_masks = _generate_ratio_sparsity(metrics, ratio_target_spaces)
    _nested_multiply_update_masks(masks, _expand_masks(update_masks, ratio_target_spaces))

    # Stage 5: align rule — derives masks from masks already generated above,
    # so it must run last.
    align_target_spaces, remained_target_spaces = target_spaces_filter(remained_target_spaces, condition_align)
    update_masks = _generate_align_sparsity(masks, align_target_spaces)
    _nested_multiply_update_masks(masks, _expand_masks(update_masks, align_target_spaces))

    return masks
|
||||
|
||||
|
||||
def target_spaces_filter(target_spaces: _PRUNING_TARGET_SPACES,
                         condition: Callable[[PruningTargetSpace], bool]) -> Tuple[_PRUNING_TARGET_SPACES, _PRUNING_TARGET_SPACES]:
    """
    Split ``target_spaces`` into (matching, not-matching) by ``condition``.

    A PARAMETER-type target space whose target is ``None`` is never selected,
    regardless of the condition.
    """
    selected = defaultdict(dict)
    rejected = defaultdict(dict)

    for module_name, spaces in target_spaces.items():
        for target_name, space in spaces.items():
            unset_parameter = space.type is TargetType.PARAMETER and space.target is None
            if unset_parameter or not condition(space):
                rejected[module_name][target_name] = space
            else:
                selected[module_name][target_name] = space

    return selected, rejected
|
||||
|
||||
|
||||
def _generate_ratio_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate masks directly from each target's ``sparse_ratio``.

    Smaller metric values mean less important and are masked (set to 0) first.
    """
    masks = defaultdict(dict)
    for module_name, spaces in target_spaces.items():
        for target_name, space in spaces.items():
            metric = metrics[module_name][target_name]
            # clamp the configured ratio into [min_sparse_ratio, max_sparse_ratio]
            lower = space.min_sparse_ratio if space.min_sparse_ratio else 0.0
            upper = space.max_sparse_ratio if space.max_sparse_ratio else 1.0
            ratio = min(upper, max(lower, space.sparse_ratio))  # type: ignore
            masks[module_name][target_name] = _ratio_mask(metric, ratio)
    return masks
|
||||
|
||||
|
||||
def _generate_threshold_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate masks from each target's ``sparse_threshold``.

    Smaller metric values mean less important: positions with metric < threshold
    become 0, positions with metric >= threshold become 1.
    """
    masks = defaultdict(dict)
    for module_name, spaces in target_spaces.items():
        for target_name, space in spaces.items():
            metric = metrics[module_name][target_name]
            mask = _threshold_mask(metric, space.sparse_threshold)  # type: ignore

            # If the achieved sparsity falls outside [min_sparse_ratio, max_sparse_ratio],
            # discard the threshold mask and regenerate by ratio at the violated bound.
            achieved = 1.0 - mask.sum() / mask.numel()
            lower = space.min_sparse_ratio if space.min_sparse_ratio else 0.0
            upper = space.max_sparse_ratio if space.max_sparse_ratio else 1.0
            if achieved < lower:
                mask = _ratio_mask(metric, lower)
            if achieved > upper:
                mask = _ratio_mask(metric, upper)

            masks[module_name][target_name] = mask
    return masks
|
||||
|
||||
|
||||
def _generate_align_sparsity(masks: _MASKS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
|
||||
align_masks = defaultdict(dict)
|
||||
for module_name, ts in target_spaces.items():
|
||||
for target_name, target_space in ts.items():
|
||||
src_mask = masks[module_name][target_space.align['target_name']] # type: ignore
|
||||
align_dims: List[int] = target_space.align['dims'] # type: ignore
|
||||
reduce_dims = [d for d in range(len(src_mask.shape)) if d not in align_dims and d - len(src_mask.shape) not in align_dims]
|
||||
align_masks[module_name][target_name] = src_mask.sum(reduce_dims).bool().float()
|
||||
return align_masks
|
||||
|
||||
|
||||
def _generate_global_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate masks for targets that share a ``global_group_id``.

    The metrics of all targets in one group are ranked together and the globally
    smallest values are masked until the group-level ``sparse_ratio`` is reached,
    while each target's own ``min_sparse_ratio`` / ``max_sparse_ratio`` bounds are
    respected.
    """
    groups: Dict[str, List[Tuple[str, str, PruningTargetSpace]]] = defaultdict(list)
    for module_name, ts in target_spaces.items():
        for target_name, target_space in ts.items():
            groups[target_space.global_group_id].append((module_name, target_name, target_space))  # type: ignore

    masks = defaultdict(dict)
    for _, group in groups.items():
        # all members of a group must agree on a single sparse_ratio
        group_sparse_ratio = None
        for _, _, target_space in group:
            if target_space.sparse_ratio is not None:
                if group_sparse_ratio is None:
                    group_sparse_ratio = target_space.sparse_ratio
                else:
                    assert group_sparse_ratio == target_space.sparse_ratio
        assert group_sparse_ratio is not None

        # at least how many elements to mask (sum of per-target minimums)
        sparse_number_low = 0
        # at most how many elements to mask (an unset max means the whole target)
        sparse_number_high = 0
        # how many elements in this group
        total_element_number = 0
        for _, _, target_space in group:
            element_number = target_space.target.numel()  # type: ignore
            total_element_number += element_number
            sparse_number_low += int(element_number * target_space.min_sparse_ratio) if target_space.min_sparse_ratio else 0
            sparse_number_high += int(element_number * target_space.max_sparse_ratio) if target_space.max_sparse_ratio else element_number
        # how many elements should be masked, controlled by sparse_ratio
        sparse_number = int(total_element_number * group_sparse_ratio)

        if sparse_number <= sparse_number_low:
            # lower bounds already mask enough: apply each target's min_sparse_ratio
            for module_name, target_name, target_space in group:
                sparse_ratio = target_space.min_sparse_ratio if target_space.min_sparse_ratio else 0.0
                masks[module_name][target_name] = _ratio_mask(metrics[module_name][target_name], sparse_ratio)
            continue

        if sparse_number >= sparse_number_high:
            # upper bounds cap the masking: apply each target's max_sparse_ratio.
            # BUGFIX: an unset max_sparse_ratio means "no cap" (it contributed the whole
            # element_number to sparse_number_high above), so the fallback here must be
            # 1.0 — the previous fallback of 0.0 masked nothing for uncapped targets.
            for module_name, target_name, target_space in group:
                sparse_ratio = target_space.max_sparse_ratio if target_space.max_sparse_ratio else 1.0
                masks[module_name][target_name] = _ratio_mask(metrics[module_name][target_name], sparse_ratio)
            continue

        # otherwise rank all metrics of the group globally and mask below the threshold
        sparse_threshold = _global_threshold_generate(metrics, group, sparse_number)
        for module_name, target_name, target_space in group:
            masks[module_name][target_name] = _threshold_mask(metrics[module_name][target_name], sparse_threshold)
    return masks
|
||||
|
||||
|
||||
def _generate_dependency_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate one shared mask per ``dependency_group_id`` group.

    The metrics of all targets in a group are averaged into a group metric; the
    group mask is computed from it (split into lcm-of-internal_metric_block parts)
    and every target in the group receives a clone of that same mask.
    """
    groups: Dict[str, List[Tuple[str, str, PruningTargetSpace]]] = defaultdict(list)
    for module_name, ts in target_spaces.items():
        for target_name, target_space in ts.items():
            groups[target_space.dependency_group_id].append((module_name, target_name, target_space))  # type: ignore

    masks = defaultdict(dict)
    for _, group in groups.items():
        # seed with 1 so the lcm is well-defined even if no member sets internal_metric_block
        block_numbers = [1]
        group_sparsity_ratio = None
        filtered_metrics = defaultdict(dict)

        for module_name, target_name, target_space in group:
            assert target_space.internal_metric_block is None or isinstance(target_space.internal_metric_block, int)
            block_numbers.append(target_space.internal_metric_block if target_space.internal_metric_block else 1)
            # all members that declare a sparse_ratio must agree on it
            if target_space.sparse_ratio is not None:
                if group_sparsity_ratio is None:
                    group_sparsity_ratio = target_space.sparse_ratio
                else:
                    assert group_sparsity_ratio == target_space.sparse_ratio
            # every member's metric participates in the fused group metric
            filtered_metrics[module_name][target_name] = metrics[module_name][target_name]
        # group internal block number = lcm of all members' internal_metric_block
        block_number = reduce(numpy.lcm, block_numbers)
        assert group_sparsity_ratio is not None
        group_metric = _metric_fuse(filtered_metrics)
        # mask each of the block_number parts independently, then all members share it
        group_mask = _ratio_mask(group_metric, group_sparsity_ratio, view_size=[block_number, -1])

        for module_name, target_name, _ in group:
            masks[module_name][target_name] = group_mask.clone()

    return masks
|
||||
|
||||
|
||||
# the following are helper functions
|
||||
|
||||
def _ratio_mask(metric: torch.Tensor, sparse_ratio: float, view_size: int | List[int] = -1):
|
||||
if sparse_ratio == 0.0:
|
||||
return torch.ones_like(metric)
|
||||
|
||||
if sparse_ratio == 1.0:
|
||||
return torch.zeros_like(metric)
|
||||
|
||||
assert 0.0 < sparse_ratio < 1.0
|
||||
if isinstance(view_size, int) or len(view_size[:-1]) == 0:
|
||||
block_number = 1
|
||||
else:
|
||||
block_number = numpy.prod(view_size[:-1])
|
||||
sparse_number_per_block = int(metric.numel() // block_number * sparse_ratio)
|
||||
viewed_metric = metric.view(view_size)
|
||||
_, indices = viewed_metric.topk(sparse_number_per_block, largest=False)
|
||||
return torch.ones_like(viewed_metric).scatter(-1, indices, 0.0).reshape_as(metric)
|
||||
|
||||
|
||||
def _threshold_mask(metric: torch.Tensor, sparse_threshold: float):
|
||||
return (metric >= sparse_threshold).float().to(metric.device)
|
||||
|
||||
|
||||
def _global_threshold_generate(metrics: _METRICS,
                               group: List[Tuple[str, str, PruningTargetSpace]],
                               sparse_number: int) -> float:
    """
    Find the metric threshold under which roughly ``sparse_number`` target elements
    in the group would be masked.

    Each metric value stands for ``grain_size`` target elements (target numel /
    metric numel), so entries are weighted by grain size. A heap of negated metric
    values keeps the currently-smallest metrics whose grains sum to about
    ``sparse_number``; the largest metric in that heap is returned as the threshold
    (callers mask metric < threshold via ``_threshold_mask``).
    """
    # heap of (-metric_value, grain_size): heap top is the LARGEST metric kept so far
    buffer = []
    # total number of target elements represented by entries currently in the heap
    buffer_elem = 0
    for module_name, target_name, target_space in group:
        metric = metrics[module_name][target_name]
        # how many target elements one metric value accounts for
        grain_size = target_space.target.numel() // metric.numel()  # type: ignore
        for m in metric.cpu().detach().view(-1):
            if buffer_elem <= sparse_number:
                # still under budget: admit unconditionally
                heapq.heappush(buffer, (-m.item(), grain_size))
                buffer_elem += grain_size
            else:
                # over budget: push the candidate and evict the largest metric,
                # adjusting the element count by the difference in grain sizes.
                # NOTE(review): the evicted entry may be the candidate itself;
                # buffer_elem then changes by (grain_size - grain_size) == 0 — looks
                # intended, but confirm the <= budget check may overshoot by one grain.
                _, previous_grain_size = heapq.heappushpop(buffer, (-m.item(), grain_size))
                buffer_elem += grain_size - previous_grain_size
    # largest metric among those selected for masking becomes the threshold
    return -heapq.heappop(buffer)[0]
|
||||
|
||||
|
||||
def _nested_multiply_update_masks(default_dict: _MASKS, update_dict: _MASKS):
|
||||
# if a target already has a mask, the old one will multiply the new one as the target mask,
|
||||
# that means the mask in default dict will more and more sparse.
|
||||
for module_name, target_tensors in update_dict.items():
|
||||
for target_name, target_tensor in target_tensors.items():
|
||||
if target_name in default_dict[module_name] and isinstance(default_dict[module_name][target_name], torch.Tensor):
|
||||
default_dict[module_name][target_name] = (default_dict[module_name][target_name] * target_tensor).bool().float()
|
||||
else:
|
||||
default_dict[module_name][target_name] = target_tensor
|
||||
|
||||
|
||||
def _metric_fuse(metrics: _METRICS) -> torch.Tensor:
|
||||
# mean all metric value
|
||||
fused_metric = None
|
||||
count = 0
|
||||
for _, module_metrics in metrics.items():
|
||||
for _, target_metric in module_metrics.items():
|
||||
if fused_metric is not None:
|
||||
fused_metric += target_metric
|
||||
else:
|
||||
fused_metric = target_metric.clone()
|
||||
count += 1
|
||||
assert fused_metric is not None
|
||||
return fused_metric / count
|
||||
|
||||
|
||||
def _expand_masks(masks: _MASKS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
|
||||
# expand the mask shape from metric shape to target shape
|
||||
new_masks = defaultdict(dict)
|
||||
for module_name, module_masks in masks.items():
|
||||
for target_name, target_mask in module_masks.items():
|
||||
target_space = target_spaces[module_name][target_name]
|
||||
if target_space._scaler:
|
||||
new_masks[module_name][target_name] = target_space._scaler.expand(target_mask, target_space.shape) # type: ignore
|
||||
else:
|
||||
new_masks[module_name][target_name] = target_mask
|
||||
return new_masks
|
|
@ -0,0 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...base.target_space import PruningTargetSpace
|
||||
|
||||
|
||||
def is_active_target(target_space: PruningTargetSpace):
    """A target actively generates sparsity when a sparse ratio or threshold is configured."""
    has_ratio = target_space.sparse_ratio is not None
    has_threshold = target_space.sparse_threshold is not None
    return has_ratio or has_threshold
|
Загрузка…
Ссылка в новой задаче