Extend StringJoin to support any dimension (#18)

This commit is contained in:
Xavier Dupré 2020-11-10 18:53:24 +01:00 коммит произвёл GitHub
Родитель 6927998670
Коммит 2ef88b0bda
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 166 добавлений и 24 удалений

Просмотреть файл

@ -12,15 +12,32 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
const std::string* X = ort_.GetTensorData<std::string>(input_X);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
const OrtValue* input_axis = ort_.KernelContext_GetInput(context, 2);
const int64_t* axis = ort_.GetTensorData<int64_t>(input_axis);
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
throw std::runtime_error("Input 2 is the separator, it has 1 element.");
OrtTensorDimensions dimensions_axis(ort_, input_axis);
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
throw std::runtime_error("Input 3 is the axis, it has 1 element.");
OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() != 2)
throw std::runtime_error(MakeString("Input 1 must have 2 dimensions but has ", dimensions.size(), "."));
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), 1);
if (*axis < 0 || *axis >= dimensions.size())
throw std::runtime_error(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis));
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
if (dimensions.size() > 1) {
for (size_t i = 0, pos = 0; i < dimensions.size(); ++i) {
if (static_cast<int64_t>(i) == *axis)
continue;
dimensions_out[pos++] = dimensions[i];
}
} else {
dimensions_out[0] = 1;
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions_out.data(), dimensions_out.size());
std::string* out = ort_.GetTensorMutableData<std::string>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
@ -28,14 +45,25 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
int64_t index = 0;
for (int64_t i = 0; i < size; ++i) {
std::ostringstream st;
for (int64_t j = 0; j < dimensions[1] - 1; ++j, ++index) {
st << X[index] << *sep;
int64_t h = 1;
for (size_t i = *axis + 1; i < dimensions.size(); ++i) {
h *= dimensions[i];
}
int64_t left_part = size / h;
int64_t right_part = size / left_part;
int64_t n_red = dimensions[*axis] - 1;
int64_t inc = right_part * (n_red + 1);
int64_t pos = 0;
for (int64_t li = 0; li < left_part; ++li) {
for (int64_t ri = 0; ri < right_part; ++ri, ++pos) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[index] << *sep;
}
st << X[index];
out[pos] = st.str();
}
st << X[index++];
out[i] = st.str();
}
}
@ -48,11 +76,19 @@ const char* CustomOpStringJoin::GetName() const {
};
size_t CustomOpStringJoin::GetInputTypeCount() const {
return 2;
return 3;
};
ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
throw std::runtime_error(MakeString("Unexpected input index ", index));
}
};
size_t CustomOpStringJoin::GetOutputTypeCount() const {

Просмотреть файл

@ -17,9 +17,9 @@ def _create_test_model_string_upper(prefix, domain='ai.onnx.contrib'):
domain=domain)]
input0 = helper.make_tensor_value_info(
'input_1', onnx_proto.TensorProto.STRING, [None, 1])
'input_1', onnx_proto.TensorProto.STRING, [None, None])
output0 = helper.make_tensor_value_info(
'customout', onnx_proto.TensorProto.STRING, [None, 1])
'customout', onnx_proto.TensorProto.STRING, [None, None])
graph = helper.make_graph(nodes, 'test0', [input0], [output0])
model = helper.make_model(
@ -33,19 +33,24 @@ def _create_test_model_string_join(prefix, domain='ai.onnx.contrib'):
helper.make_node('Identity', ['text'], ['identity1']))
nodes.append(
helper.make_node('Identity', ['sep'], ['identity2']))
nodes.append(
helper.make_node('Identity', ['axis'], ['identity3']))
nodes.append(
helper.make_node(
'%sStringJoin' % prefix, ['identity1', 'identity2'],
'%sStringJoin' % prefix, ['identity1', 'identity2', 'identity3'],
['customout'], domain=domain))
input0 = helper.make_tensor_value_info(
'text', onnx_proto.TensorProto.STRING, [None, None])
'text', onnx_proto.TensorProto.STRING, None)
input1 = helper.make_tensor_value_info(
'sep', onnx_proto.TensorProto.STRING, [1])
input2 = helper.make_tensor_value_info(
'axis', onnx_proto.TensorProto.INT64, [1])
output0 = helper.make_tensor_value_info(
'customout', onnx_proto.TensorProto.STRING, [None, 1])
'customout', onnx_proto.TensorProto.STRING, None)
graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
graph = helper.make_graph(
nodes, 'test0', [input0, input1, input2], [output0])
model = helper.make_model(
graph, opset_imports=[helper.make_operatorsetid(domain, 1)])
return model
@ -82,6 +87,8 @@ def _create_test_model_string_replace(prefix, domain='ai.onnx.contrib'):
class TestPythonOpString(unittest.TestCase):
_string_join = None
@classmethod
def setUpClass(cls):
@ -93,15 +100,33 @@ class TestPythonOpString(unittest.TestCase):
return np.array([s.upper() for s in x.ravel()]).reshape(x.shape)
@onnx_op(op_type="PyStringJoin",
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string],
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
PyCustomOpDef.dt_int64],
outputs=[PyCustomOpDef.dt_string])
def string_join(x, sep):
def string_join(x, sep, axis):
# The user custom op implementation here.
if sep.shape != (1, ):
raise RuntimeError(
"Unexpected shape {} for 'sep'.".format(sep.shape))
if axis.shape != (1, ):
raise RuntimeError(
"Unexpected shape {} for 'axis'.".format(axis.shape))
sp = sep[0]
return np.array([sp.join(row) for row in x])
ax = axis[0]
if ax < 0 or ax >= len(x.shape):
raise RuntimeError("axis must be in [%r,%r] but is" % (
0, len(x.shape), ax))
if len(x.shape) == 1:
return np.array([sp.join(x)])
dims = np.arange(len(x.shape))
dims[ax], dims[-1] = dims[-1], dims[ax]
x2 = np.transpose(x, dims)
res_shape = x2.shape[:-1]
x2 = x2.reshape((-1, x2.shape[-1]))
res = np.empty(x2.shape[0], dtype=x.dtype)
for i in range(x2.shape[0]):
res[i] = sp.join(x2[i, :])
return res.reshape(res_shape)
@onnx_op(op_type="PyStringRegexReplace",
inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_string,
@ -120,6 +145,8 @@ class TestPythonOpString(unittest.TestCase):
list(map(lambda t: reg.sub(rewrite[0], t), x.ravel())))
return res.reshape(x.shape)
cls._string_join = string_join
def test_check_types(self):
def_list = set(dir(PyCustomOpDef))
type_list = [
@ -193,9 +220,45 @@ class TestPythonOpString(unittest.TestCase):
np.array([["aa", "bb", ""]])])
self.assertEqual(text.shape, (2, 3))
sep = np.array([";"])
txout = sess.run(None, {'text': text, 'sep': sep})
axis = np.array([1], dtype=np.int64)
TestPythonOpString._string_join(text, sep, axis)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())
axis = np.array([0], dtype=np.int64)
TestPythonOpString._string_join(text, sep, axis)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(), np.array(['a;aa', 'b;bb', 'c;']).tolist())
def test_string_join_python_3d(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = _create_test_model_string_join('Py')
self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
text = np.vstack([np.array([["a", "b", "c"]]),
np.array([["aa", "bb", ""]])]).reshape((2, 3, 1))
sep = np.array([";"])
axis = np.array([1], dtype=np.int64)
TestPythonOpString._string_join(text, sep, axis)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(), np.array([['a;b;c'], ['aa;bb;']]).tolist())
def test_string_join_python_1d(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = _create_test_model_string_join('Py')
self.assertIn('op_type: "PyStringJoin"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
text = np.array(["a", "b", "cc"])
sep = np.array([";"])
axis = np.array([0], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(txout[0].shape, (1, ))
self.assertEqual(
txout[0].tolist(), np.array(["a;b;cc"]).tolist())
def test_string_join_cc(self):
so = _ort.SessionOptions()
@ -206,9 +269,52 @@ class TestPythonOpString(unittest.TestCase):
text = np.vstack([np.array([["a", "b", "c"]]),
np.array([["aa", "bb", ""]])])
sep = np.array([";"])
txout = sess.run(None, {'text': text, 'sep': sep})
axis = np.array([1], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(), np.array(["a;b;c", "aa;bb;"]).tolist())
axis = np.array([0], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(), np.array(['a;aa', 'b;bb', 'c;']).tolist())
def test_string_join_cc_1d(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = _create_test_model_string_join('')
self.assertIn('op_type: "StringJoin"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
text = np.array(["a", "b", "cc"])
sep = np.array([";"])
axis = np.array([0], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(), np.array(["a;b;cc"]).tolist())
def test_string_join_cc_3d(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = _create_test_model_string_join('')
self.assertIn('op_type: "StringJoin"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
text = np.array(["a", "b", "c", "d", "e", "f", "g", "h"]).reshape((
2, 2, 2))
sep = np.array([";"])
axis = np.array([2], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(),
np.array([['a;b', 'c;d'], ['e;f', 'g;h']]).tolist())
axis = np.array([1], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(),
np.array([['a;c', 'b;d'], ['e;g', 'f;h']]).tolist())
axis = np.array([0], dtype=np.int64)
txout = sess.run(None, {'text': text, 'sep': sep, 'axis': axis})
self.assertEqual(
txout[0].tolist(),
np.array([['a;e', 'b;f'], ['c;g', 'd;h']]).tolist())
def test_string_replace_cc(self):
so = _ort.SessionOptions()