зеркало из https://github.com/microsoft/MMdnn.git
Implement MXNet -> Other pytest.
This commit is contained in:
Родитель
b415ffc034
Коммит
4b612fd6a2
|
@ -58,9 +58,11 @@ Models | Caffe | Keras | Tensorflow
|
|||
[ResNet V1 50](https://arxiv.org/abs/1512.03385) | × | √ | √ | o | √ | √ | √
|
||||
[ResNet V2 152](https://arxiv.org/abs/1603.05027) | √ | √ | √ | √ | √ | √
|
||||
[VGG 19](http://arxiv.org/abs/1409.1556.pdf) | √ | √ | √ | √ | √ | √ | √
|
||||
[MobileNet_v1](https://arxiv.org/pdf/1704.04861.pdf)| × | √ | √ | o (only Relu) | × | × | √
|
||||
[Xception](https://arxiv.org/pdf/1610.02357.pdf) | × | √ | √ | × | × | ×
|
||||
[MobileNet_v1](https://arxiv.org/pdf/1704.04861.pdf)| × | √ | √ | × (no DepthwiseConv) | × | × | √
|
||||
[Xception](https://arxiv.org/pdf/1610.02357.pdf) | × | √ | √ | × (no SeparableConv) | × | ×
|
||||
[SqueezeNet](https://arxiv.org/pdf/1602.07360) | √ | √ | √ | √ | √ | ×
|
||||
DenseNet | | √ | √ | √ | | |
|
||||
[NASNet](https://arxiv.org/abs/1707.07012) | | √ | √ | × (no SeparableConv)
|
||||
|
||||
#### On-going frameworks
|
||||
|
||||
|
|
|
@ -64,7 +64,8 @@ def extract_model(args):
|
|||
pass
|
||||
|
||||
elif args.framework == 'mxnet':
|
||||
pass
|
||||
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
|
||||
extractor = mxnet_extractor()
|
||||
|
||||
elif args.framework == 'cntk':
|
||||
pass
|
||||
|
@ -74,7 +75,7 @@ def extract_model(args):
|
|||
files = extractor.download(args.network,args.path)
|
||||
|
||||
if files and args.image:
|
||||
predict = extractor.inference(args.network, args.image)
|
||||
predict = extractor.inference(args.network, args.path, args.image)
|
||||
top_indices = predict.argsort()[-5:][::-1]
|
||||
result = [(i, predict[i]) for i in top_indices]
|
||||
print(result)
|
||||
|
@ -103,10 +104,10 @@ def _main():
|
|||
type=_text_type, help='Test Image Path')
|
||||
|
||||
parser.add_argument(
|
||||
'--path', '-p',
|
||||
'--path', '-p', '-o',
|
||||
type=_text_type,
|
||||
default='./',
|
||||
help='Path to save the model network file (e.g keras h5')
|
||||
help='Path to save the pre-trained model files (e.g keras h5)')
|
||||
|
||||
args = parser.parse_args()
|
||||
extract_model(args)
|
||||
|
|
|
@ -131,9 +131,7 @@ def _progress_check(count, block_size, total_size):
|
|||
|
||||
def _single_thread_download(url, file_name):
|
||||
from six.moves import urllib
|
||||
import requests
|
||||
result, _ = urllib.request.urlretrieve(url, file_name, _progress_check)
|
||||
print ("")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -91,12 +91,14 @@ class TestKit(object):
|
|||
},
|
||||
|
||||
'mxnet' : {
|
||||
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'resnet' : lambda path : TestKit.Identity(path, 224, True),
|
||||
'squeezenet' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'inception_bn' : lambda path : TestKit.Identity(path, 224, False),
|
||||
'resnet152-11k' : lambda path : TestKit.Identity(path, 224, True),
|
||||
'resnext' : lambda path : TestKit.Identity(path, 224, False),
|
||||
'vgg16' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'resnet' : lambda path : TestKit.Identity(path, 224, True),
|
||||
'squeezenet_v1.0' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'squeezenet_v1.1' : lambda path : TestKit.ZeroCenter(path, 224, False),
|
||||
'inception_bn' : lambda path : TestKit.Identity(path, 224, False),
|
||||
'resnet152-11k' : lambda path : TestKit.Identity(path, 224, True),
|
||||
'resnext' : lambda path : TestKit.Identity(path, 224, False),
|
||||
'imagenet1k-resnext-50' : lambda path : TestKit.Identity(path, 224, False)
|
||||
},
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ class keras_extractor(base_extractor):
|
|||
|
||||
|
||||
@classmethod
|
||||
def inference(cls, architecture, image_path):
|
||||
def inference(cls, architecture, path, image_path):
|
||||
if cls.sanity_check(architecture):
|
||||
model = cls.architecture_map[architecture]()
|
||||
import numpy as np
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
#----------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
#----------------------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import absolute_import
|
||||
import os
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
from mmdnn.conversion.examples.extractor import base_extractor
|
||||
from mmdnn.conversion.common.utils import download_file
|
||||
|
||||
|
||||
class mxnet_extractor(base_extractor):
|
||||
|
||||
_base_model_url = 'http://data.mxnet.io/models/'
|
||||
|
||||
_image_size = 224
|
||||
|
||||
from collections import namedtuple
|
||||
Batch = namedtuple('Batch', ['data'])
|
||||
|
||||
architecture_map = {
|
||||
'imagenet1k-inception-bn' : {'symbol' : _base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'},
|
||||
'imagenet1k-resnet-18' : {'symbol' : _base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'},
|
||||
'imagenet1k-resnet-34' : {'symbol' : _base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'},
|
||||
'imagenet1k-resnet-50' : {'symbol' : _base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'},
|
||||
'imagenet1k-resnet-101' : {'symbol' : _base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'},
|
||||
'imagenet1k-resnet-152' : {'symbol' : _base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'},
|
||||
'imagenet1k-resnext-50' : {'symbol' : _base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'},
|
||||
'imagenet1k-resnext-101' : {'symbol' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'},
|
||||
'imagenet1k-resnext-101-64x4d' : {'symbol' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'},
|
||||
'imagenet11k-resnet-152' : {'symbol' : _base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json',
|
||||
'params' : _base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'},
|
||||
'imagenet11k-place365ch-resnet-152' : {'symbol' : _base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json',
|
||||
'params' : _base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'},
|
||||
'imagenet11k-place365ch-resnet-50' : {'symbol' : _base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json',
|
||||
'params' : _base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'},
|
||||
'vgg19' : {'symbol' : _base_model_url+'imagenet/vgg/vgg19-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/vgg/vgg19-0000.params'},
|
||||
'vgg16' : {'symbol' : _base_model_url+'imagenet/vgg/vgg16-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/vgg/vgg16-0000.params'},
|
||||
'squeezenet_v1.0' : {'symbol' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.0-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.0-0000.params'},
|
||||
'squeezenet_v1.1' : {'symbol' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.1-symbol.json',
|
||||
'params' : _base_model_url+'imagenet/squeezenet/squeezenet_v1.1-0000.params'}
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def download(cls, architecture, path="./"):
|
||||
if cls.sanity_check(architecture):
|
||||
architecture_file = download_file(cls.architecture_map[architecture]['symbol'], directory=path)
|
||||
if not architecture_file:
|
||||
return None
|
||||
|
||||
weight_file = download_file(cls.architecture_map[architecture]['params'], directory=path)
|
||||
if not weight_file:
|
||||
return None
|
||||
|
||||
print("MXNet Model {} saved as [{}] and [{}].".format(architecture, architecture_file, weight_file))
|
||||
return (architecture_file, weight_file)
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def inference(cls, architecture, path, image_path):
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
if cls.sanity_check(architecture):
|
||||
file_name = cls.architecture_map[architecture]['params'].split('/')[-1]
|
||||
prefix, epoch_num = file_name[:-7].rsplit('-', 1)
|
||||
|
||||
sym, arg_params, aux_params = mx.model.load_checkpoint(path + prefix, int(epoch_num))
|
||||
model = mx.mod.Module(symbol=sym)
|
||||
model.bind(for_training=False,
|
||||
data_shapes=[('data', (1, 3, cls._image_size, cls._image_size))])
|
||||
model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)
|
||||
|
||||
func = TestKit.preprocess_func['mxnet'][architecture]
|
||||
img = func(image_path)
|
||||
img = np.transpose(img, [2, 0, 1])
|
||||
img = np.expand_dims(img, axis=0)
|
||||
|
||||
model.forward(cls.Batch([mx.nd.array(img)]))
|
||||
predict = model.get_outputs()[0].asnumpy()
|
||||
predict = np.squeeze(predict)
|
||||
|
||||
del model
|
||||
return predict
|
||||
|
||||
else:
|
||||
return None
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import six
|
||||
import unittest
|
||||
import numpy as np
|
||||
from six.moves import reload_module
|
||||
|
@ -6,8 +7,10 @@ import tensorflow as tf
|
|||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
|
||||
from mmdnn.conversion.examples.keras.extractor import keras_extractor
|
||||
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
|
||||
|
||||
from mmdnn.conversion.keras.keras2_parser import Keras2Parser
|
||||
from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
|
||||
|
||||
from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
|
||||
from mmdnn.conversion.tensorflow.tensorflow_emitter import TensorflowEmitter
|
||||
|
@ -72,7 +75,7 @@ class TestModels(CorrectnessTest):
|
|||
@staticmethod
|
||||
def KerasParse(architecture_name, image_path):
|
||||
# get original model prediction result
|
||||
original_predict = keras_extractor.inference(architecture_name, image_path)
|
||||
original_predict = keras_extractor.inference(architecture_name, TestModels.cachedir, image_path)
|
||||
|
||||
# download model
|
||||
model_filename = keras_extractor.download(architecture_name, TestModels.cachedir)
|
||||
|
@ -86,6 +89,30 @@ class TestModels(CorrectnessTest):
|
|||
return original_predict
|
||||
|
||||
|
||||
@staticmethod
|
||||
def MXNetParse(architecture_name, image_path):
|
||||
# download model
|
||||
architecture_file, weight_file = mxnet_extractor.download(architecture_name, TestModels.cachedir)
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = mxnet_extractor.inference(architecture_name, TestModels.cachedir, image_path)
|
||||
|
||||
# original to IR
|
||||
import re
|
||||
if re.search('.', weight_file):
|
||||
weight_file = weight_file[:-7]
|
||||
prefix, epoch = weight_file.rsplit('-', 1)
|
||||
model = (architecture_file, prefix, epoch, [3, 224, 224])
|
||||
|
||||
parser = MXNetParser(model)
|
||||
parser.gen_IR()
|
||||
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
|
||||
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
del parser
|
||||
|
||||
return original_predict
|
||||
|
||||
|
||||
@staticmethod
|
||||
def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to CNTK.".format(architecture_name, original_framework))
|
||||
|
@ -185,8 +212,13 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
predict = model_converted.predict(input_data)
|
||||
converted_predict = np.squeeze(predict)
|
||||
|
||||
del model_converted
|
||||
del converted_model
|
||||
|
||||
import keras.backend as K
|
||||
K.clear_session()
|
||||
|
||||
os.remove("converted_model.py")
|
||||
return converted_predict
|
||||
|
||||
|
@ -201,11 +233,15 @@ class TestModels(CorrectnessTest):
|
|||
'xception' : [TensorflowEmit, KerasEmit],
|
||||
'mobilenet' : [TensorflowEmit, KerasEmit],
|
||||
'nasnet' : [TensorflowEmit, KerasEmit],
|
||||
},
|
||||
'mxnet' : {
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_keras(self):
|
||||
return
|
||||
# keras original
|
||||
ensure_dir(self.cachedir)
|
||||
ensure_dir(self.tmpdir)
|
||||
|
@ -233,3 +269,33 @@ class TestModels(CorrectnessTest):
|
|||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
print("Testing {} model all passed.".format(original_framework))
|
||||
|
||||
|
||||
def test_mxnet(self):
|
||||
# mxnet original
|
||||
ensure_dir(self.cachedir)
|
||||
ensure_dir(self.tmpdir)
|
||||
original_framework = 'mxnet'
|
||||
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = self.MXNetParse(network_name, self.image_path)
|
||||
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
self.tmpdir + network_name + "_converted.pb",
|
||||
self.tmpdir + network_name + "_converted.npy",
|
||||
self.image_path)
|
||||
|
||||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
print("Testing {} model all passed.".format(original_framework))
|
||||
|
|
Загрузка…
Ссылка в новой задаче