зеркало из https://github.com/microsoft/nni.git
[Bugbash] bugfix & add support for pytorch 2.0 (#5484)
This commit is contained in:
Родитель
16595a9191
Коммит
02fd013962
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import collections
|
||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
@ -27,6 +28,19 @@ from torch.fx.graph import Graph
|
||||||
from torch.fx.node import Target, Node
|
from torch.fx.node import Target, Node
|
||||||
from torch.fx.proxy import TracerBase
|
from torch.fx.proxy import TracerBase
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Scope is a new class to record module path in pytorch 2.0
|
||||||
|
from torch.fx.proxy import Scope
|
||||||
|
except ImportError:
|
||||||
|
# copy from pytorch 2.0
|
||||||
|
@compatibility(is_backward_compatible=False)
|
||||||
|
class Scope:
|
||||||
|
def __init__(self, module_path: str, module_type: Any):
|
||||||
|
super().__init__()
|
||||||
|
self.module_path = module_path
|
||||||
|
self.module_type = module_type
|
||||||
|
|
||||||
|
|
||||||
from . import concrete_proxy as ep
|
from . import concrete_proxy as ep
|
||||||
from .operator_patcher import OperatorPatcherContext
|
from .operator_patcher import OperatorPatcherContext
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
@ -190,6 +204,9 @@ class ConcreteTracer(TracerBase):
|
||||||
remove the 'param_shapes_constant' because we can get real shape when executing.
|
remove the 'param_shapes_constant' because we can get real shape when executing.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.scope = Scope("", None)
|
||||||
|
self.module_stack = collections.OrderedDict()
|
||||||
|
self.node_name_to_scope = {}
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
|
def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
|
||||||
|
|
|
@ -243,7 +243,7 @@ class NoMaskUpdater(DefaultMaskUpdater):
|
||||||
return True
|
return True
|
||||||
elif node.op == 'call_method':
|
elif node.op == 'call_method':
|
||||||
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
|
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
|
||||||
if node.target in ('dim', 'size', 'clone', 'detach'):
|
if node.target in ('dim', 'size'):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -434,6 +434,10 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
|
||||||
module: torch.nn.Module = model_speedup.fetch_attr(node.target)
|
module: torch.nn.Module = model_speedup.fetch_attr(node.target)
|
||||||
if isinstance(module, self.no_change_act_module):
|
if isinstance(module, self.no_change_act_module):
|
||||||
return self.direct_activation, self.indirect_activation
|
return self.direct_activation, self.indirect_activation
|
||||||
|
elif node.op == 'call_method':
|
||||||
|
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
|
||||||
|
if node.target in ('clone', 'detach'):
|
||||||
|
return self.direct_activation, self.indirect_activation
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node):
|
def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node):
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
import logging
|
import logging
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal
|
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal
|
||||||
|
@ -438,7 +439,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
|
||||||
) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]:
|
) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]:
|
||||||
assert mode in ['pruning', 'quantization', 'distillation']
|
assert mode in ['pruning', 'quantization', 'distillation']
|
||||||
|
|
||||||
configured_target_spaces = {}
|
configured_target_spaces = defaultdict(dict)
|
||||||
existed_wrappers = existed_wrappers if existed_wrappers else {}
|
existed_wrappers = existed_wrappers if existed_wrappers else {}
|
||||||
module_wrappers = {k: v for k, v in existed_wrappers.items()}
|
module_wrappers = {k: v for k, v in existed_wrappers.items()}
|
||||||
identity_module_set = set()
|
identity_module_set = set()
|
||||||
|
@ -459,7 +460,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
|
||||||
wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \
|
wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \
|
||||||
old_wrapper, list(fused_modules_pair))
|
old_wrapper, list(fused_modules_pair))
|
||||||
module_wrappers[module_name] = wrapper
|
module_wrappers[module_name] = wrapper
|
||||||
configured_target_spaces[module_name] = target_spaces
|
configured_target_spaces[module_name].update(target_spaces)
|
||||||
if len(fuse_module_names) > 0:
|
if len(fuse_module_names) > 0:
|
||||||
raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}')
|
raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}')
|
||||||
if module_set.intersection(identity_module_set):
|
if module_set.intersection(identity_module_set):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче