Mirror of https://github.com/microsoft/nni.git

[Bugfix] fix pruning speedup value inplace change issue (#5534)

This commit is contained in:
Parent: 0ea9459026
Commit: 47fde3da21

@@ -14,7 +14,7 @@ from torch.nn import functional as F
 from torch.fx.node import Node
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
-from .utils import randomize_tensor_inplace, randomize_if_tensor, tree_map_zip, torch_float_dtype
+from .utils import randomize_tensor_inplace, randomize_if_tensor, tree_map_zip, torch_float_dtype, poss_deepcopy
 
 
 class MaskUpdater:

@@ -154,8 +154,8 @@ class DefaultMaskUpdater(MaskUpdater):
 
         # Some operator may have the in_place operations, so we need to clone the input
         # before passing to the model_speedup.module
-        args_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else t, args)
-        kwargs_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else t, kwargs)
+        args_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else poss_deepcopy(t), args)
+        kwargs_cloned = tree_map_zip(lambda t: t.clone() if isinstance(t, torch.Tensor) else poss_deepcopy(t), kwargs)
 
         output = getattr(model_speedup, node.op)(node.target, args_cloned, kwargs_cloned)
 
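Why `else t` was not enough: tensor leaves were already protected by `.clone()`, but non-tensor arguments (lists, dicts, custom objects) were still passed by reference, so an operator that mutates them in place would also corrupt the values cached for later mask propagation. Below is a minimal sketch of that aliasing problem with a plain list; `inplace_op` is a hypothetical operator, not part of NNI.

    from copy import deepcopy

    def inplace_op(xs):          # hypothetical operator that mutates its argument
        xs.append(0)
        return sum(xs)

    args = [1, 2, 3]
    cached = args                # old behaviour: the non-tensor leaf is kept by reference
    inplace_op(args)
    print(cached)                # [1, 2, 3, 0] -- the cached value silently changed

    args = [1, 2, 3]
    cached = deepcopy(args)      # new behaviour: poss_deepcopy makes an independent copy
    inplace_op(args)
    print(cached)                # [1, 2, 3] -- the cached value is preserved
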
@@ -372,7 +372,7 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
         input_node = node.kwargs['input']
         input_mask = model_speedup.node_infos[input_node].output_masks
         model_speedup.node_infos[node].output_masks = \
-            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else t, input_mask)
+            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else poss_deepcopy(t), input_mask)
 
     def indirect_activation(self, model_speedup: 'ModelSpeedup', node: Node):
         if len(node.args) != 0:

@@ -393,7 +393,7 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
         sub_mask = operator.getitem(arg_0_masks, arg_1_val)
 
         model_speedup.node_infos[node].output_masks = \
-            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else t, sub_mask)
+            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else poss_deepcopy(t), sub_mask)
 
     def indirect_getitem(self, model_speedup: 'ModelSpeedup', node: Node):
        assert len(node.args) == 2

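The two NoChangeMaskUpdater hunks apply the same pattern when caching masks: `tree_map_zip` (defined in the utils hunk below) maps the lambda over every leaf of the mask pytree, cloning tensor leaves and now deep-copying non-tensor leaves instead of aliasing them. A rough sketch of the intended leafwise behaviour, using an ordinary dict comprehension as a stand-in for `tree_map_zip` and an invented mask structure:

    import torch
    from copy import deepcopy

    input_mask = {'weight': torch.ones(4), 'meta': [1, 2], 'bias': None}

    # Leafwise copy in the spirit of the updated tree_map_zip call:
    # clone tensors, best-effort deep copy of everything else.
    copied = {
        k: v.clone().detach() if isinstance(v, torch.Tensor) else deepcopy(v)
        for k, v in input_mask.items()
    }

    input_mask['weight'][0] = 0        # later in-place change to the original mask
    input_mask['meta'].append(3)
    print(copied['weight'][0].item())  # 1.0 -- the cached copy is unaffected
    print(copied['meta'])              # [1, 2]
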
@@ -28,7 +28,7 @@ from .mask_updater import (MaskUpdater,
                            NoMaskUpdater,
                            NoChangeMaskUpdater)
 from .replacer import Replacer, DefaultReplacer
-from .utils import tree_map_zip
+from .utils import tree_map_zip, poss_deepcopy
 
 
 def _normalize_input(dummy_input: Any) -> Any:

@@ -38,6 +38,7 @@ def _normalize_input(dummy_input: Any) -> Any:
         dummy_input = tuple(dummy_input)
     return dummy_input
 
 
 @compatibility(is_backward_compatible=True)
 class ModelSpeedup(torch.fx.Interpreter):
     """

@@ -224,7 +225,7 @@ class ModelSpeedup(torch.fx.Interpreter):
 
         self.node_infos[node].output_origin = output
         self.node_infos[node].output_inplace = \
-            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else t, output)
+            tree_map_zip(lambda t: t.clone().detach() if isinstance(t, torch.Tensor) else poss_deepcopy(t, self.logger), output)
         self.node_infos[node].output_masks = \
             tree_map_zip(lambda t: torch.ones_like(t).clone().detach() if isinstance(t, torch.Tensor) else None, output)
 
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
+from copy import deepcopy
 import logging
 from typing import Any, Dict, List, Tuple
 
 import torch

@@ -79,3 +80,15 @@ def tree_map_zip(fn: Any, *pytrees):
         spec_list.append(spec)
     assert all(len(args) == len(flat_args_list[0]) for args in flat_args_list), 'Inconsistent tree nodes length.'
     return tree_unflatten([fn(*args) for args in zip(*flat_args_list)], spec_list[0])
+
+
+def poss_deepcopy(o, logger: logging.Logger = None) -> Any:
+    try:
+        new_o = deepcopy(o)
+    except Exception as e:
+        if logger is not None:
+            logger.warning(str(e))
+        else:
+            print(str(e))
+        new_o = o
+    return new_o

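The new helper is a best-effort `deepcopy`: objects that cannot be deep-copied (generators, open file handles, and similar) trigger a logged warning and the original reference is returned instead of aborting the speedup pass. A small usage sketch; the helper is re-declared here only so the snippet runs standalone.

    import logging
    from copy import deepcopy

    def poss_deepcopy(o, logger: logging.Logger = None):
        # same logic as the helper added in the hunk above
        try:
            new_o = deepcopy(o)
        except Exception as e:
            if logger is not None:
                logger.warning(str(e))
            else:
                print(str(e))
            new_o = o
        return new_o

    logger = logging.getLogger('speedup')

    plain = {'ratio': 0.5, 'layers': ['conv1', 'conv2']}
    copied = poss_deepcopy(plain, logger)   # deep-copyable: independent copy returned
    copied['layers'].append('conv3')
    print(plain['layers'])                  # ['conv1', 'conv2'] -- original untouched

    gen = (i for i in range(3))             # generators cannot be deep-copied
    same = poss_deepcopy(gen, logger)       # exception is logged, original returned
    print(same is gen)                      # True -- falls back to the original object
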
@@ -62,7 +62,7 @@ def trans_legacy_config_list(config_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         group_id = None
         max_sparse_ratio = config.pop('max_sparsity_per_layer', None)
         if 'sparsity_per_layer' in config or 'sparsity' in config:
-            sparse_ratio = config.pop('sparsity_per_layer', config.pop('sparsity'))
+            sparse_ratio = config.pop('sparsity_per_layer', config.pop('sparsity', None))
         if 'total_sparsity' in config:
             sparse_ratio = config.pop('total_sparsity')
             group_id = group_id_candidate

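The config hunk fixes an eager-evaluation bug: in `config.pop('sparsity_per_layer', config.pop('sparsity'))` the default argument is evaluated before the outer `pop`, so a legacy config that only sets `sparsity_per_layer` raised `KeyError: 'sparsity'`. Giving the inner `pop` a `None` default avoids that. A standalone illustration with a plain dict whose keys mirror the legacy config fields:

    config = {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}

    # Old form: the inner pop runs first and raises KeyError('sparsity').
    try:
        ratio = config.pop('sparsity_per_layer', config.pop('sparsity'))
    except KeyError as e:
        print('old form fails:', e)

    # New form: the inner pop returns None, the outer pop returns 0.5.
    config = {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}
    ratio = config.pop('sparsity_per_layer', config.pop('sparsity', None))
    print(ratio)  # 0.5
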