Extend StringJoin to support any dimension (#18)
This commit is contained in:
Родитель
6927998670
Коммит
2ef88b0bda
|
@ -12,15 +12,32 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
const std::string* X = ort_.GetTensorData<std::string>(input_X);
|
||||
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
|
||||
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
|
||||
const OrtValue* input_axis = ort_.KernelContext_GetInput(context, 2);
|
||||
const int64_t* axis = ort_.GetTensorData<int64_t>(input_axis);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||
throw std::runtime_error("Input 2 is the separator, it has 1 element.");
|
||||
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
||||
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
||||
throw std::runtime_error("Input 3 is the axis, it has 1 element.");
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (dimensions.size() != 2)
|
||||
throw std::runtime_error(MakeString("Input 1 must have 2 dimensions but has ", dimensions.size(), "."));
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), 1);
|
||||
if (*axis < 0 || *axis >= dimensions.size())
|
||||
throw std::runtime_error(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis));
|
||||
|
||||
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
||||
if (dimensions.size() > 1) {
|
||||
for (size_t i = 0, pos = 0; i < dimensions.size(); ++i) {
|
||||
if (static_cast<int64_t>(i) == *axis)
|
||||
continue;
|
||||
dimensions_out[pos++] = dimensions[i];
|
||||
}
|
||||
} else {
|
||||
dimensions_out[0] = 1;
|
||||
}
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions_out.data(), dimensions_out.size());
|
||||
std::string* out = ort_.GetTensorMutableData<std::string>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
|
@ -28,14 +45,25 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
int64_t index = 0;
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
std::ostringstream st;
|
||||
for (int64_t j = 0; j < dimensions[1] - 1; ++j, ++index) {
|
||||
st << X[index] << *sep;
|
||||
int64_t h = 1;
|
||||
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
|
||||
h *= dimensions[i];
|
||||
}
|
||||
int64_t left_part = size / h;
|
||||
int64_t right_part = size / left_part;
|
||||
int64_t n_red = dimensions[*axis] - 1;
|
||||
int64_t inc = right_part * (n_red + 1);
|
||||
int64_t pos = 0;
|
||||
for (int64_t li = 0; li < left_part; ++li) {
|
||||
for (int64_t ri = 0; ri < right_part; ++ri, ++pos) {
|
||||
std::ostringstream st;
|
||||
int64_t index = ri + li * inc;
|
||||
for (int64_t j = 0; j < n_red; ++j, index += h) {
|
||||
st << X[index] << *sep;
|
||||
}
|
||||
st << X[index];
|
||||
out[pos] = st.str();
|
||||
}
|
||||
st << X[index++];
|
||||
out[i] = st.str();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,11 +76,19 @@ const char* CustomOpStringJoin::GetName() const {
|
|||
};
|
||||
|
||||
// Number of inputs declared by the StringJoin custom op.
// The kernel reads three inputs (indices 0, 1 and 2 in Compute: the string
// tensor, the separator and the axis), so three inputs must be reported here;
// the former value of 2 would leave the axis input undeclared.
size_t CustomOpStringJoin::GetInputTypeCount() const {
  return 3;
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
// Element type expected for the input at position `index`.
// Inputs 0 and 1 are string tensors and input 2 is an int64 tensor.
// Any other index is rejected with std::runtime_error.
ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
  if (index == 0 || index == 1) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }
  if (index == 2) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }
  throw std::runtime_error(MakeString("Unexpected input index ", index));
};
|
||||
|
||||
size_t CustomOpStringJoin::GetOutputTypeCount() const {
|
||||
|
|
|
@ -17,9 +17,9 @@ def _create_test_model_string_upper(prefix, domain='ai.onnx.contrib'):
|
|||
domain=domain)]
|
||||
|
||||
input0 = helper.make_tensor_value_info(
|
||||
'input_1', onnx_proto.TensorProto.STRING, [None, 1])
|
||||
'input_1', onnx_proto.TensorProto.STRING, [None, None])
|
||||
output0 = helper.make_tensor_value_info(
|
||||
'customout', onnx_proto.TensorProto.STRING, [None, 1])
|
||||
'customout', onnx_proto.TensorProto.STRING, [None, None])
|
||||
|
||||
graph = helper.make_graph(nodes, 'test0', [input0], [output0])
|
||||
model = helper.make_model(
|
||||
|
@ -33,19 +33,24 @@ def _create_test_model_string_join(prefix, domain='ai.onnx.contrib'):
|
|||
helper.make_node('Identity', ['text'], ['identity1']))
|
||||
nodes.append(
|
||||
helper.make_node('Identity', ['sep'], ['identity2']))
|
||||
nodes.append(
|
||||
helper.make_node('Identity', ['axis'], ['identity3']))
|
||||
nodes.append(
|
||||
helper.make_node(
|
||||
'%sStringJoin' % prefix, ['identity1', 'identity2'],
|
||||
'%sStringJoin' % prefix, ['identity1', 'identity2', 'identity3'],
|
||||
['customout'], domain=domain))
|
||||
|
||||
input0 = helper.make_tensor_value_info(
|
||||
'text', onnx_proto.TensorProto.STRING, [None, None])
|
||||
'text', onnx_proto.TensorProto.STRING, None)
|
||||
input1 = helper.make_tensor_value_info(
|
||||
'sep', onnx_proto.TensorProto.STRING, [1])
|
||||
input2 = helper.make_tensor_value_info(
|
||||
'axis', onnx_proto.TensorProto.INT64, [1])
|
||||
output0 = helper.make_tensor_value_info(
|
||||
'customout', onnx_proto.TensorProto.STRING, [None, 1])
|
||||
'customout', onnx_proto.TensorProto.STRING, None)
|
||||
|
||||
graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
|
||||
graph = helper.make_graph(
|
||||
nodes, 'test0', [input0, input1, input2], [output0])
|
||||
model = helper.make_model(
|
||||
graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
|
||||
return model
|
||||
|
@ -82,6 +87,8 @@ def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib'):
|
|||
|
||||
class TestPythonOpString(unittest.TestCase):
|
||||
|
||||
_string_join = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
|
@ -93,15 +100,33 @@ class TestPythonOpString(unittest.TestCase):
|
|||
return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)
|
||||
|
||||
@onnx_op(op_type="PyStringJoin",
|
||||
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
|
||||
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
|
||||
PyCustomOpDef.dt_int64],
|
||||
outputs=[PyCustomOpDef.dt_string])
|
||||
def string_join(x, sep):
|
||||
def string_join(x, sep, axis):
    """Join the strings of ``x`` along ``axis`` with separator ``sep[0]``.

    :param x: numpy array of strings with at least one dimension
    :param sep: array of shape ``(1, )`` holding the separator
    :param axis: integer array of shape ``(1, )`` holding the axis to reduce
    :return: array of joined strings; the reduced axis is removed and the
        remaining axes keep their order, except for a 1-D input where the
        result has shape ``(1, )``
    :raises RuntimeError: if ``sep`` or ``axis`` does not have shape
        ``(1, )`` or if the axis is outside ``[0, x.ndim - 1]``
    """
    if sep.shape != (1, ):
        raise RuntimeError(
            "Unexpected shape {} for 'sep'.".format(sep.shape))
    if axis.shape != (1, ):
        raise RuntimeError(
            "Unexpected shape {} for 'axis'.".format(axis.shape))
    sp = sep[0]
    ax = axis[0]
    if ax < 0 or ax >= len(x.shape):
        # Three placeholders for three values: the previous message had only
        # two and failed with TypeError instead of reporting the bad axis.
        raise RuntimeError("axis must be in [%r, %r] but is %r." % (
            0, len(x.shape) - 1, ax))
    if len(x.shape) == 1:
        return np.array([sp.join(x)])
    # Move (rather than swap) the reduced axis to the last position so that
    # the remaining axes keep their original order in the result.
    x2 = np.moveaxis(x, ax, -1)
    res_shape = x2.shape[:-1]
    x2 = x2.reshape((-1, x2.shape[-1]))
    # Build the result from a list so its string width is derived from the
    # joined values; reusing x.dtype would truncate them for fixed-width
    # unicode inputs.
    joined = [sp.join(row) for row in x2]
    return np.array(joined, dtype=str).reshape(res_shape)
