change to use six reload module function.

This commit is contained in:
Kit 2018-02-07 11:34:23 +08:00
Родитель bee331c341
Коммит 82a53c5779
1 изменённых файлов: 11 добавлений и 8 удалений

Просмотреть файл

@ -1,7 +1,7 @@
import os
import unittest
import numpy as np
from imp import reload
from six.moves import reload_module
import tensorflow as tf
from mmdnn.conversion.examples.imagenet_test import TestKit
@ -25,9 +25,10 @@ def _compute_SNR(x,y):
def _compute_max_relative_error(x,y):
from six.moves import xrange
rerror = 0
index = 0
for i in range(len(x)):
for i in xrange(len(x)):
den = max(1.0, np.abs(x[i]), np.abs(y[i]))
if np.abs(x[i]/den - y[i]/den) > rerror:
rerror = np.abs(x[i]/den - y[i]/den)
@ -97,7 +98,7 @@ class TestModels(CorrectnessTest):
# import converted model
import converted_model
reload (converted_model)
reload_module (converted_model)
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
func = TestKit.preprocess_func[original_framework][architecture_name]
@ -121,7 +122,7 @@ class TestModels(CorrectnessTest):
# import converted model
import converted_model
reload (converted_model)
reload_module (converted_model)
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
input_tf, model_tf = model_converted
@ -150,7 +151,7 @@ class TestModels(CorrectnessTest):
# import converted model
import converted_model
reload (converted_model)
reload_module (converted_model)
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
# input_tf, model_tf = model_converted
@ -176,7 +177,7 @@ class TestModels(CorrectnessTest):
# import converted model
import converted_model
reload (converted_model)
reload_module (converted_model)
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
func = TestKit.preprocess_func[original_framework][architecture_name]
@ -190,19 +191,21 @@ class TestModels(CorrectnessTest):
os.remove("converted_model.py")
return converted_predict
test_table = {
'keras': {
'keras' : {
'vgg16' : [CntkEmit, TensorflowEmit, KerasEmit],
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit],
'inception_v3' : [CntkEmit, TensorflowEmit, KerasEmit],
'resnet50' : [CntkEmit, TensorflowEmit, KerasEmit],
'densenet' : [CntkEmit, KerasEmit],
'densenet' : [CntkEmit, TensorflowEmit, KerasEmit],
'xception' : [TensorflowEmit, KerasEmit],
'mobilenet' : [TensorflowEmit, KerasEmit],
'nasnet' : [KerasEmit],
}
}
def test_keras(self):
# keras original
ensure_dir(self.cachedir)