Fix trace changes after torch 1.1.0
This commit is contained in:
Родитель
30f770cc97
Коммит
dc0f2e01de
|
@ -83,11 +83,21 @@ class SummaryGraph(object):
|
|||
model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)
|
||||
|
||||
with torch.onnx.set_training(model_clone, False):
|
||||
|
||||
|
||||
device = distiller.model_device(model_clone)
|
||||
dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
|
||||
self.dummy_input = dummy_input
|
||||
trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
|
||||
|
||||
if hasattr(jit, 'get_trace_graph'): # torch 1.1.0 or before
|
||||
trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
|
||||
graph = trace.graph()
|
||||
nodes = graph.nodes()
|
||||
elif hasattr(jit, '_get_trace_graph'):
|
||||
trace, _ = jit._get_trace_graph(model_clone, dummy_input, _force_outplace=True)
|
||||
graph = trace
|
||||
nodes = graph.nodes()
|
||||
else:
|
||||
raise RuntimeError('torch version {} has internal changes that are not supported yet'.format(torch.__version__))
|
||||
|
||||
# As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope names
|
||||
# of nodes in the trace graph.
|
||||
|
@ -109,7 +119,7 @@ class SummaryGraph(object):
|
|||
pre_dropout_nodes_scope_names = OrderedDict()
|
||||
|
||||
prev_non_dropout_op = None
|
||||
for node in trace.graph().nodes():
|
||||
for node in nodes:
|
||||
kind = node.kind()
|
||||
if 'aten' not in kind:
|
||||
continue
|
||||
|
@ -125,7 +135,6 @@ class SummaryGraph(object):
|
|||
# composing a GEMM operation; etc.
|
||||
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
|
||||
|
||||
graph = trace.graph()
|
||||
self.ops = OrderedDict()
|
||||
self.module_ops_map = defaultdict(list)
|
||||
self.params = OrderedDict()
|
||||
|
@ -351,7 +360,7 @@ class SummaryGraph(object):
|
|||
else:
|
||||
kernel_size, group = 1, 1
|
||||
n_ifm = self.param_shape(conv_in)[1]
|
||||
n_ofm = self.param_shape(conv_out)[1]
|
||||
n_ofm = self.param_shape(conv_out)[1]
|
||||
weights_vol = kernel_size * n_ifm * n_ofm / group
|
||||
op['attrs']['n_ifm'] = n_ifm
|
||||
op['attrs']['n_ofm'] = n_ofm
|
||||
|
|
Загрузка…
Ссылка в новой задаче