[Compression] pruning speedup: support input/output masks (#5385)

J-shang 2023-03-03 22:15:44 +08:00 committed by GitHub
Parent e7828fa32b
Commit 6bd93e5ad1
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 88 additions and 121 deletions

View file

@@ -4,6 +4,7 @@
from __future__ import annotations
from copy import deepcopy
import inspect
import logging
from pathlib import Path
import tempfile
@@ -34,7 +35,7 @@ from .utils import tree_map_zip
@compatibility(is_backward_compatible=True)
class ModelSpeedup(torch.fx.Interpreter):
"""
This class is to speedup the model with provided weight mask.
This class speeds up the model with the provided masks.
Parameters
----------
@@ -211,6 +212,8 @@ class ModelSpeedup(torch.fx.Interpreter):
self.node_infos[node].output_origin = output
self.node_infos[node].output_inplace = \
tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else deepcopy(t), output)
self.node_infos[node].output_masks = \
tree_map_zip(lambda t: torch.ones_like(t).clone().detach() if isinstance(t, torch.Tensor) else None, output)
if self.garbage_collect_values:
# do memory collect to reduce memory usage
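For reference, a minimal standalone sketch of what the tree_map_zip call above computes: an all-ones default mask mirroring an arbitrary output structure (a simplified stand-in, not the NNI API):

import torch

def ones_mask_like(output):
    # Tensors get an all-ones mask of the same shape; non-tensor
    # leaves map to None; containers are mirrored recursively.
    if isinstance(output, torch.Tensor):
        return torch.ones_like(output).detach()
    if isinstance(output, (list, tuple)):
        return type(output)(ones_mask_like(o) for o in output)
    if isinstance(output, dict):
        return {k: ones_mask_like(v) for k, v in output.items()}
    return None

# a tuple output holding a tensor and a dict of tensors
print(ones_mask_like((torch.rand(2, 3), {'logits': torch.rand(2)})))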
@@ -309,84 +312,48 @@ class ModelSpeedup(torch.fx.Interpreter):
self.node_infos[node].mask_updater = mask_updater
break
# The following code is related to preset input/output mask...
for node_info in self.node_infos.values():
if node_info.module is None:
continue
masks = self.masks_file.get(node_info.node.target, {})
# node_to_masks = defaultdict(list)
# for node in (node for node in self.node_infos if self.node_infos[node].module is not None):
# # some 'call_module's has no 'module' such as 'log_softmax'
# param_masks = self.masks_file.get(node.target, {})
# if '_output_' in param_masks:
# node_to_masks[node].append(param_masks['_output_'])
# if '_input_' in param_masks:
# func = self.fetch_attr(node.target).forward
# while hasattr(func, '__wrapped__'):
# func = func.__wrapped__
# arg_list = inspect.getfullargspec(func).args
# kw_to_posi = dict(zip(arg_list[1:], range(len(arg_list))[1:]))
# node_kw = {
# **dict(zip(range(len(arg_list))[1:], node.args)),
# **dict(zip(arg_list[1:], node.args)),
# **{kw_to_posi[k]: v for k, v in node.kwargs.items()},
# **node.kwargs,
# }
# for key, mask in param_masks['_input_'].items():
# node_to_masks[node_kw[key]].append(mask)
output_masks = {name: masks[name] for name in filter(lambda name: name.startswith('_output_'), masks.keys())}
if output_masks:
if isinstance(node_info.output_masks, torch.Tensor):
node_info.output_masks *= list(output_masks.values())[0]
elif isinstance(node_info.output_masks, (list, tuple)):
for key, mask in output_masks.items():
key = key.split('_output_')[1]
assert key.isnumeric()
if mask is not None:
node_info.output_masks[int(key)] *= mask
elif isinstance(node_info.output_masks, dict):
for key, mask in output_masks.items():
if mask is not None:
key = key.split('_output_')[1]
node_info.output_masks[key] *= mask
else:
raise RuntimeError(f'Unsupported output type {type(node_info.output_masks)}.')
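For orientation, the preset keys follow a '_output_<index-or-name>' convention (and '_input_<position-or-arg-name>' for the input masks consumed further down); a hypothetical mask entry, with shapes loosely borrowed from the test file below, might look like:

masks['conv1'] = {
    'weight': torch.ones(20, 1, 5, 5),          # parameter mask produced by the pruner
    '_output_0': torch.ones(8, 20, 24, 24),     # preset mask for the 0-th output
    '_input_input': torch.ones(8, 1, 28, 28),   # preset mask for the argument named `input`
}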
# def check_equal(a, b):
# if type(a) != type(b):
# return False
# if isinstance(a, (list, tuple)):
# if len(a) != len(b):
# return False
# for sub_a, sub_b in zip(a, b):
# if not check_equal(sub_a, sub_b):
# return False
# return True
# elif isinstance(a, dict):
# if len(set(a.keys()).symmetric_difference(b.keys())) != 0:
# return False
# for key in a.keys():
# if not check_equal(a[key], b[key]):
# return False
# return True
# else:
# assert isinstance(a, torch.Tensor), f'contents in masks can only be (list, tuple, dict, Tensor), not ({type(a)})'
# return torch.equal(a, b) # totally equal, no bias
# def check_valid(a, b):
# if isinstance(a, (list, tuple)):
# if not isinstance(b, (list, tuple)):
# return False
# if len(a) != len(b):
# return False
# for sub_a, sub_b in zip(a, b):
# if not check_equal(sub_a, sub_b):
# return False
# return True
# elif isinstance(a, dict):
# if not isinstance(b, dict):
# return False
# if len(set(a.keys()).symmetric_difference(b.keys())) != 0:
# return False
# for key in a.keys():
# if not check_equal(a[key], b[key]):
# return False
# return True
# elif isinstance(a, torch.Tensor):
# if not isinstance(b, torch.Tensor):
# return False
# return a.shape == b.shape
# else:
# return b is None
# for node, masks in node_to_masks.items():
# if len(masks) >= 1:
# mask = masks[0]
# for mask_next in masks[1:]:
# assert check_equal(mask, mask_next), f'preset-masks of "{node}" are not equal!'
# assert check_valid(self.node_infos[node].output_origin, mask),\
# f'structure of preset-mask and value of slot "{node}" are not equal!'
# self.node_infos[node].output_masks = mask
input_masks = {name: masks[name] for name in filter(lambda name: name.startswith('_input_'), masks.keys())}
if input_masks:
func = self.fetch_attr(node_info.node.target).forward
while hasattr(func, '__wrapped__'):
func = func.__wrapped__
arg_list = inspect.getfullargspec(func).args
kw_to_posi = dict(zip(arg_list[1:], range(len(arg_list) - 1)))
node_kw = {
**dict(zip(range(len(arg_list) - 1), node_info.node.args)),
**dict(zip(arg_list[1:], node_info.node.args)),
**{kw_to_posi[k]: v for k, v in node_info.node.kwargs.items()},
**node_info.node.kwargs,
}
for key, mask in input_masks.items():
key = key.split('_input_')[1]
key = int(key) if key.isnumeric() else key
if isinstance(mask, torch.Tensor):
assert isinstance(self.node_infos[node_kw[key]].output_masks, torch.Tensor)
self.node_infos[node_kw[key]].output_masks *= mask.detach().clone()
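The lookup table built above can be illustrated with a small standalone sketch: both a positional index and an argument name resolve to the same producer slot, so '_input_0' and '_input_input' are interchangeable (the forward signature and values here are hypothetical):

import inspect

def forward(self, input, scale=1.0):  # example signature
    return input * scale

args = ('x_node',)                    # stand-in for node.args
kwargs = {'scale': 2.0}               # stand-in for node.kwargs
arg_list = inspect.getfullargspec(forward).args                 # ['self', 'input', 'scale']
kw_to_posi = dict(zip(arg_list[1:], range(len(arg_list) - 1)))  # {'input': 0, 'scale': 1}
node_kw = {
    **dict(zip(range(len(arg_list) - 1), args)),      # {0: 'x_node'}
    **dict(zip(arg_list[1:], args)),                  # {'input': 'x_node'}
    **{kw_to_posi[k]: v for k, v in kwargs.items()},  # {1: 2.0}
    **kwargs,                                         # {'scale': 2.0}
}
print(node_kw[0] is node_kw['input'])  # True: index and arg name hit the same node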
def speedup_model(self) -> GraphModule:
try:

View file

@@ -70,4 +70,5 @@ def tree_map_zip(fn: Any, *pytrees):
flat_args, spec = tree_flatten(pytree)
flat_args_list.append(flat_args)
spec_list.append(spec)
assert all(len(args) == len(flat_args_list[0]) for args in flat_args_list), 'Inconsistent tree nodes length.'
return tree_unflatten([fn(*args) for args in zip(*flat_args_list)], spec_list[0])
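A usage sketch of this zip-then-map behavior, assuming the flatten/unflatten helpers come from torch.utils._pytree (a simplified inline version, not the NNI util itself):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# leaf-wise product of two pytrees with identical structure,
# mirroring tree_map_zip(lambda a, b: a * b, x, y)
x = {'w': torch.ones(2, 2), 'b': torch.ones(3)}
y = {'w': torch.full((2, 2), 2.0), 'b': torch.zeros(3)}
flat_x, spec = tree_flatten(x)
flat_y, _ = tree_flatten(y)
out = tree_unflatten([a * b for a, b in zip(flat_x, flat_y)], spec)
print(out)  # {'w': tensor of 2s, 'b': tensor of 0s}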

View file

@@ -127,7 +127,7 @@ class GroupMaskConflict(MaskFix):
for layername in depens:
group_max = depens[layername]
group_min = min_groups[layername]
if layername not in self.masks:
if layername not in self.masks or 'weight' not in self.masks[layername]:
# this layer is not pruned, or has no weight mask
continue
w_mask = self.masks[layername]['weight']
@@ -222,7 +222,7 @@ class ChannelMaskConflict(MaskFix):
sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
(_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
device = _tmp_tensor['weight'].device
device = list(_tmp_tensor.values())[0].device
for dset in depen_sets:
if len(dset) <= 1:
@@ -233,7 +233,7 @@ class ChannelMaskConflict(MaskFix):
channel_masks = []
fine_grained = False
for name in dset:
if name in self.masks:
if name in self.masks and 'weight' in self.masks[name]:
_, m = get_module_by_name(self.model, name)
assert m is not None
mask = self.masks[name]['weight']
@@ -290,7 +290,7 @@ class ChannelMaskConflict(MaskFix):
merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
for name in dset:
if name not in self.masks:
if name not in self.masks or 'weight' not in self.masks[name]:
assert all(merged_channel_mask)
continue
orig_mask = self.masks[name]['weight']
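For intuition, one common merge policy, and the one consistent with the `assert all(merged_channel_mask)` for unmasked layers above, is a union of preserved channels across the dependency set; a minimal sketch (the helper name is hypothetical):

import torch

def merge_channel_masks(channel_masks):
    # A channel survives if any layer in the dependency set keeps it,
    # so merging never removes a channel some member still needs.
    merged = torch.zeros_like(channel_masks[0], dtype=torch.bool)
    for m in channel_masks:
        merged |= m.bool()
    return merged.float()

print(merge_channel_masks([torch.tensor([1., 0., 1., 0.]),
                           torch.tensor([1., 1., 0., 0.])]))  # tensor([1., 1., 1., 0.])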
@@ -377,6 +377,9 @@ def detect_mask_prune_dim(masks, model):
dim0_preserved, dim1_preserved = 0., 0.
dim0_num, dim1_num = 0., 0.
for module_name in masks:
if 'weight' not in masks[module_name]:
continue
_, m = get_module_by_name(model, module_name)
if m is None or type(m).__name__ != 'Conv2d':
continue
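A minimal per-mask sketch of the detection idea (the real helper accumulates preserved/total counts across all Conv2d masks; this simplified function and its name are hypothetical):

import torch

def infer_conv_prune_dim(weight_mask: torch.Tensor) -> int:
    # weight_mask: (out_channels, in_channels, kH, kW); the dimension
    # with fewer fully-preserved slices is the one being pruned.
    dim0_kept = (weight_mask.sum(dim=(1, 2, 3)) != 0).float().mean()
    dim1_kept = (weight_mask.sum(dim=(0, 2, 3)) != 0).float().mean()
    return 0 if dim0_kept <= dim1_kept else 1

mask = torch.ones(8, 4, 3, 3)
mask[:4] = 0  # zero out half the output channels
print(infer_conv_prune_dim(mask))  # 0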

View file

@@ -84,53 +84,49 @@ class TorchModel(torch.nn.Module):
x = self.fc2(x)
return F.log_softmax(x, dim=1)
# class InitMaskTestCase(unittest.TestCase):
# def the_test_with_annotations(self, relu):
# torch.manual_seed(100)
# model = TorchModel(relu)
# dummy_input = torch.rand(3, 1, 28, 28)
class InitMaskTestCase(unittest.TestCase):
def the_test_with_annotations(self, relu):
torch.manual_seed(100)
model = TorchModel(relu)
dummy_input = torch.rand(3, 1, 28, 28)
# config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
# pruner = L1NormPruner(model=model, config_list=config_list)
# _, masks = pruner.compress()
# pruner.show_pruned_weights()
# pruner._unwrap_model() # unwrap all modules to normal state
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
pruner = L1NormPruner(model=model, config_list=config_list)
_, masks = pruner.compress()
pruner.show_pruned_weights()
pruner._unwrap_model() # unwrap all modules to normal state
# masks['relu1'] = {
# '_input_': {
# 'input': torch.ones((8, 20, 24, 24)),
# 1: torch.ones((8, 20, 24, 24))
# },
# '_output_': torch.ones((8, 20, 24, 24)),
# }
# masks['conv1']['_output_'] = torch.ones((8, 20, 24, 24))
masks['relu1'] = {
'_input_input': torch.ones((8, 20, 24, 24)),
'_output_0': torch.ones((8, 20, 24, 24)),
}
masks['conv1']['_output_0'] = torch.ones((8, 20, 24, 24))
# traced_model = concrete_trace(model, {'x': dummy_input}, leaf_module=(WithAnno1, WithAnno2, WithAnno3))
# ModelSpeedup(traced_model, customized_replace_func = {'WithAnno1': no_replace, 'WithAnno2': no_replace, 'WithAnno3': no_replace}
# ).run(args=[dummy_input], masks_file=masks)
# traced_model(dummy_input)
traced_model = concrete_trace(model, {'x': dummy_input}, leaf_module=(WithAnno1, WithAnno2, WithAnno3))
ModelSpeedup(traced_model, (dummy_input,), masks).speedup_model()
traced_model(dummy_input)
# print('model before speedup', repr(model))
# # 125.49 M, 0.85M, 93.29, 1.1012
# flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
# print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
print('model before speedup', repr(model))
# 125.49 M, 0.85M, 93.29, 1.1012
flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
# print('model after speedup', repr(traced_model))
# flops, params, _ = count_flops_params(traced_model, dummy_input, verbose=False)
# print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
print('model after speedup', repr(traced_model))
flops, params, _ = count_flops_params(traced_model, dummy_input, verbose=False)
print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
# def test_with_annotation0(self):
# return self.the_test_with_annotations(torch.nn.ReLU6())
def test_with_annotation0(self):
return self.the_test_with_annotations(torch.nn.ReLU6())
# def test_with_annotation1(self):
# return self.the_test_with_annotations(WithAnno1())
def test_with_annotation1(self):
return self.the_test_with_annotations(WithAnno1())
# def test_with_annotation2(self):
# return self.the_test_with_annotations(WithAnno2())
def test_with_annotation2(self):
return self.the_test_with_annotations(WithAnno2())
# def test_with_annotation3(self):
# return self.the_test_with_annotations(WithAnno3())
def test_with_annotation3(self):
return self.the_test_with_annotations(WithAnno3())
# if __name__ == '__main__':
# # unittest.main()
# InitMaskTestCase().test_with_annotation1()
if __name__ == '__main__':
# unittest.main()
InitMaskTestCase().test_with_annotation1()