зеркало из https://github.com/microsoft/nni.git
[Common] Fix some discovered issues during test. (#5487)
This commit is contained in:
Родитель
dd418617b6
Коммит
ce6b2e8fc9
|
@ -53,6 +53,7 @@ from .utils import (
|
||||||
|
|
||||||
_orig_type,
|
_orig_type,
|
||||||
_orig_isinstance,
|
_orig_isinstance,
|
||||||
|
_orig_issubclass,
|
||||||
_orig_getattr,
|
_orig_getattr,
|
||||||
|
|
||||||
_orig_range,
|
_orig_range,
|
||||||
|
@ -67,6 +68,8 @@ from .utils import (
|
||||||
_orig_zip,
|
_orig_zip,
|
||||||
_orig_enumerate,
|
_orig_enumerate,
|
||||||
_orig_slice,
|
_orig_slice,
|
||||||
|
_orig_reversed,
|
||||||
|
_orig_torch_size,
|
||||||
|
|
||||||
_orig_len,
|
_orig_len,
|
||||||
_orig_not,
|
_orig_not,
|
||||||
|
@ -74,6 +77,10 @@ from .utils import (
|
||||||
_orig_is_not,
|
_orig_is_not,
|
||||||
_orig_contains,
|
_orig_contains,
|
||||||
_orig_index,
|
_orig_index,
|
||||||
|
|
||||||
|
_orig_all,
|
||||||
|
_orig_min,
|
||||||
|
_orig_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,6 +108,9 @@ class ConcreteTracer(TracerBase):
|
||||||
_orig_is_not: ([], False, None),
|
_orig_is_not: ([], False, None),
|
||||||
_orig_contains: ([], False, None),
|
_orig_contains: ([], False, None),
|
||||||
_orig_index: ([], False, None),
|
_orig_index: ([], False, None),
|
||||||
|
_orig_all: ((), False, None),
|
||||||
|
_orig_min: ((), False, None),
|
||||||
|
_orig_max: ((), False, None),
|
||||||
|
|
||||||
# force-traced function (the factory functions of tensor creation)
|
# force-traced function (the factory functions of tensor creation)
|
||||||
torch.arange: ([], True, None),
|
torch.arange: ([], True, None),
|
||||||
|
@ -191,6 +201,9 @@ class ConcreteTracer(TracerBase):
|
||||||
_orig_set: ([], True),
|
_orig_set: ([], True),
|
||||||
_orig_frozenset: ([], True),
|
_orig_frozenset: ([], True),
|
||||||
_orig_dict: ([], True),
|
_orig_dict: ([], True),
|
||||||
|
_orig_reversed: ((), False),
|
||||||
|
|
||||||
|
_orig_torch_size: ((), False),
|
||||||
}
|
}
|
||||||
|
|
||||||
# add these to record module path information during tracing
|
# add these to record module path information during tracing
|
||||||
|
@ -534,7 +547,7 @@ class ConcreteTracer(TracerBase):
|
||||||
concrete_args: Union[Dict[str, Any], Tuple],
|
concrete_args: Union[Dict[str, Any], Tuple],
|
||||||
use_operator_patch: bool = True,
|
use_operator_patch: bool = True,
|
||||||
operator_patch_backlist: List[str] | None = None,
|
operator_patch_backlist: List[str] | None = None,
|
||||||
forwrad_function_name: str = 'forward') -> Graph:
|
forward_function_name: str = 'forward') -> Graph:
|
||||||
"""
|
"""
|
||||||
similar to _symbolic_trace.Tracer.trace
|
similar to _symbolic_trace.Tracer.trace
|
||||||
different args:
|
different args:
|
||||||
|
@ -596,10 +609,10 @@ class ConcreteTracer(TracerBase):
|
||||||
|
|
||||||
# TODO: better infomation
|
# TODO: better infomation
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
root, forwrad_function_name
|
root, forward_function_name
|
||||||
), f"traced_func_name={forwrad_function_name} doesn't exist in {_orig_type(root).__name__}"
|
), f"traced_func_name={forward_function_name} doesn't exist in {_orig_type(root).__name__}"
|
||||||
|
|
||||||
fn = getattr(root, forwrad_function_name)
|
fn = getattr(root, forward_function_name)
|
||||||
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
|
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
|
||||||
else:
|
else:
|
||||||
self.root = torch.nn.Module()
|
self.root = torch.nn.Module()
|
||||||
|
@ -908,6 +921,20 @@ class ConcreteTracer(TracerBase):
|
||||||
instance = instance.value
|
instance = instance.value
|
||||||
return _orig_isinstance(instance, clz)
|
return _orig_isinstance(instance, clz)
|
||||||
|
|
||||||
|
@functools.wraps(_orig_issubclass)
|
||||||
|
def issubclass_wrapper(subclass, clz):
|
||||||
|
if _orig_type(clz) in (slice, tuple, list, _orig_slice, _orig_tuple, _orig_list):
|
||||||
|
clz_wrapped = []
|
||||||
|
for wrapped_type, orig_type in self.clz_wrapper_map.items():
|
||||||
|
if wrapped_type in clz:
|
||||||
|
clz_wrapped.append(orig_type)
|
||||||
|
clz = (*clz_wrapped, *(aclz for aclz in clz if aclz not in self.clz_wrapper_map))
|
||||||
|
return _orig_issubclass(subclass, clz)
|
||||||
|
else:
|
||||||
|
if clz in self.clz_wrapper_map:
|
||||||
|
clz = self.clz_wrapper_map[clz]
|
||||||
|
return _orig_issubclass(subclass, clz)
|
||||||
|
|
||||||
@functools.wraps(_orig_getattr)
|
@functools.wraps(_orig_getattr)
|
||||||
def getattr_wrapper(obj, *args):
|
def getattr_wrapper(obj, *args):
|
||||||
# TODO: better infomation
|
# TODO: better infomation
|
||||||
|
@ -939,6 +966,7 @@ class ConcreteTracer(TracerBase):
|
||||||
self.patcher.patch_method(builtins, "range", range_wrapper, deduplicate=False)
|
self.patcher.patch_method(builtins, "range", range_wrapper, deduplicate=False)
|
||||||
self.patcher.patch_method(builtins, "type", type_wrapper, deduplicate=False)
|
self.patcher.patch_method(builtins, "type", type_wrapper, deduplicate=False)
|
||||||
self.patcher.patch_method(builtins, "isinstance", isinstance_wrapper, deduplicate=False)
|
self.patcher.patch_method(builtins, "isinstance", isinstance_wrapper, deduplicate=False)
|
||||||
|
self.patcher.patch_method(builtins, "issubclass", issubclass_wrapper, deduplicate=False)
|
||||||
self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False)
|
self.patcher.patch_method(builtins, "getattr", getattr_wrapper, deduplicate=False)
|
||||||
|
|
||||||
for obj, (positions, wrapped) in self.wrapped_leaf.items():
|
for obj, (positions, wrapped) in self.wrapped_leaf.items():
|
||||||
|
@ -1307,7 +1335,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
|
||||||
*,
|
*,
|
||||||
use_operator_patch: bool = True,
|
use_operator_patch: bool = True,
|
||||||
operator_patch_backlist: List[str] | None = None,
|
operator_patch_backlist: List[str] | None = None,
|
||||||
forwrad_function_name: str = 'forward',
|
forward_function_name: str = 'forward',
|
||||||
check_args: Optional[Dict[str, Any]] = None,
|
check_args: Optional[Dict[str, Any]] = None,
|
||||||
autowrap_leaf_function = None,
|
autowrap_leaf_function = None,
|
||||||
autowrap_leaf_class = None,
|
autowrap_leaf_class = None,
|
||||||
|
@ -1449,7 +1477,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
|
||||||
concrete_args=concrete_args,
|
concrete_args=concrete_args,
|
||||||
use_operator_patch=use_operator_patch,
|
use_operator_patch=use_operator_patch,
|
||||||
operator_patch_backlist=operator_patch_backlist,
|
operator_patch_backlist=operator_patch_backlist,
|
||||||
forwrad_function_name=forwrad_function_name,
|
forward_function_name=forward_function_name,
|
||||||
)
|
)
|
||||||
graph_check = tracer.trace(root,
|
graph_check = tracer.trace(root,
|
||||||
autowrap_leaf_function = autowrap_leaf_function,
|
autowrap_leaf_function = autowrap_leaf_function,
|
||||||
|
@ -1459,7 +1487,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
|
||||||
concrete_args=concrete_args,
|
concrete_args=concrete_args,
|
||||||
use_operator_patch=use_operator_patch,
|
use_operator_patch=use_operator_patch,
|
||||||
operator_patch_backlist=operator_patch_backlist,
|
operator_patch_backlist=operator_patch_backlist,
|
||||||
forwrad_function_name=forwrad_function_name,
|
forward_function_name=forward_function_name,
|
||||||
)
|
)
|
||||||
# compare to check equal
|
# compare to check equal
|
||||||
assert len(graph.nodes) == len(graph_check.nodes)
|
assert len(graph.nodes) == len(graph_check.nodes)
|
||||||
|
|
|
@ -18,6 +18,7 @@ _orig_torch_assert: Callable = torch._assert
|
||||||
|
|
||||||
_orig_type: Callable = builtins.type
|
_orig_type: Callable = builtins.type
|
||||||
_orig_isinstance: Callable = builtins.isinstance
|
_orig_isinstance: Callable = builtins.isinstance
|
||||||
|
_orig_issubclass: Callable = builtins.issubclass
|
||||||
_orig_getattr: Callable = builtins.getattr
|
_orig_getattr: Callable = builtins.getattr
|
||||||
|
|
||||||
_orig_range: Type[Any] = builtins.range
|
_orig_range: Type[Any] = builtins.range
|
||||||
|
@ -32,6 +33,8 @@ _orig_map: Type[Any] = builtins.map
|
||||||
_orig_zip: Type[Any] = builtins.zip
|
_orig_zip: Type[Any] = builtins.zip
|
||||||
_orig_enumerate: Type[Any] = builtins.enumerate
|
_orig_enumerate: Type[Any] = builtins.enumerate
|
||||||
_orig_slice: Type[Any] = builtins.slice
|
_orig_slice: Type[Any] = builtins.slice
|
||||||
|
_orig_reversed: Type[Any] = builtins.reversed
|
||||||
|
_orig_torch_size: Type[Any] = torch.Size
|
||||||
|
|
||||||
_orig_len: Callable = builtins.len
|
_orig_len: Callable = builtins.len
|
||||||
_orig_not: Callable = operator.not_
|
_orig_not: Callable = operator.not_
|
||||||
|
@ -40,6 +43,11 @@ _orig_is_not: Callable = operator.is_not
|
||||||
_orig_contains: Callable = operator.contains
|
_orig_contains: Callable = operator.contains
|
||||||
_orig_index: Callable = operator.index
|
_orig_index: Callable = operator.index
|
||||||
|
|
||||||
|
_orig_all: Callable = builtins.all
|
||||||
|
_orig_min: Callable = builtins.min
|
||||||
|
_orig_max: Callable = builtins.max
|
||||||
|
|
||||||
|
|
||||||
def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None):
|
def run_onlyif_instance(cond_type: Type[Any], return_orig: bool = True, return_const: Any = None):
|
||||||
def helper(fn):
|
def helper(fn):
|
||||||
if return_orig:
|
if return_orig:
|
||||||
|
|
|
@ -239,7 +239,7 @@ def test_mmdetection(config_file: str):
|
||||||
|
|
||||||
traced_model = concrete_trace(model, {'img': img_tensor},
|
traced_model = concrete_trace(model, {'img': img_tensor},
|
||||||
use_operator_patch=False,
|
use_operator_patch=False,
|
||||||
forwrad_function_name='forward_dummy',
|
forward_function_name='forward_dummy',
|
||||||
autowrap_leaf_function = {
|
autowrap_leaf_function = {
|
||||||
**ConcreteTracer.default_autowrap_leaf_function,
|
**ConcreteTracer.default_autowrap_leaf_function,
|
||||||
all: ((), False, None),
|
all: ((), False, None),
|
||||||
|
|
Загрузка…
Ссылка в новой задаче