[Bugfix] Fix pruning speedup in-place value change issue (#5534)

This commit is contained in:
J-shang 2023-05-05 10:33:33 +08:00 committed by GitHub
Parent 0ea9459026
Commit 47fde3da21
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files: 22 additions and 8 deletions

View file

@@ -14,7 +14,7 @@ from torch.nn import functional as F
 from torch.fx.node import Node
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from .utils import randomize_tensor_inplace, randomize_if_tensor, tree_map_zip, torch_float_dtype
+from .utils import randomize_tensor_inplace, randomize_if_tensor, tree_map_zip, torch_float_dtype, poss_deepcopy

 class MaskUpdater:
@@ -154,8 +154,8 @@ class DefaultMaskUpdater(MaskUpdater):
         # Some operator may have the in_place operations, so we need to clone the input
         # before passing to the model_speedup.module
-        args_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else t, args)
-        kwargs_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else t, kwargs)
+        args_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else poss_deepcopy(t), args)
+        kwargs_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else poss_deepcopy(t), kwargs)

         output = getattr(model_speedup, node.op)(node.target, args_cloned, kwargs_cloned)
@@ -372,7 +372,7 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
         input_node = node.kwargs['input']
         input_mask = model_speedup.node_infos[input_node].output_masks
         model_speedup.node_infos[node].output_masks = \
-            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else t, input_mask)
+            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else poss_deepcopy(t), input_mask)

     def indirect_activation(self, model_speedup: 'ModelSpeedup', node: Node):
         if len(node.args) != 0:
@@ -393,7 +393,7 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
         sub_mask = operator.getitem(arg_0_masks, arg_1_val)
         model_speedup.node_infos[node].output_masks = \
-            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else t, sub_mask)
+            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else poss_deepcopy(t), sub_mask)

     def indirect_getitem(self, model_speedup: 'ModelSpeedup', node: Node):
         assert len(node.args) == 2

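The hunks above change the fallback branch of the input-cloning lambdas from "else t" to "else poss_deepcopy(t)": tensors were already protected by .clone(), but non-tensor arguments (lists, dicts, arbitrary objects) were still passed by reference and could be mutated by in-place operations inside the executed module. Below is a minimal sketch of the failure mode and the fix, using plain copy.deepcopy in place of the repository's poss_deepcopy helper and a made-up InplaceBlock module:

    from copy import deepcopy

    import torch
    from torch import nn


    class InplaceBlock(nn.Module):
        # Hypothetical module: mutates both its tensor input and a plain list in place.
        def forward(self, x: torch.Tensor, history: list):
            x.relu_()                       # in-place tensor op
            history.append(float(x.sum()))  # in-place mutation of a non-tensor argument
            return x


    x = torch.randn(4)
    history: list = []

    # Tensors are protected by .clone(); non-tensor leaves need a deep copy as well,
    # otherwise the caller's history list is mutated through the shared reference.
    out = InplaceBlock()(x.clone(), deepcopy(history))
    assert history == []                    # the original list is untouched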
View file

@@ -28,7 +28,7 @@ from .mask_updater import (MaskUpdater,
                            NoMaskUpdater,
                            NoChangeMaskUpdater)
 from .replacer import Replacer, DefaultReplacer
-from .utils import tree_map_zip
+from .utils import tree_map_zip, poss_deepcopy

 def _normalize_input(dummy_input: Any) -> Any:
@@ -38,6 +38,7 @@ def _normalize_input(dummy_input: Any) -> Any:
         dummy_input = tuple(dummy_input)
     return dummy_input

+
 @compatibility(is_backward_compatible=True)
 class ModelSpeedup(torch.fx.Interpreter):
     """
@@ -224,7 +225,7 @@ class ModelSpeedup(torch.fx.Interpreter):
         self.node_infos[node].output_origin = output
         self.node_infos[node].output_inplace = \
-            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else t, output)
+            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else poss_deepcopy(t, self.logger), output)
         self.node_infos[node].output_masks = \
             tree_map_zip(lambda t: torch.ones_like(t).clone().detach() if isinstance(t, torch.Tensor) else None, output)

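The last hunk above stores a detached snapshot of every node output (output_inplace) and an all-ones mask per tensor leaf (output_masks), and now deep-copies non-tensor leaves of the snapshot as well. A rough sketch of that snapshot-and-mask pattern, using torch.utils._pytree.tree_map over a single pytree as a stand-in for the repository's tree_map_zip helper; the example output dictionary is made up:

    from copy import deepcopy

    import torch
    from torch.utils._pytree import tree_map

    # A made-up node output: a pytree mixing tensor and non-tensor leaves.
    output = {'logits': torch.randn(2, 3), 'meta': {'shape': [2, 3]}}

    # Detached snapshot: tensors are cloned, every other leaf is deep-copied so that
    # later in-place updates cannot reach back into the stored structure.
    output_inplace = tree_map(
        lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else deepcopy(t),
        output,
    )

    # Initial masks: an all-ones tensor per tensor leaf, None for non-tensor leaves.
    output_masks = tree_map(
        lambda t: torch.ones_like(t) if isinstance(t, torch.Tensor) else None,
        output,
    )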
View file

@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

+from copy import deepcopy
 import logging
 from typing import Any, Dict, List, Tuple

 import torch
@@ -79,3 +80,15 @@ def tree_map_zip(fn: Any, *pytrees):
         spec_list.append(spec)
     assert all(len(args) == len(flat_args_list[0]) for args in flat_args_list), 'Inconsistent tree nodes length.'
     return tree_unflatten([fn(*args) for args in zip(*flat_args_list)], spec_list[0])
+
+
+def poss_deepcopy(o, logger: logging.Logger = None) -> Any:
+    try:
+        new_o = deepcopy(o)
+    except Exception as e:
+        if logger is not None:
+            logger.warning(str(e))
+        else:
+            print(str(e))
+        new_o = o
+    return new_o

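poss_deepcopy, added above, is a best-effort deep copy: it returns a real copy when deepcopy succeeds and falls back to the original reference (with a logged warning) when it does not, e.g. for objects that cannot be pickled. A usage sketch, assuming the helper defined in the hunk above is in scope:

    import logging

    logger = logging.getLogger('speedup_example')

    plain = {'ratio': 0.5}
    copied = poss_deepcopy(plain, logger)   # deepcopy succeeds
    assert copied == plain and copied is not plain

    gen = (i for i in range(3))             # generators cannot be deep-copied
    same = poss_deepcopy(gen, logger)       # deepcopy raises, a warning is logged
    assert same is gen                      # the original object is returned unchanged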
View file

@@ -62,7 +62,7 @@ def trans_legacy_config_list(config_list: List[Dict[str, Any]]) -> List[Dict[str
         group_id = None
         max_sparse_ratio = config.pop('max_sparsity_per_layer', None)
         if 'sparsity_per_layer' in config or 'sparsity' in config:
-            sparse_ratio = config.pop('sparsity_per_layer', config.pop('sparsity'))
+            sparse_ratio = config.pop('sparsity_per_layer', config.pop('sparsity', None))
         if 'total_sparsity' in config:
             sparse_ratio = config.pop('total_sparsity')
             group_id = group_id_candidate
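
The fix above matters because Python evaluates the default argument of dict.pop before the call: with the old code, a legacy config that only sets sparsity_per_layer still executes config.pop('sparsity') and raises KeyError. A short sketch of both forms on a made-up config dict:

    config = {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}

    # Old form: the inner pop runs first, so a missing 'sparsity' key raises KeyError
    # even though 'sparsity_per_layer' is present.
    try:
        sparse_ratio = config.pop('sparsity_per_layer', config.pop('sparsity'))
    except KeyError as err:
        print('old form fails:', err)

    # Fixed form: the inner pop gets its own default, so it quietly yields None and
    # the outer pop returns the 'sparsity_per_layer' value.
    config = {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}
    sparse_ratio = config.pop('sparsity_per_layer', config.pop('sparsity', None))
    assert sparse_ratio == 0.5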