Implement MXNet -> Other pytest.

This commit is contained in:
Kit 2018-02-07 16:39:29 +08:00
Родитель b415ffc034
Коммит 4b612fd6a2
7 изменённых файлов: 188 добавлений и 16 удалений

Просмотреть файл

@ -58,9 +58,11 @@ Models | Caffe | Keras | Tensorflow
[ResNet V1 50](https://arxiv.org/abs/1512.03385) | × | √ | √ | o | √ | √ | √
[ResNet V2 152](https://arxiv.org/abs/1603.05027) | √ | √ | √ | √ | √ | √
[VGG 19](http://arxiv.org/abs/1409.1556.pdf) | √ | √ | √ | √ | √ | √ | √
[MobileNet_v1](https://arxiv.org/pdf/1704.04861.pdf)| × | √ | √ | o (only Relu) | × | × |
[Xception](https://arxiv.org/pdf/1610.02357.pdf) | × | √ | √ | × | × | ×
[MobileNet_v1](https://arxiv.org/pdf/1704.04861.pdf)| × | √ | √ | × (no DepthwiseConv) | × | × |
[Xception](https://arxiv.org/pdf/1610.02357.pdf) | × | √ | √ | × (no SeparableConv) | × | ×
[SqueezeNet](https://arxiv.org/pdf/1602.07360) | √ | √ | √ | √ | √ | ×
DenseNet | | √ | √ | √ | | |
[NASNet](https://arxiv.org/abs/1707.07012) | | √ | √ | × (no SeparableConv)
#### On-going frameworks

Просмотреть файл

@ -64,7 +64,8 @@ def extract_model(args):
pass
elif args.framework == 'mxnet':
pass
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
extractor = mxnet_extractor()
elif args.framework == 'cntk':
pass
@ -74,7 +75,7 @@ def extract_model(args):
files = extractor.download(args.network,args.path)
if files and args.image:
predict = extractor.inference(args.network, args.image)
predict = extractor.inference(args.network, args.path, args.image)
top_indices = predict.argsort()[-5:][::-1]
result = [(i, predict[i]) for i in top_indices]
print(result)
@ -103,10 +104,10 @@ def _main():
type=_text_type, help='Test Image Path')
parser.add_argument(
'--path', '-p',
'--path', '-p', '-o',
type=_text_type,
default='./',
help='Path to save the model network file (e.g keras h5')
help='Path to save the pre-trained model files (e.g keras h5)')
args = parser.parse_args()
extract_model(args)

Просмотреть файл

@ -131,9 +131,7 @@ def _progress_check(count, block_size, total_size):
def _single_thread_download(url, file_name):
from six.moves import urllib
import requests
result, _ = urllib.request.urlretrieve(url, file_name, _progress_check)
print ("")
return result

Просмотреть файл

@ -91,12 +91,14 @@ class TestKit(object):
},
'mxnet' : {
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False),
'resnet' : lambda path : TestKit.Identity(path, 224, True),
'squeezenet' : lambda path : TestKit.ZeroCenter(path, 224, False),
'inception_bn' : lambda path : TestKit.Identity(path, 224, False),
'resnet152-11k' : lambda path : TestKit.Identity(path, 224, True),
'resnext' : lambda path : TestKit.Identity(path, 224, False),
'vgg16' : lambda path : TestKit.ZeroCenter(path, 224, False),
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False),
'resnet' : lambda path : TestKit.Identity(path, 224, True),
'squeezenet_v1.0' : lambda path : TestKit.ZeroCenter(path, 224, False),
'squeezenet_v1.1' : lambda path : TestKit.ZeroCenter(path, 224, False),
'inception_bn' : lambda path : TestKit.Identity(path, 224, False),
'resnet152-11k' : lambda path : TestKit.Identity(path, 224, True),
'resnext' : lambda path : TestKit.Identity(path, 224, False),
'imagenet1k-resnext-50' : lambda path : TestKit.Identity(path, 224, False)
},

Просмотреть файл

@ -58,7 +58,7 @@ class keras_extractor(base_extractor):
@classmethod
def inference(cls, architecture, image_path):
def inference(cls, architecture, path, image_path):
if cls.sanity_check(architecture):
model = cls.architecture_map[architecture]()
import numpy as np

Просмотреть файл

@ -0,0 +1,103 @@
#----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------
from __future__ import absolute_import
import os
from mmdnn.conversion.examples.imagenet_test import TestKit
from mmdnn.conversion.examples.extractor import base_extractor
from mmdnn.conversion.common.utils import download_file
class mxnet_extractor(base_extractor):
_base_model_url = 'http://data.mxnet.io/models/'
_image_size = 224
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
architecture_map = {
'imagenet1k-inception-bn' : {'symbol' : _base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json',
'params' : _base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'},
'imagenet1k-resnet-18' : {'symbol' : _base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json',
'params' : _base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'},
'imagenet1k-resnet-34' : {'symbol' : _base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json',
'params' : _base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'},
'imagenet1k-resnet-50' : {'symbol' : _base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json',
'params' : _base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'},
'imagenet1k-resnet-101' : {'symbol' : _base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json',
'params' : _base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'},
'imagenet1k-resnet-152' : {'symbol' : _base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json',
'params' : _base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'},
'imagenet1k-resnext-50' : {'symbol' : _base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json',
'params' : _base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'},
'imagenet1k-resnext-101' : {'symbol' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json',
'params' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'},
'imagenet1k-resnext-101-64x4d' : {'symbol' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json',
'params' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'},
'imagenet11k-resnet-152' : {'symbol' : _base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json',
'params' : _base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'},
'imagenet11k-place365ch-resnet-152' : {'symbol' : _base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json',
'params' : _base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'},
'imagenet11k-place365ch-resnet-50' : {'symbol' : _base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json',
'params' : _base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'},
'vgg19' : {'symbol' : _base_model_url+'imagenet/vgg/vgg19-symbol.json',
'params' : _base_model_url+'imagenet/vgg/vgg19-0000.params'},
'vgg16' : {'symbol' : _base_model_url+'imagenet/vgg/vgg16-symbol.json',
'params' : _base_model_url+'imagenet/vgg/vgg16-0000.params'},
'squeezenet_v1.0' : {'symbol' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.0-symbol.json',
'params' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.0-0000.params'},
'squeezenet_v1.1' : {'symbol' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.1-symbol.json',
'params' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.1-0000.params'}
}
@classmethod
def download(cls, architecture, path="./"):
if cls.sanity_check(architecture):
architecture_file = download_file(cls.architecture_map[architecture]['symbol'], directory=path)
if not architecture_file:
return None
weight_file = download_file(cls.architecture_map[architecture]['params'], directory=path)
if not weight_file:
return None
print("MXNet Model {} saved as [{}] and [{}].".format(architecture, architecture_file, weight_file))
return (architecture_file, weight_file)
else:
return None
@classmethod
def inference(cls, architecture, path, image_path):
import mxnet as mx
import numpy as np
if cls.sanity_check(architecture):
file_name = cls.architecture_map[architecture]['params'].split('/')[-1]
prefix, epoch_num = file_name[:-7].rsplit('-', 1)
sym, arg_params, aux_params = mx.model.load_checkpoint(path + prefix, int(epoch_num))
model = mx.mod.Module(symbol=sym)
model.bind(for_training=False,
data_shapes=[('data', (1, 3, cls._image_size, cls._image_size))])
model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)
func = TestKit.preprocess_func['mxnet'][architecture]
img = func(image_path)
img = np.transpose(img, [2, 0, 1])
img = np.expand_dims(img, axis=0)
model.forward(cls.Batch([mx.nd.array(img)]))
predict = model.get_outputs()[0].asnumpy()
predict = np.squeeze(predict)
del model
return predict
else:
return None

Просмотреть файл

@ -1,4 +1,5 @@
import os
import six
import unittest
import numpy as np
from six.moves import reload_module
@ -6,8 +7,10 @@ import tensorflow as tf
from mmdnn.conversion.examples.imagenet_test import TestKit
from mmdnn.conversion.examples.keras.extractor import keras_extractor
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
from mmdnn.conversion.keras.keras2_parser import Keras2Parser
from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
from mmdnn.conversion.tensorflow.tensorflow_emitter import TensorflowEmitter
@ -72,7 +75,7 @@ class TestModels(CorrectnessTest):
@staticmethod
def KerasParse(architecture_name, image_path):
# get original model prediction result
original_predict = keras_extractor.inference(architecture_name, image_path)
original_predict = keras_extractor.inference(architecture_name, TestModels.cachedir, image_path)
# download model
model_filename = keras_extractor.download(architecture_name, TestModels.cachedir)
@ -86,6 +89,30 @@ class TestModels(CorrectnessTest):
return original_predict
@staticmethod
def MXNetParse(architecture_name, image_path):
# download model
architecture_file, weight_file = mxnet_extractor.download(architecture_name, TestModels.cachedir)
# get original model prediction result
original_predict = mxnet_extractor.inference(architecture_name, TestModels.cachedir, image_path)
# original to IR
import re
if re.search('.', weight_file):
weight_file = weight_file[:-7]
prefix, epoch = weight_file.rsplit('-', 1)
model = (architecture_file, prefix, epoch, [3, 224, 224])
parser = MXNetParser(model)
parser.gen_IR()
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
del parser
return original_predict
@staticmethod
def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
print("Testing {} from {} to CNTK.".format(architecture_name, original_framework))
@ -185,8 +212,13 @@ class TestModels(CorrectnessTest):
predict = model_converted.predict(input_data)
converted_predict = np.squeeze(predict)
del model_converted
del converted_model
import keras.backend as K
K.clear_session()
os.remove("converted_model.py")
return converted_predict
@ -201,11 +233,15 @@ class TestModels(CorrectnessTest):
'xception' : [TensorflowEmit, KerasEmit],
'mobilenet' : [TensorflowEmit, KerasEmit],
'nasnet' : [TensorflowEmit, KerasEmit],
},
'mxnet' : {
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit],
}
}
def test_keras(self):
return
# keras original
ensure_dir(self.cachedir)
ensure_dir(self.tmpdir)
@ -233,3 +269,33 @@ class TestModels(CorrectnessTest):
print("Testing {} model {} passed.".format(original_framework, network_name))
print("Testing {} model all passed.".format(original_framework))
def test_mxnet(self):
# mxnet original
ensure_dir(self.cachedir)
ensure_dir(self.tmpdir)
original_framework = 'mxnet'
for network_name in self.test_table[original_framework].keys():
print("Testing {} model {} start.".format(original_framework, network_name))
# get original model prediction result
original_predict = self.MXNetParse(network_name, self.image_path)
for emit in self.test_table[original_framework][network_name]:
converted_predict = emit.__func__(
original_framework,
network_name,
self.tmpdir + network_name + "_converted.pb",
self.tmpdir + network_name + "_converted.npy",
self.image_path)
self._compare_outputs(original_predict, converted_predict)
os.remove(self.tmpdir + network_name + "_converted.pb")
os.remove(self.tmpdir + network_name + "_converted.npy")
print("Testing {} model {} passed.".format(original_framework, network_name))
print("Testing {} model all passed.".format(original_framework))