Mirror of https://github.com/microsoft/nni.git

[Compression] pruning speedup: support input/output masks (#5385)

Parent: e7828fa32b
Commit: 6bd93e5ad1
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 from copy import deepcopy
+import inspect
 import logging
 from pathlib import Path
 import tempfile
@@ -34,7 +35,7 @@ from .utils import tree_map_zip
 @compatibility(is_backward_compatible=True)
 class ModelSpeedup(torch.fx.Interpreter):
     """
-    This class is to speedup the model with provided weight mask.
+    This class is to speedup the model with provided masks.
 
     Parameters
     ----------
@@ -211,6 +212,8 @@ class ModelSpeedup(torch.fx.Interpreter):
         self.node_infos[node].output_origin = output
         self.node_infos[node].output_inplace = \
             tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else deepcopy(t), output)
+        self.node_infos[node].output_masks = \
+            tree_map_zip(lambda t: torch.ones_like(t).clone().detach() if isinstance(t, torch.Tensor) else None, output)
 
         if self.garbage_collect_values:
             # do memory collect to reduce memory usage
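
The two added lines above seed every node with an all-ones output mask mirroring the structure of its output. A minimal sketch of that behavior (the stand-in `tree_map_zip` below handles a single tree of tensors, lists/tuples, and dicts; NNI's real helper in `.utils` zips several pytrees):

    import torch

    def tree_map_zip(fn, tree):
        # Simplified stand-in: recurse through containers, apply fn to leaves.
        if isinstance(tree, (list, tuple)):
            return type(tree)(tree_map_zip(fn, t) for t in tree)
        if isinstance(tree, dict):
            return {k: tree_map_zip(fn, v) for k, v in tree.items()}
        return fn(tree)

    output = (torch.rand(2, 3), {'logits': torch.rand(2, 10)}, 'aux')
    masks = tree_map_zip(
        lambda t: torch.ones_like(t).detach() if isinstance(t, torch.Tensor) else None,
        output)
    # Tensors get an all-ones mask of the same shape; non-tensors map to None.
    print(masks[0].shape, masks[1]['logits'].shape, masks[2])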
@@ -309,84 +312,48 @@ class ModelSpeedup(torch.fx.Interpreter):
                 self.node_infos[node].mask_updater = mask_updater
                 break
 
+        # The following code is related to preset input/output mask...
+        for node_info in self.node_infos.values():
+            if node_info.module is None:
+                continue
+            masks = self.masks_file.get(node_info.node.target, {})
+
-        # node_to_masks = defaultdict(list)
-        # for node in (node for node in self.node_infos if self.node_infos[node].module is not None):
-        #     # some 'call_module's have no 'module', such as 'log_softmax'
-        #     param_masks = self.masks_file.get(node.target, {})
-        #     if '_output_' in param_masks:
-        #         node_to_masks[node].append(param_masks['_output_'])
-        #     if '_input_' in param_masks:
-        #         func = self.fetch_attr(node.target).forward
-        #         while hasattr(func, '__wrapped__'):
-        #             func = func.__wrapped__
-        #         arg_list = inspect.getfullargspec(func).args
-        #         kw_to_posi = dict(zip(arg_list[1:], range(len(arg_list))[1:]))
-        #         node_kw = {
-        #             **dict(zip(range(len(arg_list))[1:], node.args)),
-        #             **dict(zip(arg_list[1:], node.args)),
-        #             **{kw_to_posi[k]: v for k, v in node.kwargs.items()},
-        #             **node.kwargs,
-        #         }
-        #         for key, mask in param_masks['_input_'].items():
-        #             node_to_masks[node_kw[key]].append(mask)
+            output_masks = {name: masks[name] for name in filter(lambda name: name.startswith('_output_'), masks.keys())}
+            if output_masks:
+                if isinstance(node_info.output_masks, torch.Tensor):
+                    node_info.output_masks *= list(output_masks.values())[0]
+                elif isinstance(node_info.output_masks, (list, tuple)):
+                    for key, mask in output_masks.items():
+                        key = key.split('_output_')[1]
+                        assert key.isnumeric()
+                        if mask is not None:
+                            node_info.output_masks[int(key)] *= mask
+                elif isinstance(node_info.output_masks, dict):
+                    for key, mask in output_masks.items():
+                        if mask is not None:
+                            key = key.split('_output_')[1]
+                            node_info.output_masks[key] *= mask
+                else:
+                    raise RuntimeError(f'Unsupported output type {type(node_info.output_masks)}.')
 
-        # def check_equal(a, b):
-        #     if type(a) != type(b):
-        #         return False
-        #     if isinstance(a, (list, tuple)):
-        #         if len(a) != len(b):
-        #             return False
-        #         for sub_a, sub_b in zip(a, b):
-        #             if not check_equal(sub_a, sub_b):
-        #                 return False
-        #         return True
-        #     elif isinstance(a, dict):
-        #         if len(set(a.keys()).symmetric_difference(b.keys())) != 0:
-        #             return False
-        #         for key in a.keys():
-        #             if not check_equal(a[key], b[key]):
-        #                 return False
-        #         return True
-        #     else:
-        #         assert isinstance(a, torch.Tensor), f'contents in masks can only be (list, tuple, dict, Tensor), not ({type(a)})'
-        #         return torch.equal(a, b)  # totally equal, no bias
-
-        # def check_valid(a, b):
-        #     if isinstance(a, (list, tuple)):
-        #         if not isinstance(b, (list, tuple)):
-        #             return False
-        #         if len(a) != len(b):
-        #             return False
-        #         for sub_a, sub_b in zip(a, b):
-        #             if not check_equal(sub_a, sub_b):
-        #                 return False
-        #         return True
-        #     elif isinstance(a, dict):
-        #         if not isinstance(b, dict):
-        #             return False
-        #         if len(set(a.keys()).symmetric_difference(b.keys())) != 0:
-        #             return False
-        #         for key in a.keys():
-        #             if not check_equal(a[key], b[key]):
-        #                 return False
-        #         return True
-        #     elif isinstance(a, torch.Tensor):
-        #         if not isinstance(b, torch.Tensor):
-        #             return False
-        #         return a.shape == b.shape
-        #     else:
-        #         return b is None
-
-        # for node, masks in node_to_masks.items():
-        #     if len(masks) >= 1:
-        #         mask = masks[0]
-        #         for mask_next in masks[1:]:
-        #             assert check_equal(mask, mask_next), f'preset-masks of "{node}" are not equal!'
-        #         assert check_valid(self.node_infos[node].output_origin, mask),\
-        #             f'structure of preset-mask and value of slot "{node}" are not equal!'
-        #         self.node_infos[node].output_masks = mask
+            input_masks = {name: masks[name] for name in filter(lambda name: name.startswith('_input_'), masks.keys())}
+            if input_masks:
+                func = self.fetch_attr(node_info.node.target).forward
+                while hasattr(func, '__wrapped__'):
+                    func = func.__wrapped__
+                arg_list = inspect.getfullargspec(func).args
+                kw_to_posi = dict(zip(arg_list[1:], range(len(arg_list) - 1)))
+                node_kw = {
+                    **dict(zip(range(len(arg_list) - 1), node_info.node.args)),
+                    **dict(zip(arg_list[1:], node_info.node.args)),
+                    **{kw_to_posi[k]: v for k, v in node_info.node.kwargs.items()},
+                    **node_info.node.kwargs,
+                }
+                for key, mask in input_masks.items():
+                    key = key.split('_input_')[1]
+                    key = int(key) if key.isnumeric() else key
+                    if isinstance(mask, torch.Tensor):
+                        assert isinstance(self.node_infos[node_kw[key]].output_masks, torch.Tensor)
+                        self.node_infos[node_kw[key]].output_masks *= mask.detach().clone()
 
     def speedup_model(self) -> GraphModule:
         try:
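
For reference, the key convention this new block consumes: `_output_<index>` masks multiply into the node's own output masks, while `_input_<arg name or position>` masks multiply into the producer node's output masks, with `node_kw` resolving either addressing form to the same argument. A hedged sketch with made-up layer names and shapes (not taken from the commit):

    import inspect
    import torch

    # Illustrative masks_file entries, shaped like what the block reads:
    masks_file = {
        'conv1': {
            'weight': torch.ones(20, 1, 5, 5),          # ordinary parameter mask
            '_output_0': torch.ones(8, 20, 24, 24),     # preset mask on output 0
        },
        'relu1': {
            '_input_input': torch.ones(8, 20, 24, 24),  # preset mask on the arg named 'input'
        },
    }

    def forward(self, input):  # toy stand-in for a traced module's forward
        return input

    # How '_input_input' is resolved to an argument, mirroring the new code:
    arg_list = inspect.getfullargspec(forward).args            # ['self', 'input']
    kw_to_posi = dict(zip(arg_list[1:], range(len(arg_list) - 1)))
    assert kw_to_posi == {'input': 0}                          # name -> positional index
    key = '_input_input'.split('_input_')[1]
    key = int(key) if key.isnumeric() else key
    assert key == 'input'  # '_input_0' would resolve to position 0 instead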
@@ -70,4 +70,5 @@ def tree_map_zip(fn: Any, *pytrees):
         flat_args, spec = tree_flatten(pytree)
         flat_args_list.append(flat_args)
         spec_list.append(spec)
+    assert all(len(args) == len(flat_args_list[0]) for args in flat_args_list), 'Inconsistent tree nodes length.'
     return tree_unflatten([fn(*args) for args in zip(*flat_args_list)], spec_list[0])
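
The added assert rejects pytrees whose leaf counts disagree before `zip` would silently truncate them. A self-contained restatement of the patched helper, plus a call that trips the new check:

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    def tree_map_zip(fn, *pytrees):
        # Flatten every tree, zip the leaf lists, apply fn leaf-wise,
        # then rebuild using the first tree's structure.
        flat_args_list, spec_list = [], []
        for pytree in pytrees:
            flat_args, spec = tree_flatten(pytree)
            flat_args_list.append(flat_args)
            spec_list.append(spec)
        assert all(len(args) == len(flat_args_list[0]) for args in flat_args_list), \
            'Inconsistent tree nodes length.'
        return tree_unflatten([fn(*args) for args in zip(*flat_args_list)], spec_list[0])

    a = [torch.ones(2), torch.ones(3)]
    b = [torch.zeros(2), torch.zeros(3)]
    print(tree_map_zip(torch.add, a, b))              # ok: same leaf counts
    try:
        tree_map_zip(torch.add, a, [torch.zeros(2)])  # leaf counts differ
    except AssertionError as e:
        print(e)                                      # Inconsistent tree nodes length.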
@@ -127,7 +127,7 @@ class GroupMaskConflict(MaskFix):
         for layername in depens:
             group_max = depens[layername]
             group_min = min_groups[layername]
-            if layername not in self.masks:
+            if layername not in self.masks or 'weight' not in self.masks[layername]:
                 # this layer not pruned
                 continue
             w_mask = self.masks[layername]['weight']
@@ -222,7 +222,7 @@ class ChannelMaskConflict(MaskFix):
         sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
 
         (_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
-        device = _tmp_tensor['weight'].device
+        device = list(_tmp_tensor.values())[0].device
 
         for dset in depen_sets:
             if len(dset) <= 1:
@@ -233,7 +233,7 @@ class ChannelMaskConflict(MaskFix):
             channel_masks = []
             fine_grained = False
             for name in dset:
-                if name in self.masks:
+                if name in self.masks and 'weight' in self.masks[name]:
                     _, m = get_module_by_name(self.model, name)
                     assert m is not None
                     mask = self.masks[name]['weight']
@@ -290,7 +290,7 @@ class ChannelMaskConflict(MaskFix):
             merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
 
             for name in dset:
-                if name not in self.masks:
+                if name not in self.masks or 'weight' not in self.masks[name]:
                     assert all(merged_channel_mask)
                     continue
                 orig_mask = self.masks[name]['weight']
@@ -377,6 +377,9 @@ def detect_mask_prune_dim(masks, model):
     dim0_preserved, dim1_preserved = 0., 0.
     dim0_num, dim1_num = 0., 0.
     for module_name in masks:
+        if 'weight' not in masks[module_name]:
+            continue
+
         _, m = get_module_by_name(model, module_name)
         if m is None or type(m).__name__ != 'Conv2d':
             continue
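
The four hunks above all apply the same guard: an entry in `masks` no longer implies a `'weight'` mask exists, since with input/output masks an entry may now carry only non-weight masks. A small sketch with made-up layer names of the pattern the patched code follows:

    import torch

    masks = {
        'conv1': {'weight': torch.ones(8, 1, 5, 5), 'bias': torch.ones(8)},
        'conv2': {'bias': torch.ones(16)},  # no 'weight' key here
    }

    for name in masks:
        if 'weight' not in masks[name]:
            continue  # treat as unpruned, as the patched checks do
        w_mask = masks[name]['weight']
        print(name, w_mask.shape)
    # Likewise, `device = list(_tmp_tensor.values())[0].device` picks a
    # reference device without assuming a 'weight' entry is present.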
@@ -84,53 +84,49 @@ class TorchModel(torch.nn.Module):
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
 
-# class InitMaskTestCase(unittest.TestCase):
-#     def the_test_with_annotations(self, relu):
-#         torch.manual_seed(100)
-#         model = TorchModel(relu)
-#         dummy_input = torch.rand(3, 1, 28, 28)
+class InitMaskTestCase(unittest.TestCase):
+    def the_test_with_annotations(self, relu):
+        torch.manual_seed(100)
+        model = TorchModel(relu)
+        dummy_input = torch.rand(3, 1, 28, 28)
 
-#         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
-#         pruner = L1NormPruner(model=model, config_list=config_list)
-#         _, masks = pruner.compress()
-#         pruner.show_pruned_weights()
-#         pruner._unwrap_model()  # unwrap all modules to normal state
+        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
+        pruner = L1NormPruner(model=model, config_list=config_list)
+        _, masks = pruner.compress()
+        pruner.show_pruned_weights()
+        pruner._unwrap_model()  # unwrap all modules to normal state
 
-#         masks['relu1'] = {
-#             '_input_': {
-#                 'input': torch.ones((8, 20, 24, 24)),
-#                 1: torch.ones((8, 20, 24, 24))
-#             },
-#             '_output_': torch.ones((8, 20, 24, 24)),
-#         }
-#         masks['conv1']['_output_'] = torch.ones((8, 20, 24, 24))
+        masks['relu1'] = {
+            '_input_input': torch.ones((8, 20, 24, 24)),
+            '_output_0': torch.ones((8, 20, 24, 24)),
+        }
+        masks['conv1']['_output_0'] = torch.ones((8, 20, 24, 24))
 
-#         traced_model = concrete_trace(model, {'x': dummy_input}, leaf_module=(WithAnno1, WithAnno2, WithAnno3))
-#         ModelSpeedup(traced_model, customized_replace_func = {'WithAnno1': no_replace, 'WithAnno2': no_replace, 'WithAnno3': no_replace}
-#                      ).run(args=[dummy_input], masks_file=masks)
-#         traced_model(dummy_input)
+        traced_model = concrete_trace(model, {'x': dummy_input}, leaf_module=(WithAnno1, WithAnno2, WithAnno3))
+        ModelSpeedup(traced_model, (dummy_input,), masks).speedup_model()
+        traced_model(dummy_input)
 
-#         print('model before speedup', repr(model))
-#         # 125.49 M, 0.85M, 93.29, 1.1012
-#         flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
-#         print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
+        print('model before speedup', repr(model))
+        # 125.49 M, 0.85M, 93.29, 1.1012
+        flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
+        print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
 
-#         print('model after speedup', repr(traced_model))
-#         flops, params, _ = count_flops_params(traced_model, dummy_input, verbose=False)
-#         print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
+        print('model after speedup', repr(traced_model))
+        flops, params, _ = count_flops_params(traced_model, dummy_input, verbose=False)
+        print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M')
 
-#     def test_with_annotation0(self):
-#         return self.the_test_with_annotations(torch.nn.ReLU6())
+    def test_with_annotation0(self):
+        return self.the_test_with_annotations(torch.nn.ReLU6())
 
-#     def test_with_annotation1(self):
-#         return self.the_test_with_annotations(WithAnno1())
+    def test_with_annotation1(self):
+        return self.the_test_with_annotations(WithAnno1())
 
-#     def test_with_annotation2(self):
-#         return self.the_test_with_annotations(WithAnno2())
+    def test_with_annotation2(self):
+        return self.the_test_with_annotations(WithAnno2())
 
-#     def test_with_annotation3(self):
-#         return self.the_test_with_annotations(WithAnno3())
+    def test_with_annotation3(self):
+        return self.the_test_with_annotations(WithAnno3())
 
-# if __name__ == '__main__':
-#     # unittest.main()
-#     InitMaskTestCase().test_with_annotation1()
+if __name__ == '__main__':
+    # unittest.main()
+    InitMaskTestCase().test_with_annotation1()
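
The commented-out lines removed in this test used a draft nested format and the old `run(args=..., masks_file=...)` entry point; the live test exercises the flat key format and the `ModelSpeedup(module, dummy_input, masks).speedup_model()` call shown in the hunk. Side by side, with illustrative shapes only:

    import torch

    # Draft format (removed): nested dicts keyed by '_input_' / '_output_'.
    old_style = {
        '_input_': {'input': torch.ones((8, 20, 24, 24))},
        '_output_': torch.ones((8, 20, 24, 24)),
    }
    # Final format (added): flat '_input_<arg name or position>' / '_output_<index>' keys.
    new_style = {
        '_input_input': torch.ones((8, 20, 24, 24)),
        '_output_0': torch.ones((8, 20, 24, 24)),
    }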