[Common] Fix some discovered issues during test. (#5487)

This commit is contained in:
Super Daniel 2023-03-30 10:34:51 +08:00 коммит произвёл GitHub
Родитель dd418617b6
Коммит ce6b2e8fc9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 44 добавлений и 8 удалений

Просмотреть файл

@ -53,6 +53,7 @@ from .utils import (
_orig_type,
_orig_isinstance,
_orig_issubclass,
_orig_getattr,
_orig_range,
@ -67,6 +68,8 @@ from .utils import (
_orig_zip,
_orig_enumerate,
_orig_slice,
_orig_reversed,
_orig_torch_size,
_orig_len,
_orig_not,
@ -74,6 +77,10 @@ from .utils import (
_orig_is_not,
_orig_contains,
_orig_index,
_orig_all,
_orig_min,
_orig_max,
)
@ -101,6 +108,9 @@ class ConcreteTracer(TracerBase):
_orig_is_not: ([], False, None),
_orig_contains: ([], False, None),
_orig_index: ([], False, None),
_orig_all: ((), False, None),
_orig_min: ((), False, None),
_orig_max: ((), False, None),
# force-traced function (the factory functions of tensor creation)
torch.arange: ([], True, None),
@ -191,6 +201,9 @@ class ConcreteTracer(TracerBase):
_orig_set: ([], True),
_orig_frozenset: ([], True),
_orig_dict: ([], True),
_orig_reversed: ((), False),
_orig_torch_size: ((), False),
}
# add these to record module path information during tracing
@ -534,7 +547,7 @@ class ConcreteTracer(TracerBase):
concrete_args: Union[Dict[str, Any], Tuple],
use_operator_patch: bool = True,
operator_patch_backlist: List[str] | None = None,
forwrad_function_name: str = 'forward') -> Graph:
forward_function_name: str = 'forward') -> Graph:
"""
similar to _symbolic_trace.Tracer.trace
different args:
@ -596,10 +609,10 @@ class ConcreteTracer(TracerBase):
# TODO: better infomation
assert hasattr(
root, forwrad_function_name
), f"traced_func_name={forwrad_function_name} doesn't exist in {_orig_type(root).__name__}"
root, forward_function_name
), f"traced_func_name={forward_function_name} doesn't exist in {_orig_type(root).__name__}"
fn = getattr(root, forwrad_function_name)
fn = getattr(root, forward_function_name)
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
else:
self.root = torch.nn.Module()
@ -908,6 +921,20 @@ class ConcreteTracer(TracerBase):
instance = instance.value
return _orig_isinstance(instance, clz)
@functools.wraps(_orig_issubclass)
def issubclass_wrapper(subclass, clz):
if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list):
clz_wrapped = []
for wrapped_type, orig_type in self.clz_wrapper_map.items():
if wrapped_type in clz:
clz_wrapped.append(orig_type)
clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map))
return _orig_issubclass(subclass, clz)
else:
if clz in self.clz_wrapper_map:
clz = self.clz_wrapper_map[clz]
return _orig_issubclass(subclass, clz)
@functools.wraps(_orig_getattr)
def getattr_wrapper(obj, *args):
# TODO: better infomation
@ -939,6 +966,7 @@ class ConcreteTracer(TracerBase):
self.patcher.patch_method(builtins, "range", range_wrapper, deduplicate=False)
self.patcher.patch_method(builtins, "type", type_wrapper, deduplicate=False)
self.patcher.patch_method(builtins, "isinstance", isinstance_wrapper, deduplicate=False)
self.patcher.patch_method(builtins, "issubclass", issubclass_wrapper, deduplicate=False)
self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False)
for obj, (positions, wrapped) in self.wrapped_leaf.items():
@ -1307,7 +1335,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
*,
use_operator_patch: bool = True,
operator_patch_backlist: List[str] | None = None,
forwrad_function_name: str = 'forward',
forward_function_name: str = 'forward',
check_args: Optional[Dict[str, Any]] = None,
autowrap_leaf_function = None,
autowrap_leaf_class = None,
@ -1449,7 +1477,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
concrete_args=concrete_args,
use_operator_patch=use_operator_patch,
operator_patch_backlist=operator_patch_backlist,
forwrad_function_name=forwrad_function_name,
forward_function_name=forward_function_name,
)
graph_check = tracer.trace(root,
autowrap_leaf_function = autowrap_leaf_function,
@ -1459,7 +1487,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
concrete_args=concrete_args,
use_operator_patch=use_operator_patch,
operator_patch_backlist=operator_patch_backlist,
forwrad_function_name=forwrad_function_name,
forward_function_name=forward_function_name,
)
# compare to check equal
assert len(graph.nodes) == len(graph_check.nodes)

Просмотреть файл

@ -18,6 +18,7 @@ _orig_torch_assert: Callable = torch._assert
_orig_type: Callable = builtins.type
_orig_isinstance: Callable = builtins.isinstance
_orig_issubclass: Callable = builtins.issubclass
_orig_getattr: Callable = builtins.getattr
_orig_range: Type[Any] = builtins.range
@ -32,6 +33,8 @@ _orig_map: Type[Any] = builtins.map
_orig_zip: Type[Any] = builtins.zip
_orig_enumerate: Type[Any] = builtins.enumerate
_orig_slice: Type[Any] = builtins.slice
_orig_reversed: Type[Any] = builtins.reversed
_orig_torch_size: Type[Any] = torch.Size
_orig_len: Callable = builtins.len
_orig_not: Callable = operator.not_
@ -40,6 +43,11 @@ _orig_is_not: Callable = operator.is_not
_orig_contains: Callable = operator.contains
_orig_index: Callable = operator.index
_orig_all: Callable = builtins.all
_orig_min: Callable = builtins.min
_orig_max: Callable = builtins.max
def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None):
def helper(fn):
if return_orig:

Просмотреть файл

@ -239,7 +239,7 @@ def test_mmdetection(config_file: str):
traced_model = concrete_trace(model, {'img': img_tensor},
use_operator_patch=False,
forwrad_function_name='forward_dummy',
forward_function_name='forward_dummy',
autowrap_leaf_function = {
**ConcreteTracer.default_autowrap_leaf_function,
all: ((), False, None),