Merge branch 'bowbao/onnxruntime_ci_stage2'

This commit is contained in:
Bowen Bao 2018-11-01 06:58:24 +00:00
Родитель 0cd2faec1e fca139674c
Коммит 29818ffd05
4 изменённых файлов: 162 добавлений и 11 удалений

Просмотреть файл

@ -4435,8 +4435,8 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
input = Utils::ConvertVariableType<float, float16>(input, true);
//// This is a workaround allowing CNTK V1 pretrained models to continue running after removal of sequence axis from input
//if (input.Shape().Rank() > 1)
// inputArgType = ToTypeProto(input.Shape().SubShape(0, 1), input.HasBatchAxis(), input.HasSequenceAxis());
if ((src->Attributes()[L"spatial"].Value<bool>() ? 1 : 0) && input.Shape().Rank() > 1)
inputArgType = ToTypeProto(input.Shape().SubShape(0, 1), input.HasBatchAxis(), input.HasSequenceAxis());
}
}

Просмотреть файл

@ -8,10 +8,16 @@ import numpy as np
import pytest
import os
import re
import shutil
import time
import tempfile
onnx = pytest.importorskip("onnx")
from onnx import numpy_helper
from .onnx_test_helper import find_onnx_value_info_proto_with_matching_name, save_cntk_data_as_onnx_tensor
from .onnx_verify_helper import verify_model
# To test models locally, create folder 'onnx_models' and put in model folders.
# For example.
# .
@ -199,6 +205,78 @@ def test_onnx_model(model_name, round_trip):
rtol=1e-3,
atol=1e-4)
# Helper for exporting test data.
model_file = 'model.onnx'
data_dir = 'test_data_set_0'
def SaveData(test_data_dir, prefix, onnx_variables, variables, data_list, names, batch_size=1):
if isinstance(data_list, np.ndarray):
data_list = [data_list]
for (i, d), v, n in zip(enumerate(data_list), variables, names):
onnx_value_info_proto = find_onnx_value_info_proto_with_matching_name(onnx_variables, n, onnx_variables[0])
save_cntk_data_as_onnx_tensor(os.path.join(test_data_dir, '{0}_{1}.pb'.format(prefix, i)), v, d, onnx_value_info_proto)
def Save(dir, func, inputs, outputs, batch_size=1):
if not os.path.exists(dir):
os.makedirs(dir)
model_file_path = os.path.join(dir, model_file)
func.save(model_file_path, C.ModelFormat.ONNX)
onnx_model = onnx.load(model_file_path)
onnx_model_description = onnx_model.graph.doc_string
uid_name_map = dict(tuple(x[3:-3].split(', ')) for x in re.findall(r'<<<[^>]*>>>', onnx_model_description)[1:])
input_names = [uid_name_map[x.uid] for x in func.arguments]
# handle block outputs
output_names = []
block_uid_count = {}
# when block are exported as a single onnx node, the onnx node output takes name from block node output.
# when block are exported by exporting nodes within that block, the onnx node output takes name from inner node output.
# the cntk node that provides the name will have its uid stored in the uid_name_map.
# this function tries to find the deepest inner output node whose uid is in uid_name_map.
def find_deepest_inner_block_output(output):
# might be a placeholder
if not output.is_output:
return False, output
if output.owner and output.owner.is_block:
block_uid_count[output.owner.uid] = block_uid_count[output.owner.uid] + 1 if output.owner.uid in block_uid_count else 0
found, inner_output = find_deepest_inner_block_output(output.owner.block_root.outputs[block_uid_count[output.owner.uid]])
if found:
return True, inner_output
return output.uid in uid_name_map, output
for output in func.outputs:
_, output = find_deepest_inner_block_output(output)
output_names.append(uid_name_map[output.uid])
test_data_dir = os.path.join(dir, data_dir)
if not os.path.exists(test_data_dir):
os.makedirs(test_data_dir)
SaveData(test_data_dir, 'input', onnx_model.graph.input, func.arguments, inputs, input_names, batch_size)
SaveData(test_data_dir, 'output', onnx_model.graph.output, func.outputs, outputs, output_names, batch_size)
# Initialize tmp-directory for exporting cntk models
tmpdir = 'tmp_exported_models'
if os.path.isdir(tmpdir):
# os.mkdir might get called before shutil.rmtree complete. So rename the current tmpdir to avoid collision.
tmp = tempfile.mktemp(dir=os.path.dirname(tmpdir))
shutil.move(tmpdir, tmp)
shutil.rmtree(tmp)
os.mkdir(tmpdir)
# test_cntk_model will create exported onnx model with test data in the following tmp folder:
# .
# +-- tmp_exported_models # models exported in 'model.onnx' onnx format.
# | +-- test_model1
# | | +-- model.onnx
# | | +-- test_data_set_0
# | | | +-- input_0.pb
# | | | +-- input_1.pb
# | | | +-- output_0.pb
# | | +-- test_data_set_1
# | | | +-- input_0.pb
# | | | +-- input_1.pb
# | | | +-- output_0.pb
# | +-- test_model2
# ...
@pytest.mark.parametrize('model_name',
[model_name for model_name in cntk_model_names],
ids=[model_name for model_name in cntk_model_names])
@ -208,14 +286,17 @@ def test_cntk_model(model_name):
model_dir = os.path.join(cntk_base_dir, model_name)
model = C.Function.load(model_dir, format=C.ModelFormat.CNTKv2)
resave_model_path = 'model_resave.onnx'
model.save(resave_model_path, format=C.ModelFormat.ONNX)
reloaded_model = C.Function.load(resave_model_path, format=C.ModelFormat.ONNX)
resave_model_dir = os.path.join(tmpdir, 'test_' + model_name)
resave_model_path = os.path.join(resave_model_dir, model_file)
np.random.seed(3)
input_shape = (1,) + model.arguments[0].shape
data_x = np.asarray(np.random.uniform(-1, 1, input_shape), dtype=np.float32)
data_y = model.eval({model.arguments[0]:data_x})
Save(resave_model_dir, model, data_x, data_y)
reloaded_model = C.Function.load(resave_model_path, format=C.ModelFormat.ONNX)
data_y_ = reloaded_model.eval({reloaded_model.arguments[0]:data_x})
np.testing.assert_equal(len(data_y), len(data_y_))
@ -225,4 +306,6 @@ def test_cntk_model(model_name):
data_y[i],
data_y_[i],
rtol=1e-3,
atol=1e-4)
atol=1e-4)
verify_model(model_name, str(os.path.abspath(tmpdir)))

