зеркало из https://github.com/microsoft/MMdnn.git
tensorflow pytest
This commit is contained in:
Родитель
ed7947304f
Коммит
7175832333
|
@ -64,7 +64,7 @@ class TestKit(object):
|
|||
'caffe' : {
|
||||
'alexnet' : lambda path : TestKit.ZeroCenter(path, 227, True),
|
||||
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'inception_v1' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'inception_v1' : lambda path : TestKit.ZeroCenter(path, 227, True),
|
||||
'resnet152' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'squeezenet' : lambda path : TestKit.ZeroCenter(path, 227, False)
|
||||
},
|
||||
|
@ -74,7 +74,12 @@ class TestKit(object):
|
|||
'inception_v1' : lambda path : TestKit.Standard(path, 224),
|
||||
'inception_v3' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v1_50' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'resnet_v1_101' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'resnet_v1_152' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'resnet_v2_50' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v2_152' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet_v2_200' : lambda path : TestKit.Standard(path, 299),
|
||||
'resnet152' : lambda path : TestKit.Standard(path, 299),
|
||||
'mobilenet' : lambda path : TestKit.Standard(path, 224)
|
||||
},
|
||||
|
|
|
@ -43,7 +43,47 @@ class tensorflow_extractor(base_extractor):
|
|||
'arg_scope' : inception.inception_v3_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
}
|
||||
},
|
||||
'resnet_v1_50' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz',
|
||||
'filename' : 'resnet_v1_50.ckpt',
|
||||
'builder' : lambda : resnet_v1.resnet_v1_50,
|
||||
'arg_scope' : resnet_v2.resnet_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
|
||||
'num_classes' : 1000,
|
||||
},
|
||||
'resnet_v1_152' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz',
|
||||
'filename' : 'resnet_v1_152.ckpt',
|
||||
'builder' : lambda : resnet_v1.resnet_v1_152,
|
||||
'arg_scope' : resnet_v2.resnet_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
|
||||
'num_classes' : 1000,
|
||||
},
|
||||
'resnet_v2_50' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz',
|
||||
'filename' : 'resnet_v2_50.ckpt',
|
||||
'builder' : lambda : resnet_v2.resnet_v2_50,
|
||||
'arg_scope' : resnet_v2.resnet_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'resnet_v2_152' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz',
|
||||
'filename' : 'resnet_v2_152.ckpt',
|
||||
'builder' : lambda : resnet_v2.resnet_v2_152,
|
||||
'arg_scope' : resnet_v2.resnet_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
'resnet_v2_200' : {
|
||||
'url' : 'http://download.tensorflow.org/models/resnet_v2_200_2017_04_14.tar.gz',
|
||||
'filename' : 'resnet_v2_200.ckpt',
|
||||
'builder' : lambda : resnet_v2.resnet_v2_200,
|
||||
'arg_scope' : resnet_v2.resnet_arg_scope,
|
||||
'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
|
||||
'num_classes' : 1001,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -177,8 +177,6 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
@staticmethod
|
||||
def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to CNTK.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
converted_file = original_framework + '_cntk_' + architecture_name + "_converted"
|
||||
converted_file = converted_file.replace('.', '_')
|
||||
|
@ -200,8 +198,6 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
@staticmethod
|
||||
def TensorflowEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to TensorFlow.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
converted_file = original_framework + '_tensorflow_' + architecture_name + "_converted"
|
||||
converted_file = converted_file.replace('.', '_')
|
||||
|
@ -230,10 +226,9 @@ class TestModels(CorrectnessTest):
|
|||
@staticmethod
|
||||
def PytorchEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
import torch
|
||||
print("Testing {} from {} to PyTorch.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
converted_file = original_framework + '_keras_' + architecture_name + "_converted"
|
||||
converted_file = original_framework + '_pytorch_' + architecture_name + "_converted"
|
||||
converted_file = converted_file.replace('.', '_')
|
||||
emitter = PytorchEmitter((architecture_path, weight_path))
|
||||
emitter.run(converted_file + '.py', converted_file + '.npy', 'test')
|
||||
|
@ -264,8 +259,6 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
@staticmethod
|
||||
def KerasEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to Keras.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
converted_file = original_framework + '_keras_' + architecture_name + "_converted"
|
||||
converted_file = converted_file.replace('.', '_')
|
||||
|
@ -304,6 +297,7 @@ class TestModels(CorrectnessTest):
|
|||
'mobilenet' : [TensorflowEmit, KerasEmit],
|
||||
'nasnet' : [TensorflowEmit, KerasEmit],
|
||||
},
|
||||
|
||||
'mxnet' : {
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'imagenet1k-inception-bn' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
|
@ -312,6 +306,7 @@ class TestModels(CorrectnessTest):
|
|||
'imagenet1k-resnext-101-64x4d' : [CntkEmit, TensorflowEmit, PytorchEmit], # Keras is too slow
|
||||
'imagenet1k-resnext-50' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
},
|
||||
|
||||
'caffe' : {
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'alexnet' : [CntkEmit],
|
||||
|
@ -319,9 +314,15 @@ class TestModels(CorrectnessTest):
|
|||
'resnet152' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'squeezenet' : [CntkEmit, PytorchEmit]
|
||||
},
|
||||
|
||||
'tensorflow' : {
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'inception_v1' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'inception_v3' : [CntkEmit, TensorflowEmit, KerasEmit], # TODO: PytorchEmit
|
||||
'resnet_v1_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'resnet_v1_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'resnet_v2_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'resnet_v2_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -331,13 +332,14 @@ class TestModels(CorrectnessTest):
|
|||
ensure_dir(self.tmpdir)
|
||||
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
print("Test {} from {} start.".format(network_name, original_framework), file=sys.stderr, flush=True)
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = parser(network_name, self.image_path)
|
||||
|
||||
IR_file = TestModels.tmpdir + original_framework + '_' + network_name + "_converted"
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
print('Testing conversion {} from {} to {}.'.format(network_name, original_framework, emit.__func__.__name__[:-4]), file=sys.stderr, flush=True)
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
|
@ -347,6 +349,8 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
print('Conversion {} from {} to {} passed.'.format(network_name, original_framework, emit.__func__.__name__[:-4]), file=sys.stderr, flush=True)
|
||||
|
||||
try:
|
||||
os.remove(IR_file + ".json")
|
||||
except OSError:
|
||||
|
@ -363,13 +367,13 @@ class TestModels(CorrectnessTest):
|
|||
self._test_function('tensorflow', self.TensorFlowParse)
|
||||
|
||||
|
||||
def test_caffe(self):
|
||||
self._test_function('caffe', self.CaffeParse)
|
||||
# def test_caffe(self):
|
||||
# self._test_function('caffe', self.CaffeParse)
|
||||
|
||||
|
||||
def test_keras(self):
|
||||
self._test_function('keras', self.KerasParse)
|
||||
# def test_keras(self):
|
||||
# self._test_function('keras', self.KerasParse)
|
||||
|
||||
|
||||
def test_mxnet(self):
|
||||
self._test_function('mxnet', self.MXNetParse)
|
||||
# def test_mxnet(self):
|
||||
# self._test_function('mxnet', self.MXNetParse)
|
||||
|
|
Загрузка…
Ссылка в новой задаче