Add check for empty input in StringJoin operator and fix empty string input error in BlingFire sentence breaker. (#175)

* Add test cases and fix empty string error in BlingFire sentence breaker.

* Throw error if the input text to join is an empty array.

* Fix scalar support and access violation.

* Resolve comments.

* Resolve comments.

Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
This commit is contained in:
Zuwei Zhao 2021-10-27 20:21:16 +08:00 коммит произвёл GitHub
Родитель 64c972fb02
Коммит 05f7ded825
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 231 добавлений и 32 удалений

Просмотреть файл

@ -17,16 +17,22 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, input_X, X);
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
// Setup output
// Check input
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORT_CXX_API_THROW("Input 2 is the separator, it has 1 element.", ORT_INVALID_ARGUMENT);
ORT_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_axis(ort_, input_axis);
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
ORT_CXX_API_THROW("Input 3 is the axis, it has 1 element.", ORT_INVALID_ARGUMENT);
ORT_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
if (*axis < 0 || *axis >= dimensions.size())
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
if (dimensions.size() == 0) {
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
if (X.size() != 1)
ORT_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
} else {
if (*axis < 0 || *axis >= dimensions.size())
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
}
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
if (dimensions.size() > 1) {
@ -45,27 +51,38 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
std::vector<std::string> out(size);
// Do computation
int64_t h = 1;
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
h *= dimensions[i];
}
int64_t left_part = size / h;
int64_t right_part = size / left_part;
int64_t n_red = dimensions[*axis] - 1;
int64_t inc = right_part * (n_red + 1);
int64_t pos = 0;
for (int64_t li = 0; li < left_part; ++li) {
for (int64_t ri = 0; ri < right_part; ++ri, ++pos) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[index] << sep[0];
if (dimensions.size() > 0) {
if (X.size() > 0) {
// Do computation
int64_t h = 1;
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
h *= dimensions[i];
}
st << X[index];
out[pos] = st.str();
int64_t left_part = size / h;
int64_t right_part = size / left_part;
int64_t n_red = dimensions[*axis] - 1;
int64_t inc = right_part * (n_red + 1);
int64_t pos = 0;
for (int64_t li = 0; li < left_part; ++li) {
for (int64_t ri = 0; ri < right_part; ++ri, ++pos) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[index] << sep[0];
}
st << X[index];
out[pos] = st.str();
}
}
} else {
// for input 1 contains 0 elements, output joined string is empty string
out[0] = "";
}
} else {
// for input 1 (scalar) which has 1 element, output joined string is input string itself. See issue: https://github.com/onnx/onnx/issues/3724
out[0] = X[0];
}
FillTensorDataString(api_, ort_, context, out, output);
}

Просмотреть файл

