Add check for empty input in StringJoin operator and fix empty string input error in BlingFire sentence breaker. (#175)
* Add test cases and fix empty string error in BlingFire sentence breaker. * Throw error if input text to join is empty array. * Fix scalar support and access violation. * Resolve comments. * Resolve comments. Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
This commit is contained in:
Родитель
64c972fb02
Коммит
05f7ded825
|
@ -17,16 +17,22 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
|
||||
|
||||
// Setup output
|
||||
// Check input
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||
ORT_CXX_API_THROW("Input 2 is the separator, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||
ORT_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
|
||||
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
||||
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
||||
ORT_CXX_API_THROW("Input 3 is the axis, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||
ORT_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (*axis < 0 || *axis >= dimensions.size())
|
||||
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
||||
if (dimensions.size() == 0) {
|
||||
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
|
||||
if (X.size() != 1)
|
||||
ORT_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
|
||||
} else {
|
||||
if (*axis < 0 || *axis >= dimensions.size())
|
||||
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
||||
if (dimensions.size() > 1) {
|
||||
|
@ -45,27 +51,38 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
std::vector<std::string> out(size);
|
||||
|
||||
// Do computation
|
||||
int64_t h = 1;
|
||||
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
|
||||
h *= dimensions[i];
|
||||
}
|
||||
int64_t left_part = size / h;
|
||||
int64_t right_part = size / left_part;
|
||||
int64_t n_red = dimensions[*axis] - 1;
|
||||
int64_t inc = right_part * (n_red + 1);
|
||||
int64_t pos = 0;
|
||||
for (int64_t li = 0; li < left_part; ++li) {
|
||||
for (int64_t ri = 0; ri < right_part; ++ri, ++pos) {
|
||||
std::ostringstream st;
|
||||
int64_t index = ri + li * inc;
|
||||
for (int64_t j = 0; j < n_red; ++j, index += h) {
|
||||
st << X[index] << sep[0];
|
||||
if (dimensions.size() > 0) {
|
||||
if (X.size() > 0) {
|
||||
// Do computation
|
||||
int64_t h = 1;
|
||||
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
|
||||
h *= dimensions[i];
|
||||
}
|
||||
st << X[index];
|
||||
out[pos] = st.str();
|
||||
int64_t left_part = size / h;
|
||||
int64_t right_part = size / left_part;
|
||||
int64_t n_red = dimensions[*axis] - 1;
|
||||
int64_t inc = right_part * (n_red + 1);
|
||||
int64_t pos = 0;
|
||||
for (int64_t li = 0; li < left_part; ++li) {
|
||||
for (int64_t ri = 0; ri < right_part; ++ri, ++pos) {
|
||||
std::ostringstream st;
|
||||
int64_t index = ri + li * inc;
|
||||
for (int64_t j = 0; j < n_red; ++j, index += h) {
|
||||
st << X[index] << sep[0];
|
||||
}
|
||||
st << X[index];
|
||||
out[pos] = st.str();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// for input 1 contains 0 elements, output joined string is empty string
|
||||
out[0] = "";
|
||||
}
|
||||
} else {
|
||||
// for input 1 (scalar) which has 1 element, output joined string is input string itself. See issue: https://github.com/onnx/onnx/issues/3724
|
||||
out[0] = X[0];
|
||||
}
|
||||
|
||||
FillTensorDataString(api_, ort_, context, out, output);
|
||||
}
|
||||
|
||||
|
|
|
@ -52,16 +52,22 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
|
||||
// inline split output_str by newline '\n'
|
||||
std::vector<char*> output_sentences;
|
||||
bool head_flag = true;
|
||||
for (int i = 0; i < output_length; i++) {
|
||||
if (head_flag) {
|
||||
output_sentences.push_back(&output_str[i]);
|
||||
head_flag = false;
|
||||
}
|
||||
|
||||
if (output_str[i] == '\n') {
|
||||
head_flag = true;
|
||||
output_str[i] = '\0';
|
||||
if (output_length == 0) {
|
||||
// put one empty string if output_length is 0
|
||||
output_sentences.push_back("");
|
||||
} else {
|
||||
bool head_flag = true;
|
||||
for (int i = 0; i < output_length; i++) {
|
||||
if (head_flag) {
|
||||
output_sentences.push_back(&output_str[i]);
|
||||
head_flag = false;
|
||||
}
|
||||
|
||||
if (output_str[i] == '\n') {
|
||||
head_flag = true;
|
||||
output_str[i] = '\0';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Двоичный файл не отображается.
|
@ -265,3 +265,140 @@ TEST(string_operator, test_string_ecmaregex_replace) {
|
|||
outputs[0].values_string = {"Test+ test", "tEsT+ Test", " TEST+ test"};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
|
||||
TEST(utils, test_string_join) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(3);
|
||||
inputs[0].name = "text";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[0].dims = {1,3};
|
||||
inputs[0].values_string = {"abc","zzz","efg"};
|
||||
|
||||
inputs[1].name = "sep";
|
||||
inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[1].dims = {1};
|
||||
inputs[1].values_string = {"-"};
|
||||
|
||||
inputs[2].name = "axis";
|
||||
inputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
inputs[2].dims = {1};
|
||||
inputs[2].values_int64 = {1};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "out";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {1};
|
||||
outputs[0].values_string = {"abc-zzz-efg"};
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "custom_op_string_join.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
TEST(utils, test_string_join_values_empty_string) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(3);
|
||||
inputs[0].name = "text";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[0].dims = {1};
|
||||
inputs[0].values_string = {""};
|
||||
|
||||
inputs[1].name = "sep";
|
||||
inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[1].dims = {1};
|
||||
inputs[1].values_string = {" "};
|
||||
|
||||
inputs[2].name = "axis";
|
||||
inputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
inputs[2].dims = {1};
|
||||
inputs[2].values_int64 = {0};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "out";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {1};
|
||||
outputs[0].values_string = {""};
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "custom_op_string_join.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
TEST(utils, test_string_join_dims_zero_values_empty) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(3);
|
||||
inputs[0].name = "text";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
// If dims contains value 0, values must have length 0. See issue: https://github.com/onnx/onnx/issues/3724
|
||||
inputs[0].dims = {0};
|
||||
inputs[0].values_string = {};
|
||||
|
||||
inputs[1].name = "sep";
|
||||
inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[1].dims = {1};
|
||||
inputs[1].values_string = {" "};
|
||||
|
||||
inputs[2].name = "axis";
|
||||
inputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
inputs[2].dims = {1};
|
||||
inputs[2].values_int64 = {0};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "out";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {1};
|
||||
outputs[0].values_string = {""};
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "custom_op_string_join.onnx";
|
||||
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
TEST(utils, test_string_join_dims_empty_values_scalar) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(3);
|
||||
inputs[0].name = "text";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
// dims size 0 means it's scalar, and values must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
|
||||
inputs[0].dims = {};
|
||||
inputs[0].values_string = {"abc"};
|
||||
|
||||
inputs[1].name = "sep";
|
||||
inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[1].dims = {1};
|
||||
inputs[1].values_string = {" "};
|
||||
|
||||
inputs[2].name = "axis";
|
||||
inputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
inputs[2].dims = {1};
|
||||
inputs[2].values_int64 = {0};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "out";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {1};
|
||||
outputs[0].values_string = {"abc"};
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "custom_op_string_join.onnx";
|
||||
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
|
|
@ -24,6 +24,19 @@ class TestBlingFireSentenceBreaker(unittest.TestCase):
|
|||
"Я увидел девушку с телескопом."])
|
||||
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
|
||||
|
||||
def test_text_to_case2(self):
|
||||
# input is empty
|
||||
inputs = np.array([""])
|
||||
outputs = np.array([""])
|
||||
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
|
||||
|
||||
def test_text_to_case3(self):
|
||||
# input is whitespace
|
||||
inputs = np.array([" "])
|
||||
# output of blingfire sbd.bin model
|
||||
outputs = np.array([""])
|
||||
_run_blingfire_sentencebreaker(input=inputs, output=outputs, model_path=_get_test_data_file('data', 'default_sentence_break_model.bin'))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -616,6 +616,32 @@ class TestPythonOpString(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
txout[0].tolist(), np.array(["a;b;cc"]).tolist())
|
||||
|
||||
def test_string_join_empty(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_join('')
|
||||
self.assertIn('op_type: "StringJoin"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
text = np.array([""])
|
||||
sep = np.array([" "])
|
||||
axis = np.array([0], dtype=np.int64)
|
||||
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
|
||||
self.assertEqual(
|
||||
txout[0].tolist(), np.array([""]).tolist())
|
||||
|
||||
def test_string_join_scalar(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_join('')
|
||||
self.assertIn('op_type: "StringJoin"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
text = np.array("a scalar string")
|
||||
sep = np.array([" "])
|
||||
axis = np.array([0], dtype=np.int64)
|
||||
txt_out = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
|
||||
self.assertEqual(
|
||||
txt_out[0].tolist(), np.array(["a scalar string"]).tolist())
|
||||
|
||||
def test_string_join_cc_3d(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
|
|
Загрузка…
Ссылка в новой задаче