Fix trace changes after torch 1.1.0

Shital Shah 2020-03-03 23:20:56 -08:00
Parent 30f770cc97
Commit dc0f2e01de
1 changed file with 14 additions and 5 deletions


@@ -87,7 +87,17 @@ class SummaryGraph(object):
             device = distiller.model_device(model_clone)
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             self.dummy_input = dummy_input
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
+            if hasattr(jit, 'get_trace_graph'): # torch 1.1.0 or before
+                trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
+                graph = trace.graph()
+                nodes = graph.nodes()
+            elif hasattr(jit, '_get_trace_graph'):
+                trace, _ = jit._get_trace_graph(model_clone, dummy_input, _force_outplace=True)
+                graph = trace
+                nodes = graph.nodes()
+            else:
+                raise RuntimeError('torch version {} has internal changes that are not supported yet'.format(torch.__version__))
 
             # As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope names
             # of nodes in the trace graph.
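
For context, the branching above is plain feature detection against PyTorch's private tracing entry point: jit.get_trace_graph existed up to torch 1.1.0 and returned a trace object wrapping the graph, while the renamed jit._get_trace_graph returns the graph itself (hence graph = trace in the elif branch). Below is a minimal standalone sketch of the same shim on a toy model; TinyNet and dummy are illustrative names, not part of this commit.

import torch
import torch.nn as nn
from torch import jit

class TinyNet(nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

model = TinyNet().eval()
dummy = torch.randn(1, 4)

if hasattr(jit, 'get_trace_graph'):       # torch 1.1.0 or before
    trace, _ = jit.get_trace_graph(model, dummy, _force_outplace=True)
    graph = trace.graph()                 # old API: the trace object wraps the graph
elif hasattr(jit, '_get_trace_graph'):    # later versions: returns the graph directly
    graph, _ = jit._get_trace_graph(model, dummy, _force_outplace=True)
else:
    raise RuntimeError('unsupported torch version {}'.format(torch.__version__))

for node in graph.nodes():
    print(node.kind())                    # e.g. 'aten::add', 'aten::relu'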
@@ -109,7 +119,7 @@ class SummaryGraph(object):
             pre_dropout_nodes_scope_names = OrderedDict()
 
             prev_non_dropout_op = None
-            for node in trace.graph().nodes():
+            for node in nodes:
                 kind = node.kind()
                 if 'aten' not in kind:
                     continue
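
The loop now walks the precomputed nodes list instead of re-deriving it from the trace, and skips everything that is not an aten kernel (prim::Constant, prim::ListConstruct, and similar bookkeeping nodes) before the dropout scope-name workaround runs. A short illustration of the same filter, reusing the graph from the sketch above:

# Keep only computational kernels; prim:: bookkeeping nodes are skipped.
aten_nodes = [n for n in graph.nodes() if 'aten' in n.kind()]
for node in aten_nodes:
    # scopeName() carries the module path that the dropout workaround repairs
    print(node.kind(), node.scopeName())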
@@ -125,7 +135,6 @@ class SummaryGraph(object):
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
-            graph = trace.graph()
             self.ops = OrderedDict()
             self.module_ops_map = defaultdict(list)
             self.params = OrderedDict()
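
With graph now assigned inside the version shim, the redundant graph = trace.graph() after ONNX trace optimization is dropped, and the bookkeeping containers are built as before. For completeness, a sketch of that initialization with the imports it relies on; the names mirror the SummaryGraph attributes, but the snippet itself is illustrative:

from collections import OrderedDict, defaultdict

ops = OrderedDict()                   # op name -> op record, in traversal order
module_ops_map = defaultdict(list)    # module name -> ops that module produced
params = OrderedDict()                # parameter id -> parameter record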