fix(bugbash): bugs in compression and speedup (#5108)

This commit is contained in:
Louis-J 2022-09-06 08:11:22 +08:00 коммит произвёл GitHub
Родитель 5874c27fcd
Коммит 55158b78f5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 72 добавлений и 21 удалений

Просмотреть файл

@ -53,7 +53,7 @@ Coarse-grained pruning or structured pruning is pruning a regular group of weigh
Only :ref:`level-pruner` and :ref:`admm-pruner` support fine-grained pruning, all other pruners do some kind of structured pruning on weights.
.. _dependency-awareode-for-output-channel-pruning:
.. _dependency-aware-mode-for-output-channel-pruning:
Dependency-aware Mode for Output Channel Pruning
------------------------------------------------
@ -105,4 +105,6 @@ In addition, for the convolutional layers that have more than one filter group,
``dependency-aware pruner`` will also try to prune the same number of the channels for each filter group.
Overall, this pruner will prune the model according to the L1 norm of each filter and try to meet the topological constrains (channel dependency, etc) to improve the final speed gain after the speedup process.
.. Note:: Operations that will be recognized as having channel dependencies: add/sub/mul/div, addcmul/addcdiv, logical_and/or/xor
In the dependency-aware mode, the pruner will provide a better speed gain from the model pruning.

Просмотреть файл

@ -25,7 +25,7 @@ jitid_2_dtype = {4: torch.long, 6:torch.float32}
__all__ = [
'getattr_python', 'jit_to_python_function', 'num2tensor_python', 'parse_constant', 'slice_python',
'translate_list', 'tupleunpack_python', 'dtype_trans', 'memory_format_trans'
'translate_list', 'tupleunpack_python', 'arg_trans_dtype', 'arg_trans_memory_format', 'arg_trans_layout'
]
def translate_list(list_node: torch._C.Value, speedup: ModelSpeedup=None) -> List:
@ -273,11 +273,11 @@ enum_to_dtype_names = {
enum_to_dtype_dict = {}
for enum_value, dtype_name in enum_to_dtype_names.items():
if hasattr(torch, dtype_name):
enum_to_dtype_dict[enum_value] = getattr(torch, dtype_name)
for enum_value, name in enum_to_dtype_names.items():
if hasattr(torch, name):
enum_to_dtype_dict[enum_value] = getattr(torch, name)
def dtype_trans(ivalue: Union[int, torch.dtype]):
def arg_trans_dtype(ivalue: Union[int, torch.dtype]):
"""
Special process for dtype.
Torch will transform dtype to an enum in cpp, so the value of dtype we get in jit is an int.
@ -294,7 +294,7 @@ def dtype_trans(ivalue: Union[int, torch.dtype]):
elif isinstance(ivalue, int):
if ivalue in enum_to_dtype_dict:
return enum_to_dtype_dict[ivalue]
raise TypeError('No torch.dtype corresponding to the value "%s"', ivalue)
raise TypeError('No torch.dtype corresponding to the value "%s"' % ivalue)
enum_to_memory_format_dict = {
0: torch.contiguous_format,
@ -303,7 +303,7 @@ enum_to_memory_format_dict = {
3: torch.channels_last_3d,
}
def memory_format_trans(ivalue: Union[int, torch.memory_format]):
def arg_trans_memory_format(ivalue: Union[int, torch.memory_format]):
"""
Special process for memory_format.
Torch will transform memory_format to an enum in cpp, so the value of memory_format we get in jit is an int.
@ -318,14 +318,49 @@ def memory_format_trans(ivalue: Union[int, torch.memory_format]):
if ivalue is None or isinstance(ivalue, torch.memory_format):
return ivalue
elif isinstance(ivalue, int):
global enum_to_memory_format_dict
if ivalue in enum_to_memory_format_dict:
return enum_to_memory_format_dict[ivalue]
raise TypeError('No torch.memory_format corresponding to the value "%s"', ivalue)
raise TypeError('No torch.memory_format corresponding to the value "%s"' % ivalue)
enum_to_layout_names = {
0: 'strided',
1: 'sparse_coo',
2: 'sparse_csr',
3: '_mkldnn',
4: 'sparse_csc',
5: 'sparse_bsr',
6: 'sparse_bsc',
}
enum_to_layout_dict = {}
for enum_value, name in enum_to_layout_names.items():
if hasattr(torch, name):
enum_to_layout_dict[enum_value] = getattr(torch, name)
def arg_trans_layout(ivalue: Union[int, torch.layout]):
"""
Special process for layout.
Torch will transform layout to an enum in cpp, so the value of layout we get in jit is an int.
This function is used to recover the int to torch.layout in python.
Parameters
----------
ivalue
The value of layout or method to be recovered.
"""
if ivalue is None or isinstance(ivalue, torch.layout):
return ivalue
elif isinstance(ivalue, int):
if ivalue in enum_to_layout_dict:
return enum_to_layout_dict[ivalue]
raise TypeError('No torch.layout corresponding to the value "%s"' % ivalue)
special_treat_dict = {
'dtype': dtype_trans,
'memory_format': memory_format_trans,
'dtype': arg_trans_dtype,
'memory_format': arg_trans_memory_format,
'layout': arg_trans_layout,
}
schema_fix_dict = {
@ -358,6 +393,7 @@ schema_fix_dict = {
# """aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, int? dtype=None, int? layout=None, Device? devi
# ce=None, bool? pin_memory=None) -> (Tensor"""'
}
@lru_cache(maxsize=256)
def parse_aten_schema(schema: str):
"""
@ -390,9 +426,9 @@ def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], p
"""
translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
"""
positional = list()
keyword = dict()
undetermined = list()
positional: List[str] = list()
keyword: Dict[str, str] = dict()
undetermined: List[Union[int, str]] = list()
for ainput in input_nodes:
if ainput.node().kind() == 'prim::ListConstruct':
@ -404,17 +440,17 @@ def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], p
if len(positional) < positional_num:
undetermined.append(len(positional))
else:
undetermined.append(keyword_list[positional_num - len(positional)])
undetermined.append(keyword_list[len(keyword)])
arg = None
if len(positional) < positional_num:
positional.append(arg)
else:
keyword[keyword_list[positional_num - len(positional)]] = arg
keyword[keyword_list[len(keyword)]] = arg
return positional, keyword, undetermined
def special_treat_to_constant_value(positional: List, keyword: Dict[str], undetermined: List[Union[int, str]],
special_treat: Dict[Union[int, str], Callable]):
special_treat: Dict[Union[int, str], Callable]) -> Dict[Union[int, str], Callable]:
"""
if any argument with special_treat is not in undetermined, do the treat
"""

Просмотреть файл

@ -15,8 +15,15 @@ __all__ = ['ChannelDependency', 'GroupDependency', 'ReshapeDependency',
CONV_TYPE = 'aten::_convolution'
ADD_TYPES = ['aten::add', 'aten::add_']
MUL_TYPES = ['aten::mul', 'atem::mul_']
ADD_MUL_LOGICAL_TYPES = [
'aten::add', 'aten::add_', 'aten::sub', 'aten::sub_', 'aten::subtract', 'aten::subtract_',
'aten::mul', 'aten::mul_', 'aten::div', 'aten::div_', 'aten::multiply', 'aten::multiply_', 'aten::divide', 'aten::divide_',
'aten::addcmul', 'aten::addcmul_',
'aten::addcdiv', 'aten::addcdiv_',
'aten::logical_xor', 'aten::logical_xor_',
'aten::logical_and', 'aten::logical_and_',
'aten::logical_or', 'aten::logical_or_',
]
CAT_TYPE = 'aten::cat'
logger = logging.getLogger('Shape_Dependency')
RESHAPE_OPS = [CAT_TYPE, 'aten::view',
@ -173,7 +180,7 @@ class ChannelDependency(Dependency):
parent_layers = []
# find the node that contains aten::add
# or aten::cat operations
if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
if node.op_type in ADD_MUL_LOGICAL_TYPES:
# refer issue 4540 for more details. Multiplication actually
# will not introduce the channel dependency, cause the misaligned
# channels can propagate to each other. However, when one of the input

Просмотреть файл

@ -70,6 +70,12 @@ class TorchModel1(torch.nn.Module):
self.asub = ASubModel()
def forward(self, x: torch.Tensor):
y1 = torch.ones_like(x)
y2 = torch.rand_like(x)
y3 = torch.randn_like(x)
y4 = torch.zeros_like(x)
x = x - y1 + y2 + y3 + y4
x = x.contiguous(memory_format=torch.channels_last)
x = torch._C._nn.upsample_bilinear2d(x, (28, 28), False)
x = torch._C._nn.upsample_nearest2d(x, (28, 28))