Fixing package path in tflite test (#3427)

This commit is contained in:
Sammy 2019-06-24 23:55:55 -04:00 коммит произвёл Tianqi Chen
Родитель 25bad4402f
Коммит e97c01012d
1 изменённых файлов: 5 добавлений и 2 удалений

Просмотреть файл

@ -32,7 +32,10 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow import lite as interpreter_wrapper
try:
from tensorflow import lite as interpreter_wrapper
except ImportError:
from tensorflow.contrib import lite as interpreter_wrapper
import tvm.relay.testing.tf as tf_testing
@ -131,7 +134,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
if init_global_variables:
sess.run(variables.global_variables_initializer())
# convert to tflite model
converter = tf.contrib.lite.TFLiteConverter.from_session(
converter = interpreter_wrapper.TFLiteConverter.from_session(
sess, input_tensors, output_tensors)
tflite_model_buffer = converter.convert()
tflite_output = run_tflite_graph(tflite_model_buffer, in_data)