[compression] pruning dependency v2 (#5569)

This commit is contained in:
super-dainiu 2023-06-26 15:02:29 +08:00 committed by GitHub
Parent d5eeae4f2a
Commit a9ff00b7e5
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 516 additions and 1 deletion


@@ -0,0 +1,444 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import operator
from copy import deepcopy
from typing import Any, Dict, List, Set
from uuid import uuid4

import numpy as np
import torch
import torch.fx
from torch.fx.passes.shape_prop import TensorMetadata

from nni.contrib.compression.base.config import trans_legacy_config_list, select_modules_by_config

__all__ = [
    'build_channel_dependency',
    'build_group_dependency',
    'build_reshape_dependency',
    'build_weight_sharing_dependency',
    'auto_set_denpendency_group_ids',
]
# see https://pytorch.org/docs/stable/torch.html#pointwise-ops
CALL_FUNCTION_REDUCE = [
    torch.add, torch.sub, torch.subtract, torch.mul, torch.div, torch.multiply, torch.divide,
    torch.addcmul, torch.addcdiv, torch.logical_xor, torch.logical_and, torch.logical_or,
    operator.add, operator.iadd, operator.sub, operator.isub, operator.mul, operator.imul,
    operator.truediv, operator.floordiv, operator.or_, operator.and_, operator.xor,
]
CALL_METHOD_REDUCE = [
    'add', 'add_', 'sub', 'sub_', 'subtract', 'subtract_', 'mul', 'mul_',
    'div', 'div_', 'multiply', 'multiply_', 'divide', 'divide_',
    'addcmul', 'addcmul_', 'addcdiv', 'addcdiv_', 'logical_xor', 'logical_xor_',
    'logical_and', 'logical_and_', 'logical_or', 'logical_or_',
]

# see https://pytorch.org/docs/stable/torch.html#indexing-slicing-joining-mutating-ops
CALL_FUNCTION_CONCAT = [
    torch.cat, torch.concat, torch.concatenate, torch.column_stack, torch.dstack,
    torch.hstack, torch.row_stack, torch.stack,
]
CALL_FUNCTION_RESHAPE = [
    *CALL_FUNCTION_CONCAT,
    torch.chunk, torch.diagonal, torch.flatten, torch.flip, torch.fliplr, torch.flipud, torch.moveaxis, torch.movedim, torch.narrow,
    torch.reshape, torch.split, torch.split_with_sizes, torch.squeeze, torch.unsqueeze, torch.transpose, torch.t, torch.permute,
    torch.repeat_interleave,
]
CALL_METHOD_RESHAPE = [
    'chunk', 'diagonal', 'flatten', 'flip', 'fliplr', 'flipud', 'moveaxis', 'movedim', 'narrow',
    'reshape', 'split', 'split_with_sizes', 'squeeze', 'unsqueeze', 'transpose', 't', 'permute',
    'view', 'view_as', 'expand', 'expand_as', 'expand_dims', 'repeat', 'repeat_interleave',
]
logger = logging.getLogger('Shape Dependency')


def lcm_list(L):
    lcm = 1
    for i in L:
        lcm = np.lcm(lcm, i)
    return lcm


def gcd_list(L):
    gcd = L[0]
    for i in L:
        gcd = np.gcd(gcd, i)
    return gcd
def channel_dependency_breakpoint(node: torch.fx.Node):
    """
    Check if this operation will break the channel dependency.

    Parameters
    ----------
    node : torch.fx.Node
        The node to be checked.

    Returns
    -------
    bool
        True if this operation breaks the channel dependency.
    """
    if len(node.all_input_nodes) == 0:
        return True

    def find_parent(node: torch.fx.Node):
        par_arg = node.all_input_nodes[0]
        while isinstance(par_arg, tuple):
            par_arg = par_arg[0]
        return par_arg

    def get_shape(node):
        meta = node.meta.get('tensor_meta', None)
        if isinstance(meta, TensorMetadata):
            return meta[0]
        elif isinstance(meta, tuple):
            return meta[0][0]
        return None

    in_shape = get_shape(find_parent(node))
    out_shape = get_shape(node)
    if not in_shape or not out_shape or len(in_shape) <= 1 or len(out_shape) <= 1:
        return True
    in_channel = in_shape[1]
    out_channel = out_shape[1]
    # TODO: add more rules here
    return in_channel != out_channel
def find_adjacent_layers(node: torch.fx.Node,
                         module: torch.fx.GraphModule,
                         target_types: List[torch.nn.Module],
                         direction: str = 'parent') -> List[torch.fx.Node]:
    """
    Find the nearest layers of ``target_types`` in the module for the target node,
    searching in the given direction.

    Parameters
    ----------
    node : torch.fx.Node
        The target node.
    module : torch.fx.GraphModule
        The model to be analyzed.
    target_types : list
        The types of the adjacent layers to search for.
    direction : str
        The search direction, 'parent' or 'child'.

    Returns
    -------
    layers : list
        The nearest parent (or child) conv/linear layers for the target node.
    """
    layers = []
    if direction == 'parent':
        nodes = node.all_input_nodes
    elif direction == 'child':
        nodes = node.users
    else:
        raise ValueError("search direction should be 'parent' or 'child'")

    for adjacent in nodes:
        if adjacent.op == 'call_module':
            target_module = module.get_submodule(adjacent.target)
            if isinstance(target_module, tuple(target_types)):
                layers.append(adjacent)
                continue
        elif channel_dependency_breakpoint(adjacent):
            continue
        layers.extend(find_adjacent_layers(adjacent, module, target_types, direction))
    return layers
def auto_set_denpendency_group_ids(graph_module: torch.fx.GraphModule,
                                   config_list: List[Dict[str, Any]],
                                   prune_type: str = 'Filter',
                                   prune_axis: int = 0) -> List[Dict[str, Any]]:
    """
    Auto find the output dependencies between all ``Conv2d``, ``Linear``, ``ConvTranspose2d``
    and ``Embedding`` modules, then set the ``dependency_group_id`` in the config list.
    Note that a newly generated dependency group id will replace any previously
    configured one in the affected configs.

    Parameters
    ----------
    graph_module : torch.fx.GraphModule
        The traced model to be analyzed.
    config_list : list
        The compression config list.
    prune_type : str
        The channel pruning type, 'Filter' or 'BatchNorm'.
    prune_axis : int
        The pruned axis of the layers. 0 for output channel, 1 for input channel.

    Returns
    -------
    config_list : list
        The new config list with dependency group ids.
    """
    channel_dependency = build_channel_dependency(graph_module, prune_type, prune_axis)
    module2uid = {}
    for d_set in channel_dependency:
        uid = uuid4().hex
        module2uid.update({node.target: uid for node in d_set})

    group_dependency = build_group_dependency(graph_module)
    group_dependency = {node.target: group_max for node, (group_max, _) in group_dependency.items()}

    config_list = trans_legacy_config_list(config_list)
    new_config_list = []
    for config in config_list:
        modules, public_config, _ = select_modules_by_config(graph_module, config)
        for target in modules.keys():
            sub_config = deepcopy(public_config)
            if target in module2uid:
                sub_config['dependency_group_id'] = module2uid[target]
            if target in group_dependency:
                sub_config['internal_metric_block'] = int(group_dependency[target])
            new_config_list.append({
                'op_names': [target],
                **sub_config
            })
    return new_config_list
def convert_dependency_to_set(dependency: Dict[Any, Set[Any]]) -> List[Set[Any]]:
    """
    Convert the dependency dict to a list of sets of dependent nodes.

    Parameters
    ----------
    dependency : dict
        The individual dependency mapping for the target graph.

    Returns
    -------
    list
        A list of dependency sets for the target graph.
    """
    visited = set()
    d_sets = []
    for key, value in dependency.items():
        if key not in visited:
            tmp = set([key]) | value
            visited.update(tmp)
            if len(tmp) > 1:
                d_sets.append(tmp)
    return d_sets
def build_weight_sharing_dependency(graph_module: torch.fx.GraphModule) -> List[List[str]]:
    """
    Analyze the weight sharing dependencies between layers in a model
    (i.e., different nodes referring to the same module).

    Parameters
    ----------
    graph_module : torch.fx.GraphModule
        The target graph and module.

    Returns
    -------
    dependency : list
        The weight sharing dependency for the target graph.
    """
    dependency = {}
    target_types = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ConvTranspose2d, torch.nn.Embedding]
    for node in graph_module.graph.nodes:
        if node.op == 'call_module':
            target_module = graph_module.get_submodule(node.target)
            if isinstance(target_module, tuple(target_types)):
                dependency[node.target] = dependency.get(node.target, []) + [node]
    return [dep for dep in dependency.values() if len(dep) > 1]
def build_channel_dependency(graph_module: torch.fx.GraphModule,
                             prune_type: str = 'Filter',
                             prune_axis: int = 0) -> List[Set[torch.fx.Node]]:
    """
    Analyze the channel dependencies between the conv layers in a model.

    Parameters
    ----------
    graph_module : torch.fx.GraphModule
        The target graph and module.
    prune_type : str
        The channel pruning type. 'Filter' prunes the filters of the conv
        layer to prune the corresponding channels. 'BatchNorm' prunes the
        batchnorm layer to prune the corresponding channels.
    prune_axis : int
        The pruned axis of the conv layer. 0 for output channel, 1 for input channel.

    Returns
    -------
    dependency : list
        The channel dependency for the target graph.
    """
    if prune_type not in ['Filter', 'BatchNorm']:
        raise ValueError('Unsupported prune type: {}'.format(prune_type))
    if prune_axis not in [0, 1]:
        raise ValueError('Unsupported prune axis: {}. Support 0 for output channel and 1 for input channel.'.format(prune_axis))

    graph = graph_module.graph
    target_types = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ConvTranspose2d, torch.nn.Embedding]
    if prune_type == 'BatchNorm':
        target_types.append(torch.nn.BatchNorm2d)

    dependency = dict()
    for node in graph.nodes:
        d_set = set()
        node: torch.fx.Node

        # input channel dependency
        if prune_axis == 1:
            # Some pruners may prune the input channels of the convolutional layers.
            # While pruning the input channels, the layers that share the same input
            # tensor should prune the same channels; we say these layers have an input
            # channel dependency. If we only prune the input channels of one layer in
            # a dependency set, there will be a shape conflict for the other layers in
            # the same set, which may trigger a runtime error. Here we judge whether an
            # operation will truncate the dependency by analyzing whether the number of
            # channels before and after the operation has changed. If not, the input
            # channel dependency is passed on to the following nodes.
            d_set = set(find_adjacent_layers(node, graph_module, target_types, 'child'))

        # output channel dependency
        else:
            # Output channel dependency indicates the dependent layers when pruning the
            # output channels of conv layers (for example, L1FilterPruner).
            if node.op == 'call_module':
                submodule = graph_module.get_submodule(node.target)
                # additional dependency for depth-wise conv (group number == output channel number):
                if (isinstance(submodule, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)) and submodule.groups == submodule.out_channels) \
                        or (isinstance(submodule, torch.nn.GroupNorm) and submodule.num_groups == submodule.num_channels):
                    d_set = set([node] + find_adjacent_layers(node, graph_module, target_types, 'parent'))
            elif node.op == 'call_function':
                if node.target in CALL_FUNCTION_REDUCE:
                    # Refer to issue 4540 for more details. Multiplication actually will
                    # not introduce a channel dependency, because the misaligned channels
                    # can propagate to each other. However, when one of the input tensors
                    # comes from a skip connection (residual), the channel propagation may
                    # fail (the input is also used by another layer and cannot be pruned);
                    # in this case, we need to fix the conflict manually.
                    d_set = set(find_adjacent_layers(node, graph_module, target_types, 'parent'))
                elif node.target in CALL_FUNCTION_CONCAT:
                    # To determine whether this cat operation will introduce a channel
                    # dependency, we need the specific input parameters of the cat operation.
                    # Read the dim explicitly: `kwargs.get('dim') or args[1]` would misread
                    # dim=0 and crash when the dim argument is omitted (the default is 0).
                    cat_dim = node.kwargs.get('dim', node.args[1] if len(node.args) > 1 else 0)
                    if cat_dim != 1:
                        d_set = set(find_adjacent_layers(node, graph_module, target_types, 'parent'))
            elif node.op == 'call_method':
                if node.target in CALL_METHOD_REDUCE:
                    d_set = set(find_adjacent_layers(node, graph_module, target_types, 'parent'))

        # merge dependencies
        for parent in tuple(d_set):
            if parent in dependency:
                d_set |= dependency[parent]
        for parent in d_set:
            dependency[parent] = d_set

    return convert_dependency_to_set(dependency)
def build_group_dependency(graph_module: torch.fx.GraphModule):
    """
    Analyze the group dependencies between the conv layers in a model.

    Parameters
    ----------
    graph_module : torch.fx.GraphModule
        The target graph and module.

    Returns
    -------
    dependency : dict
        The group dependency for the target graph.
    """
    graph = graph_module.graph
    dependency = dict()
    groups = dict()
    for node in graph.nodes:
        node: torch.fx.Node
        d_set = set()

        # Record the group count of each conv layer in the model. Note that the
        # recorded group count of a conv layer may be larger than its original
        # groups, because the input channels of a following group conv constrain
        # it as well. To make this clear, assume we have two group conv layers:
        # conv1 (groups=2) and conv2 (groups=4), where conv2 takes the output
        # features of conv1 as input. Then the filters of conv1 must still be
        # divisible into 4 groups after filter pruning, because the input
        # channels of conv2 are divided into 4 groups.
        if node.op == 'call_module':
            num_groups = 0
            submodule = graph_module.get_submodule(node.target)
            if isinstance(submodule, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                # depth-wise conv will not introduce extra group dependency
                num_groups = submodule.groups if submodule.groups != submodule.out_channels else 1
            elif isinstance(submodule, torch.nn.GroupNorm):
                num_groups = submodule.num_groups
            else:
                continue
            groups[node] = groups.get(node, []) + [num_groups]

            if num_groups > 1:
                # For a conv layer whose group count is larger than 1, the number of
                # output channels of its parent conv layers must be divisible by group.
                d_set = find_adjacent_layers(node, graph_module, [torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear], 'parent')
                for parent in d_set:
                    groups[parent] = groups.get(parent, []) + [num_groups]

    for node in groups:
        group_max = lcm_list(groups[node])
        group_min = gcd_list(groups[node]) if min(groups[node]) == gcd_list(groups[node]) else 1
        if group_max == 1 and group_min == 1:
            continue
        dependency[node] = (group_max, group_min)
    return dependency
def build_reshape_dependency(graph_module: torch.fx.GraphModule):
    """
    Some models have view/reshape functions whose parameters are fixed and cannot
    be replaced at all, so these functions may impose constraints on their input
    shapes. This function finds the direct input conv/linear layers of such reshape
    functions. If you get a shape conflict when running forward inference on the
    sped-up model, please try removing these layers from the pruner config list and
    try again.
    """
    graph = graph_module.graph
    dependency = dict()
    for node in graph.nodes:
        parent_layers = []
        # find the nodes that contain reshape-like functions
        if node.op == 'call_function' and node.target in [torch.narrow, torch.reshape]:
            logger.info(f'Detect reshape-like functions: {node.target}, args = {node.args}, kwargs = {node.kwargs}')
            parent_layers = find_adjacent_layers(node, graph_module, [torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear], 'parent')
            dependency[node] = parent_layers
        elif node.op == 'call_method' and node.target in ['narrow', 'reshape', 'view']:
            logger.info(f'Detect reshape-like methods: {node.target}, args = {node.args}, kwargs = {node.kwargs}')
            parent_layers = find_adjacent_layers(node, graph_module, [torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear], 'parent')
            dependency[node] = parent_layers
    return dependency
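
For context, here is a minimal usage sketch of auto_set_denpendency_group_ids as defined above. The TwoConvNet model, the input shape, and the 'sparse_ratio' config key are illustrative assumptions, not part of this commit; the import path follows the test file below.

import torch
from nni.common.concrete_trace_utils import concrete_trace
from nni.compression.pytorch.speedup.v2.dependency import auto_set_denpendency_group_ids

class TwoConvNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(3, 16, 3, 1, 1)

    def forward(self, x):
        # the elementwise add couples the output channels of conv1 and conv2
        return self.conv1(x) + self.conv2(x)

model = TwoConvNet()
graph_module = concrete_trace(model, (torch.randn(1, 3, 32, 32),))
config_list = [{'op_types': ['Conv2d'], 'sparse_ratio': 0.5}]
config_list = auto_set_denpendency_group_ids(graph_module, config_list)
# the result has one entry per conv, and conv1/conv2 share the same
# 'dependency_group_id', so a pruner must prune their channels consistently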


@@ -14,6 +14,7 @@ import torch.fx
 from torch.fx import GraphModule
 from torch.fx._compatibility import compatibility
 from torch.fx.node import Node, Target
+from torch.fx.passes.shape_prop import ShapeProp

 from nni.common.concrete_trace_utils import concrete_trace
 from nni.compression.pytorch.utils import set_nested_attr
@@ -36,6 +37,8 @@ def _normalize_input(dummy_input: Any) -> Any:
         dummy_input = (dummy_input, )
     elif isinstance(dummy_input, list):
         dummy_input = tuple(dummy_input)
+    elif isinstance(dummy_input, dict):
+        dummy_input = tuple(dummy_input.values())
     return dummy_input
@@ -97,7 +100,14 @@ class ModelSpeedup(torch.fx.Interpreter):
                  logger: logging.Logger | None = None):
         self.dummy_input = _normalize_input(dummy_input)
         self.bound_model = model
-        self.graph_module = graph_module if isinstance(graph_module, GraphModule) else concrete_trace(model, self.dummy_input)
+        if isinstance(graph_module, GraphModule):
+            self.graph_module = graph_module
+        elif isinstance(dummy_input, dict):
+            self.graph_module = concrete_trace(model, dummy_input)
+        else:
+            self.graph_module = concrete_trace(model, self.dummy_input)
+        ShapeProp(self.graph_module).propagate(*self.dummy_input)  # attach shapes to graph_module
         super().__init__(self.graph_module, garbage_collect_values)
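
The ShapeProp call added above is what attaches the tensor_meta that channel_dependency_breakpoint in dependency.py later reads. A standalone sketch of that mechanism (the Sequential model and input shape are illustrative assumptions):

import torch
from torch.fx.passes.shape_prop import ShapeProp
from nni.common.concrete_trace_utils import concrete_trace

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dummy_input = (torch.randn(1, 3, 32, 32),)
gm = concrete_trace(model, dummy_input)
ShapeProp(gm).propagate(*dummy_input)  # run the graph once, recording output metadata per node
for node in gm.graph.nodes:
    meta = node.meta.get('tensor_meta', None)
    if meta is not None:
        print(node.name, meta.shape)  # e.g. the conv output is torch.Size([1, 8, 30, 30])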


@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest
import torch

from nni.common.concrete_trace_utils import concrete_trace
from nni.compression.pytorch.speedup.v2.dependency import build_channel_dependency


# should have the dependency sets {conv1, conv2} and {conv3, conv4}
class PatternA(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(3, 16, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        return torch.cat((self.conv1(x), self.conv2(x)), dim=0) + torch.cat((self.conv3(x), self.conv4(x)), dim=0)


# should not have dependency (concatenation along the channel dim)
class PatternB(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        return torch.cat((self.conv1(x), self.conv2(x)), dim=1)


# should not have dependency (shape breakpoint)
class PatternC(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(3, 16, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        return torch.sum(self.conv1(x), dim=1) + torch.sum(self.conv2(x), dim=1)


def check_sets_all_match(a, b):
    assert len(a) == len(b)
    b = list(b)  # copy so the parametrized expectation is not mutated across runs
    for s in a:
        s = {node.name for node in s}
        assert s in b
        b.remove(s)
    assert len(b) == 0


@pytest.mark.parametrize('mod, deps', [
    (PatternA, [{'conv1', 'conv2'}, {'conv3', 'conv4'}]),
    (PatternB, []),
    (PatternC, []),
])
def test_channel_dependency(mod, deps):
    model = mod()
    dummy_input = (torch.randn(1, 3, 224, 224), )
    traced = concrete_trace(model, dummy_input)
    dependency = build_channel_dependency(traced)
    check_sets_all_match(dependency, deps)
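
A group-dependency case in the same spirit as the patterns above could look like the sketch below. It is not part of the committed tests; the expected values follow from the LCM/GCD logic in build_group_dependency, where conv2's groups=4 constrains the output channels of its parent conv1.

from nni.compression.pytorch.speedup.v2.dependency import build_group_dependency

class PatternD(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(16, 16, 3, 1, 1, groups=4)

    def forward(self, x: torch.Tensor):
        return self.conv2(self.conv1(x))

def test_group_dependency_sketch():
    traced = concrete_trace(PatternD(), (torch.randn(1, 3, 32, 32),))
    dependency = build_group_dependency(traced)
    # conv1 is constrained to stay divisible into 4 groups by its child conv2
    assert {node.name for node in dependency} == {'conv1', 'conv2'}
    assert all(group_max == 4 for group_max, _ in dependency.values())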