From 4b612fd6a26002d701deb6084c05577285b94bfe Mon Sep 17 00:00:00 2001 From: Kit Date: Wed, 7 Feb 2018 16:39:29 +0800 Subject: [PATCH] Implement MXNet -> Other pytest. --- README.md | 6 +- mmdnn/conversion/_script/extractModel.py | 9 +- mmdnn/conversion/common/utils.py | 2 - mmdnn/conversion/examples/imagenet_test.py | 14 +-- mmdnn/conversion/examples/keras/extractor.py | 2 +- mmdnn/conversion/examples/mxnet/extractor.py | 103 +++++++++++++++++++ tests/test_conversion_imagenet.py | 68 +++++++++++- 7 files changed, 188 insertions(+), 16 deletions(-) create mode 100644 mmdnn/conversion/examples/mxnet/extractor.py diff --git a/README.md b/README.md index ab759cb..7d827be 100644 --- a/README.md +++ b/README.md @@ -58,9 +58,11 @@ Models | Caffe | Keras | Tensorflow [ResNet V1 50](https://arxiv.org/abs/1512.03385) | × | √ | √ | o | √ | √ | √ [ResNet V2 152](https://arxiv.org/abs/1603.05027) | √ | √ | √ | √ | √ | √ [VGG 19](http://arxiv.org/abs/1409.1556.pdf) | √ | √ | √ | √ | √ | √ | √ -[MobileNet_v1](https://arxiv.org/pdf/1704.04861.pdf)| × | √ | √ | o (only Relu) | × | × | √ -[Xception](https://arxiv.org/pdf/1610.02357.pdf) | × | √ | √ | × | × | × +[MobileNet_v1](https://arxiv.org/pdf/1704.04861.pdf)| × | √ | √ | × (no DepthwiseConv) | × | × | √ +[Xception](https://arxiv.org/pdf/1610.02357.pdf) | × | √ | √ | × (no SeparableConv) | × | × [SqueezeNet](https://arxiv.org/pdf/1602.07360) | √ | √ | √ | √ | √ | × +DenseNet | | √ | √ | √ | | | +[NASNet](https://arxiv.org/abs/1707.07012) | | √ | √ | × (no SeparableConv) #### On-going frameworks diff --git a/mmdnn/conversion/_script/extractModel.py b/mmdnn/conversion/_script/extractModel.py index 8d4ffcc..356cc49 100644 --- a/mmdnn/conversion/_script/extractModel.py +++ b/mmdnn/conversion/_script/extractModel.py @@ -64,7 +64,8 @@ def extract_model(args): pass elif args.framework == 'mxnet': - pass + from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor + extractor = mxnet_extractor() elif args.framework == 'cntk': pass @@ -74,7 +75,7 @@ def extract_model(args): files = extractor.download(args.network,args.path) if files and args.image: - predict = extractor.inference(args.network, args.image) + predict = extractor.inference(args.network, args.path, args.image) top_indices = predict.argsort()[-5:][::-1] result = [(i, predict[i]) for i in top_indices] print(result) @@ -103,10 +104,10 @@ def _main(): type=_text_type, help='Test Image Path') parser.add_argument( - '--path', '-p', + '--path', '-p', '-o', type=_text_type, default='./', - help='Path to save the model network file (e.g keras h5') + help='Path to save the pre-trained model files (e.g keras h5)') args = parser.parse_args() extract_model(args) diff --git a/mmdnn/conversion/common/utils.py b/mmdnn/conversion/common/utils.py index 1e30fdf..5671801 100644 --- a/mmdnn/conversion/common/utils.py +++ b/mmdnn/conversion/common/utils.py @@ -131,9 +131,7 @@ def _progress_check(count, block_size, total_size): def _single_thread_download(url, file_name): from six.moves import urllib - import requests result, _ = urllib.request.urlretrieve(url, file_name, _progress_check) - print ("") return result diff --git a/mmdnn/conversion/examples/imagenet_test.py b/mmdnn/conversion/examples/imagenet_test.py index 6495f6d..8572d6c 100644 --- a/mmdnn/conversion/examples/imagenet_test.py +++ b/mmdnn/conversion/examples/imagenet_test.py @@ -91,12 +91,14 @@ class TestKit(object): }, 'mxnet' : { - 'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False), - 'resnet' : lambda path : TestKit.Identity(path, 224, True), - 'squeezenet' : lambda path : TestKit.ZeroCenter(path, 224, False), - 'inception_bn' : lambda path : TestKit.Identity(path, 224, False), - 'resnet152-11k' : lambda path : TestKit.Identity(path, 224, True), - 'resnext' : lambda path : TestKit.Identity(path, 224, False), + 'vgg16' : lambda path : TestKit.ZeroCenter(path, 224, False), + 'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False), + 'resnet' : lambda path : TestKit.Identity(path, 224, True), + 'squeezenet_v1.0' : lambda path : TestKit.ZeroCenter(path, 224, False), + 'squeezenet_v1.1' : lambda path : TestKit.ZeroCenter(path, 224, False), + 'inception_bn' : lambda path : TestKit.Identity(path, 224, False), + 'resnet152-11k' : lambda path : TestKit.Identity(path, 224, True), + 'resnext' : lambda path : TestKit.Identity(path, 224, False), 'imagenet1k-resnext-50' : lambda path : TestKit.Identity(path, 224, False) }, diff --git a/mmdnn/conversion/examples/keras/extractor.py b/mmdnn/conversion/examples/keras/extractor.py index 7867c95..261f7be 100644 --- a/mmdnn/conversion/examples/keras/extractor.py +++ b/mmdnn/conversion/examples/keras/extractor.py @@ -58,7 +58,7 @@ class keras_extractor(base_extractor): @classmethod - def inference(cls, architecture, image_path): + def inference(cls, architecture, path, image_path): if cls.sanity_check(architecture): model = cls.architecture_map[architecture]() import numpy as np diff --git a/mmdnn/conversion/examples/mxnet/extractor.py b/mmdnn/conversion/examples/mxnet/extractor.py new file mode 100644 index 0000000..2c270f1 --- /dev/null +++ b/mmdnn/conversion/examples/mxnet/extractor.py @@ -0,0 +1,103 @@ +#---------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +#---------------------------------------------------------------------------------------------- + +from __future__ import absolute_import +import os +from mmdnn.conversion.examples.imagenet_test import TestKit +from mmdnn.conversion.examples.extractor import base_extractor +from mmdnn.conversion.common.utils import download_file + + +class mxnet_extractor(base_extractor): + + _base_model_url = 'http://data.mxnet.io/models/' + + _image_size = 224 + + from collections import namedtuple + Batch = namedtuple('Batch', ['data']) + + architecture_map = { + 'imagenet1k-inception-bn' : {'symbol' : _base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json', + 'params' : _base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'}, + 'imagenet1k-resnet-18' : {'symbol' : _base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json', + 'params' : _base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'}, + 'imagenet1k-resnet-34' : {'symbol' : _base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json', + 'params' : _base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'}, + 'imagenet1k-resnet-50' : {'symbol' : _base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json', + 'params' : _base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'}, + 'imagenet1k-resnet-101' : {'symbol' : _base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json', + 'params' : _base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'}, + 'imagenet1k-resnet-152' : {'symbol' : _base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json', + 'params' : _base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'}, + 'imagenet1k-resnext-50' : {'symbol' : _base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json', + 'params' : _base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'}, + 'imagenet1k-resnext-101' : {'symbol' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json', + 'params' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'}, + 'imagenet1k-resnext-101-64x4d' : {'symbol' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json', + 'params' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'}, + 'imagenet11k-resnet-152' : {'symbol' : _base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json', + 'params' : _base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'}, + 'imagenet11k-place365ch-resnet-152' : {'symbol' : _base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json', + 'params' : _base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'}, + 'imagenet11k-place365ch-resnet-50' : {'symbol' : _base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json', + 'params' : _base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'}, + 'vgg19' : {'symbol' : _base_model_url+'imagenet/vgg/vgg19-symbol.json', + 'params' : _base_model_url+'imagenet/vgg/vgg19-0000.params'}, + 'vgg16' : {'symbol' : _base_model_url+'imagenet/vgg/vgg16-symbol.json', + 'params' : _base_model_url+'imagenet/vgg/vgg16-0000.params'}, + 'squeezenet_v1.0' : {'symbol' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.0-symbol.json', + 'params' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.0-0000.params'}, + 'squeezenet_v1.1' : {'symbol' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.1-symbol.json', + 'params' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.1-0000.params'} + } + + + @classmethod + def download(cls, architecture, path="./"): + if cls.sanity_check(architecture): + architecture_file = download_file(cls.architecture_map[architecture]['symbol'], directory=path) + if not architecture_file: + return None + + weight_file = download_file(cls.architecture_map[architecture]['params'], directory=path) + if not weight_file: + return None + + print("MXNet Model {} saved as [{}] and [{}].".format(architecture, architecture_file, weight_file)) + return (architecture_file, weight_file) + + else: + return None + + + @classmethod + def inference(cls, architecture, path, image_path): + import mxnet as mx + import numpy as np + if cls.sanity_check(architecture): + file_name = cls.architecture_map[architecture]['params'].split('/')[-1] + prefix, epoch_num = file_name[:-7].rsplit('-', 1) + + sym, arg_params, aux_params = mx.model.load_checkpoint(path + prefix, int(epoch_num)) + model = mx.mod.Module(symbol=sym) + model.bind(for_training=False, + data_shapes=[('data', (1, 3, cls._image_size, cls._image_size))]) + model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True) + + func = TestKit.preprocess_func['mxnet'][architecture] + img = func(image_path) + img = np.transpose(img, [2, 0, 1]) + img = np.expand_dims(img, axis=0) + + model.forward(cls.Batch([mx.nd.array(img)])) + predict = model.get_outputs()[0].asnumpy() + predict = np.squeeze(predict) + + del model + return predict + + else: + return None diff --git a/tests/test_conversion_imagenet.py b/tests/test_conversion_imagenet.py index 01bd5c8..37bad16 100644 --- a/tests/test_conversion_imagenet.py +++ b/tests/test_conversion_imagenet.py @@ -1,4 +1,5 @@ import os +import six import unittest import numpy as np from six.moves import reload_module @@ -6,8 +7,10 @@ import tensorflow as tf from mmdnn.conversion.examples.imagenet_test import TestKit from mmdnn.conversion.examples.keras.extractor import keras_extractor +from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor from mmdnn.conversion.keras.keras2_parser import Keras2Parser +from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter from mmdnn.conversion.tensorflow.tensorflow_emitter import TensorflowEmitter @@ -72,7 +75,7 @@ class TestModels(CorrectnessTest): @staticmethod def KerasParse(architecture_name, image_path): # get original model prediction result - original_predict = keras_extractor.inference(architecture_name, image_path) + original_predict = keras_extractor.inference(architecture_name, TestModels.cachedir, image_path) # download model model_filename = keras_extractor.download(architecture_name, TestModels.cachedir) @@ -86,6 +89,30 @@ class TestModels(CorrectnessTest): return original_predict + @staticmethod + def MXNetParse(architecture_name, image_path): + # download model + architecture_file, weight_file = mxnet_extractor.download(architecture_name, TestModels.cachedir) + + # get original model prediction result + original_predict = mxnet_extractor.inference(architecture_name, TestModels.cachedir, image_path) + + # original to IR + import re + if re.search('.', weight_file): + weight_file = weight_file[:-7] + prefix, epoch = weight_file.rsplit('-', 1) + model = (architecture_file, prefix, epoch, [3, 224, 224]) + + parser = MXNetParser(model) + parser.gen_IR() + parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb") + parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy") + del parser + + return original_predict + + @staticmethod def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path): print("Testing {} from {} to CNTK.".format(architecture_name, original_framework)) @@ -185,8 +212,13 @@ class TestModels(CorrectnessTest): predict = model_converted.predict(input_data) converted_predict = np.squeeze(predict) + del model_converted del converted_model + + import keras.backend as K + K.clear_session() + os.remove("converted_model.py") return converted_predict @@ -201,11 +233,15 @@ class TestModels(CorrectnessTest): 'xception' : [TensorflowEmit, KerasEmit], 'mobilenet' : [TensorflowEmit, KerasEmit], 'nasnet' : [TensorflowEmit, KerasEmit], + }, + 'mxnet' : { + 'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit], } } def test_keras(self): + return # keras original ensure_dir(self.cachedir) ensure_dir(self.tmpdir) @@ -233,3 +269,33 @@ class TestModels(CorrectnessTest): print("Testing {} model {} passed.".format(original_framework, network_name)) print("Testing {} model all passed.".format(original_framework)) + + + def test_mxnet(self): + # mxnet original + ensure_dir(self.cachedir) + ensure_dir(self.tmpdir) + original_framework = 'mxnet' + + for network_name in self.test_table[original_framework].keys(): + print("Testing {} model {} start.".format(original_framework, network_name)) + + # get original model prediction result + original_predict = self.MXNetParse(network_name, self.image_path) + + for emit in self.test_table[original_framework][network_name]: + converted_predict = emit.__func__( + original_framework, + network_name, + self.tmpdir + network_name + "_converted.pb", + self.tmpdir + network_name + "_converted.npy", + self.image_path) + + self._compare_outputs(original_predict, converted_predict) + + + os.remove(self.tmpdir + network_name + "_converted.pb") + os.remove(self.tmpdir + network_name + "_converted.npy") + print("Testing {} model {} passed.".format(original_framework, network_name)) + + print("Testing {} model all passed.".format(original_framework))