This commit is contained in:
Xavier Dupré 2020-11-16 19:05:12 +01:00 коммит произвёл GitHub
Родитель fadcf2ab89
Коммит db43f413b8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 276 добавлений и 0 удалений

Просмотреть файл

@ -0,0 +1,119 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_split.hpp"
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
}
void KernelStringSplit::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const std::string* X = ort_.GetTensorData<std::string>(input_X);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
throw std::runtime_error("Input 2 is the delimiter, it has 1 element.");
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
throw std::runtime_error("Input 3 is skip_empty, it has 1 element.");
OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() != 1)
throw std::runtime_error("Only 1D tensor are supported as input.");
std::vector<std::string> words;
std::vector<int64_t> indices;
int64_t maxc = 0;
int64_t col;
std::string delimiter = *sep;
bool keep = !(*skip_empty);
std::size_t current, previous = 0;
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[row];
if (str.empty())
continue;
previous = 0;
col = 0;
current = str.find_first_of(delimiter);
while (current != std::string::npos) {
if (keep || current > previous) {
words.push_back(str.substr(previous, current - previous));
indices.push_back(row);
indices.push_back(col);
++col;
}
previous = current + 1;
current = str.find_first_of(delimiter, previous);
}
if (keep || current > previous) {
words.push_back(str.substr(previous, current - previous));
indices.push_back(row);
indices.push_back(col);
++col;
}
maxc = col > maxc ? col : maxc;
}
std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());
std::vector<int64_t> shape_text(1, words.size());
OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());
std::vector<int64_t> shape_shape(1, 2);
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());
int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
std::string* p_text = ort_.GetTensorMutableData<std::string>(out_text);
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
p_shape[0] = dimensions[0];
p_shape[1] = maxc;
std::copy(words.begin(), words.end(), p_text);
}
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
return new KernelStringSplit(api);
};
const char* CustomOpStringSplit::GetName() const {
return "StringSplit";
};
size_t CustomOpStringSplit::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
}
};
size_t CustomOpStringSplit::GetOutputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
switch (index) {
case 0:
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
default:
throw std::runtime_error(MakeString("Unexpected output index ", index));
}
};

Просмотреть файл

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "kernels.h"
#include "utils.h"
struct KernelStringSplit : BaseKernel {
KernelStringSplit(OrtApi api);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info);
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};

Просмотреть файл

@ -5,6 +5,7 @@
#include "kernels/string_hash.hpp"
#include "kernels/string_join.hpp"
#include "kernels/string_regex_replace.hpp"
#include "kernels/string_split.hpp"
#include "kernels/string_upper.hpp"
#include "kernels/test_output.hpp"
#include "utils.h"
@ -15,6 +16,7 @@ CustomOpStringHash c_CustomOpStringHash;
CustomOpStringHashFast c_CustomOpStringHashFast;
CustomOpStringJoin c_CustomOpStringJoin;
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
CustomOpStringSplit c_CustomOpStringSplit;
CustomOpStringUpper c_CustomOpStringUpper;
CustomOpOne c_CustomOpOne;
CustomOpTwo c_CustomOpTwo;
@ -26,6 +28,7 @@ OrtCustomOp* operator_lists[] = {
&c_CustomOpStringHashFast,
&c_CustomOpStringJoin,
&c_CustomOpStringRegexReplace,
&c_CustomOpStringSplit,
&c_CustomOpStringUpper,
&c_CustomOpOne,
&c_CustomOpTwo,

Просмотреть файл

@ -150,6 +150,36 @@ def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
return model
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
nodes = []
nodes.append(helper.make_node('Identity', ['input'], ['id1']))
nodes.append(helper.make_node('Identity', ['delimiter'], ['id2']))
nodes.append(helper.make_node('Identity', ['skip_empty'], ['id3']))
nodes.append(
helper.make_node(
'%sStringSplit' % prefix, ['id1', 'id2', 'id3'],
['indices', 'values', 'shape'], domain=domain))
input0 = helper.make_tensor_value_info(
'input', onnx_proto.TensorProto.STRING, [])
input1 = helper.make_tensor_value_info(
'delimiter', onnx_proto.TensorProto.STRING, [])
input2 = helper.make_tensor_value_info(
'skip_empty', onnx_proto.TensorProto.BOOL, [])
output0 = helper.make_tensor_value_info(
'indices', onnx_proto.TensorProto.INT64, [])
output1 = helper.make_tensor_value_info(
'values', onnx_proto.TensorProto.STRING, [])
output2 = helper.make_tensor_value_info(
'shape', onnx_proto.TensorProto.INT64, [])
graph = helper.make_graph(nodes, 'test0', [input0, input1, input2],
[output0, output1, output2])
model = helper.make_model(
graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
return model
class TestPythonOpString(unittest.TestCase):
_string_join = None
@ -246,6 +276,36 @@ class TestPythonOpString(unittest.TestCase):
def string_equal(x, y):
return x == y
@onnx_op(op_type="PyStringSplit",
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
PyCustomOpDef.dt_bool],
outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_string,
PyCustomOpDef.dt_int64])
def string_split(input, delimiter, skip_empty):
if delimiter.shape != (1, ):
raise RuntimeError("demiliter must a single element tensor.")
if skip_empty.shape != (1, ):
raise RuntimeError("skip_empty must a single element tensor.")
if len(input.shape) != 1:
raise RuntimeError("input must a one dimension tensor.")
delimiter = delimiter[0]
skip_empty = skip_empty[0]
texts = []
indices = []
max_split = 0
for row, text in enumerate(input):
if not text:
continue
res = text.split(delimiter)
if skip_empty:
res = [t for t in res if t]
texts.extend(res)
max_split = max(max_split, len(res))
indices.extend((row, i) for i in range(len(res)))
return (np.array(indices, dtype=np.int64),
np.array(texts),
np.array([len(input), max_split], dtype=np.int64))
cls._string_join = string_join
cls._string_to_crc32 = string_to_crc32
@ -576,6 +636,79 @@ class TestPythonOpString(unittest.TestCase):
txout = sess.run(None, {'x': y, 'y': x})
self.assertEqual(txout[0].tolist(), (y == x).tolist())
def test_string_split_python(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = _create_test_model_string_split('Py')
self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
delimiter = np.array([","])
for skip in [True, False]:
with self.subTest(skip=skip):
skip_empty = np.array([skip])
txout = sess.run(
None, {'input': input, 'delimiter': delimiter,
'skip_empty': skip_empty})
if skip_empty:
exp_indices = np.array(
[[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
else:
exp_indices = np.array(
[[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]])
exp_text = np.array(['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
exp_shape = np.array([4, 3])
self.assertEqual(exp_indices.tolist(), txout[0].tolist())
self.assertEqual(exp_text.tolist(), txout[1].tolist())
self.assertEqual(exp_shape.tolist(), txout[2].tolist())
def test_string_split_cc(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = _create_test_model_string_split('')
self.assertIn('op_type: "StringSplit"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
input = np.array(["a,,b", "", "aa,b,c", "dddddd"])
delimiter = np.array([","])
for skip in [True, False]:
with self.subTest(skip=skip):
skip_empty = np.array([skip])
txout = sess.run(
None, {'input': input, 'delimiter': delimiter,
'skip_empty': skip_empty})
try:
from tensorflow.raw_ops import StringSplit
dotf = True
except ImportError:
dotf = False
if dotf:
tfres = StringSplit(
input=input, delimiter=",,", skip_empty=skip)
self.assertEqual([_.decode() for _ in tfres[1].numpy().tolist()],
txout[1].tolist())
self.assertEqual(tfres[0].numpy().tolist(), txout[0].tolist())
self.assertEqual(tfres[2].numpy().tolist(), txout[2].tolist())
if skip_empty:
exp_indices = np.array(
[[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
else:
exp_indices = np.array(
[[0, 0], [0, 1], [0, 2], [2, 0], [2, 1], [2, 2], [3, 0]])
exp_text = np.array(['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
exp_shape = np.array([4, 3])
self.assertEqual(exp_indices.tolist(), txout[0].tolist())
self.assertEqual(exp_text.tolist(), txout[1].tolist())
self.assertEqual(exp_shape.tolist(), txout[2].tolist())
if __name__ == "__main__":
unittest.main()