Просмотреть файл

@ -10,6 +10,7 @@ import scipy
import cntk as C
import pytest
onnx = pytest.importorskip("onnx")
from .onnx_verify_helper import verify_model, get_onnx_test_runner_callscript
CNTK_FREEDIM_AXIS_DENOTATION = -3
DIM_SIZE_FOR_NON_BATCH_OPS = 1
@ -249,8 +250,7 @@ def save_test_data(model, onnx_model, test_data_path, input_data, output_data, n
model.outputs[i], output_data_i, onnx_value_info_proto)
# print out command line for onnx test runner
verify_filename = os.path.join(str(tmpdir), '../verify.bat')
append_write = 'a' if os.path.exists(verify_filename) else 'w'
with open(verify_filename, append_write) as file:
file.write(R'onnx_test_runner.exe -n ' + name + ' ' + str(tmpdir) + '\n')
print(R'onnx_test_runner.exe -n ' + name + ' ' + str(tmpdir))
print(get_onnx_test_runner_callscript(name, tmpdir))
failed_cases_count = verify_model(name, tmpdir)
assert failed_cases_count == 0

Просмотреть файл

@ -0,0 +1,68 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
from __future__ import print_function
import os, re, sys, subprocess
windows = os.getenv("OS")=="Windows_NT"
known_issues = [
'BatchNormalization_float160',
'SpatialBatchNormalization_float160',
'DepthToSpace',
'RNN',
'test_sequence_slice_',
'test_sequence_slice_0',
'test_sequence_slice_1',
'SequenceSoftmax',
'SpaceToDepth',
'top_k',
'ConvTranspose_with_OutputShape_0',
'Flatten_1',
'Gather_1',
# Not in onnxruntime
'LayerNorm_0',
'MVN_0',
'MVN_1',
'MVN_2',
'MVN_3',
]
def parse_single_result_case(case_str):
fails = re.search(r'Failed Test Cases:\w+', case_str)
if fails:
failed_case = fails.group().split(':')[1]
if not failed_case in known_issues:
print(case_str, file=sys.stderr)
return 1
return 0
def parse_verify_out_str(content):
total_failed_cases = 0
case_list = re.findall(r'result:[\s\S]*?Failed Test Cases:[^\n]*\n', content)
for case_str in case_list:
total_failed_cases += parse_single_result_case(case_str)
if total_failed_cases:
print('ERROR: onnx_test_runner produced ' + str(total_failed_cases) + ' failed cases.', file=sys.stderr)
sys.exit(1)
return total_failed_cases
def verify_model(model_name, model_dir):
path_prefix = os.path.join(os.environ['CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY'], 'ONNXRuntime') if 'CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY' in os.environ else ''
onnx_test_runner_path_str = str(os.path.join(path_prefix, 'onnx_test_runner.exe'))
# run only on windows.
if not os.path.exists(onnx_test_runner_path_str) or not windows:
return 0
callargs = [onnx_test_runner_path_str, '-n', model_name, str(model_dir)]
process = subprocess.run(callargs, stdout=subprocess.PIPE)
return parse_verify_out_str(process.stdout.decode('utf-8'))
def get_onnx_test_runner_callscript(model_name, model_dir):
return R'onnx_test_runner.exe -n ' + model_name + ' ' + str(model_dir)