[Bugbash] bugfix & add support for pytorch 2.0 (#5484)

J-shang 2023-03-29 12:17:01 +08:00 committed by GitHub
Parent 16595a9191
Commit 02fd013962
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 3 deletions

View file

@@ -3,6 +3,7 @@
from __future__ import annotations
import collections
import sys
import inspect
import logging
@@ -27,6 +28,19 @@ from torch.fx.graph import Graph
from torch.fx.node import Target, Node
from torch.fx.proxy import TracerBase
+try:
+    # Scope is a new class to record module path in pytorch 2.0
+    from torch.fx.proxy import Scope
+except ImportError:
+    # copy from pytorch 2.0
+    @compatibility(is_backward_compatible=False)
+    class Scope:
+        def __init__(self, module_path: str, module_type: Any):
+            super().__init__()
+            self.module_path = module_path
+            self.module_type = module_type
from . import concrete_proxy as ep
from .operator_patcher import OperatorPatcherContext
from .utils import (
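
The try/except above keeps the tracer importable on torch releases both before and after 2.0: Scope is taken from torch.fx.proxy when it exists and re-declared locally otherwise. A minimal standalone sketch of the same pattern (the example module path and type below are illustrative, not taken from the diff):

# Prefer the class shipped with PyTorch 2.0, fall back to a local copy otherwise.
from typing import Any

try:
    from torch.fx.proxy import Scope  # available in torch >= 2.0
except ImportError:
    class Scope:  # same two fields as the PyTorch 2.0 class
        def __init__(self, module_path: str, module_type: Any):
            self.module_path = module_path
            self.module_type = module_type

# Illustrative usage: a Scope pairs a dotted module path with its module type.
import torch

scope = Scope("encoder.layers.0.linear", torch.nn.Linear)
print(scope.module_path, scope.module_type)
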
@@ -190,6 +204,9 @@ class ConcreteTracer(TracerBase):
remove the 'param_shapes_constant' because we can get real shape when executing.
"""
super().__init__()
+self.scope = Scope("", None)
+self.module_stack = collections.OrderedDict()
+self.node_name_to_scope = {}
@contextmanager
def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
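
ConcreteTracer subclasses TracerBase directly, and in PyTorch 2.0 TracerBase records, for every node it creates, which module scope that node was created under; the three attributes initialized above (scope, module_stack, node_name_to_scope) are what that bookkeeping reads and writes, so a subclass that only calls super().__init__() has to set them itself. A rough illustration of the shape of that mapping, assuming the PyTorch 2.0 convention of node name -> (module path, module type):

# Illustrative only: hand-built example of what node_name_to_scope is expected
# to hold after a trace (node and module names here are hypothetical).
import torch

node_name_to_scope = {
    "linear_1": ("encoder.linear", torch.nn.Linear),  # node -> (module path, module type)
    "relu_1": ("encoder.act", torch.nn.ReLU),
}
print(node_name_to_scope["linear_1"])
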

View file

@@ -243,7 +243,7 @@ class NoMaskUpdater(DefaultMaskUpdater):
return True
elif node.op == 'call_method':
if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
-if node.target in ('dim', 'size', 'clone', 'detach'):
+if node.target in ('dim', 'size'):
return True
return False
@@ -434,6 +434,10 @@ class NoChangeMaskUpdater(DefaultMaskUpdater):
module: torch.nn.Module = model_speedup.fetch_attr(node.target)
if isinstance(module, self.no_change_act_module):
return self.direct_activation, self.indirect_activation
+elif node.op == 'call_method':
+    if isinstance(node.args[0], Node) and isinstance(model_speedup.node_infos[node.args[0]].output_origin, torch.Tensor):
+        if node.target in ('clone', 'detach'):
+            return self.direct_activation, self.indirect_activation
return None
def direct_update_process(self, model_speedup: 'ModelSpeedup', node: Node):
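
The two hunks above move 'clone' and 'detach' from the "no mask" path to the "no change" path: unlike dim() and size(), which return plain Python values, clone() and detach() return a tensor with the same shape and contents, so the input mask can be propagated to the output unchanged. A small standalone check of that invariant with plain torch (independent of the speedup classes):

import torch

x = torch.randn(4, 3)
mask = torch.tensor([1., 0., 1.]).expand_as(x)

# clone()/detach() keep shape and values, so a mask valid for x is valid for the result.
for out in (x.clone(), x.detach()):
    assert out.shape == x.shape
    assert torch.equal(out * mask, x * mask)

# dim()/size() return non-tensor values, so there is nothing to attach a mask to.
print(x.dim(), x.size())  # 2 torch.Size([4, 3])
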

View file

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections import defaultdict
import logging
import inspect
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal
@@ -438,7 +439,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
) -> Tuple[Dict[str, ModuleWrapper], Dict[str, Dict[str, TargetSpace]]]:
assert mode in ['pruning', 'quantization', 'distillation']
-configured_target_spaces = {}
+configured_target_spaces = defaultdict(dict)
existed_wrappers = existed_wrappers if existed_wrappers else {}
module_wrappers = {k: v for k, v in existed_wrappers.items()}
identity_module_set = set()
@@ -459,7 +460,7 @@ def register_wrappers(model: torch.nn.Module, config_list: List[Dict[str, Any]],
wrapper, target_spaces = create_module_wrapper(model, module, module_name, mode, public_config, \
old_wrapper, list(fused_modules_pair))
module_wrappers[module_name] = wrapper
-configured_target_spaces[module_name] = target_spaces
+configured_target_spaces[module_name].update(target_spaces)
if len(fuse_module_names) > 0:
raise ValueError(f'{fuse_module_names} can\'t be fused with {modules.keys()}')
if module_set.intersection(identity_module_set):
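
Switching configured_target_spaces from a plain dict with assignment to a defaultdict(dict) with update() means that when the same module name is registered more than once (for example via fused module pairs), earlier target spaces are merged with later ones instead of being overwritten. A minimal sketch of the difference, using illustrative keys rather than real TargetSpace objects:

from collections import defaultdict

# Plain assignment: the second registration silently discards the first.
plain = {}
plain['conv1'] = {'weight': 'space_a'}
plain['conv1'] = {'bias': 'space_b'}

# defaultdict(dict) + update(): both registrations survive.
merged = defaultdict(dict)
merged['conv1'].update({'weight': 'space_a'})
merged['conv1'].update({'bias': 'space_b'})

print(plain['conv1'])   # {'bias': 'space_b'}
print(merged['conv1'])  # {'weight': 'space_a', 'bias': 'space_b'}
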