зеркало из https://github.com/microsoft/nni.git
[Bugbash] bugfix & add support for pytorch 2.0 (#5484)
This commit is contained in:
Родитель
16595a9191
Коммит
02fd013962
|
@ -3,6 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import sys
|
||||
import inspect
|
||||
import logging
|
||||
|
@ -27,6 +28,19 @@ from torch.fx.graph import Graph
|
|||
from torch.fx.node import Target, Node
|
||||
from torch.fx.proxy import TracerBase
|
||||
|
||||
try:
|
||||
# Scope is a new class to record module path in pytorch 2.0
|
||||
from torch.fx.proxy import Scope
|
||||
except ImportError:
|
||||
# copy from pytorch 2.0
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class Scope:
|
||||
def __init__(self, module_path: str, module_type: Any):
|
||||
super().__init__()
|
||||
self.module_path = module_path
|
||||
self.module_type = module_type
|
||||
|
||||
|
||||
from . import concrete_proxy as ep
|
||||
from .operator_patcher import OperatorPatcherContext
|
||||
from .utils import (
|
||||
|
@ -190,6 +204,9 @@ class ConcreteTracer(TracerBase):
|
|||
remove the 'param_shapes_constant' because we can get real shape when executing.
|
||||
"""
|
||||
super().__init__()
|
||||
self.scope = Scope("", None)
|
||||
self.module_stack = collections.OrderedDict()
|
||||
self.node_name_to_scope = {}
|
||||
|
||||
@contextmanager
|
||||
def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
|
||||
|
|
|
@ -243,7 +243,7 @@ class NoMaskUpdater(DefaultMaskUpdater):
|
|||
return True
|
||||
elif node.op == 'call_method':
|
||||
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
|
||||
if node.target in ('dim', 'size', 'clone', 'detach'):
|
||||
if node.target in ('dim', 'size'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -434,6 +434,10 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
|
|||
module: torch.nn.Module = model_speedup.fetch_attr(node.target)
|
||||
if isinstance(module, self.no_change_act_module):
|
||||
return self.direct_activation, self.indirect_activation
|
||||
elif node.op == 'call_method':
|
||||
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
|
||||
if node.target in ('clone', 'detach'):
|
||||
return self.direct_activation, self.indirect_activation
|
||||
return None
|
||||
|
||||
def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node):
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal
|
||||
|
@ -438,7 +439,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
|
|||
) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]:
|
||||
assert mode in ['pruning', 'quantization', 'distillation']
|
||||
|
||||
configured_target_spaces = {}
|
||||
configured_target_spaces = defaultdict(dict)
|
||||
existed_wrappers = existed_wrappers if existed_wrappers else {}
|
||||
module_wrappers = {k: v for k, v in existed_wrappers.items()}
|
||||
identity_module_set = set()
|
||||
|
@ -459,7 +460,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
|
|||
wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \
|
||||
old_wrapper, list(fused_modules_pair))
|
||||
module_wrappers[module_name] = wrapper
|
||||
configured_target_spaces[module_name] = target_spaces
|
||||
configured_target_spaces[module_name].update(target_spaces)
|
||||
if len(fuse_module_names) > 0:
|
||||
raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}')
|
||||
if module_set.intersection(identity_module_set):
|
||||
|
|
Загрузка…
Ссылка в новой задаче