Add string strip text operator (#460)
* add string strip text operator --------- Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Родитель
93f239c143
Коммит
30eb7afcfa
|
@ -0,0 +1,54 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_strip.hpp"
|
||||
#include "string_tensor.h"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
const char* WHITE_SPACE_CHARS = " \t\n\r\f\v";
|
||||
|
||||
KernelStringStrip::KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringStrip::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
std::vector<std::string> X;
|
||||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
|
||||
// For each string in input, replace with whitespace-trimmed version.
|
||||
for (size_t i = 0; i < X.size(); ++i) {
|
||||
size_t nonWhitespaceBegin = X[i].find_first_not_of(WHITE_SPACE_CHARS);
|
||||
if (nonWhitespaceBegin != std::string::npos) {
|
||||
size_t nonWhitespaceEnd = X[i].find_last_not_of(WHITE_SPACE_CHARS);
|
||||
size_t nonWhitespaceRange = nonWhitespaceEnd - nonWhitespaceBegin + 1;
|
||||
|
||||
X[i] = X[i].substr(nonWhitespaceBegin, nonWhitespaceRange);
|
||||
}
|
||||
}
|
||||
|
||||
// Fills the output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
FillTensorDataString(api_, ort_, context, X, output);
|
||||
}
|
||||
|
||||
const char* CustomOpStringStrip::GetName() const { return "StringStrip"; };
|
||||
|
||||
size_t CustomOpStringStrip::GetInputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringStrip::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
size_t CustomOpStringStrip::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringStrip::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringStrip : BaseKernel {
|
||||
KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringStrip : OrtW::CustomOpBase<CustomOpStringStrip, KernelStringStrip> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
|
@ -4,6 +4,7 @@
|
|||
#include "text/string_join.hpp"
|
||||
#include "text/string_lower.hpp"
|
||||
#include "text/string_split.hpp"
|
||||
#include "text/string_strip.hpp"
|
||||
#include "text/string_to_vector.hpp"
|
||||
#include "text/string_upper.hpp"
|
||||
#include "text/vector_to_string.hpp"
|
||||
|
@ -17,15 +18,14 @@
|
|||
#if defined(ENABLE_RE2_REGEX)
|
||||
#include "text/re2_strings/string_regex_replace.hpp"
|
||||
#include "text/re2_strings/string_regex_split.hpp"
|
||||
#endif // ENABLE_RE2_REGEX
|
||||
#endif // ENABLE_RE2_REGEX
|
||||
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
|
||||
LoadCustomOpClasses<CustomOpClassBegin,
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
|
||||
LoadCustomOpClasses<CustomOpClassBegin,
|
||||
#if defined(ENABLE_RE2_REGEX)
|
||||
CustomOpStringRegexReplace,
|
||||
CustomOpStringRegexSplitWithOffsets,
|
||||
#endif // ENABLE_RE2_REGEX
|
||||
#endif // ENABLE_RE2_REGEX
|
||||
CustomOpRaggedTensorToDense,
|
||||
CustomOpRaggedTensorToSparse,
|
||||
CustomOpStringRaggedTensorToDense,
|
||||
|
@ -38,10 +38,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Text =
|
|||
CustomOpStringMapping,
|
||||
CustomOpMaskedFill,
|
||||
CustomOpStringSplit,
|
||||
CustomOpStringStrip,
|
||||
CustomOpStringToVector,
|
||||
CustomOpVectorToString,
|
||||
CustomOpStringLength,
|
||||
CustomOpStringConcat,
|
||||
CustomOpStringECMARegexReplace,
|
||||
CustomOpStringECMARegexSplitWithOffsets
|
||||
>;
|
||||
CustomOpStringECMARegexSplitWithOffsets>;
|
||||
|
|
|
@ -173,6 +173,22 @@ def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
|
|||
return model
|
||||
|
||||
|
||||
def _create_test_model_string_strip(prefix, domain='ai.onnx.contrib'):
|
||||
nodes = []
|
||||
nodes[0:] = [helper.make_node('Identity', ['input_1'], ['identity1'])]
|
||||
nodes[1:] = [helper.make_node('%sStringStrip' % prefix,
|
||||
['identity1'], ['customout'],
|
||||
domain=domain)]
|
||||
|
||||
input0 = helper.make_tensor_value_info(
|
||||
'input_1', onnx_proto.TensorProto.STRING, [None, None])
|
||||
output0 = helper.make_tensor_value_info(
|
||||
'customout', onnx_proto.TensorProto.STRING, [None, None])
|
||||
|
||||
graph = helper.make_graph(nodes, 'test0', [input0], [output0])
|
||||
model = make_onnx_model(graph)
|
||||
return model
|
||||
|
||||
def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
|
||||
nodes = []
|
||||
nodes.append(helper.make_node('Identity', ['input'], ['id1']))
|
||||
|
@ -436,6 +452,26 @@ class TestPythonOpString(unittest.TestCase):
|
|||
for t in type_list:
|
||||
self.assertIn(t, def_list)
|
||||
|
||||
def test_string_strip_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_strip('')
|
||||
self.assertIn('op_type: "StringStrip"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
input_1 = np.array([[" a b c "]])
|
||||
txout = sess.run(None, {'input_1': input_1})
|
||||
self.assertEqual(txout[0].tolist(), np.array([["a b c"]]).tolist())
|
||||
|
||||
def test_string_strip_cc_empty(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = _create_test_model_string_strip('')
|
||||
self.assertIn('op_type: "StringStrip"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
input_1 = np.array([[""]])
|
||||
txout = sess.run(None, {'input_1': input_1})
|
||||
self.assertEqual(txout[0].tolist(), np.array([[""]]).tolist())
|
||||
|
||||
def test_string_upper_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
|
@ -1151,7 +1187,6 @@ class TestPythonOpString(unittest.TestCase):
|
|||
res.__len__ = lambda self: len(vocab)
|
||||
|
||||
vocab_table = _CreateTable(["want", "##want", "##ed", "wa", "un", "runn", "##ing"])
|
||||
|
||||
text = tf.convert_to_tensor(["unwanted running", "unwantedX running"], dtype=tf.string)
|
||||
try:
|
||||
tf_tokens, tf_rows, tf_begins, tf_ends = (
|
||||
|
|
Загрузка…
Ссылка в новой задаче