зеркало из https://github.com/microsoft/MMdnn.git
to pytorch test
This commit is contained in:
Родитель
3a81b7adce
Коммит
7c1ae6c4f8
|
@ -4,6 +4,7 @@ import unittest
|
|||
import numpy as np
|
||||
from six.moves import reload_module
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from mmdnn.conversion.examples.imagenet_test import TestKit
|
||||
|
||||
from mmdnn.conversion.examples.keras.extractor import keras_extractor
|
||||
|
@ -15,7 +16,7 @@ from mmdnn.conversion.mxnet.mxnet_parser import MXNetParser
|
|||
from mmdnn.conversion.cntk.cntk_emitter import CntkEmitter
|
||||
from mmdnn.conversion.tensorflow.tensorflow_emitter import TensorflowEmitter
|
||||
from mmdnn.conversion.keras.keras2_emitter import Keras2Emitter
|
||||
from mmdnn.conversion.caffe.caffe_emitter import CaffeEmitter
|
||||
from mmdnn.conversion.pytorch.pytorch_emitter import PytorchEmitter
|
||||
|
||||
def _compute_SNR(x,y):
|
||||
noise = x - y
|
||||
|
@ -167,29 +168,36 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def CaffeEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to Caffe.".format(architecture_name, original_framework))
|
||||
def PytorchEmit(original_framework, architecture_name, architecture_path, weight_path, image_path):
|
||||
print("Testing {} from {} to Pytorch.".format(architecture_name, original_framework))
|
||||
|
||||
# IR to code
|
||||
emitter = CaffeEmitter((architecture_path, weight_path))
|
||||
emitter.run("converted_model.py", None, 'test')
|
||||
emitter = PytorchEmitter((architecture_path, weight_path))
|
||||
emitter.run("converted_model.py", "pytorch_weight.npy", 'test')
|
||||
del emitter
|
||||
|
||||
# import converted model
|
||||
import converted_model
|
||||
reload_module (converted_model)
|
||||
model_converted = converted_model.KitModel(TestModels.tmpdir + architecture_name + "_converted.npy")
|
||||
# input_tf, model_tf = model_converted
|
||||
model_converted = converted_model.KitModel("pytorch_weight.npy")
|
||||
model_converted.eval()
|
||||
|
||||
func = TestKit.preprocess_func[original_framework][architecture_name]
|
||||
img = func(image_path)
|
||||
# input_data = np.expand_dims(img, 0)
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
img = np.expand_dims(img, 0).copy()
|
||||
input_data = torch.from_numpy(img)
|
||||
input_data = torch.autograd.Variable(input_data, requires_grad = False)
|
||||
|
||||
# del model_converted
|
||||
# del converted_model
|
||||
# os.remove("converted_model.py")
|
||||
# converted_predict = np.squeeze(predict)
|
||||
# return converted_predict
|
||||
predict = model_converted(input_data)
|
||||
predict = predict.data.numpy()
|
||||
|
||||
del model_converted
|
||||
del converted_model
|
||||
os.remove("converted_model.py")
|
||||
os.remove("pytorch_weight.npy")
|
||||
converted_predict = np.squeeze(predict)
|
||||
return converted_predict
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
@ -225,22 +233,22 @@ class TestModels(CorrectnessTest):
|
|||
|
||||
test_table = {
|
||||
'keras' : {
|
||||
'vgg16' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'inception_v3' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'resnet50' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'densenet' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'vgg16' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'inception_v3' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'resnet50' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'densenet' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'xception' : [TensorflowEmit, KerasEmit],
|
||||
'mobilenet' : [TensorflowEmit, KerasEmit],
|
||||
'nasnet' : [TensorflowEmit, KerasEmit],
|
||||
},
|
||||
'mxnet' : {
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'imagenet1k-inception-bn' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'imagenet1k-resnet-152' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'squeezenet_v1.1' : [CntkEmit, TensorflowEmit, KerasEmit],
|
||||
'imagenet1k-resnext-101-64x4d' : [TensorflowEmit], # TODO: CntkEmit
|
||||
'imagenet1k-resnext-50' : [TensorflowEmit, KerasEmit], # TODO: CntkEmit
|
||||
'vgg19' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'imagenet1k-inception-bn' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'imagenet1k-resnet-152' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'squeezenet_v1.1' : [CntkEmit, TensorflowEmit, KerasEmit, PytorchEmit],
|
||||
'imagenet1k-resnext-101-64x4d' : [TensorflowEmit, PytorchEmit], # TODO: CntkEmit
|
||||
'imagenet1k-resnext-50' : [TensorflowEmit, KerasEmit, PytorchEmit], # TODO: CntkEmit
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче