Add operator StringSplit (#24)
This commit is contained in:
Родитель
fadcf2ab89
Коммит
db43f413b8
|
@ -0,0 +1,119 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_split.hpp"
|
||||
|
||||
// Constructs the StringSplit kernel. The constructor body is empty: the ORT
// API table is simply forwarded to the BaseKernel base class.
KernelStringSplit::KernelStringSplit(OrtApi api)
    : BaseKernel(api) {}
|
||||
|
||||
// Splits every string of a 1-D input tensor on the delimiter characters and
// writes the tokens in a sparse, coordinate-style layout.
//
// Inputs:
//   0: 1-D string tensor holding the texts to split.
//   1: string tensor of shape {1} holding the delimiter. Every character of
//      it acts as a separator (std::string::find_first_of semantics).
//   2: bool tensor of shape {1}; when true, empty tokens are dropped.
// Outputs:
//   0: int64 tensor of shape {n_tokens, 2}: (row, column) of each token.
//   1: string tensor of shape {n_tokens}: the tokens, in row-major order.
//   2: int64 tensor of shape {2}: {number of input rows,
//      largest number of tokens produced by a single row}.
//
// Empty input strings never produce a token, whatever skip_empty says.
// Throws std::runtime_error when an input does not have the expected shape.
void KernelStringSplit::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  const std::string* X = ort_.GetTensorData<std::string>(input_X);
  const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
  const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
  const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
  const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);

  // Validate the shapes before any of the data pointers is dereferenced.
  OrtTensorDimensions dimensions_sep(ort_, input_sep);
  if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
    throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
  OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
  if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
    throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
  OrtTensorDimensions dimensions(ort_, input_X);
  if (dimensions.size() != 1)
    throw std::runtime_error("Only 1D tensor are supported as input.");

  const std::string& delimiter = *sep;
  const bool keep = !(*skip_empty);

  std::vector<std::string> words;
  std::vector<int64_t> indices;  // flattened (row, col) pairs
  int64_t maxc = 0;

  for (int64_t row = 0; row < dimensions[0]; ++row) {
    const std::string& str = X[row];
    if (str.empty())
      continue;

    int64_t col = 0;
    std::size_t previous = 0;
    for (;;) {
      const std::size_t found = str.find_first_of(delimiter, previous);
      // The last token ends at the end of the string, not at npos. Comparing
      // against npos would make the emptiness check below always pass and
      // leak a trailing empty token (e.g. "a,") even when skip_empty is set.
      const std::size_t end = (found == std::string::npos) ? str.size() : found;
      if (keep || end > previous) {
        words.push_back(str.substr(previous, end - previous));
        indices.push_back(row);
        indices.push_back(col);
        ++col;
      }
      if (found == std::string::npos)
        break;
      previous = found + 1;
    }
    maxc = col > maxc ? col : maxc;
  }

  // Setup outputs
  std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
  OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());

  std::vector<int64_t> shape_text(1, static_cast<int64_t>(words.size()));
  OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());

  std::vector<int64_t> shape_shape(1, 2);
  OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());

  int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
  std::string* p_text = ort_.GetTensorMutableData<std::string>(out_text);
  int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);

  // std::copy is a no-op on an empty range, whereas memcpy must not be handed
  // the possibly-null data() pointer of an empty vector.
  std::copy(indices.begin(), indices.end(), p_indices);
  p_shape[0] = dimensions[0];
  p_shape[1] = maxc;
  std::copy(words.begin(), words.end(), p_text);
}
|
||||
|
||||
// Creates a new KernelStringSplit for this operator and returns it as an
// untyped pointer. The kernel info argument is not used.
// NOTE(review): the object is allocated with `new`; presumably the custom-op
// framework deletes it -- confirm against Ort::CustomOpBase.
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
  KernelStringSplit* kernel = new KernelStringSplit(api);
  return kernel;
}
|
||||
|
||||
// Returns the operator type name, "StringSplit".
const char* CustomOpStringSplit::GetName() const {
  static const char* const kOperatorName = "StringSplit";
  return kOperatorName;
}
|
||||
|
||||
// Number of inputs of the operator: text, delimiter, skip_empty.
size_t CustomOpStringSplit::GetInputTypeCount() const {
  const size_t input_count = 3;
  return input_count;
}
|
||||
|
||||
// Element type of each input: inputs 0 and 1 are string tensors, input 2 is
// a bool tensor. Any other index throws std::runtime_error.
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
  if (index == 0 || index == 1)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  if (index == 2)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
  throw std::runtime_error(MakeString("Unexpected input index ", index));
}
|
||||
|
||||
// Number of outputs of the operator: indices, values, shape.
size_t CustomOpStringSplit::GetOutputTypeCount() const {
  const size_t output_count = 3;
  return output_count;
}
|
||||
|
||||
// Element type of each output: outputs 0 and 2 are int64 tensors, output 1
// is a string tensor. Any other index throws std::runtime_error.
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
  if (index == 0 || index == 2)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  if (index == 1)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  throw std::runtime_error(MakeString("Unexpected output index ", index));
}
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils.h"
|
||||
|
||||
struct KernelStringSplit : BaseKernel {
|
||||
KernelStringSplit(OrtApi api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
// Operator descriptor for StringSplit: names the operator, declares its
// input/output element types and creates KernelStringSplit instances.
struct CustomOpStringSplit
    : public Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
  // Creates the kernel object executing this operator.
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info);

