Mirror of https://github.com/microsoft/nni.git

[common] cpu/gpu mix trace fix (#5583)

This commit is contained in:
Parent: 1605cd0c53
Commit: aec1fec127
@@ -19,6 +19,7 @@ from contextlib import contextmanager
 import torch
 from torch._C import ScriptObject
 from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
+from torch.utils._pytree import tree_map

 import torch.fx
 from torch.fx import GraphModule
@@ -86,6 +87,7 @@ from .utils import (
 )

+# pyright: reportGeneralTypeIssues=false
 _logger = logging.getLogger(__name__)
 HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

@@ -214,7 +216,7 @@ class ConcreteTracer(TracerBase):
     node_to_originating_module : Dict[torch.fx.Node, str] = {}

     @compatibility(is_backward_compatible=True)
-    def __init__(self):
+    def __init__(self, cpu_offload = False):
         """
         similar to _symbolic_trace.Tracer.__init__.
         remove the 'param_shapes_constant' because we can get real shape when executing.
@@ -223,6 +225,7 @@ class ConcreteTracer(TracerBase):
         self.scope = Scope("", None)
         self.module_stack = collections.OrderedDict()
         self.node_name_to_scope = {}
+        self.cpu_offload = cpu_offload

     @contextmanager
     def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
@@ -275,36 +278,65 @@ class ConcreteTracer(TracerBase):
         actually execute the code.
         apply the patcher, and the _autowrap_check to the target function.
         """
-        if kind == 'call_function':
-            assert isinstance(target, Callable)
-            fn = target
-            if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'):
-                _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
-            with self.do_temp_disable(call=True):
-                return OperatorPatcherContext.patch_run(fn, *args, **kwargs)
-        elif kind == 'call_method':
-            self_obj, *args_tail = args
-            fn = _orig_getattr(self_obj, target)
-            if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'):
-                _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
-            with self.do_temp_disable(call=True):
-                return OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs)
-        elif kind == 'call_module':
-            assert isinstance(target, str)
-            mod = self.fetch_attr(target)
-            if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(mod, '__globals__'):
-                _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
-            with self.do_temp_disable(call=True):
-                return OperatorPatcherContext.patch_run(mod, *args, **kwargs)
-        elif kind == 'get_attr':
-            assert isinstance(target, str)
-            return self.fetch_attr(target)
-        elif kind == 'output':
+        if kind == 'output':
             return args[0]
         elif kind == 'placeholder':
             return self.placeholder_dict[target]
-        else:
-            raise RuntimeError()
+
+        to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t
+        to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t
+
+        def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
+            if self.cpu_offload:
+                args = tree_map(to_cuda, args)
+                kwargs = tree_map(to_cuda, kwargs)
+
+            if kind == 'call_function':
+                assert isinstance(target, Callable)
+                fn = target
+                if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \
+                        and hasattr(fn, '__globals__'):
+                    _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
+                return OperatorPatcherContext.patch_run(fn, *args, **kwargs)
+            elif kind == 'call_method':
+                self_obj, *args_tail = args
+                fn = _orig_getattr(self_obj, target)
+                if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \
+                        and hasattr(fn, '__globals__'):
+                    _autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
+                result = fn(*args_tail, **kwargs)
+            elif kind == 'call_module':
+                assert isinstance(target, str)
+                mod = self.fetch_attr(target)
+                if self.cpu_offload:
+                    mod.cuda()  # how does this behave under ddp?
+                if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \
+                        and hasattr(mod, '__globals__'):
+                    _autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
+                result = OperatorPatcherContext.patch_run(mod, *args, **kwargs)
+                if self.cpu_offload:
+                    mod.cpu()
+            elif kind == 'get_attr':
+                assert isinstance(target, str)
+                return self.fetch_attr(target)
+            else:
+                raise RuntimeError()
+            return result
+
+        with self.do_temp_disable(call=True):
+            result = run(kind, target, args, kwargs)
+            if self.cpu_offload:
+                if isinstance(result, torch.Tensor):
+                    result = result.cpu()
+                elif isinstance(result, (list, dict, tuple)):
+                    result = tree_map(to_cpu, result)
+                else:
+                    _logger.warning(f"result of target {target} is {type(result)}, which is not a common behavior.")

+                torch.cuda.empty_cache()
+
+        self.temp_disable_call = False
+        return result

     @compatibility(is_backward_compatible=True)
     def proxy(self, value: Any, node: Node) -> ep.ConcreteProxy:
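Note: the offload round-trip above is the heart of this fix — tensor arguments are shuttled to GPU before each traced call and results are shuttled back to CPU afterward, so only one call's working set lives on the GPU at a time. A minimal standalone sketch of the same pattern (illustrative only; `run_offloaded` is not a name from this commit):

    import torch
    from torch.utils._pytree import tree_map

    to_cpu = lambda t: t.cpu() if isinstance(t, torch.Tensor) else t
    to_cuda = lambda t: t.cuda() if isinstance(t, torch.Tensor) else t

    def run_offloaded(fn, *args, **kwargs):
        # move every tensor in the (possibly nested) arguments to GPU, run, move results back
        args = tree_map(to_cuda, args)
        kwargs = tree_map(to_cuda, kwargs)
        result = tree_map(to_cpu, fn(*args, **kwargs))
        torch.cuda.empty_cache()  # release cached blocks, as run_target does after each call
        return result

    if torch.cuda.is_available():
        out = run_offloaded(torch.add, torch.ones(2), torch.ones(2))
        assert out.device.type == 'cpu'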
@@ -315,8 +347,8 @@ class ConcreteTracer(TracerBase):

     @compatibility(is_backward_compatible=True)
     def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
-                      name: Optional[str] = None, type_expr: Optional[Any] = None,
-                      proxy_factory_fn: Optional[Callable[[Node], Any]] = None):
+                     name: Optional[str] = None, type_expr: Optional[Any] = None,
+                     proxy_factory_fn: Optional[Callable[[Node], Any]] = None):
         """
         similar to _symbolic_trace.Tracer.create_proxy.
         use the 'run_target' to actually execute the code, and store the value in 'value' field.
@@ -502,10 +534,10 @@ class ConcreteTracer(TracerBase):
         cnt = 0
         self.placeholder_dict = {}
         arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
-        diff_len = len(arg_names) - len(default_value_list)
+        diff_len = _orig_len(arg_names) - _orig_len(default_value_list)
         default_args = {arg_names[idx + diff_len]: default_value_list[idx] for idx in range(len(default_value_list))}
         if isinstance(concrete_args, tuple):
-            if len(arg_names) != len(concrete_args):
+            if _orig_len(arg_names) != _orig_len(concrete_args):
                 raise RuntimeError(f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments")
             concrete_args = {name: val for name, val in zip(arg_names, concrete_args)}
         def proxy_placeholder(name: str):
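Note: `_orig_len` is the tracer's saved reference to the unpatched builtin, used so that this bookkeeping code does not itself get traced once `len` is wrapped. The capture-before-patch pattern in isolation (names here are illustrative, not the module's actual patching code):

    import builtins

    _orig_len = builtins.len  # capture before patching

    def traced_len(obj):
        # stand-in for an instrumented len; a real tracer would record a graph node here
        return _orig_len(obj)

    builtins.len = traced_len
    try:
        assert _orig_len([1, 2, 3]) == 3  # internals bypass the patched builtin
    finally:
        builtins.len = _orig_len  # always restore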
@@ -671,7 +703,7 @@ class ConcreteTracer(TracerBase):
                 return _orig_module_getattribute(mod, attr)
             except AttributeError:
                 return _orig_module_getattr(mod, attr)
-        with self.do_temp_disable(call=True, attr=True):
+        with self.do_temp_disable(attr=True):
             try:
                 attr_val = _orig_module_getattribute(mod, attr)
             except AttributeError:
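Note: do_temp_disable is a re-entrant guard around the tracer's own internals; the change above narrows the module-getattr path so that only attribute interception is suspended while call tracing stays active. A rough sketch of how such a flag-restoring context manager is typically shaped (simplified; the real method also handles agfunc_apply and nesting):

    from contextlib import contextmanager

    class TinyTracer:
        def __init__(self):
            self.temp_disable_call = False
            self.temp_disable_attr = False

        @contextmanager
        def do_temp_disable(self, call=False, attr=False):
            # save the flags, raise the requested ones, and always restore on exit
            saved = (self.temp_disable_call, self.temp_disable_attr)
            self.temp_disable_call = self.temp_disable_call or call
            self.temp_disable_attr = self.temp_disable_attr or attr
            try:
                yield
            finally:
                self.temp_disable_call, self.temp_disable_attr = saved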
@@ -992,6 +1024,9 @@ class ConcreteTracer(TracerBase):
             pass

         self.submodule_paths = None
+        with MagicMethodPatcher():
+            GraphModule(self.root, self.graph)  # assign graph.owning_module
+        self.graph.eliminate_dead_code()
         return self.graph

     # List of pairs of (global dict, function name) functions
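Note: `Graph.eliminate_dead_code()` is the stock torch.fx pass that drops side-effect-free nodes with no users; constructing a GraphModule first assigns `graph.owning_module`, which the pass consults for module metadata. A standalone illustration (the toy module `M` is made up for demonstration):

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            unused = x + 1  # dead: this value never reaches the output
            return x * 2

    gm = torch.fx.symbolic_trace(M())
    n_before = len(gm.graph.nodes)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    assert len(gm.graph.nodes) < n_before  # the add node has been removed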
@@ -1318,13 +1353,13 @@ def _retain_weight_consistency(root: torch.nn.Module):
     for module in root.modules():
         for name, param in module.named_parameters():
             if _orig_isinstance(param, ep.ConcreteProxy):
-                param: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false
+                param: ep.ConcreteProxy
                 _logger.warning(f'Parameter {name} of {module} is a ConcreteProxy. Some weight may be modified inplace within forward().')
                 setattr(module, name, param.value)
                 _flag |= 1
         for name, buffer in module.named_buffers():
             if _orig_isinstance(buffer, ep.ConcreteProxy):
-                buffer: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false
+                buffer: ep.ConcreteProxy
                 _logger.warning(f'Buffer {name} of {module} is a ConcreteProxy. Some buffer may be modified inplace within forward().')
                 setattr(module, name, buffer.value)
                 _flag |= 1
@@ -1344,7 +1379,9 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
                    autowrap_leaf_class = None,
                    leaf_module: Tuple | None = None,
                    fake_middle_class = None,
-                   dce = False) -> GraphModule:
+                   dce = False,
+                   cpu_offload = False,
+                   ) -> GraphModule:
     """
     Concrete tracing API

@@ -1467,33 +1504,39 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
             The struct of dict is: leaf_class: ([(module_path, module_name)], is_iterator_class).
             is_iterator_class: Is the class init from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True.

+        cpu_offload (bool): Whether to offload modules to CPU during tracing. If set to True, the traced code is executed on GPU,
+            and each module is offloaded back to CPU afterward. This is useful for reducing GPU memory usage during tracing,
+            but may cause performance issues. If set to False, there is no offloading and the traced code runs on the default device.
+
     Returns:
         fx.GraphModule: a Module created from the recorded operations from ``root``.
     """
-    tracer = ConcreteTracer()
+    tracer = ConcreteTracer(cpu_offload = cpu_offload)
     is_training = root.training
     root.eval()

     graph = tracer.trace(root,
         autowrap_leaf_function = autowrap_leaf_function,
         autowrap_leaf_class = autowrap_leaf_class,
         leaf_module = leaf_module,
         fake_middle_class = fake_middle_class,
-        concrete_args=concrete_args,
-        use_operator_patch=use_operator_patch,
-        operator_patch_backlist=operator_patch_backlist,
-        forward_function_name=forward_function_name,
+        concrete_args = concrete_args,
+        use_operator_patch = use_operator_patch,
+        operator_patch_backlist = operator_patch_backlist,
+        forward_function_name = forward_function_name,
     )
     graph_check = tracer.trace(root,
         autowrap_leaf_function = autowrap_leaf_function,
         autowrap_leaf_class = autowrap_leaf_class,
         leaf_module = leaf_module,
         fake_middle_class = fake_middle_class,
-        concrete_args=concrete_args,
-        use_operator_patch=use_operator_patch,
-        operator_patch_backlist=operator_patch_backlist,
-        forward_function_name=forward_function_name,
+        concrete_args = concrete_args,
+        use_operator_patch = use_operator_patch,
+        operator_patch_backlist = operator_patch_backlist,
+        forward_function_name = forward_function_name,
     )
     # compare the two graphs to check that tracing is deterministic
-    assert len(graph.nodes) == len(graph_check.nodes)
+    assert len(graph.nodes) == len(graph_check.nodes), f'number of nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}'
     for node_a, node_b in zip(graph.nodes, graph_check.nodes):
         node_a: Node
         node_b: Node
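Note: with the new flag, callers opt in at the top-level API. A hedged usage sketch (assumes a CUDA device is available; the model and input shape are arbitrary):

    import torch
    import torchvision.models as models
    from nni.common.concrete_trace_utils import concrete_trace

    model = models.resnet18().eval()          # parameters can start on CPU
    dummy_inputs = (torch.rand(1, 3, 224, 224),)

    # each submodule is moved to GPU only for its own forward and offloaded
    # back to CPU afterward, trading transfer time for lower peak GPU memory
    traced = concrete_trace(model, dummy_inputs, cpu_offload=True)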
@@ -1507,14 +1550,13 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
             assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\
                 hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function)
         else:
-            assert node_a.op == node_b.op and target_a == target_b
+            assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}'

     with MagicMethodPatcher():
         name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
         traced = GraphModule(tracer.root, graph, name)

     # TODO: better information
-    # # assert root(**concrete_args) == traced(**concrete_args)
     if check_args is not None:
         assert root(**check_args) == traced(**check_args)
@@ -1545,4 +1587,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
         recursively_check_node(node)
     traced.recompile()

+    if is_training:
+        root.train()
+
     return traced
@@ -52,7 +52,7 @@ class TransformerOp(ast.NodeTransformer):
         return super().visit(node)

     def visit_Call(self, node: ast.Call):
-        if isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 0:
+        if isinstance(node.func, ast.Name) and node.func.id == 'super' and _orig_len(node.args) == 0:
             return self.generic_visit(ast.Call(
                 func=ast.Name(id='super', ctx=ast.Load()),
                 args=[
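Note: `TransformerOp` is an `ast.NodeTransformer`, and the condition above is the standard shape-check for a bare `super()` call, which the patcher must rewrite because a regenerated function loses the implicit `__class__` closure cell that zero-argument `super()` depends on. The detection logic in isolation (a standalone sketch, not the repo's code):

    import ast

    class FindBareSuper(ast.NodeTransformer):
        def visit_Call(self, node: ast.Call):
            # same test as TransformerOp: a Name node 'super' called with no arguments
            if isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 0:
                print('found bare super() call')
            return self.generic_visit(node)

    src = "class A(B):\n    def f(self):\n        return super().f()"
    FindBareSuper().visit(ast.parse(src))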
@@ -173,6 +173,8 @@ class OperatorPatcher:
         self.function_cache_orig: Dict[int, Callable] = {}

     def patch_inner(self, func):
+        if _orig_isinstance(func, torch.nn.Module):
+            return self.patch_inner_helper(func) # better not cache this
         if id(func) not in self.function_cache:
             self.function_cache[id(func)] = self.patch_inner_helper(func)
             self.function_cache_orig[id(func)] = func
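Note on the cache guard above: `id(func)` is only unique while the object is alive, which is why the cache also pins the original object in `function_cache_orig`; a bound `nn.Module` is better left uncached entirely, since module instances come and go during tracing. The id-reuse hazard in isolation (illustrative snippet; the behavior is CPython-specific):

    def make():
        return lambda: None

    a = make()
    key = id(a)
    del a        # the object is freed; CPython may reuse its address
    b = make()
    # id(b) == key is possible, so a cache keyed only by id() could hand the
    # patched version of the dead `a` to the unrelated `b`; holding a strong
    # reference (as function_cache_orig does) rules the reuse out
    print(id(b) == key)  # often True on CPython, but not guaranteed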
@@ -0,0 +1,56 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import pytest
+
+import torch
+import torchvision.models as models
+
+from nni.common.concrete_trace_utils import concrete_trace
+
+model_list = [
+    models.alexnet,
+    models.convnext_base,
+    models.densenet121,
+    models.efficientnet_b0,
+    models.mobilenet_v2,
+    models.resnet18,
+    models.resnext50_32x4d,
+    models.vit_b_16,
+]
+
+
+def check_equal(a, b):
+    if type(a) != type(b):
+        return False
+    if isinstance(a, (list, tuple, set)):
+        if len(a) != len(b):
+            return False
+        for sub_a, sub_b in zip(a, b):
+            if not check_equal(sub_a, sub_b):
+                return False
+        return True
+    elif isinstance(a, dict):
+        keys_a, keys_b = set(a.keys()), set(b.keys())
+        if keys_a != keys_b:
+            return False
+        for key in keys_a:
+            if not check_equal(a[key], b[key]):
+                return False
+        return True
+    elif isinstance(a, torch.Tensor):
+        # may not be exactly equal on gpu
+        return torch.std(a - b).item() < 1e-6
+    else:
+        return a == b
+
+
+@pytest.mark.parametrize('model_fn', model_list)
+def test_torchvision_models(model_fn):
+    model = model_fn()
+    model.eval()
+    dummy_inputs = (torch.rand(2, 3, 224, 224), )
+    traced = concrete_trace(model, dummy_inputs, use_operator_patch=True)
+    out_orig = model.forward(*dummy_inputs)
+    out_traced = traced.forward(*dummy_inputs)
+    assert check_equal(out_orig, out_traced), f'{traced.code}'
+    del out_orig, out_traced
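Note: the tensor branch of `check_equal` tolerates small numeric drift because CUDA kernels need not be bitwise deterministic between the two forwards. An equivalent check using the more conventional `torch.allclose` (a sketch, not part of the commit):

    import torch

    def tensors_close(a: torch.Tensor, b: torch.Tensor) -> bool:
        # elementwise relative/absolute tolerance instead of the std of the difference
        return torch.allclose(a, b, rtol=1e-5, atol=1e-6)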
@@ -25,7 +25,7 @@ config_files_correct = (
     # 'ddod/ddod_r50_fpn_1x_coco',
     # 'deepfashion/mask-rcnn_r50_fpn_15e_deepfashion',
     # 'deformable_detr/deformable-detr_r50_16xb2-50e_coco',
-    'detr/detr_r18_8xb2-500e_coco',
+    # 'detr/detr_r18_8xb2-500e_coco',
     # 'double_heads/dh-faster-rcnn_r50_fpn_1x_coco',
     'dyhead/atss_r50-caffe_fpn_dyhead_1x_coco',
     # 'dynamic_rcnn/dynamic-rcnn_r50_fpn_1x_coco',
@@ -68,7 +68,7 @@ config_files_correct = (
     'ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco',
     # 'swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco',
     # 'tood/tood_r50_fpn_1x_coco',
-    'vfnet/vfnet_r50_fpn_1x_coco',
+    # 'vfnet/vfnet_r50_fpn_1x_coco',
     # 'yolact/yolact_r50_1xb8-55e_coco',
     'yolo/yolov3_d53_8xb8-320-273e_coco',
     'yolof/yolof_r50-c5_8xb8-1x_coco',