ort-customops/operators/text/string_join.cc

79 строки
2.8 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_functions.h"
#include "string_tensor.h"
// Joins the strings of `input_X` along dimension `axis`, inserting `input_sep`
// between consecutive elements, and writes the joined strings to `output`.
//
// Parameters:
//   input_X   - string tensor to join. A rank-0 (scalar) input must hold exactly
//               one element; see https://github.com/onnx/onnx/issues/3724.
//   input_sep - separator inserted between consecutive joined elements.
//   axis      - dimension to reduce. For non-scalar inputs it must lie in
//               [0, rank); it is ignored for scalar inputs.
//   output    - receives the joined strings. Its shape is the input shape with
//               `axis` removed; when that would leave no dimensions (rank 0 or
//               rank 1 input) the output shape is [1].
//
// Returns nullptr on success, or an ORT_INVALID_ARGUMENT status when a scalar
// input does not hold exactly one element or `axis` is out of range.
OrtStatusPtr string_join(const ortc::Tensor<std::string>& input_X,
                         std::string_view input_sep,
                         int64_t axis,
                         ortc::Tensor<std::string>& output) {
  // Setup inputs
  auto& X = input_X.Data();
  auto& dimensions = input_X.Shape();

  if (dimensions.size() == 0) {
    // Rank 0 means input 1 is a scalar, which must hold exactly 1 element.
    // See issue: https://github.com/onnx/onnx/issues/3724
    if (X.size() != 1) {
      return OrtW::CreateStatus(
          MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()).c_str(),
          ORT_INVALID_ARGUMENT);
    }
  } else if (axis < 0 || axis >= static_cast<int64_t>(dimensions.size())) {
    return OrtW::CreateStatus(
        MakeString("axis must be positive and smaller than the number of dimension but it is ", axis).c_str(),
        ORT_INVALID_ARGUMENT);
  }

  // Output shape: the input shape with `axis` removed, or [1] when that would
  // leave no dimensions at all.
  std::vector<int64_t> dimensions_out;
  if (dimensions.size() > 1) {
    dimensions_out.reserve(dimensions.size() - 1);
    for (size_t i = 0; i < dimensions.size(); ++i) {
      if (static_cast<int64_t>(i) != axis) {
        dimensions_out.push_back(dimensions[i]);
      }
    }
  } else {
    dimensions_out.push_back(1);
  }

  // Seed with int64_t so the product is accumulated in the type it is stored in.
  const int64_t size = std::accumulate(dimensions_out.begin(), dimensions_out.end(),
                                       int64_t{1}, std::multiplies<int64_t>());

  // Value-initialised entries are empty strings. That is already the correct
  // join result when the reduced axis has length 0, so an empty input needs no
  // further work. Writing out[0] unconditionally in that case was out of bounds
  // whenever the output itself has zero elements (e.g. shape [2, 0] with axis 0).
  std::vector<std::string> out(static_cast<size_t>(size));

  if (dimensions.size() == 0) {
    // For input 1 (scalar) which has 1 element, the joined string is the input
    // string itself. See issue: https://github.com/onnx/onnx/issues/3724
    out[0] = X[0];
  } else if (X.size() > 0) {
    // X is non-empty, so every dimension is >= 1 and the divisions below are safe.
    // `inner` is the stride of `axis` in X: the product of the dimensions after it.
    int64_t inner = 1;
    for (size_t i = static_cast<size_t>(axis) + 1; i < dimensions.size(); ++i) {
      inner *= dimensions[i];
    }
    // `outer` is the product of the dimensions before `axis`.
    const int64_t outer = size / inner;
    const int64_t axis_len = dimensions[static_cast<size_t>(axis)];
    // Distance in X between consecutive `outer` slices.
    const int64_t outer_stride = inner * axis_len;

    size_t pos = 0;
    for (int64_t o = 0; o < outer; ++o) {
      for (int64_t r = 0; r < inner; ++r, ++pos) {
        // Build the joined string directly in its output slot.
        std::string& joined = out[pos];
        int64_t index = o * outer_stride + r;
        joined += X[static_cast<size_t>(index)];
        for (int64_t j = 1; j < axis_len; ++j) {
          index += inner;
          joined.append(input_sep.data(), input_sep.size());
          joined += X[static_cast<size_t>(index)];
        }
      }
    }
  }

  output.SetStringOutput(out, dimensions_out);
  return nullptr;
}