@ -52,16 +52,22 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
// inline split output_str by newline '\n'
std::vector<char*> output_sentences;
bool head_flag = true;
for (int i = 0; i < output_length; i++) {
if (head_flag) {
output_sentences.push_back(&output_str[i]);
head_flag = false;
}
if (output_str[i] == '\n') {
head_flag = true;
output_str[i] = '\0';
if (output_length == 0) {
// put one empty string if output_length is 0
output_sentences.push_back("");
} else {
bool head_flag = true;
for (int i = 0; i < output_length; i++) {
if (head_flag) {
output_sentences.push_back(&output_str[i]);
head_flag = false;
}
if (output_str[i] == '\n') {
head_flag = true;
output_str[i] = '\0';
}
}
}

Двоичные данные
test/data/custom_op_string_join.onnx Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -265,3 +265,140 @@ TEST(string_operator, test_string_ecmaregex_replace) {
outputs[0].values_string = {"Test+ test", "tEsT+ Test", " TEST+ test"};
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
// Joins the three strings of a {1,3} tensor along axis 1 with "-" and expects
// a single joined string.
TEST(utils, test_string_join) {
  const auto kStringType = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  const auto kInt64Type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;

  auto env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> feeds(3);

  // Input 1: the strings to join.
  TestValue& text = feeds[0];
  text.name = "text";
  text.element_type = kStringType;
  text.dims = {1,3};
  text.values_string = {"abc","zzz","efg"};

  // Input 2: the separator (one element).
  TestValue& separator = feeds[1];
  separator.name = "sep";
  separator.element_type = kStringType;
  separator.dims = {1};
  separator.values_string = {"-"};

  // Input 3: the axis to join along (one element).
  TestValue& axis = feeds[2];
  axis.name = "axis";
  axis.element_type = kInt64Type;
  axis.dims = {1};
  axis.values_int64 = {1};

  std::vector<TestValue> expected(1);
  TestValue& joined = expected[0];
  joined.name = "out";
  joined.element_type = kStringType;
  joined.dims = {1};
  joined.values_string = {"abc-zzz-efg"};

  // The model lives in ../data relative to this source file.
  const std::filesystem::path model_file =
      std::filesystem::path(__FILE__).parent_path() / ".." / "data" / "custom_op_string_join.onnx";
  TestInference(*env, model_file.c_str(), feeds, expected, GetLibraryPath());
}
// A one-element tensor holding the empty string joins to the empty string.
TEST(utils, test_string_join_values_empty_string) {
  const auto kStringType = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  const auto kInt64Type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;

  auto env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> feeds(3);

  // Input 1: a single empty string.
  TestValue& text = feeds[0];
  text.name = "text";
  text.element_type = kStringType;
  text.dims = {1};
  text.values_string = {""};

  // Input 2: the separator (one element).
  TestValue& separator = feeds[1];
  separator.name = "sep";
  separator.element_type = kStringType;
  separator.dims = {1};
  separator.values_string = {" "};

  // Input 3: the axis to join along (one element).
  TestValue& axis = feeds[2];
  axis.name = "axis";
  axis.element_type = kInt64Type;
  axis.dims = {1};
  axis.values_int64 = {0};

  std::vector<TestValue> expected(1);
  TestValue& joined = expected[0];
  joined.name = "out";
  joined.element_type = kStringType;
  joined.dims = {1};
  joined.values_string = {""};

  // The model lives in ../data relative to this source file.
  const std::filesystem::path model_file =
      std::filesystem::path(__FILE__).parent_path() / ".." / "data" / "custom_op_string_join.onnx";
  TestInference(*env, model_file.c_str(), feeds, expected, GetLibraryPath());
}
// A tensor of shape {0} (no elements) is expected to join to one empty string.
TEST(utils, test_string_join_dims_zero_values_empty) {
  const auto kStringType = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  const auto kInt64Type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;

  auto env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> feeds(3);

  // Input 1: when dims contains a 0, the value list must be empty.
  // See issue: https://github.com/onnx/onnx/issues/3724
  TestValue& text = feeds[0];
  text.name = "text";
  text.element_type = kStringType;
  text.dims = {0};
  text.values_string = {};

  // Input 2: the separator (one element).
  TestValue& separator = feeds[1];
  separator.name = "sep";
  separator.element_type = kStringType;
  separator.dims = {1};
  separator.values_string = {" "};

  // Input 3: the axis to join along (one element).
  TestValue& axis = feeds[2];
  axis.name = "axis";
  axis.element_type = kInt64Type;
  axis.dims = {1};
  axis.values_int64 = {0};

  std::vector<TestValue> expected(1);
  TestValue& joined = expected[0];
  joined.name = "out";
  joined.element_type = kStringType;
  joined.dims = {1};
  joined.values_string = {""};

  // The model lives in ../data relative to this source file.
  const std::filesystem::path model_file =
      std::filesystem::path(__FILE__).parent_path() / ".." / "data" / "custom_op_string_join.onnx";
  TestInference(*env, model_file.c_str(), feeds, expected, GetLibraryPath());
}
// A scalar input (empty dims, one value) is expected to come back unchanged.
TEST(utils, test_string_join_dims_empty_values_scalar) {
  const auto kStringType = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  const auto kInt64Type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;

  auto env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> feeds(3);

  // Input 1: empty dims means a scalar, which must carry exactly one value.
  // See issue: https://github.com/onnx/onnx/issues/3724
  TestValue& text = feeds[0];
  text.name = "text";
  text.element_type = kStringType;
  text.dims = {};
  text.values_string = {"abc"};

  // Input 2: the separator (one element).
  TestValue& separator = feeds[1];
  separator.name = "sep";
  separator.element_type = kStringType;
  separator.dims = {1};
  separator.values_string = {" "};

  // Input 3: the axis to join along (one element).
  TestValue& axis = feeds[2];
  axis.name = "axis";
  axis.element_type = kInt64Type;
  axis.dims = {1};
  axis.values_int64 = {0};

  std::vector<TestValue> expected(1);
  TestValue& joined = expected[0];
  joined.name = "out";
  joined.element_type = kStringType;
  joined.dims = {1};
  joined.values_string = {"abc"};

  // The model lives in ../data relative to this source file.
  const std::filesystem::path model_file =
      std::filesystem::path(__FILE__).parent_path() / ".." / "data" / "custom_op_string_join.onnx";
  TestInference(*env, model_file.c_str(), feeds, expected, GetLibraryPath());
}

Просмотреть файл

@ -24,6 +24,19 @@ class TestBlingFireSentenceBreaker(unittest.TestCase):
"Я увидел девушку с телескопом."])
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
def test_text_to_case2(self):
    """An empty input string is expected to produce one empty output string."""
    empty_text = np.array([""])
    expected = np.array([""])
    model_file = _get_test_data_file('data', 'default_sentence_break_model.bin')
    _run_blingfire_sentencebreaker(input=empty_text, output=expected, model_path=model_file)
def test_text_to_case3(self):
    """A whitespace-only input is expected to produce one empty output string."""
    whitespace_text = np.array([" "])
    # This is what the blingfire sbd.bin model outputs for whitespace.
    expected = np.array([""])
    model_file = _get_test_data_file('data', 'default_sentence_break_model.bin')
    _run_blingfire_sentencebreaker(input=whitespace_text, output=expected, model_path=model_file)
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -616,6 +616,32 @@ class TestPythonOpString(unittest.TestCase):
self.assertEqual(
txout[0].tolist(), np.array(["a;b;cc"]).tolist())
def test_string_join_empty(self):
    """Joining a one-element array holding "" is expected to return [""]."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_join('')
    self.assertIn('op_type: "StringJoin"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)
    feeds = {
        'text': np.array([""]),
        'sep': np.array([" "]),
        'axis': np.array([0], dtype=np.int64),
    }
    results = session.run(None, feeds)
    self.assertEqual(results[0].tolist(), np.array([""]).tolist())
def test_string_join_scalar(self):
    """A scalar (0-d) string input is expected to come back as a one-element list."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_join('')
    self.assertIn('op_type: "StringJoin"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)
    feeds = {
        'text': np.array("a scalar string"),
        'sep': np.array([" "]),
        'axis': np.array([0], dtype=np.int64),
    }
    results = session.run(None, feeds)
    self.assertEqual(results[0].tolist(), np.array(["a scalar string"]).tolist())
def test_string_join_cc_3d(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())