зеркало из https://github.com/microsoft/MMdnn.git
caffe paser and extractor
This commit is contained in:
Родитель
7c1ae6c4f8
Коммит
779a252690
|
@ -63,6 +63,7 @@ Models | Caffe | Keras | Tensorflow
|
|||
[SqueezeNet](https://arxiv.org/pdf/1602.07360) | √ | √ | √ | √ | √ | ×
|
||||
DenseNet | | √ | √ | √ | | |
|
||||
[NASNet](https://arxiv.org/abs/1707.07012) | | √ | √ | × (no SeparableConv)
|
||||
[ResNext] | | √ | √ | √ | √ |
|
||||
|
||||
#### On-going frameworks
|
||||
|
||||
|
|
|
@ -28,30 +28,6 @@ def _convert(args):
|
|||
|
||||
elif args.srcFramework == 'caffe2':
|
||||
raise NotImplementedError("Caffe2 is not supported yet.")
|
||||
'''
|
||||
assert args.inputShape != None
|
||||
from dlconv.caffe2.conversion.transformer import Caffe2Transformer
|
||||
transformer = Caffe2Transformer(args.network, args.weights, args.inputShape, 'tensorflow')
|
||||
|
||||
graph = transformer.transform_graph()
|
||||
data = transformer.transform_data()
|
||||
|
||||
from dlconv.common.writer import JsonFormatter, ModelSaver, PyWriter
|
||||
JsonFormatter(graph).dump(args.dstPath + ".json")
|
||||
print ("IR saved as [{}.json].".format(args.dstPath))
|
||||
|
||||
prototxt = graph.as_graph_def().SerializeToString()
|
||||
with open(args.dstPath + ".pb", 'wb') as of:
|
||||
of.write(prototxt)
|
||||
print ("IR saved as [{}.pb].".format(args.dstPath))
|
||||
|
||||
import numpy as np
|
||||
with open(args.dstPath + ".npy", 'wb') as of:
|
||||
np.save(of, data)
|
||||
print ("IR weights saved as [{}.npy].".format(args.dstPath))
|
||||
|
||||
return 0
|
||||
'''
|
||||
|
||||
elif args.srcFramework == 'keras':
|
||||
if args.network != None:
|
||||
|
@ -66,14 +42,10 @@ def _convert(args):
|
|||
if args.dstNodeName is None:
|
||||
raise ValueError("Need to provide the output node of Tensorflow model.")
|
||||
|
||||
if args.weights is None:
|
||||
# only convert network structure
|
||||
model = args.network
|
||||
else:
|
||||
model = (args.network, args.weights)
|
||||
assert args.network or args.frozen_pb
|
||||
|
||||
from mmdnn.conversion.tensorflow.tensorflow_parser import TensorflowParser
|
||||
parser = TensorflowParser(model, args.dstNodeName)
|
||||
parser = TensorflowParser(args.network, args.weights, args.frozen_pb, args.dstNodeName)
|
||||
|
||||
elif args.srcFramework == 'mxnet':
|
||||
assert args.inputShape != None
|
||||
|
@ -97,10 +69,7 @@ def _convert(args):
|
|||
else:
|
||||
raise ValueError("Unknown framework [{}].".format(args.srcFramework))
|
||||
|
||||
parser.gen_IR()
|
||||
parser.save_to_json(args.dstPath + ".json")
|
||||
parser.save_to_proto(args.dstPath + ".pb")
|
||||
parser.save_weights(args.dstPath + ".npy")
|
||||
parser.run(args.dstPath)
|
||||
|
||||
return 0
|
||||
|
||||
|
@ -140,6 +109,13 @@ def _main():
|
|||
default=None,
|
||||
help="[Tensorflow] Output nodes' name of the graph.")
|
||||
|
||||
parser.add_argument(
|
||||
'--frozen_pb',
|
||||
type=_text_type,
|
||||
default=None,
|
||||
help="[Tensorflow] frozen model file.")
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--inputShape',
|
||||
nargs='+',
|
||||
|
|
|
@ -61,7 +61,8 @@ def extract_model(args):
|
|||
extractor = keras_extractor()
|
||||
|
||||
elif args.framework == 'tensorflow' or args.framework == 'tf':
|
||||
pass
|
||||
from mmdnn.conversion.examples.tensorflow.extractor import tensorflow_extractor
|
||||
extractor = tensorflow_extractor()
|
||||
|
||||
elif args.framework == 'mxnet':
|
||||
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
|
||||
|
@ -72,7 +73,7 @@ def extract_model(args):
|
|||
else:
|
||||
raise ValueError("Unknown framework [{}].".format(args.framework))
|
||||
|
||||
files = extractor.download(args.network,args.path)
|
||||
files = extractor.download(args.network, args.path)
|
||||
|
||||
if files and args.image:
|
||||
predict = extractor.inference(args.network, args.path, args.image)
|
||||
|
|
|
@ -110,7 +110,7 @@ if __name__=='__main__':
|
|||
with open("graph.txt", 'w') as f:
|
||||
for layer in self.IR_graph.topological_sort:
|
||||
current_node = self.IR_graph.get_node(layer)
|
||||
print("========current_node=========\n{}".format(current_node.layer), file=f)
|
||||
print("========current_node=========\n{}".format(current_node.layer))
|
||||
# test end
|
||||
|
||||
for layer in self.IR_graph.topological_sort:
|
||||
|
|
|
@ -29,9 +29,9 @@ CNTK model file is saved as [cntk_inception_v3.dnn], generated by [cntk_inceptio
|
|||
|
||||
Ubuntu 16.04 with
|
||||
|
||||
- CNTK gpu 2.3
|
||||
- CNTK CPU 2.4
|
||||
|
||||
@ 2017/12/01
|
||||
@ 2018/02/08
|
||||
|
||||
## Limitation
|
||||
|
||||
|
|
|
@ -18,6 +18,13 @@ class Parser(object):
|
|||
self.weights = dict()
|
||||
|
||||
|
||||
def run(self, dest_path):
|
||||
self.gen_IR()
|
||||
self.save_to_json(dest_path + ".json")
|
||||
self.save_to_proto(dest_path + ".pb")
|
||||
self.save_weights(dest_path + ".npy")
|
||||
|
||||
|
||||
@property
|
||||
def src_graph(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -15,6 +15,7 @@ __all__ = ["assign_IRnode_values", "convert_onnx_pad_to_tf", 'convert_tf_pad_to_
|
|||
'compute_tf_same_padding', 'is_valid_padding', 'download_file',
|
||||
'shape_to_list', 'list_to_shape']
|
||||
|
||||
|
||||
def assign_attr_value(attr, val):
|
||||
from mmdnn.conversion.common.IR.graph_pb2 import TensorShape
|
||||
'''Assign value to AttrValue proto according to data type.'''
|
||||
|
@ -172,7 +173,7 @@ def _multi_thread_download(url, file_name, file_size, thread_count):
|
|||
return file_name
|
||||
|
||||
|
||||
def download_file(url, directory='./', local_fname=None, force_write=False):
|
||||
def download_file(url, directory='./', local_fname=None, force_write=False, auto_unzip=False):
|
||||
"""Download the data from source url, unless it's already here.
|
||||
|
||||
Args:
|
||||
|
@ -187,21 +188,43 @@ def download_file(url, directory='./', local_fname=None, force_write=False):
|
|||
if not os.path.isdir(directory):
|
||||
os.mkdir(directory)
|
||||
|
||||
if local_fname is None:
|
||||
local_fname = url.split('/')[-1]
|
||||
if not local_fname:
|
||||
k = url.rfind('/')
|
||||
local_fname = url[k + 1:]
|
||||
|
||||
local_fname = os.path.join(directory, local_fname)
|
||||
|
||||
if os.path.exists(local_fname) and not force_write:
|
||||
print ("File [{}] existed!".format(local_fname))
|
||||
return local_fname
|
||||
|
||||
print ("Downloading file [{}] from [{}]".format(local_fname, url))
|
||||
else:
|
||||
print ("Downloading file [{}] from [{}]".format(local_fname, url))
|
||||
try:
|
||||
import wget
|
||||
ret = wget.download(url, local_fname)
|
||||
except:
|
||||
ret = _single_thread_download(url, local_fname)
|
||||
|
||||
try:
|
||||
import wget
|
||||
return wget.download(url, local_fname)
|
||||
except:
|
||||
return _single_thread_download(url, local_fname)
|
||||
if auto_unzip:
|
||||
if ret.endswith(".tar.gz") or ret.endswith(".tgz"):
|
||||
try:
|
||||
import tarfile
|
||||
tar = tarfile.open(ret)
|
||||
tar.extractall(directory)
|
||||
tar.close()
|
||||
except:
|
||||
print("Unzip file [{}] failed.".format(ret))
|
||||
|
||||
elif ret.endswith('.zip'):
|
||||
try:
|
||||
import zipfile
|
||||
zip_ref = zipfile.ZipFile(ret, 'r')
|
||||
zip_ref.extractall(directory)
|
||||
zip_ref.close()
|
||||
except:
|
||||
print("Unzip file [{}] failed.".format(ret))
|
||||
return ret
|
||||
"""
|
||||
r = requests.head(url)
|
||||
try:
|
||||
|
|
|
@ -60,9 +60,11 @@ def _main():
|
|||
import numpy as np
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
|
||||
net = caffe.Net(arch_fn, weight_fn, caffe.TEST)
|
||||
net = caffe.Net(arch_fn.encode("utf-8"), weight_fn.encode("utf-8"), caffe.TEST)
|
||||
# net = caffe.Net(arch_fn, weight_fn, caffe.TEST)
|
||||
func = TestKit.preprocess_func['caffe'][args.network]
|
||||
img = func(args.image)
|
||||
print(img.size)
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
img = np.expand_dims(img, 0)
|
||||
net.blobs['data'].data[...] = img
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
#----------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
#----------------------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import absolute_import
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
from mmdnn.conversion.examples.extractor import base_extractor
|
||||
from mmdnn.conversion.common.utils import download_file
|
||||
|
||||
|
||||
class caffe_extractor(base_extractor):
|
||||
|
||||
BASE_MODEL_URL = 'http://data.mxnet.io/models/imagenet/test/caffe/'
|
||||
|
||||
architecture_map = {
|
||||
'alexnet' : {'prototxt' : 'https://raw.githubusercontent.com/BVLC/caffe/master/models/bvlc_alexnet/deploy.prototxt',
|
||||
'caffemodel' : 'http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel'},
|
||||
'inception_v1' : {'prototxt' : 'https://raw.githubusercontent.com/BVLC/caffe/master/models/bvlc_googlenet/deploy.prototxt',
|
||||
'caffemodel' : 'http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel'},
|
||||
'vgg16' : {'prototxt' : 'https://gist.githubusercontent.com/ksimonyan/211839e770f7b538e2d8/raw/c3ba00e272d9f48594acef1f67e5fd12aff7a806/VGG_ILSVRC_16_layers_deploy.prototxt',
|
||||
'caffemodel' : 'http://data.mxnet.io/models/imagenet/test/caffe/VGG_ILSVRC_16_layers.caffemodel'},
|
||||
'vgg19' : {'prototxt' : 'https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/bb2b4fe0a9bb0669211cf3d0bc949dfdda173e9e/VGG_ILSVRC_19_layers_deploy.prototxt',
|
||||
'caffemodel' : 'http://data.mxnet.io/models/imagenet/test/caffe/VGG_ILSVRC_19_layers.caffemodel'},
|
||||
'resnet50' : {'prototxt' : BASE_MODEL_URL + 'ResNet-50-deploy.prototxt',
|
||||
'caffemodel' : BASE_MODEL_URL + 'ResNet-50-model.caffemodel'},
|
||||
'resnet101' : {'prototxt' : BASE_MODEL_URL + 'ResNet-101-deploy.prototxt',
|
||||
'caffemodel' : BASE_MODEL_URL + 'ResNet-101-model.caffemodel'},
|
||||
'resnet152' : {'prototxt' : BASE_MODEL_URL + 'ResNet-152-deploy.prototxt',
|
||||
'caffemodel' : BASE_MODEL_URL + 'ResNet-152-model.caffemodel'},
|
||||
'squeezenet' : {'prototxt' : "https://raw.githubusercontent.com/DeepScale/SqueezeNet/master/SqueezeNet_v1.1/deploy.prototxt",
|
||||
'caffemodel' : "https://github.com/DeepScale/SqueezeNet/raw/master/SqueezeNet_v1.1/squeezenet_v1.1.caffemodel"}
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def download(cls, architecture, path="./"):
|
||||
if cls.sanity_check(architecture):
|
||||
architecture_file = download_file(cls.architecture_map[architecture]['prototxt'], directory=path)
|
||||
if not architecture_file:
|
||||
return None
|
||||
|
||||
weight_file = download_file(cls.architecture_map[architecture]['caffemodel'], directory=path)
|
||||
if not weight_file:
|
||||
return None
|
||||
|
||||
print("Caffe Model {} saved as [{}] and [{}].".format(architecture, architecture_file, weight_file))
|
||||
return (architecture_file, weight_file)
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@classmethod
|
||||
def inference(cls, architecture_name, architecture, path, image_path):
|
||||
if cls.sanity_check(architecture_name):
|
||||
import caffe
|
||||
import numpy as np
|
||||
net = caffe.Net(architecture, path, caffe.TEST)
|
||||
func = TestKit.preprocess_func['caffe'][architecture_name]
|
||||
img = func(image_path)
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
img = np.expand_dims(img, 0)
|
||||
net.blobs['data'].data[...] = img
|
||||
predict = np.squeeze(net.forward()['prob'][0])
|
||||
predict = np.squeeze(predict)
|
||||
return predict
|
||||
|
||||
else:
|
||||
return None
|
|
@ -64,7 +64,7 @@ class TestKit(object):
|
|||
'caffe' : {
|
||||
'alexnet' : lambda path : TestKit.ZeroCenter(path, 227, True),
|
||||
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'inception_v1' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'inception_v1' : lambda path : TestKit.ZeroCenter(path, 227, True),
|
||||
'resnet152' : lambda path : TestKit.ZeroCenter(path, 224, True),
|
||||
'squeezenet' : lambda path : TestKit.ZeroCenter(path, 227, False)
|
||||
},
|
||||
|
|
|
@ -480,8 +480,17 @@ class MXNetParser(Parser):
|
|||
# print("Warning: Layer [{}] has changed model data format from [{}] to [{}]".format(source_node.name, self.data_format, layout))
|
||||
self.data_format = layout
|
||||
|
||||
# groups
|
||||
group = int(layer_attr.get("num_group", "1"))
|
||||
IR_node.attr["group"].i = group
|
||||
in_channel = self.IR_layer_map[IR_node.input[0]].attr["_output_shapes"].list.shape[0].dim[-1].size
|
||||
|
||||
if group == in_channel:
|
||||
self._copy_and_reop(source_node, IR_node, "DepthwiseConv")
|
||||
else:
|
||||
self._copy_and_reop(source_node, IR_node, "Conv")
|
||||
# in_channel = in_channel // group
|
||||
|
||||
assert "num_filter" in layer_attr
|
||||
out_channel = int(layer_attr.get("num_filter"))
|
||||
|
||||
|
@ -511,14 +520,6 @@ class MXNetParser(Parser):
|
|||
# data_format
|
||||
assign_IRnode_values(IR_node, {'data_format' : layout})
|
||||
|
||||
# groups
|
||||
group = int(layer_attr.get("num_group", "1"))
|
||||
IR_node.attr["group"].i = group
|
||||
if group == in_channel:
|
||||
self._copy_and_reop(source_node, IR_node, "DepthwiseConv")
|
||||
else:
|
||||
self._copy_and_reop(source_node, IR_node, "Conv")
|
||||
|
||||
# padding
|
||||
if "pad" in layer_attr:
|
||||
pad = MXNetParser.str2intList(layer_attr.get("pad"))
|
||||
|
|
|
@ -177,16 +177,15 @@ class TensorflowParser(Parser):
|
|||
output_node.real_name = source_node.name
|
||||
|
||||
|
||||
def __init__(self, input_args, dest_nodes = None):
|
||||
def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes = None):
|
||||
super(TensorflowParser, self).__init__()
|
||||
|
||||
# load model files into Keras graph
|
||||
from six import string_types as _string_types
|
||||
if isinstance(input_args, _string_types):
|
||||
model = TensorflowParser._load_meta(input_args)
|
||||
elif isinstance(input_args, tuple):
|
||||
model = TensorflowParser._load_meta(input_args[0])
|
||||
self.ckpt_data = TensorflowParser._load_weights(input_args[1])
|
||||
# load model files into TensorFlow graph
|
||||
if meta_file:
|
||||
model = TensorflowParser._load_meta(meta_file)
|
||||
|
||||
if checkpoint_file:
|
||||
self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
|
||||
self.weight_loaded = True
|
||||
|
||||
if dest_nodes != None:
|
||||
|
@ -194,7 +193,7 @@ class TensorflowParser(Parser):
|
|||
model = extract_sub_graph(model, dest_nodes.split(','))
|
||||
|
||||
# Build network graph
|
||||
self.tf_graph = TensorflowGraph(model)
|
||||
self.tf_graph = TensorflowGraph(model)
|
||||
self.tf_graph.build()
|
||||
|
||||
|
||||
|
|
|
@ -4,11 +4,12 @@ import unittest
|
|||
import numpy as np
|
||||
from six.moves import reload_module
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
# import torch
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
|
||||
from mmdnn.conversion.examples.keras.extractor import keras_extractor
|
||||
from mmdnn.conversion.examples.mxnet.extractor import mxnet_extractor
|
||||
from mmdnn.conversion.examples.caffe.extractor import caffe_extractor
|
||||
|
||||
from mmdnn.conversion.keras.keras2_parser import Keras2Parser
|
||||
from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
|
||||
|
@ -83,9 +84,7 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
# original to IR
|
||||
parser = Keras2Parser(model_filename)
|
||||
parser.gen_IR()
|
||||
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
|
||||
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
parser.run(TestModels.tmpdir + architecture_name + "_converted")
|
||||
del parser
|
||||
return original_predict
|
||||
|
||||
|
@ -106,13 +105,42 @@ class TestModels(CorrectnessTest):
|
|||
model = (architecture_file, prefix, epoch, [3, 224, 224])
|
||||
|
||||
parser = MXNetParser(model)
|
||||
parser.gen_IR()
|
||||
parser.save_to_proto(TestModels.tmpdir + architecture_name + "_converted.pb")
|
||||
parser.save_weights(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
parser.run(TestModels.tmpdir + architecture_name + "_converted")
|
||||
del parser
|
||||
|
||||
return original_predict
|
||||
|
||||
@staticmethod
|
||||
def CaffeParse(architecture_name, image_path):
|
||||
# download model
|
||||
architecture_file, weight_file = caffe_extractor.download(architecture_name, TestModels.cachedir)
|
||||
|
||||
# get original model prediction result
|
||||
|
||||
original_predict = caffe_extractor.inference(architecture_name,architecture_file, weight_file, image_path)
|
||||
|
||||
# original to IR
|
||||
from mmdnn.conversion.caffe.transformer import CaffeTransformer
|
||||
transformer = CaffeTransformer(architecture_file, weight_file, "tensorflow", None, phase = 'TRAIN')
|
||||
graph = transformer.transform_graph()
|
||||
data = transformer.transform_data()
|
||||
|
||||
from mmdnn.conversion.caffe.writer import ModelSaver, PyWriter
|
||||
|
||||
prototxt = graph.as_graph_def().SerializeToString()
|
||||
pb_path = TestModels.tmpdir + architecture_name + "_converted.pb"
|
||||
with open(pb_path, 'wb') as of:
|
||||
of.write(prototxt)
|
||||
print ("IR network structure is saved as [{}].".format(pb_path))
|
||||
|
||||
import numpy as np
|
||||
npy_path = TestModels.tmpdir + architecture_name + "_converted.npy"
|
||||
with open(npy_path, 'wb') as of:
|
||||
np.save(of, data)
|
||||
print ("IR weights are saved as [{}].".format(npy_path))
|
||||
|
||||
return original_predict
|
||||
|
||||
|
||||
@staticmethod
|
||||
def CntkEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
|
@ -140,7 +168,7 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
@staticmethod
|
||||
def TensorflowEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to Tensorflow.".format(architecture_name, original_framework))
|
||||
print("Testing {} from {} to TensorFlow.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
emitter = TensorflowEmitter((architecture_path, weight_path))
|
||||
|
@ -169,7 +197,7 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
@staticmethod
|
||||
def PytorchEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to Pytorch.".format(architecture_name, original_framework))
|
||||
print("Testing {} from {} to PyTorch.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
emitter = PytorchEmitter((architecture_path, weight_path))
|
||||
|
@ -249,51 +277,28 @@ class TestModels(CorrectnessTest):
|
|||
'squeezenet_v1.1' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'imagenet1k-resnext-101-64x4d' : [TensorflowEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'imagenet1k-resnext-50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
},
|
||||
'caffe' : {
|
||||
'vgg19' : [KerasEmit],
|
||||
# 'alexnet' : [KerasEmit],
|
||||
# 'inception_v1' : [CntkEmit],
|
||||
# 'resnet152' : [CntkEmit],
|
||||
# 'squeezenet' : [CntkEmit]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_keras(self):
|
||||
# keras original
|
||||
ensure_dir(self.cachedir)
|
||||
ensure_dir(self.tmpdir)
|
||||
original_framework = 'keras'
|
||||
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = self.KerasParse(network_name, self.image_path)
|
||||
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
converted_predict = emit.__func__(
|
||||
original_framework,
|
||||
network_name,
|
||||
self.tmpdir + network_name + "_converted.pb",
|
||||
self.tmpdir + network_name + "_converted.npy",
|
||||
self.image_path)
|
||||
|
||||
self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
print("Testing {} model all passed.".format(original_framework))
|
||||
|
||||
|
||||
def test_mxnet(self):
|
||||
def test_caffe(self):
|
||||
# mxnet original
|
||||
ensure_dir(self.cachedir)
|
||||
ensure_dir(self.tmpdir)
|
||||
original_framework = 'mxnet'
|
||||
original_framework = 'caffe'
|
||||
|
||||
for network_name in self.test_table[original_framework].keys():
|
||||
print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# get original model prediction result
|
||||
original_predict = self.MXNetParse(network_name, self.image_path)
|
||||
original_predict = self.CaffeParse(network_name, self.image_path)
|
||||
# print(original_predict)
|
||||
|
||||
for emit in self.test_table[original_framework][network_name]:
|
||||
converted_predict = emit.__func__(
|
||||
|
@ -311,3 +316,69 @@ class TestModels(CorrectnessTest):
|
|||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
print("Testing {} model all passed.".format(original_framework))
|
||||
|
||||
# def test_keras(self):
|
||||
# # keras original
|
||||
# ensure_dir(self.cachedir)
|
||||
# ensure_dir(self.tmpdir)
|
||||
# original_framework = 'keras'
|
||||
|
||||
# for network_name in self.test_table[original_framework].keys():
|
||||
# print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# # get original model prediction result
|
||||
# original_predict = self.KerasParse(network_name, self.image_path)
|
||||
|
||||
# for emit in self.test_table[original_framework][network_name]:
|
||||
# converted_predict = emit.__func__(
|
||||
# original_framework,
|
||||
# network_name,
|
||||
# self.tmpdir + network_name + "_converted.pb",
|
||||
# self.tmpdir + network_name + "_converted.npy",
|
||||
# self.image_path)
|
||||
|
||||
# self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
# os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
# os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
# print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
<<<<<<< HEAD
|
||||
# print("Testing {} model all passed.".format(original_framework))
|
||||
=======
|
||||
os.remove(self.tmpdir + network_name + "_converted.json")
|
||||
os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
>>>>>>> 4c794431b8c07d1be1dffc1a9c705cb57831b31e
|
||||
|
||||
|
||||
# def test_mxnet(self):
|
||||
# # mxnet original
|
||||
# ensure_dir(self.cachedir)
|
||||
# ensure_dir(self.tmpdir)
|
||||
# original_framework = 'mxnet'
|
||||
|
||||
# for network_name in self.test_table[original_framework].keys():
|
||||
# print("Testing {} model {} start.".format(original_framework, network_name))
|
||||
|
||||
# # get original model prediction result
|
||||
# original_predict = self.MXNetParse(network_name, self.image_path)
|
||||
|
||||
# for emit in self.test_table[original_framework][network_name]:
|
||||
# converted_predict = emit.__func__(
|
||||
# original_framework,
|
||||
# network_name,
|
||||
# self.tmpdir + network_name + "_converted.pb",
|
||||
# self.tmpdir + network_name + "_converted.npy",
|
||||
# self.image_path)
|
||||
|
||||
# self._compare_outputs(original_predict, converted_predict)
|
||||
|
||||
|
||||
# os.remove(self.tmpdir + network_name + "_converted.pb")
|
||||
# os.remove(self.tmpdir + network_name + "_converted.npy")
|
||||
# print("Testing {} model {} passed.".format(original_framework, network_name))
|
||||
|
||||
# print("Testing {} model all passed.".format(original_framework))
|
||||
|
|
Загрузка…
Ссылка в новой задаче