From 717583233367248995cacce5113e30eb0df0b4d3 Mon Sep 17 00:00:00 2001 From: Kit Date: Fri, 9 Feb 2018 14:57:39 +0800 Subject: [PATCH] tensorflow pytest --- mmdnn/conversion/examples/imagenet_test.py | 7 +++- .../examples/tensorflow/extractor.py | 42 ++++++++++++++++++- tests/test_conversion_imagenet.py | 34 ++++++++------- 3 files changed, 66 insertions(+), 17 deletions(-) diff --git a/mmdnn/conversion/examples/imagenet_test.py b/mmdnn/conversion/examples/imagenet_test.py index c0cebf4..4698e9a 100644 --- a/mmdnn/conversion/examples/imagenet_test.py +++ b/mmdnn/conversion/examples/imagenet_test.py @@ -64,7 +64,7 @@ class TestKit(object): 'caffe' : { 'alexnet' : lambda path : TestKit.ZeroCenter(path, 227, True), 'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, True), - 'inception_v1' : lambda path : TestKit.ZeroCenter(path, 224, True), + 'inception_v1' : lambda path : TestKit.ZeroCenter(path, 227, True), 'resnet152' : lambda path : TestKit.ZeroCenter(path, 224, True), 'squeezenet' : lambda path : TestKit.ZeroCenter(path, 227, False) }, @@ -74,7 +74,12 @@ class TestKit(object): 'inception_v1' : lambda path : TestKit.Standard(path, 224), 'inception_v3' : lambda path : TestKit.Standard(path, 299), 'resnet' : lambda path : TestKit.Standard(path, 299), + 'resnet_v1_50' : lambda path : TestKit.ZeroCenter(path, 224, False), 'resnet_v1_101' : lambda path : TestKit.ZeroCenter(path, 224, False), + 'resnet_v1_152' : lambda path : TestKit.ZeroCenter(path, 224, False), + 'resnet_v2_50' : lambda path : TestKit.Standard(path, 299), + 'resnet_v2_152' : lambda path : TestKit.Standard(path, 299), + 'resnet_v2_200' : lambda path : TestKit.Standard(path, 299), 'resnet152' : lambda path : TestKit.Standard(path, 299), 'mobilenet' : lambda path : TestKit.Standard(path, 224) }, diff --git a/mmdnn/conversion/examples/tensorflow/extractor.py b/mmdnn/conversion/examples/tensorflow/extractor.py index d86b257..6da9e4a 100644 --- a/mmdnn/conversion/examples/tensorflow/extractor.py +++ b/mmdnn/conversion/examples/tensorflow/extractor.py @@ -43,7 +43,47 @@ class tensorflow_extractor(base_extractor): 'arg_scope' : inception.inception_v3_arg_scope, 'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]), 'num_classes' : 1001, - } + }, + 'resnet_v1_50' : { + 'url' : 'http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz', + 'filename' : 'resnet_v1_50.ckpt', + 'builder' : lambda : resnet_v1.resnet_v1_50, + 'arg_scope' : resnet_v2.resnet_arg_scope, + 'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]), + 'num_classes' : 1000, + }, + 'resnet_v1_152' : { + 'url' : 'http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz', + 'filename' : 'resnet_v1_152.ckpt', + 'builder' : lambda : resnet_v1.resnet_v1_152, + 'arg_scope' : resnet_v2.resnet_arg_scope, + 'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]), + 'num_classes' : 1000, + }, + 'resnet_v2_50' : { + 'url' : 'http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz', + 'filename' : 'resnet_v2_50.ckpt', + 'builder' : lambda : resnet_v2.resnet_v2_50, + 'arg_scope' : resnet_v2.resnet_arg_scope, + 'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]), + 'num_classes' : 1001, + }, + 'resnet_v2_152' : { + 'url' : 'http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz', + 'filename' : 'resnet_v2_152.ckpt', + 'builder' : lambda : resnet_v2.resnet_v2_152, + 'arg_scope' : resnet_v2.resnet_arg_scope, + 'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]), + 'num_classes' : 1001, + }, + 'resnet_v2_200' : { + 'url' : 'http://download.tensorflow.org/models/resnet_v2_200_2017_04_14.tar.gz', + 'filename' : 'resnet_v2_200.ckpt', + 'builder' : lambda : resnet_v2.resnet_v2_200, + 'arg_scope' : resnet_v2.resnet_arg_scope, + 'input' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]), + 'num_classes' : 1001, + }, } diff --git a/tests/test_conversion_imagenet.py b/tests/test_conversion_imagenet.py index c1dc996..920e8f0 100644 --- a/tests/test_conversion_imagenet.py +++ b/tests/test_conversion_imagenet.py @@ -177,8 +177,6 @@ class TestModels(CorrectnessTest): @staticmethod def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path): - print("Testing {} from {} to CNTK.".format(architecture_name, original_framework)) - # IR to code converted_file = original_framework + '_cntk_' + architecture_name + "_converted" converted_file = converted_file.replace('.', '_') @@ -200,8 +198,6 @@ class TestModels(CorrectnessTest): @staticmethod def TensorflowEmit(original_framework, architecture_name, architecture_path, weight_path, image_path): - print("Testing {} from {} to TensorFlow.".format(architecture_name, original_framework)) - # IR to code converted_file = original_framework + '_tensorflow_' + architecture_name + "_converted" converted_file = converted_file.replace('.', '_') @@ -230,10 +226,9 @@ class TestModels(CorrectnessTest): @staticmethod def PytorchEmit(original_framework, architecture_name, architecture_path, weight_path, image_path): import torch - print("Testing {} from {} to PyTorch.".format(architecture_name, original_framework)) # IR to code - converted_file = original_framework + '_keras_' + architecture_name + "_converted" + converted_file = original_framework + '_pytorch_' + architecture_name + "_converted" converted_file = converted_file.replace('.', '_') emitter = PytorchEmitter((architecture_path, weight_path)) emitter.run(converted_file + '.py', converted_file + '.npy', 'test') @@ -264,8 +259,6 @@ class TestModels(CorrectnessTest): @staticmethod def KerasEmit(original_framework, architecture_name, architecture_path, weight_path, image_path): - print("Testing {} from {} to Keras.".format(architecture_name, original_framework)) - # IR to code converted_file = original_framework + '_keras_' + architecture_name + "_converted" converted_file = converted_file.replace('.', '_') @@ -304,6 +297,7 @@ class TestModels(CorrectnessTest): 'mobilenet' : [TensorflowEmit, KerasEmit], 'nasnet' : [TensorflowEmit, KerasEmit], }, + 'mxnet' : { 'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], 'imagenet1k-inception-bn' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], @@ -312,6 +306,7 @@ class TestModels(CorrectnessTest): 'imagenet1k-resnext-101-64x4d' : [CntkEmit, TensorflowEmit, PytorchEmit], # Keras is too slow 'imagenet1k-resnext-50' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], }, + 'caffe' : { 'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], 'alexnet' : [CntkEmit], @@ -319,9 +314,15 @@ class TestModels(CorrectnessTest): 'resnet152' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], 'squeezenet' : [CntkEmit, PytorchEmit] }, + 'tensorflow' : { 'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit], 'inception_v1' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit + 'inception_v3' : [CntkEmit, TensorflowEmit, KerasEmit], # TODO: PytorchEmit + 'resnet_v1_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit + 'resnet_v1_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit + 'resnet_v2_50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit + 'resnet_v2_152' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit }, } @@ -331,13 +332,14 @@ class TestModels(CorrectnessTest): ensure_dir(self.tmpdir) for network_name in self.test_table[original_framework].keys(): - print("Testing {} model {} start.".format(original_framework, network_name)) + print("Test {} from {} start.".format(network_name, original_framework), file=sys.stderr, flush=True) # get original model prediction result original_predict = parser(network_name, self.image_path) IR_file = TestModels.tmpdir + original_framework + '_' + network_name + "_converted" for emit in self.test_table[original_framework][network_name]: + print('Testing conversion {} from {} to {}.'.format(network_name, original_framework, emit.__func__.__name__[:-4]), file=sys.stderr, flush=True) converted_predict = emit.__func__( original_framework, network_name, @@ -347,6 +349,8 @@ class TestModels(CorrectnessTest): self._compare_outputs(original_predict, converted_predict) + print('Conversion {} from {} to {} passed.'.format(network_name, original_framework, emit.__func__.__name__[:-4]), file=sys.stderr, flush=True) + try: os.remove(IR_file + ".json") except OSError: @@ -363,13 +367,13 @@ class TestModels(CorrectnessTest): self._test_function('tensorflow', self.TensorFlowParse) - def test_caffe(self): - self._test_function('caffe', self.CaffeParse) + # def test_caffe(self): + # self._test_function('caffe', self.CaffeParse) - def test_keras(self): - self._test_function('keras', self.KerasParse) + # def test_keras(self): + # self._test_function('keras', self.KerasParse) - def test_mxnet(self): - self._test_function('mxnet', self.MXNetParse) + # def test_mxnet(self): + # self._test_function('mxnet', self.MXNetParse)