  // Operator type name.
  const char* GetName() const;

  // Number of inputs and the element type of input `index`.
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;

  // Number of outputs and the element type of output `index`.
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
|
|
@ -5,6 +5,7 @@
|
|||
#include "kernels/string_hash.hpp"
|
||||
#include "kernels/string_join.hpp"
|
||||
#include "kernels/string_regex_replace.hpp"
|
||||
#include "kernels/string_split.hpp"
|
||||
#include "kernels/string_upper.hpp"
|
||||
#include "kernels/test_output.hpp"
|
||||
#include "utils.h"
|
||||
|
@ -15,6 +16,7 @@ CustomOpStringHash c_CustomOpStringHash;
|
|||
CustomOpStringHashFast c_CustomOpStringHashFast;
|
||||
CustomOpStringJoin c_CustomOpStringJoin;
|
||||
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
|
||||
CustomOpStringSplit c_CustomOpStringSplit;
|
||||
CustomOpStringUpper c_CustomOpStringUpper;
|
||||
CustomOpOne c_CustomOpOne;
|
||||
CustomOpTwo c_CustomOpTwo;
|
||||
|
@ -26,6 +28,7 @@ OrtCustomOp* operator_lists[] = {
|
|||
&c_CustomOpStringHashFast,
|
||||
&c_CustomOpStringJoin,
|
||||
&c_CustomOpStringRegexReplace,
|
||||
&c_CustomOpStringSplit,
|
||||
&c_CustomOpStringUpper,
|
||||
&c_CustomOpOne,
|
||||
&c_CustomOpTwo,
|
||||
|
|
|
@ -150,6 +150,36 @@ def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
|
|||
return model
|
||||
|
||||
|
||||
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
    """Build an ONNX model made of one ``<prefix>StringSplit`` node.

    The three graph inputs (``input``, ``delimiter``, ``skip_empty``) each go
    through an ``Identity`` node before reaching the split node, whose three
    outputs (``indices``, ``values``, ``shape``) are the graph outputs.

    :param prefix: string prepended to ``StringSplit`` to form the op type
    :param domain: domain of the split node and of the imported opset
    :return: the ONNX model
    """
    def value_info(name, elem_type):
        # Every graph input/output is declared with an empty shape list.
        return helper.make_tensor_value_info(name, elem_type, [])

    renames = (('input', 'id1'), ('delimiter', 'id2'), ('skip_empty', 'id3'))
    nodes = [helper.make_node('Identity', [src], [dst])
             for src, dst in renames]
    nodes.append(
        helper.make_node(
            '%sStringSplit' % prefix, ['id1', 'id2', 'id3'],
            ['indices', 'values', 'shape'], domain=domain))

    graph_inputs = [
        value_info('input', onnx_proto.TensorProto.STRING),
        value_info('delimiter', onnx_proto.TensorProto.STRING),
        value_info('skip_empty', onnx_proto.TensorProto.BOOL),
    ]
    graph_outputs = [
        value_info('indices', onnx_proto.TensorProto.INT64),
        value_info('values', onnx_proto.TensorProto.STRING),
        value_info('shape', onnx_proto.TensorProto.INT64),
    ]

    graph = helper.make_graph(nodes, 'test0', graph_inputs, graph_outputs)
    return helper.make_model(
        graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
|
||||
|
||||
|
||||
class TestPythonOpString(unittest.TestCase):
|
||||
|
||||
_string_join = None
|
||||
|
@ -246,6 +276,36 @@ class TestPythonOpString(unittest.TestCase):
|
|||
def string_equal(x, y):
|
||||
return x == y
|
||||
|
||||
@onnx_op(op_type="PyStringSplit",
         inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
                 PyCustomOpDef.dt_bool],
         outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
                  PyCustomOpDef.dt_int64])
def string_split(input, delimiter, skip_empty):
    """Python reference implementation of the StringSplit operator.

    :param input: 1-D array of strings to split
    :param delimiter: array of shape ``(1,)`` holding the separator
    :param skip_empty: array of shape ``(1,)``; when true, empty tokens
        are dropped
    :return: tuple ``(indices, texts, shape)`` where *indices* is an int64
        array of shape ``(n_tokens, 2)`` with the (row, column) of every
        token, *texts* holds the tokens and *shape* is
        ``[len(input), largest number of tokens in a row]``
    :raises RuntimeError: if an argument does not have the expected shape

    Empty input strings are skipped and never produce a token.
    """
    if delimiter.shape != (1, ):
        raise RuntimeError("delimiter must be a single element tensor.")
    if skip_empty.shape != (1, ):
        raise RuntimeError("skip_empty must be a single element tensor.")
    if len(input.shape) != 1:
        raise RuntimeError("input must be a one dimension tensor.")
    delimiter = delimiter[0]
    skip_empty = skip_empty[0]
    texts = []
    indices = []
    max_split = 0
    for row, text in enumerate(input):
        if not text:
            continue
        res = text.split(delimiter)
        if skip_empty:
            res = [t for t in res if t]
        texts.extend(res)
        max_split = max(max_split, len(res))
        indices.extend((row, i) for i in range(len(res)))
    # np.array([]) is a float64 array of shape (0,): force the declared
    # layout and dtypes so that a result without any token stays valid.
    out_indices = np.array(indices, dtype=np.int64).reshape((-1, 2))
    out_texts = np.array(texts) if texts else np.array([], dtype=str)
    return (out_indices,
            out_texts,
            np.array([len(input), max_split], dtype=np.int64))
|
||||
|
||||
cls._string_join = string_join
|
||||
cls._string_to_crc32 = string_to_crc32
|
||||
|
||||
|
@ -576,6 +636,79 @@ class TestPythonOpString(unittest.TestCase):
|
|||
txout = sess.run(None, {'x': y, 'y': x})
|
||||
self.assertEqual(txout[0].tolist(), (y == x).tolist())
|
||||
|
||||
def test_string_split_python(self):
    """Run the Python StringSplit op with and without ``skip_empty``."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_split('Py')
    self.assertIn('op_type: "PyStringSplit"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)
    input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
    delimiter = np.array([","])

    # Expected (indices, values) for each value of skip_empty.
    expected = {
        True: ([[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]],
               ['a', 'b', 'aa', 'b', 'c', 'dddddd']),
        False: ([[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]],
                ['a', '', 'b', 'aa', 'b', 'c', 'dddddd']),
    }
    # The dense shape does not depend on skip_empty for this input.
    exp_shape = [4, 3]

    for skip in [True, False]:
        with self.subTest(skip=skip):
            feeds = {'input': input, 'delimiter': delimiter,
                     'skip_empty': np.array([skip])}
            txout = session.run(None, feeds)

            exp_indices, exp_text = expected[skip]
            self.assertEqual(exp_indices, txout[0].tolist())
            self.assertEqual(exp_text, txout[1].tolist())
            self.assertEqual(exp_shape, txout[2].tolist())
|
||||
|
||||
def test_string_split_cc(self):
    """Run the C++ StringSplit kernel with and without ``skip_empty``.