|
||||
|
||||
@onnx_op(op_type="PyStringRegexReplace",
|
||||
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
|
||||
|
@ -120,6 +145,8 @@ class TestPythonOpString(unittest.TestCase):
|
|||
list(map(lambda t: reg.sub(rewrite[0], t), x.ravel())))
|
||||
return res.reshape(x.shape)
|
||||
|
||||
cls._string_join = string_join
|
||||
|
||||
def test_check_types(self):
|
||||
def_list = set(dir(PyCustomOpDef))
|
||||
type_list = [
|
||||
|
@ -193,9 +220,45 @@ class TestPythonOpString(unittest.TestCase):
|
|||
np.array([["aa", "bb", ""]])])
|
||||
self.assertEqual(text.shape, (2, 3))
|
||||
sep = np.array([";"])
|
||||
txout = sess.run(None, {'text': text, 'sep': sep})
|
||||
axis = np.array([1], dtype=np.int64)
|
||||
TestPythonOpString._string_join(text, sep, axis)
|
||||
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
|
||||
self.assertEqual(
|
||||
txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())
|
||||
axis = np.array([0], dtype=np.int64)
|
||||
TestPythonOpString._string_join(text, sep, axis)
|
||||
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
|
||||
self.assertEqual(
|
||||
txout[0].tolist(), np.array(['a;aa', 'b;bb', 'c;']).tolist())
|
||||
|
||||
def test_string_join_python_3d(self):
    """PyStringJoin on a (2, 3, 1) tensor joined along axis 1."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_join('Py')
    self.assertIn('op_type: "PyStringJoin"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)

    rows = [np.array([["a", "b", "c"]]),
            np.array([["aa", "bb", ""]])]
    text = np.vstack(rows).reshape((2, 3, 1))
    sep = np.array([";"])
    axis = np.array([1], dtype=np.int64)
    # Direct call of the registered Python implementation; only checks
    # that it runs, the returned value is not inspected.
    TestPythonOpString._string_join(text, sep, axis)

    feeds = {'text': text, 'sep': sep, 'axis': axis}
    outputs = session.run(None, feeds)
    self.assertEqual(
        outputs[0].tolist(), np.array([['a;b;c'], ['aa;bb;']]).tolist())
|
||||
|
||||
def test_string_join_python_1d(self):
    """PyStringJoin on a 1-D tensor yields a single joined string."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_join('Py')
    self.assertIn('op_type: "PyStringJoin"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)

    feeds = {
        'text': np.array(["a", "b", "cc"]),
        'sep': np.array([";"]),
        'axis': np.array([0], dtype=np.int64),
    }
    outputs = session.run(None, feeds)
    joined = outputs[0]
    self.assertEqual(joined.shape, (1, ))
    self.assertEqual(
        joined.tolist(), np.array(["a;b;cc"]).tolist())
|
||||
|
||||
def test_string_join_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
|
@ -206,9 +269,52 @@ class TestPythonOpString(unittest.TestCase):
|
|||
text = np.vstack([np.array([["a", "b", "c"]]),
|
||||
np.array([["aa", "bb", ""]])])
|
||||
sep = np.array([";"])
|
||||
txout = sess.run(None, {'text': text, 'sep': sep})
|
||||
axis = np.array([1], dtype=np.int64)
|
||||
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
|
||||
self.assertEqual(
|
||||
txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())
|
||||
axis = np.array([0], dtype=np.int64)
|
||||
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
|
||||
self.assertEqual(
|
||||
txout[0].tolist(), np.array(['a;aa', 'b;bb', 'c;']).tolist())
|
||||
|
||||
def test_string_join_cc_1d(self):
    """C++ StringJoin on a 1-D tensor yields a single joined string."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_join('')
    self.assertIn('op_type: "StringJoin"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)

    feeds = {
        'text': np.array(["a", "b", "cc"]),
        'sep': np.array([";"]),
        'axis': np.array([0], dtype=np.int64),
    }
    outputs = session.run(None, feeds)
    self.assertEqual(
        outputs[0].tolist(), np.array(["a;b;cc"]).tolist())
|
||||
|
||||
def test_string_join_cc_3d(self):
    """C++ StringJoin on a (2, 2, 2) tensor, reduced along each axis."""
    options = _ort.SessionOptions()
    options.register_custom_ops_library(_get_library_path())
    model = _create_test_model_string_join('')
    self.assertIn('op_type: "StringJoin"', str(model))
    session = _ort.InferenceSession(model.SerializeToString(), options)

    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    text = np.array(letters).reshape((2, 2, 2))
    sep = np.array([";"])

    # Same order as before: last axis first, then axis 1, then axis 0.
    cases = [
        (2, [['a;b', 'c;d'], ['e;f', 'g;h']]),
        (1, [['a;c', 'b;d'], ['e;g', 'f;h']]),
        (0, [['a;e', 'b;f'], ['c;g', 'd;h']]),
    ]
    for axis_value, expected in cases:
        axis = np.array([axis_value], dtype=np.int64)
        outputs = session.run(
            None, {'text': text, 'sep': sep, 'axis': axis})
        self.assertEqual(
            outputs[0].tolist(),
            np.array(expected).tolist())
|
||||
|
||||
def test_string_replace_cc(self):
|
||||
so = _ort.SessionOptions()
|
||||
|
|
Загрузка…
Ссылка в новой задаче