[Bugbash] bugfix & add support for pytorch 2.0 (#5484)

This commit is contained in:
J-shang 2023-03-29 12:17:01 +08:00 коммит произвёл GitHub
Родитель 16595a9191
Коммит 02fd013962
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 25 добавлений и 3 удалений

Просмотреть файл

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import collections
import sys import sys
import inspect import inspect
import logging import logging
@ -27,6 +28,19 @@ from torch.fx.graph import Graph
from torch.fx.node import Target, Node from torch.fx.node import Target, Node
from torch.fx.proxy import TracerBase from torch.fx.proxy import TracerBase
try:
# Scope is a new class to record module path in pytorch 2.0
from torch.fx.proxy import Scope
except ImportError:
# copy from pytorch 2.0
@compatibility(is_backward_compatible=False)
class Scope:
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
from . import concrete_proxy as ep from . import concrete_proxy as ep
from .operator_patcher import OperatorPatcherContext from .operator_patcher import OperatorPatcherContext
from .utils import ( from .utils import (
@ -190,6 +204,9 @@ class ConcreteTracer(TracerBase):
remove the 'param_shapes_constant' because we can get real shape when executing. remove the 'param_shapes_constant' because we can get real shape when executing.
""" """
super().__init__() super().__init__()
self.scope = Scope("", None)
self.module_stack = collections.OrderedDict()
self.node_name_to_scope = {}
@contextmanager @contextmanager
def do_temp_disable(self, call=False, attr=False, agfunc_apply=False): def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):

Просмотреть файл

@ -243,7 +243,7 @@ class NoMaskUpdater(DefaultMaskUpdater):
return True return True
elif node.op == 'call_method': elif node.op == 'call_method':
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor): if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
if node.target in ('dim', 'size', 'clone', 'detach'): if node.target in ('dim', 'size'):
return True return True
return False return False
@ -434,6 +434,10 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
module: torch.nn.Module = model_speedup.fetch_attr(node.target) module: torch.nn.Module = model_speedup.fetch_attr(node.target)
if isinstance(module, self.no_change_act_module): if isinstance(module, self.no_change_act_module):
return self.direct_activation, self.indirect_activation return self.direct_activation, self.indirect_activation
elif node.op == 'call_method':
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
if node.target in ('clone', 'detach'):
return self.direct_activation, self.indirect_activation
return None return None
def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node): def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node):

Просмотреть файл

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict
import logging import logging
import inspect import inspect
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal
@ -438,7 +439,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]: ) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]:
assert mode in ['pruning', 'quantization', 'distillation'] assert mode in ['pruning', 'quantization', 'distillation']
configured_target_spaces = {} configured_target_spaces = defaultdict(dict)
existed_wrappers = existed_wrappers if existed_wrappers else {} existed_wrappers = existed_wrappers if existed_wrappers else {}
module_wrappers = {k: v for k, v in existed_wrappers.items()} module_wrappers = {k: v for k, v in existed_wrappers.items()}
identity_module_set = set() identity_module_set = set()
@ -459,7 +460,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \ wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \
old_wrapper, list(fused_modules_pair)) old_wrapper, list(fused_modules_pair))
module_wrappers[module_name] = wrapper module_wrappers[module_name] = wrapper
configured_target_spaces[module_name] = target_spaces configured_target_spaces[module_name].update(target_spaces)
if len(fuse_module_names) > 0: if len(fuse_module_names) > 0:
raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}') raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}')
if module_set.intersection(identity_module_set): if module_set.intersection(identity_module_set):