    When tensorflow can be imported, the outputs are also compared with
    ``tensorflow.raw_ops.StringSplit`` called with the same delimiter.
    """
    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    onnx_model = _create_test_model_string_split('')
    self.assertIn('op_type: "StringSplit"', str(onnx_model))
    sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
    input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
    sep = ","
    delimiter = np.array([sep])

    # tensorflow is optional: resolve it once instead of on every iteration.
    try:
        from tensorflow.raw_ops import StringSplit
    except ImportError:
        StringSplit = None

    for skip in [True, False]:
        with self.subTest(skip=skip):
            skip_empty = np.array([skip])

            txout = sess.run(
                None, {'input': input, 'delimiter': delimiter,
                       'skip_empty': skip_empty})

            if StringSplit is not None:
                # Same delimiter as the one given to the kernel under test.
                tfres = StringSplit(
                    input=input, delimiter=sep, skip_empty=skip)
                self.assertEqual(
                    [val.decode() for val in tfres[1].numpy().tolist()],
                    txout[1].tolist())
                self.assertEqual(tfres[0].numpy().tolist(), txout[0].tolist())
                self.assertEqual(tfres[2].numpy().tolist(), txout[2].tolist())

            if skip:
                exp_indices = np.array(
                    [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
                exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
            else:
                exp_indices = np.array(
                    [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]])
                exp_text = np.array(['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
            exp_shape = np.array([4, 3])
            self.assertEqual(exp_indices.tolist(), txout[0].tolist())
            self.assertEqual(exp_text.tolist(), txout[1].tolist())
            self.assertEqual(exp_shape.tolist(), txout[2].tolist())
|
||||
|
||||
|
||||
# Run the whole test suite when the file is executed as a script.
if "__main__" == __name__:
    unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче