Add operator StringSplit (#24)
This commit is contained in:
Родитель
fadcf2ab89
Коммит
db43f413b8
|
@ -0,0 +1,119 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#include "string_split.hpp"
|
||||||
|
|
||||||
|
// Constructs the StringSplit kernel; the OrtApi handle is passed on to BaseKernel.
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {}
|
||||||
|
|
||||||
|
// Splits every string of a 1-D string tensor into tokens.
//
// Inputs:
//   0: 1-D tensor of strings to split.
//   1: string tensor with exactly one element, the delimiter. Every character
//      of that string acts as a separator (std::string::find_first_of
//      semantics). An empty delimiter never matches, so each non-empty row
//      yields a single token.
//   2: bool tensor with exactly one element; when true, empty tokens are
//      dropped.
// Outputs:
//   0: int64 tensor of shape [N, 2] holding the (row, column) of each token.
//   1: string tensor of shape [N] holding the tokens, in the same order.
//   2: int64 tensor [number of rows, largest number of tokens found in a row].
//
// Rows that are empty strings never produce a token.
// Throws std::runtime_error when an input does not have the expected shape.
void KernelStringSplit::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  const std::string* X = ort_.GetTensorData<std::string>(input_X);
  const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
  const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
  const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
  const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);

  // Validate the shapes before any element is dereferenced.
  OrtTensorDimensions dimensions_sep(ort_, input_sep);
  if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
    throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
  OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
  if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
    throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
  OrtTensorDimensions dimensions(ort_, input_X);
  if (dimensions.size() != 1)
    throw std::runtime_error("Only 1D tensor are supported as input.");

  std::vector<std::string> words;
  std::vector<int64_t> indices;  // flattened (row, column) pairs
  int64_t maxc = 0;              // largest token count seen in a single row
  const std::string& delimiter = *sep;
  const bool keep = !(*skip_empty);

  for (int64_t row = 0; row < dimensions[0]; ++row) {
    const std::string& str = X[row];
    if (str.empty())
      continue;

    int64_t col = 0;
    std::size_t previous = 0;
    std::size_t current = str.find_first_of(delimiter);
    while (current != std::string::npos) {
      if (keep || current > previous) {
        words.push_back(str.substr(previous, current - previous));
        indices.push_back(row);
        indices.push_back(col);
        ++col;
      }
      previous = current + 1;
      current = str.find_first_of(delimiter, previous);
    }

    // The last token runs to the end of the string. It must be measured
    // against str.size(): comparing npos with `previous` is always true and
    // would emit an empty token after a trailing delimiter even when empty
    // tokens are supposed to be skipped.
    current = str.size();
    if (keep || current > previous) {
      words.push_back(str.substr(previous, current - previous));
      indices.push_back(row);
      indices.push_back(col);
      ++col;
    }

    if (col > maxc)
      maxc = col;
  }

  // Setup outputs
  std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
  OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());

  std::vector<int64_t> shape_text(1, static_cast<int64_t>(words.size()));
  OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());

  std::vector<int64_t> shape_shape(1, 2);
  OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());

  int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
  std::string* p_text = ort_.GetTensorMutableData<std::string>(out_text);
  int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);

  // std::copy is well defined for empty ranges, unlike memcpy on a null
  // data() pointer when no token was produced.
  std::copy(indices.begin(), indices.end(), p_indices);
  p_shape[0] = dimensions[0];
  p_shape[1] = maxc;
  std::copy(words.begin(), words.end(), p_text);
}
|
||||||
|
|
||||||
|
// Returns a newly allocated KernelStringSplit built from `api`.
// The kernel info argument is not used.
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
  KernelStringSplit* kernel = new KernelStringSplit(api);
  return kernel;
}
|
||||||
|
|
||||||
|
// Name under which the operator is exposed.
const char* CustomOpStringSplit::GetName() const {
  constexpr const char* op_name = "StringSplit";
  return op_name;
}
|
||||||
|
|
||||||
|
// The operator takes three inputs (see GetInputType for their element types).
size_t CustomOpStringSplit::GetInputTypeCount() const {
  constexpr size_t input_count = 3;
  return input_count;
}
|
||||||
|
|
||||||
|
// Element type of input `index`: inputs 0 and 1 are strings, input 2 is bool.
// Throws std::runtime_error for any other index.
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
  if (index == 0 || index == 1) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }
  if (index == 2) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
  }
  throw std::runtime_error(MakeString("Unexpected input index ", index));
}
|
||||||
|
|
||||||
|
// The operator produces three outputs (see GetOutputType for their element types).
size_t CustomOpStringSplit::GetOutputTypeCount() const {
  constexpr size_t output_count = 3;
  return output_count;
}
|
||||||
|
|
||||||
|
// Element type of output `index`: outputs 0 and 2 are int64, output 1 is string.
// Throws std::runtime_error for any other index.
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
  if (index == 0 || index == 2) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }
  if (index == 1) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }
  throw std::runtime_error(MakeString("Unexpected output index ", index));
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "kernels.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
struct KernelStringSplit : BaseKernel {
|
||||||
|
KernelStringSplit(OrtApi api);
|
||||||
|
void Compute(OrtKernelContext* context);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Custom operator description for StringSplit, backed by KernelStringSplit.
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
  // Creates the kernel executing the operator.
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info);

  // Operator name.
  const char* GetName() const;

  // Number of inputs and the element type of each of them.
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;

  // Number of outputs and the element type of each of them.
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
|
|
@ -5,6 +5,7 @@
|
||||||
#include "kernels/string_hash.hpp"
|
#include "kernels/string_hash.hpp"
|
||||||
#include "kernels/string_join.hpp"
|
#include "kernels/string_join.hpp"
|
||||||
#include "kernels/string_regex_replace.hpp"
|
#include "kernels/string_regex_replace.hpp"
|
||||||
|
#include "kernels/string_split.hpp"
|
||||||
#include "kernels/string_upper.hpp"
|
#include "kernels/string_upper.hpp"
|
||||||
#include "kernels/test_output.hpp"
|
#include "kernels/test_output.hpp"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
@ -15,6 +16,7 @@ CustomOpStringHash c_CustomOpStringHash;
|
||||||
CustomOpStringHashFast c_CustomOpStringHashFast;
|
CustomOpStringHashFast c_CustomOpStringHashFast;
|
||||||
CustomOpStringJoin c_CustomOpStringJoin;
|
CustomOpStringJoin c_CustomOpStringJoin;
|
||||||
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
|
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
|
||||||
|
CustomOpStringSplit c_CustomOpStringSplit;
|
||||||
CustomOpStringUpper c_CustomOpStringUpper;
|
CustomOpStringUpper c_CustomOpStringUpper;
|
||||||
CustomOpOne c_CustomOpOne;
|
CustomOpOne c_CustomOpOne;
|
||||||
CustomOpTwo c_CustomOpTwo;
|
CustomOpTwo c_CustomOpTwo;
|
||||||
|
@ -26,6 +28,7 @@ OrtCustomOp* operator_lists[] = {
|
||||||
&c_CustomOpStringHashFast,
|
&c_CustomOpStringHashFast,
|
||||||
&c_CustomOpStringJoin,
|
&c_CustomOpStringJoin,
|
||||||
&c_CustomOpStringRegexReplace,
|
&c_CustomOpStringRegexReplace,
|
||||||
|
&c_CustomOpStringSplit,
|
||||||
&c_CustomOpStringUpper,
|
&c_CustomOpStringUpper,
|
||||||
&c_CustomOpOne,
|
&c_CustomOpOne,
|
||||||
&c_CustomOpTwo,
|
&c_CustomOpTwo,
|
||||||
|
|
|
@ -150,6 +150,36 @@ def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
    """Build an ONNX model around a single ``<prefix>StringSplit`` node.

    The graph inputs ``input``, ``delimiter`` and ``skip_empty`` each go
    through an ``Identity`` node before reaching the split node, whose
    outputs ``indices``, ``values`` and ``shape`` are the graph outputs.

    :param prefix: string prepended to ``StringSplit`` to form the op type
    :param domain: domain of the split node and of the opset import
    :return: the model returned by ``helper.make_model``
    """
    tensor_types = onnx_proto.TensorProto

    graph_inputs = [
        helper.make_tensor_value_info(name, elem_type, [])
        for name, elem_type in (('input', tensor_types.STRING),
                                ('delimiter', tensor_types.STRING),
                                ('skip_empty', tensor_types.BOOL))]
    graph_outputs = [
        helper.make_tensor_value_info(name, elem_type, [])
        for name, elem_type in (('indices', tensor_types.INT64),
                                ('values', tensor_types.STRING),
                                ('shape', tensor_types.INT64))]

    nodes = [
        helper.make_node('Identity', [source], [target])
        for source, target in (('input', 'id1'),
                               ('delimiter', 'id2'),
                               ('skip_empty', 'id3'))]
    split_node = helper.make_node(
        '%sStringSplit' % prefix, ['id1', 'id2', 'id3'],
        ['indices', 'values', 'shape'], domain=domain)
    nodes.append(split_node)

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
|
||||||
|
|
||||||
|
|
||||||
class TestPythonOpString(unittest.TestCase):
|
class TestPythonOpString(unittest.TestCase):
|
||||||
|
|
||||||
_string_join = None
|
_string_join = None
|
||||||
|
@ -246,6 +276,36 @@ class TestPythonOpString(unittest.TestCase):
|
||||||
def string_equal(x, y):
|
def string_equal(x, y):
|
||||||
return x == y
|
return x == y
|
||||||
|
|
||||||
|
@onnx_op(op_type="PyStringSplit",
         inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                 PyCustomOpDef.dt_bool],
         outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
                  PyCustomOpDef.dt_int64])
def string_split(input, delimiter, skip_empty):
    """Python implementation of the split operator.

    :param input: 1-D array of strings to split
    :param delimiter: array of shape ``(1,)`` holding the separator,
        used with ``str.split``
    :param skip_empty: array of shape ``(1,)``; when true, empty
        tokens are removed
    :return: tuple ``(indices, values, shape)`` where *indices* is an
        int64 array of ``(row, column)`` pairs, *values* the tokens in
        the same order, and *shape* the int64 array
        ``[len(input), largest number of tokens in a row]``
    :raises RuntimeError: if an argument does not have the expected shape

    Empty rows produce no token.
    """
    if delimiter.shape != (1, ):
        raise RuntimeError("delimiter must be a single element tensor.")
    if skip_empty.shape != (1, ):
        raise RuntimeError("skip_empty must be a single element tensor.")
    if len(input.shape) != 1:
        raise RuntimeError("input must be a one dimension tensor.")
    delimiter = delimiter[0]
    skip_empty = skip_empty[0]
    texts = []
    indices = []
    max_split = 0
    for row, text in enumerate(input):
        if not text:
            continue
        res = text.split(delimiter)
        if skip_empty:
            res = [t for t in res if t]
        texts.extend(res)
        max_split = max(max_split, len(res))
        indices.extend((row, i) for i in range(len(res)))
    # The explicit reshape and dtype keep the outputs well formed when no
    # token was produced: np.array([]) alone is a 1-D float64 array.
    return (np.array(indices, dtype=np.int64).reshape(-1, 2),
            np.array(texts, dtype=str),
            np.array([len(input), max_split], dtype=np.int64))
|
||||||
|
|
||||||
cls._string_join = string_join
|
cls._string_join = string_join
|
||||||
cls._string_to_crc32 = string_to_crc32
|
cls._string_to_crc32 = string_to_crc32
|
||||||
|
|
||||||
|
@ -576,6 +636,79 @@ class TestPythonOpString(unittest.TestCase):
|
||||||
txout = sess.run(None, {'x': y, 'y': x})
|
txout = sess.run(None, {'x': y, 'y': x})
|
||||||
self.assertEqual(txout[0].tolist(), (y == x).tolist())
|
self.assertEqual(txout[0].tolist(), (y == x).tolist())
|
||||||
|
|
||||||
|
def test_string_split_python(self):
    """Run the Python split operator with and without skipping empty tokens."""
    session_options = _ort.SessionOptions()
    session_options.register_custom_ops_library(_get_library_path())
    onnx_model = _create_test_model_string_split('Py')
    self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
    sess = _ort.InferenceSession(
        onnx_model.SerializeToString(), session_options)
    input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
    delimiter = np.array([","])

    # Expected (indices, values) for each value of skip_empty.
    expected = {
        True: ([[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]],
               ['a', 'b', 'aa', 'b', 'c', 'dddddd']),
        False: ([[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]],
                ['a', '', 'b', 'aa', 'b', 'c', 'dddddd']),
    }
    exp_shape = [4, 3]

    for skip in (True, False):
        with self.subTest(skip=skip):
            feeds = {'input': input, 'delimiter': delimiter,
                     'skip_empty': np.array([skip])}
            txout = sess.run(None, feeds)

            exp_indices, exp_text = expected[skip]
            self.assertEqual(exp_indices, txout[0].tolist())
            self.assertEqual(exp_text, txout[1].tolist())
            self.assertEqual(exp_shape, txout[2].tolist())
|
||||||
|
|
||||||
|
def test_string_split_cc(self):
    """Run the C++ split operator with and without skipping empty tokens.

    When tensorflow can be imported, the outputs are also compared with
    ``tensorflow.raw_ops.StringSplit`` called with the same delimiter.
    """
    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    onnx_model = _create_test_model_string_split('')
    self.assertIn('op_type: "StringSplit"', str(onnx_model))
    sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
    input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
    separator = ","
    delimiter = np.array([separator])

    # The optional reference implementation is looked up once, not on
    # every iteration of the loop.
    try:
        from tensorflow.raw_ops import StringSplit
    except ImportError:
        StringSplit = None

    for skip in [True, False]:
        with self.subTest(skip=skip):
            skip_empty = np.array([skip])

            txout = sess.run(
                None, {'input': input, 'delimiter': delimiter,
                       'skip_empty': skip_empty})

            if StringSplit is not None:
                # Same delimiter as the one given to the operator under test.
                tfres = StringSplit(
                    input=input, delimiter=separator, skip_empty=skip)
                self.assertEqual(
                    [_.decode() for _ in tfres[1].numpy().tolist()],
                    txout[1].tolist())
                self.assertEqual(
                    tfres[0].numpy().tolist(), txout[0].tolist())
                self.assertEqual(
                    tfres[2].numpy().tolist(), txout[2].tolist())

            if skip:
                exp_indices = np.array(
                    [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
                exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
            else:
                exp_indices = np.array(
                    [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]])
                exp_text = np.array(
                    ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
            exp_shape = np.array([4, 3])
            self.assertEqual(exp_indices.tolist(), txout[0].tolist())
            self.assertEqual(exp_text.tolist(), txout[1].tolist())
            self.assertEqual(exp_shape.tolist(), txout[2].tolist())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче