From 8b7dac3f8c101934b5fede626b48ec67b6982e57 Mon Sep 17 00:00:00 2001 From: J-shang <33053116+J-shang@users.noreply.github.com> Date: Tue, 7 Mar 2023 10:37:04 +0800 Subject: [PATCH] [Compression] pruning stage 1: add pruning tools (#5387) --- nni/contrib/compression/pruning/__init__.py | 0 .../compression/pruning/tools/__init__.py | 7 + .../pruning/tools/calculate_metrics.py | 59 ++++ .../compression/pruning/tools/collect_data.py | 26 ++ .../compression/pruning/tools/sparse_gen.py | 305 ++++++++++++++++++ .../compression/pruning/tools/utils.py | 10 + 6 files changed, 407 insertions(+) create mode 100644 nni/contrib/compression/pruning/__init__.py create mode 100644 nni/contrib/compression/pruning/tools/__init__.py create mode 100644 nni/contrib/compression/pruning/tools/calculate_metrics.py create mode 100644 nni/contrib/compression/pruning/tools/collect_data.py create mode 100644 nni/contrib/compression/pruning/tools/sparse_gen.py create mode 100644 nni/contrib/compression/pruning/tools/utils.py diff --git a/nni/contrib/compression/pruning/__init__.py b/nni/contrib/compression/pruning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nni/contrib/compression/pruning/tools/__init__.py b/nni/contrib/compression/pruning/tools/__init__.py new file mode 100644 index 000000000..09669f8aa --- /dev/null +++ b/nni/contrib/compression/pruning/tools/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
_METRICS = Dict[str, Dict[str, torch.Tensor]]


def norm_metrics(p: str | int, data: _DATA, target_spaces: _PRUNING_TARGET_SPACES) -> _METRICS:
    """
    Calculate the norm of each block of the value in the given data.

    Parameters
    ----------
    p
        The order of norm. Please refer to ``torch.norm``
        (https://pytorch.org/docs/stable/generated/torch.norm.html).
    data
        {module_name: {target_name: val}}.
    target_spaces
        {module_name: {target_name: pruning_target_space}}. Used to get the related scaler for each value in data.

    Returns
    -------
    {module_name: {target_name: metric_tensor}} with one metric value per (scaled) block.
    """
    def reduce_func(t: torch.Tensor) -> torch.Tensor:
        return t.norm(p=p, dim=-1)  # type: ignore

    metrics = defaultdict(dict)
    for module_name, module_data in data.items():
        for target_name, target_data in module_data.items():
            target_space = target_spaces[module_name][target_name]
            if target_space._scaler is None:
                # fine-grained case: each element is its own block, so the norm degenerates to abs()
                metrics[module_name][target_name] = target_data.abs()
            else:
                metrics[module_name][target_name] = target_space._scaler.shrink(target_data, reduce_func)
    return metrics


def fpgm_metrics(p: str | int, data: _DATA, target_spaces: _PRUNING_TARGET_SPACES) -> _METRICS:
    """
    FPGM (filter pruning via geometric median) metric: for each block, the sum of p-norm
    distances from that block to every other block. Blocks close to the geometric median
    (small metric) are considered redundant.

    Parameters
    ----------
    p
        The order of norm used for the pairwise distances.
    data
        {module_name: {target_name: val}}.
    target_spaces
        {module_name: {target_name: pruning_target_space}}. Each target must have a scaler
        (FPGM is undefined for fine-grained sparsity).
    """
    def reduce_func(t: torch.Tensor) -> torch.Tensor:
        reshape_data = t.reshape(-1, t.shape[-1])
        if isinstance(p, (int, float)):
            # vectorized pairwise distance matrix replaces the original O(n^2) Python loop;
            # same values, computed at C speed (O(n^2) memory, fine for channel counts)
            metric = torch.cdist(reshape_data, reshape_data, p=float(p)).sum(dim=-1)
        else:
            # string norms (e.g. 'fro', 'nuc') are not supported by cdist; keep the loop
            metric = torch.zeros(reshape_data.shape[0]).type_as(reshape_data)
            for i in range(reshape_data.shape[0]):
                metric[i] = (reshape_data - reshape_data[i]).norm(p=p, dim=-1).sum()  # type: ignore
        return metric.reshape(t.shape[:-1])

    metrics = defaultdict(dict)
    for module_name, module_data in data.items():
        for target_name, target_data in module_data.items():
            target_space = target_spaces[module_name][target_name]
            assert target_space._scaler is not None, 'FPGM metric do not support finegrained sparse pattern.'
            metrics[module_name][target_name] = target_space._scaler.shrink(target_data, reduce_func)
    return metrics
_DATA = Dict[str, Dict[str, torch.Tensor]]


def active_sparse_targets_filter(target_spaces: _PRUNING_TARGET_SPACES) -> _DATA:
    """
    Collect a detached copy of every target that should actively generate sparsity,
    i.e. every target space that declares a sparse ratio or a sparse threshold.

    Returns {module_name: {target_name: cloned_detached_target_tensor}}.
    """
    collected = defaultdict(dict)
    for module_name, module_spaces in target_spaces.items():
        for target_name, space in module_spaces.items():
            if not is_active_target(space):
                continue
            assert space.target is not None
            # clone + detach so metric computation never touches the live parameter
            collected[module_name][target_name] = space.target.clone().detach()
    return collected
def generate_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate masks from metrics by applying the common mask-generation rules in a fixed order:

    * Threshold: targets with ``sparse_threshold`` set get ``metric >= threshold -> 1, else 0``.
    * Dependency: targets sharing a ``dependency_group_id`` are fused (mean metric, lcm of
      internal metric blocks) and share one group mask.
    * Global: targets sharing a ``global_group_id`` are ranked together and the smallest
      metric values across the whole group are masked, controlled by ``sparse_ratio``.
    * Ratio: plain per-target masking of the smallest metric values by ``sparse_ratio``.
    * Align: targets with ``align`` set derive their mask from another already generated mask.

    Masks produced by multiple rules for the same target are multiplied together,
    so a target's mask can only become sparser.
    """
    masks = defaultdict(dict)

    def merge(update_masks, spaces):
        # expand metric-shaped masks to target shape, then AND into the accumulated masks
        _nested_multiply_update_masks(masks, _expand_masks(update_masks, spaces))

    # 1. threshold rule, drawn from the full target set
    threshold_spaces, _ = target_spaces_filter(target_spaces, lambda ts: ts.sparse_threshold is not None)
    merge(_generate_threshold_sparsity(metrics, threshold_spaces), threshold_spaces)

    # 2. dependency rule, also drawn from the full target set.
    # NOTE(review): the remainder from the threshold filter is deliberately discarded here, so a
    # target may receive both a threshold mask and a dependency mask — confirm this is by design.
    dependency_spaces, remained = target_spaces_filter(target_spaces, lambda ts: ts.dependency_group_id is not None)
    merge(_generate_dependency_sparsity(metrics, dependency_spaces), dependency_spaces)

    # 3. global group rule, on what is left after the dependency filter
    global_spaces, remained = target_spaces_filter(remained, lambda ts: ts.global_group_id is not None)
    merge(_generate_global_sparsity(metrics, global_spaces), global_spaces)

    # 4. plain per-target ratio rule
    ratio_spaces, remained = target_spaces_filter(remained, lambda ts: ts.sparse_ratio is not None)
    merge(_generate_ratio_sparsity(metrics, ratio_spaces), ratio_spaces)

    # 5. align rule reads the masks accumulated so far, not the metrics
    align_spaces, remained = target_spaces_filter(remained, lambda ts: ts.align is not None)
    merge(_generate_align_sparsity(masks, align_spaces), align_spaces)

    return masks
def target_spaces_filter(target_spaces: _PRUNING_TARGET_SPACES,
                         condition: Callable[[PruningTargetSpace], bool]) -> Tuple[_PRUNING_TARGET_SPACES, _PRUNING_TARGET_SPACES]:
    """
    Split ``target_spaces`` into two nested dicts: (matching, remainder).

    A target goes to the matching side only if ``condition`` holds for it; parameter
    targets whose tensor is missing are always left in the remainder regardless of
    the condition, because no mask can be generated for them.
    """
    matched = defaultdict(dict)
    rest = defaultdict(dict)

    for module_name, module_spaces in target_spaces.items():
        for target_name, space in module_spaces.items():
            missing_param = space.type is TargetType.PARAMETER and space.target is None
            bucket = rest if (missing_param or not condition(space)) else matched
            bucket[module_name][target_name] = space

    return matched, rest
target_space.min_sparse_ratio if target_space.min_sparse_ratio else 0.0 + max_sparse_ratio = target_space.max_sparse_ratio if target_space.max_sparse_ratio else 1.0 + sparse_ratio = min(max_sparse_ratio, max(min_sparse_ratio, target_space.sparse_ratio)) # type: ignore + masks[module_name][target_name] = _ratio_mask(metric, sparse_ratio) + return masks + + +def _generate_threshold_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS: + # NOTE: smaller metric value means more un-important + masks = defaultdict(dict) + for module_name, ts in target_spaces.items(): + for target_name, target_space in ts.items(): + metric = metrics[module_name][target_name] + # metric < threshold will be 0, metric >= threshold will be 1 + mask = _threshold_mask(metric, target_space.sparse_threshold) # type: ignore + + # if sparse_ratio does not meet `min_sparse_ratio`, `max_sparse_ratio`, re-generate mask + sparse_ratio = 1.0 - mask.sum() / mask.numel() + min_sparse_ratio = target_space.min_sparse_ratio if target_space.min_sparse_ratio else 0.0 + max_sparse_ratio = target_space.max_sparse_ratio if target_space.max_sparse_ratio else 1.0 + if sparse_ratio < min_sparse_ratio: + mask = _ratio_mask(metric, min_sparse_ratio) + if sparse_ratio > max_sparse_ratio: + mask = _ratio_mask(metric, max_sparse_ratio) + + masks[module_name][target_name] = mask + return masks + + +def _generate_align_sparsity(masks: _MASKS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS: + align_masks = defaultdict(dict) + for module_name, ts in target_spaces.items(): + for target_name, target_space in ts.items(): + src_mask = masks[module_name][target_space.align['target_name']] # type: ignore + align_dims: List[int] = target_space.align['dims'] # type: ignore + reduce_dims = [d for d in range(len(src_mask.shape)) if d not in align_dims and d - len(src_mask.shape) not in align_dims] + align_masks[module_name][target_name] = src_mask.sum(reduce_dims).bool().float() + return align_masks + + +def 
def _generate_global_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate masks for targets grouped by ``global_group_id``: metrics of a whole group are
    ranked together and the globally smallest values are masked, subject to each target's
    min/max sparse ratio bounds.
    """
    groups: Dict[str, List[Tuple[str, str, PruningTargetSpace]]] = defaultdict(list)
    for module_name, ts, in target_spaces.items():
        for target_name, target_space in ts.items():
            groups[target_space.global_group_id].append((module_name, target_name, target_space))  # type: ignore

    masks = defaultdict(dict)
    for _, group in groups.items():
        # every target in a global group must declare the same sparse_ratio (or none)
        group_sparse_ratio = None
        for _, _, target_space in group:
            if target_space.sparse_ratio is not None:
                if group_sparse_ratio is None:
                    group_sparse_ratio = target_space.sparse_ratio
                else:
                    assert group_sparse_ratio == target_space.sparse_ratio
        assert group_sparse_ratio is not None

        # at least how many elements to mask (sum of per-target min ratios)
        sparse_number_low = 0
        # at most how many elements to mask (per-target max ratios; unset means the whole target)
        sparse_number_high = 0
        # how many elements in this group
        total_element_number = 0
        for _, _, target_space in group:
            element_number = target_space.target.numel()  # type: ignore
            total_element_number += element_number
            sparse_number_low += int(element_number * target_space.min_sparse_ratio) if target_space.min_sparse_ratio else 0
            sparse_number_high += int(element_number * target_space.max_sparse_ratio) if target_space.max_sparse_ratio \
                else element_number
        # how many elements should be masked, controlled by sparse_ratio
        sparse_number = int(total_element_number * group_sparse_ratio)

        if sparse_number <= sparse_number_low:
            # requested sparsity is below every floor: apply each target's min_sparse_ratio directly
            for module_name, target_name, target_space in group:
                sparse_ratio = target_space.min_sparse_ratio if target_space.min_sparse_ratio else 0.0
                masks[module_name][target_name] = _ratio_mask(metrics[module_name][target_name], sparse_ratio)
            continue

        if sparse_number >= sparse_number_high:
            # requested sparsity exceeds every cap: apply each target's max_sparse_ratio directly.
            # BUGFIX: the fallback for an unset max_sparse_ratio must be 1.0 (not 0.0) to stay
            # consistent with the `else element_number` term used for sparse_number_high above.
            for module_name, target_name, target_space in group:
                sparse_ratio = target_space.max_sparse_ratio if target_space.max_sparse_ratio else 1.0
                masks[module_name][target_name] = _ratio_mask(metrics[module_name][target_name], sparse_ratio)
            continue

        # otherwise rank all metric values across the group and cut at a shared threshold
        sparse_threshold = _global_threshold_generate(metrics, group, sparse_number)
        for module_name, target_name, target_space in group:
            masks[module_name][target_name] = _threshold_mask(metrics[module_name][target_name], sparse_threshold)
    return masks


def _generate_dependency_sparsity(metrics: _METRICS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS:
    """
    Generate one shared mask per ``dependency_group_id`` group: the group's metrics are
    mean-fused, the lcm of all ``internal_metric_block`` values decides how many internal
    blocks the fused metric is split into, and the resulting mask is cloned to every
    target in the group.
    """
    groups: Dict[str, List[Tuple[str, str, PruningTargetSpace]]] = defaultdict(list)
    for module_name, ts in target_spaces.items():
        for target_name, target_space in ts.items():
            groups[target_space.dependency_group_id].append((module_name, target_name, target_space))  # type: ignore

    masks = defaultdict(dict)
    for _, group in groups.items():
        block_numbers = [1]
        group_sparsity_ratio = None
        filtered_metrics = defaultdict(dict)

        for module_name, target_name, target_space in group:
            assert target_space.internal_metric_block is None or isinstance(target_space.internal_metric_block, int)
            block_numbers.append(target_space.internal_metric_block if target_space.internal_metric_block else 1)
            # all targets in the group must agree on one sparse_ratio
            if target_space.sparse_ratio is not None:
                if group_sparsity_ratio is None:
                    group_sparsity_ratio = target_space.sparse_ratio
                else:
                    assert group_sparsity_ratio == target_space.sparse_ratio
            filtered_metrics[module_name][target_name] = metrics[module_name][target_name]
        # lcm of all internal block counts so every target's blocking divides the group blocking
        block_number = reduce(numpy.lcm, block_numbers)
        assert group_sparsity_ratio is not None
        group_metric = _metric_fuse(filtered_metrics)
        group_mask = _ratio_mask(group_metric, group_sparsity_ratio, view_size=[block_number, -1])

        for module_name, target_name, _ in group:
            masks[module_name][target_name] = group_mask.clone()

    return masks
view_size: int | List[int] = -1): + if sparse_ratio == 0.0: + return torch.ones_like(metric) + + if sparse_ratio == 1.0: + return torch.zeros_like(metric) + + assert 0.0 < sparse_ratio < 1.0 + if isinstance(view_size, int) or len(view_size[:-1]) == 0: + block_number = 1 + else: + block_number = numpy.prod(view_size[:-1]) + sparse_number_per_block = int(metric.numel() // block_number * sparse_ratio) + viewed_metric = metric.view(view_size) + _, indices = viewed_metric.topk(sparse_number_per_block, largest=False) + return torch.ones_like(viewed_metric).scatter(-1, indices, 0.0).reshape_as(metric) + + +def _threshold_mask(metric: torch.Tensor, sparse_threshold: float): + return (metric >= sparse_threshold).float().to(metric.device) + + +def _global_threshold_generate(metrics: _METRICS, + group: List[Tuple[str, str, PruningTargetSpace]], + sparse_number: int) -> float: + buffer = [] + buffer_elem = 0 + for module_name, target_name, target_space in group: + metric = metrics[module_name][target_name] + grain_size = target_space.target.numel() // metric.numel() # type: ignore + for m in metric.cpu().detach().view(-1): + if buffer_elem <= sparse_number: + heapq.heappush(buffer, (-m.item(), grain_size)) + buffer_elem += grain_size + else: + _, previous_grain_size = heapq.heappushpop(buffer, (-m.item(), grain_size)) + buffer_elem += grain_size - previous_grain_size + return -heapq.heappop(buffer)[0] + + +def _nested_multiply_update_masks(default_dict: _MASKS, update_dict: _MASKS): + # if a target already has a mask, the old one will multiply the new one as the target mask, + # that means the mask in default dict will more and more sparse. 
+ for module_name, target_tensors in update_dict.items(): + for target_name, target_tensor in target_tensors.items(): + if target_name in default_dict[module_name] and isinstance(default_dict[module_name][target_name], torch.Tensor): + default_dict[module_name][target_name] = (default_dict[module_name][target_name] * target_tensor).bool().float() + else: + default_dict[module_name][target_name] = target_tensor + + +def _metric_fuse(metrics: _METRICS) -> torch.Tensor: + # mean all metric value + fused_metric = None + count = 0 + for _, module_metrics in metrics.items(): + for _, target_metric in module_metrics.items(): + if fused_metric is not None: + fused_metric += target_metric + else: + fused_metric = target_metric.clone() + count += 1 + assert fused_metric is not None + return fused_metric / count + + +def _expand_masks(masks: _MASKS, target_spaces: _PRUNING_TARGET_SPACES) -> _MASKS: + # expand the mask shape from metric shape to target shape + new_masks = defaultdict(dict) + for module_name, module_masks in masks.items(): + for target_name, target_mask in module_masks.items(): + target_space = target_spaces[module_name][target_name] + if target_space._scaler: + new_masks[module_name][target_name] = target_space._scaler.expand(target_mask, target_space.shape) # type: ignore + else: + new_masks[module_name][target_name] = target_mask + return new_masks diff --git a/nni/contrib/compression/pruning/tools/utils.py b/nni/contrib/compression/pruning/tools/utils.py new file mode 100644 index 000000000..cc8e4ce5d --- /dev/null +++ b/nni/contrib/compression/pruning/tools/utils.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from ...base.target_space import PruningTargetSpace + + +def is_active_target(target_space: PruningTargetSpace): + return target_space.sparse_ratio is not None or target_space.sparse_threshold is not None