[common] cpu/gpu mix trace fix (#5583)

super-dainiu 2023-05-29 21:46:36 +08:00 committed by GitHub
Parent 1605cd0c53
Commit aec1fec127
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 154 additions and 51 deletions

View file

@ -19,6 +19,7 @@ from contextlib import contextmanager
import torch
from torch._C import ScriptObject
from torch.nn.modules.container import Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
from torch.utils._pytree import tree_map
import torch.fx
from torch.fx import GraphModule
@ -86,6 +87,7 @@ from .utils import (
)
# pyright: reportGeneralTypeIssues=false
_logger = logging.getLogger(__name__)
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
@ -214,7 +216,7 @@ class ConcreteTracer(TracerBase):
node_to_originating_module : Dict[torch.fx.Node, str] = {}
@compatibility(is_backward_compatible=True)
def __init__(self):
def __init__(self, cpu_offload = False):
"""
similar to _symbolic_trace.Tracer.__init__.
'param_shapes_constant' is removed because the real shape can be obtained during execution.
@ -223,6 +225,7 @@ class ConcreteTracer(TracerBase):
self.scope = Scope("", None)
self.module_stack = collections.OrderedDict()
self.node_name_to_scope = {}
self.cpu_offload = cpu_offload
@contextmanager
def do_temp_disable(self, call=False, attr=False, agfunc_apply=False):
@ -275,36 +278,65 @@ class ConcreteTracer(TracerBase):
actually execute the code.
apply the patcher, and the _autowrap_check to the target function.
"""
if kind == 'call_function':
assert isinstance(target, Callable)
fn = target
if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'):
_autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
with self.do_temp_disable(call=True):
return OperatorPatcherContext.patch_run(fn, *args, **kwargs)
elif kind == 'call_method':
self_obj, *args_tail = args
fn = _orig_getattr(self_obj, target)
if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(fn, '__globals__'):
_autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
with self.do_temp_disable(call=True):
return OperatorPatcherContext.patch_run(fn, *args_tail, **kwargs)
elif kind == 'call_module':
assert isinstance(target, str)
mod = self.fetch_attr(target)
if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' and hasattr(mod, '__globals__'):
_autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
with self.do_temp_disable(call=True):
return OperatorPatcherContext.patch_run(mod, *args, **kwargs)
elif kind == 'get_attr':
assert isinstance(target, str)
return self.fetch_attr(target)
elif kind == 'output':
if kind == 'output':
return args[0]
elif kind == 'placeholder':
return self.placeholder_dict[target]
else:
raise RuntimeError()
to_cpu = lambda t: t.cpu() if _orig_isinstance(t, torch.Tensor) else t
to_cuda = lambda t: t.cuda() if _orig_isinstance(t, torch.Tensor) else t
def run(kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
if self.cpu_offload:
args = tree_map(to_cuda, args)
kwargs = tree_map(to_cuda, kwargs)
if kind == 'call_function':
assert isinstance(target, Callable)
fn = target
if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \
and hasattr(fn, '__globals__'):
_autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
return OperatorPatcherContext.patch_run(fn, *args, **kwargs)
elif kind == 'call_method':
self_obj, *args_tail = args
fn = _orig_getattr(self_obj, target)
if _orig_getattr(fn, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \
and hasattr(fn, '__globals__'):
_autowrap_check(self, fn.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
result = fn(*args_tail, **kwargs)
elif kind == 'call_module':
assert isinstance(target, str)
mod = self.fetch_attr(target)
if self.cpu_offload:
mod.cuda() # how does this work in DDP?
if _orig_getattr(mod, '__module__', None) != 'nni.common.concrete_trace_utils.concrete_tracer' \
and hasattr(mod, '__globals__'):
_autowrap_check(self, mod.__globals__, self._autowrap_function_ids, self.autowrap_leaf_pairs, self.agfunc_dict)
result = OperatorPatcherContext.patch_run(mod, *args, **kwargs)
if self.cpu_offload:
mod.cpu()
elif kind == 'get_attr':
assert isinstance(target, str)
return self.fetch_attr(target)
else:
raise RuntimeError()
return result
with self.do_temp_disable(call=True):
result = run(kind, target, args, kwargs)
if self.cpu_offload:
if isinstance(result, torch.Tensor):
result = result.cpu()
elif isinstance(result, (list, dict, tuple)):
result = tree_map(to_cpu, result)
else:
_logger.warning(f"Result of target {target} is of type {type(result)}, which is not a common return type.")
torch.cuda.empty_cache()
self.temp_disable_call = False
return result
@compatibility(is_backward_compatible=True)
def proxy(self, value: Any, node: Node) -> ep.ConcreteProxy:
@ -315,8 +347,8 @@ class ConcreteTracer(TracerBase):
@compatibility(is_backward_compatible=True)
def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr: Optional[Any] = None,
proxy_factory_fn: Optional[Callable[[Node], Any]] = None):
name: Optional[str] = None, type_expr: Optional[Any] = None,
proxy_factory_fn: Optional[Callable[[Node], Any]] = None):
"""
similar to _symbolic_trace.Tracer.create_proxy.
use 'run_target' to actually execute the code, and store the value in the 'value' field.
@ -502,10 +534,10 @@ class ConcreteTracer(TracerBase):
cnt = 0
self.placeholder_dict = {}
arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
diff_len = len(arg_names) - len(default_value_list)
diff_len = _orig_len(arg_names) - _orig_len(default_value_list)
default_args = {arg_names[idx + diff_len]: default_value_list[idx] for idx in range(len(default_value_list))}
if isinstance(concrete_args, tuple):
if len(arg_names) != len(concrete_args):
if _orig_len(arg_names) != _orig_len(concrete_args):
raise RuntimeError(f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments")
concrete_args = {name: val for name, val in zip(arg_names, concrete_args)}
def proxy_placeholder(name: str):
@ -671,7 +703,7 @@ class ConcreteTracer(TracerBase):
return _orig_module_getattribute(mod, attr)
except AttributeError:
return _orig_module_getattr(mod, attr)
with self.do_temp_disable(call=True, attr=True):
with self.do_temp_disable(attr=True):
try:
attr_val = _orig_module_getattribute(mod, attr)
except AttributeError:
@ -992,6 +1024,9 @@ class ConcreteTracer(TracerBase):
pass
self.submodule_paths = None
with MagicMethodPatcher():
GraphModule(self.root, self.graph) # assign graph.owning_module
self.graph.eliminate_dead_code()
return self.graph
# List of pairs of (global dict, function name) functions
@ -1318,13 +1353,13 @@ def _retain_weight_consistency(root: torch.nn.Module):
for module in root.modules():
for name, param in module.named_parameters():
if _orig_isinstance(param, ep.ConcreteProxy):
param: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false
param: ep.ConcreteProxy
_logger.warning(f'Parameter {name} of {module} is a ConcreteProxy. Some weight may be modified inplace within forward().')
setattr(module, name, param.value)
_flag |= 1
for name, buffer in module.named_buffers():
if _orig_isinstance(buffer, ep.ConcreteProxy):
buffer: ep.ConcreteProxy # pyright: reportGeneralTypeIssues=false
buffer: ep.ConcreteProxy
_logger.warning(f'Buffer {name} of {module} is a ConcreteProxy. Some buffer may be modified inplace within forward().')
setattr(module, name, buffer.value)
_flag |= 1
@ -1344,7 +1379,9 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
autowrap_leaf_class = None,
leaf_module: Tuple | None = None,
fake_middle_class = None,
dce = False) -> GraphModule:
dce = False,
cpu_offload = False,
) -> GraphModule:
"""
Concrete tracing API
@ -1467,33 +1504,39 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
The structure of the dict is: leaf_class: ([(module_path, module_name)], is_iterator_class).
is_iterator_class: whether the class is initialized from an iterator. Only 'tuple', 'list', 'set' or 'dict' needs to set it to True.
cpu_offload (bool): Whether to offload modules to the CPU during tracing. If set to True, the traced code is executed on the GPU,
but modules and intermediate results are offloaded to the CPU afterward. This reduces GPU memory usage during tracing, but may cause performance issues.
If set to False, there is no offloading during tracing and the traced code is executed on the default device.
Returns:
fx.GraphModule: a Module created from the recorded operations from ``root``.
"""
tracer = ConcreteTracer()
tracer = ConcreteTracer(cpu_offload = cpu_offload)
is_training = root.training
root.eval()
graph = tracer.trace(root,
autowrap_leaf_function = autowrap_leaf_function,
autowrap_leaf_class = autowrap_leaf_class,
leaf_module = leaf_module,
fake_middle_class = fake_middle_class,
concrete_args=concrete_args,
use_operator_patch=use_operator_patch,
operator_patch_backlist=operator_patch_backlist,
forward_function_name=forward_function_name,
concrete_args = concrete_args,
use_operator_patch = use_operator_patch,
operator_patch_backlist = operator_patch_backlist,
forward_function_name = forward_function_name,
)
graph_check = tracer.trace(root,
autowrap_leaf_function = autowrap_leaf_function,
autowrap_leaf_class = autowrap_leaf_class,
leaf_module = leaf_module,
fake_middle_class = fake_middle_class,
concrete_args=concrete_args,
use_operator_patch=use_operator_patch,
operator_patch_backlist=operator_patch_backlist,
forward_function_name=forward_function_name,
concrete_args = concrete_args,
use_operator_patch = use_operator_patch,
operator_patch_backlist = operator_patch_backlist,
forward_function_name = forward_function_name,
)
# compare to check equal
assert len(graph.nodes) == len(graph_check.nodes)
assert len(graph.nodes) == len(graph_check.nodes), f'number of nodes: {len(graph.nodes)} vs {len(graph_check.nodes)}'
for node_a, node_b in zip(graph.nodes, graph_check.nodes):
node_a: Node
node_b: Node
@ -1507,14 +1550,13 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
assert node_b.op == 'call_function' and isinstance(target_b, Callable) and target_b.__name__ == 'apply' and\
hasattr(target_b, '__self__') and issubclass(target_b.__self__, torch.autograd.Function)
else:
assert node_a.op == node_b.op and target_a == target_b
assert node_a.op == node_b.op and target_a == target_b, f'op: {node_a.op} vs {node_b.op}, target: {target_a} vs {target_b}'
with MagicMethodPatcher():
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
traced = GraphModule(tracer.root, graph, name)
# TODO: better information
# # assert root(**concrete_args) == traced(**concrete_args)
if check_args is not None:
assert root(**check_args) == traced(**check_args)
@ -1545,4 +1587,7 @@ def concrete_trace(root : Union[torch.nn.Module, Callable[..., Any]],
recursively_check_node(node)
traced.recompile()
if is_training:
root.train()
return traced
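
The cpu_offload behavior documented in the docstring above can be exercised roughly as in the sketch below. This is a minimal illustration, not code from this commit: the model choice and input shape are taken from the new test file for convenience, and it assumes a CUDA device is available when offloading is requested.

import torch
import torchvision.models as models
from nni.common.concrete_trace_utils import concrete_trace

model = models.resnet18().eval()
dummy_inputs = (torch.rand(2, 3, 224, 224),)

# with cpu_offload=True each traced call runs on the GPU, while modules and
# intermediate results are moved back to the CPU afterwards to limit GPU memory use
offload = torch.cuda.is_available()
traced = concrete_trace(model, dummy_inputs, use_operator_patch=True, cpu_offload=offload)

out = traced(*dummy_inputs)  # the traced GraphModule then runs on the default (CPU) device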

View file

@ -52,7 +52,7 @@ class TransformerOp(ast.NodeTransformer):
return super().visit(node)
def visit_Call(self, node: ast.Call):
if isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 0:
if isinstance(node.func, ast.Name) and node.func.id == 'super' and _orig_len(node.args) == 0:
return self.generic_visit(ast.Call(
func=ast.Name(id='super', ctx=ast.Load()),
args=[
@ -173,6 +173,8 @@ class OperatorPatcher:
self.function_cache_orig: Dict[int, Callable] = {}
def patch_inner(self, func):
if _orig_isinstance(func, torch.nn.Module):
return self.patch_inner_helper(func) # better not cache this
if id(func) not in self.function_cache:
self.function_cache[id(func)] = self.patch_inner_helper(func)
self.function_cache_orig[id(func)] = func

View file

@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytest
import torch
import torchvision.models as models
from nni.common.concrete_trace_utils import concrete_trace
model_list = [
models.alexnet,
models.convnext_base,
models.densenet121,
models.efficientnet_b0,
models.mobilenet_v2,
models.resnet18,
models.resnext50_32x4d,
models.vit_b_16,
]
def check_equal(a, b):
if type(a) != type(b):
return False
if isinstance(a, (list, tuple, set)):
if len(a) != len(b):
return False
for sub_a, sub_b in zip(a, b):
if not check_equal(sub_a, sub_b):
return False
return True
elif isinstance(a, dict):
keys_a, keys_b = set(a.keys()), set(b.keys())
if keys_a != keys_b:
return False
for key in keys_a:
if not check_equal(a[key], b[key]):
return False
return True
elif isinstance(a, torch.Tensor):
# may not be exactly equal on GPU
return torch.std(a - b).item() < 1e-6
else:
return a == b
@pytest.mark.parametrize('model_fn', model_list)
def test_torchvision_models(model_fn):
model = model_fn()
model.eval()
dummy_inputs = (torch.rand(2, 3, 224, 224), )
traced = concrete_trace(model, dummy_inputs, use_operator_patch=True)
out_orig = model.forward(*dummy_inputs)
out_traced = traced.forward(*dummy_inputs)
assert check_equal(out_orig, out_traced), f'{traced.code}'
del out_orig, out_traced
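
The new test exercises only the default (no-offload) path. A GPU-guarded variant covering the cpu_offload flag added in this commit might look like the following sketch; the skipif guard and the test name are assumptions, not part of the commit:

@pytest.mark.skipif(not torch.cuda.is_available(), reason='cpu_offload tracing requires a CUDA device')
@pytest.mark.parametrize('model_fn', model_list)
def test_torchvision_models_cpu_offload(model_fn):
    model = model_fn()
    model.eval()
    dummy_inputs = (torch.rand(2, 3, 224, 224), )
    # tracing runs each call on the GPU and offloads modules and results back to the CPU
    traced = concrete_trace(model, dummy_inputs, use_operator_patch=True, cpu_offload=True)
    out_orig = model.forward(*dummy_inputs)
    out_traced = traced.forward(*dummy_inputs)
    assert check_equal(out_orig, out_traced), f'{traced.code}'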

View file

@ -25,7 +25,7 @@ config_files_correct = (
# 'ddod/ddod_r50_fpn_1x_coco',
# 'deepfashion/mask-rcnn_r50_fpn_15e_deepfashion',
# 'deformable_detr/deformable-detr_r50_16xb2-50e_coco',
'detr/detr_r18_8xb2-500e_coco',
# 'detr/detr_r18_8xb2-500e_coco',
# 'double_heads/dh-faster-rcnn_r50_fpn_1x_coco',
'dyhead/atss_r50-caffe_fpn_dyhead_1x_coco',
# 'dynamic_rcnn/dynamic-rcnn_r50_fpn_1x_coco',
@ -68,7 +68,7 @@ config_files_correct = (
'ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco',
# 'swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco',
# 'tood/tood_r50_fpn_1x_coco',
'vfnet/vfnet_r50_fpn_1x_coco',
# 'vfnet/vfnet_r50_fpn_1x_coco',
# 'yolact/yolact_r50_1xb8-55e_coco',
'yolo/yolov3_d53_8xb8-320-273e_coco',
'yolof/yolof_r50-c5_8xb8-1x_coco',