diff --git a/nni/common/concrete_trace_utils/concrete_tracer.py b/nni/common/concrete_trace_utils/concrete_tracer.py
index 2fbdeac80..b6eaddccb 100644
--- a/nni/common/concrete_trace_utils/concrete_tracer.py
+++ b/nni/common/concrete_trace_utils/concrete_tracer.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import collections
 import sys
 import inspect
 import logging
@@ -27,6 +28,19 @@ from torch.fx.graph import Graph
 from torch.fx.node import Target, Node
 from torch.fx.proxy import TracerBase
 
+try:
+    # Scope is a new class to record module path in pytorch 2.0
+    from torch.fx.proxy import Scope
+except ImportError:
+    # copy from pytorch 2.0
+    @compatibility(is_backward_compatible=False)
+    class Scope:
+        def __init__(self, module_path: str, module_type: Any):
+            super().__init__()
+            self.module_path = module_path
+            self.module_type = module_type
+
+
 from . import concrete_proxy as ep
 from .operator_patcher import OperatorPatcherContext
 from .utils import (
@@ -190,6 +204,9 @@ class ConcreteTracer(TracerBase):
             remove the 'param_shapes_constant' because we can get real shape when executing.
         """
         super().__init__()
+        self.scope = Scope("", None)
+        self.module_stack = collections.OrderedDict()
+        self.node_name_to_scope = {}
 
     @contextmanager
     def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
diff --git a/nni/compression/pytorch/speedup/v2/mask_updater.py b/nni/compression/pytorch/speedup/v2/mask_updater.py
index af743d5e3..add24d67a 100644
--- a/nni/compression/pytorch/speedup/v2/mask_updater.py
+++ b/nni/compression/pytorch/speedup/v2/mask_updater.py
@@ -243,7 +243,7 @@ class NoMaskUpdater(DefaultMaskUpdater):
             return True
         elif node.op == 'call_method':
             if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
-                if node.target in ('dim', 'size', 'clone', 'detach'):
+                if node.target in ('dim', 'size'):
                     return True
         return False
 
@@ -434,6 +434,10 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
             module: torch.nn.Module = model_speedup.fetch_attr(node.target)
             if isinstance(module, self.no_change_act_module):
                 return self.direct_activation, self.indirect_activation
+        elif node.op == 'call_method':
+            if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
+                if node.target in ('clone', 'detach'):
+                    return self.direct_activation, self.indirect_activation
         return None
 
     def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node):
diff --git a/nni/contrib/compression/base/wrapper.py b/nni/contrib/compression/base/wrapper.py
index 472a23297..06d5faed4 100644
--- a/nni/contrib/compression/base/wrapper.py
+++ b/nni/contrib/compression/base/wrapper.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+from collections import defaultdict
 import logging
 import inspect
 from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal
@@ -438,7 +439,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
                       ) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]:
     assert mode in ['pruning', 'quantization', 'distillation']
-    configured_target_spaces = {}
+    configured_target_spaces = defaultdict(dict)
     existed_wrappers = existed_wrappers if existed_wrappers else {}
     module_wrappers = {k: v for k, v in existed_wrappers.items()}
     identity_module_set = set()
@@ -459,7 +460,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
             wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \
                                                            old_wrapper, list(fused_modules_pair))
             module_wrappers[module_name] = wrapper
-            configured_target_spaces[module_name] = target_spaces
+            configured_target_spaces[module_name].update(target_spaces)
         if len(fuse_module_names) > 0:
             raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}')
         if module_set.intersection(identity_module_set):
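
For context on the `defaultdict(dict)` change in register_wrappers: with a plain dict, registering target spaces a second time for the same module (for example, once directly and once through a fused-modules pair) silently overwrites the first entry, whereas `defaultdict(dict)` plus `update()` merges the two registrations. A minimal sketch of that dictionary behavior follows; the module and target names are made up for illustration and are not taken from the patch.

    from collections import defaultdict

    # Plain dict: a second assignment for the same module replaces the first.
    plain = {}
    plain['conv1'] = {'weight': 'pruning_space'}   # hypothetical module/target names
    plain['conv1'] = {'bias': 'quant_space'}       # the 'weight' entry is lost

    # defaultdict(dict) + update(): the second registration is merged in.
    merged = defaultdict(dict)
    merged['conv1'].update({'weight': 'pruning_space'})
    merged['conv1'].update({'bias': 'quant_space'})

    print(plain['conv1'])   # {'bias': 'quant_space'}
    print(merged['conv1'])  # {'weight': 'pruning_space', 'bias': 'quant_space'}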