Fix trace changes after torch 1.1.0
This commit is contained in:
Родитель
30f770cc97
Коммит
dc0f2e01de
|
@ -87,7 +87,17 @@ class SummaryGraph(object):
|
|||
device = distiller.model_device(model_clone)
|
||||
dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
|
||||
self.dummy_input = dummy_input
|
||||
|
||||
if hasattr(jit, 'get_trace_graph'): # torch 1.1.0 or before
|
||||
trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
|
||||
graph = trace.graph()
|
||||
nodes = graph.nodes()
|
||||
elif hasattr(jit, '_get_trace_graph'):
|
||||
trace, _ = jit._get_trace_graph(model_clone, dummy_input, _force_outplace=True)
|
||||
graph = trace
|
||||
nodes = graph.nodes()
|
||||
else:
|
||||
raise RuntimeError('torch version {} has internal changes that are not supported yet'.format(torch.__version__))
|
||||
|
||||
# As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope names
|
||||
# of nodes in the trace graph.
|
||||
|
@ -109,7 +119,7 @@ class SummaryGraph(object):
|
|||
pre_dropout_nodes_scope_names = OrderedDict()
|
||||
|
||||
prev_non_dropout_op = None
|
||||
for node in trace.graph().nodes():
|
||||
for node in nodes:
|
||||
kind = node.kind()
|
||||
if 'aten' not in kind:
|
||||
continue
|
||||
|
@ -125,7 +135,6 @@ class SummaryGraph(object):
|
|||
# composing a GEMM operation; etc.
|
||||
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
|
||||
|
||||
graph = trace.graph()
|
||||
self.ops = OrderedDict()
|
||||
self.module_ops_map = defaultdict(list)
|
||||
self.params = OrderedDict()
|
||||
|
|
Загрузка…
Ссылка в новой задаче