// onnxruntime-extensions/pyop/py_c_api.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <pybind11/iostream.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <thread>
#include "ortx_utils.h"
#include "ortx_tokenizer.h"
#include "ortx_processor.h"
#include "pykernel.h"
namespace py = pybind11;
// Returns the total number of elements for a tensor shape `sp` (any iterable
// of dimension extents): the product of all extents, or 1 for an empty
// (scalar) shape.
template <typename T>
int64_t NumOfElement(const T& sp) {
  // Accumulate directly in int64_t so the result matches the return type
  // (the original accumulated in size_t and implicitly converted on return).
  int64_t count = 1;
  for (auto dim : sp) {
    count *= static_cast<int64_t>(dim);
  }
  return count;
}
// Registers the onnxruntime-extensions C-API wrappers on the given pybind11
// module. Opaque Ortx* handles are passed to/from Python as std::uintptr_t
// integers; Python code is responsible for releasing them via `delete_object`.
void AddGlobalMethodsCApi(pybind11::module& m) {
  m.def(
      "create_processor",
      [](const char* processor_def_json) {
        OrtxProcessor* processor = nullptr;
        auto err = OrtxCreateProcessor(&processor, processor_def_json);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to create processor") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(processor);
      },
      "Create a processor.");

  m.def(
      "load_images",
      [](const std::vector<std::string>& image_paths) {
        OrtxRawImages* images = nullptr;
        size_t num_images = image_paths.size();
        // OrtxLoadImages takes a C array of C strings; the std::strings in
        // image_paths stay alive for the duration of the call, so borrowing
        // their c_str() pointers is safe.
        auto image_ptrs = std::make_unique<const char*[]>(num_images);
        for (size_t i = 0; i < num_images; ++i) {
          image_ptrs[i] = image_paths[i].c_str();
        }
        auto err = OrtxLoadImages(&images, image_ptrs.get(), num_images, nullptr);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to load images") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(images);
      },
      "Load images.");

  m.def(
      "image_pre_process",
      [](std::uintptr_t processor_h, std::uintptr_t images_h) {
        OrtxProcessor* processor = reinterpret_cast<OrtxProcessor*>(processor_h);
        OrtxRawImages* images = reinterpret_cast<OrtxRawImages*>(images_h);
        OrtxTensorResult* result{};
        auto err = OrtxImagePreProcess(processor, images, &result);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to preprocess images") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(result);
      },
      "Preprocess images.");

  m.def(
      "tensor_result_get_at",
      [](std::uintptr_t result_h, size_t index) {
        OrtxTensorResult* result = reinterpret_cast<OrtxTensorResult*>(result_h);
        OrtxTensor* tensor{};
        auto err = OrtxTensorResultGetAt(result, index, &tensor);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to get tensor") + OrtxGetLastErrorMessage());
        }
        extDataType_t tensor_type;
        OrtxGetTensorType(tensor, &tensor_type);
        // Only these three element types can be exposed as numpy arrays.
        // Reject everything else up front: previously, a known-but-unsupported
        // type (anything other than kOrtxUnknownType) fell through with a null
        // shape/data pointer and an uninitialized num_dims, then dereferenced
        // them below.
        if (tensor_type != extDataType_t::kOrtxFloat &&
            tensor_type != extDataType_t::kOrtxInt64 &&
            tensor_type != extDataType_t::kOrtxUint8) {
          throw std::runtime_error("unsupported tensor type");
        }
        const int64_t* shape{};
        size_t num_dims = 0;
        const void* data{};
        size_t elem_size = 1;
        OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&data), &shape, &num_dims);
        OrtxGetTensorSizeOfElement(tensor, &elem_size);
        // Copy the shape into the size_t vector pybind11 expects.
        std::vector<std::size_t> npy_dims(shape, shape + num_dims);
        py::array obj{};
        if (tensor_type == extDataType_t::kOrtxFloat) {
          obj = py::array_t<float>(npy_dims);
        } else if (tensor_type == extDataType_t::kOrtxInt64) {
          obj = py::array_t<int64_t>(npy_dims);
        } else {  // kOrtxUint8 — the only remaining possibility after the check above
          obj = py::array_t<uint8_t>(npy_dims);
        }
        // Copy the tensor contents into the freshly allocated numpy buffer.
        if (data != nullptr) {
          memcpy(obj.mutable_data(), data, NumOfElement(npy_dims) * elem_size);
        }
        return obj;
      },
      "Get tensor at index.");

  m.def(
      "create_tokenizer",
      [](std::string tokenizer_def_json) {
        OrtxTokenizer* tokenizer = nullptr;
        auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(tokenizer);
      },
      "Create a tokenizer.");

  m.def(
      "batch_tokenize",
      [](std::uintptr_t h, const std::vector<std::string>& inputs) -> std::vector<std::vector<int64_t>> {
        std::vector<std::vector<int64_t>> output;
        OrtxTokenizer* tokenizer = reinterpret_cast<OrtxTokenizer*>(h);
        OrtxTokenId2DArray* tid_output = nullptr;
        std::vector<const char*> cs_inputs;
        cs_inputs.reserve(inputs.size());
        for (const auto& input : inputs) {
          cs_inputs.push_back(input.c_str());
        }
        auto err = OrtxTokenize(tokenizer, cs_inputs.data(), inputs.size(), &tid_output);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to tokenize") + OrtxGetLastErrorMessage());
        }
        for (size_t i = 0; i < inputs.size(); ++i) {
          const extTokenId_t* t2d{};
          size_t length{};
          err = OrtxTokenId2DArrayGetItem(tid_output, i, &t2d, &length);
          if (err != kOrtxOK) {
            // Release the C-API result before throwing (previously leaked).
            OrtxDisposeOnly(tid_output);
            throw std::runtime_error(std::string("Failed to get token id") + OrtxGetLastErrorMessage());
          }
          output.push_back(std::vector<int64_t>(t2d, t2d + length));
        }
        OrtxDisposeOnly(tid_output);
        return output;
      },
      "Batch tokenize.");

  m.def(
      "batch_detokenize",
      [](std::uintptr_t h, const std::vector<std::vector<int64_t>>& inputs) -> std::vector<std::string> {
        std::vector<std::string> result;
        OrtxTokenizer* tokenizer = reinterpret_cast<OrtxTokenizer*>(h);
        OrtxStringArray* output = nullptr;
        for (size_t i = 0; i < inputs.size(); ++i) {
          // Narrow the Python-side int64 token ids to the C API's token type.
          std::vector<extTokenId_t> input;
          input.reserve(inputs[i].size());
          std::transform(inputs[i].begin(), inputs[i].end(), std::back_inserter(input),
                         [](int64_t v) { return static_cast<extTokenId_t>(v); });
          auto err = OrtxDetokenize1D(tokenizer, input.data(), input.size(), &output);
          if (err != kOrtxOK) {
            throw std::runtime_error(std::string("Failed to detokenize") + OrtxGetLastErrorMessage());
          }
          size_t length;
          err = OrtxStringArrayGetBatch(output, &length);
          if (err != kOrtxOK) {
            // Release the C-API result before throwing (previously leaked).
            OrtxDisposeOnly(output);
            throw std::runtime_error(std::string("Failed to get batch size") + OrtxGetLastErrorMessage());
          }
          // `j` (was `i`, shadowing the outer loop index) walks the strings
          // produced for this one input sequence.
          for (size_t j = 0; j < length; ++j) {
            const char* item;
            err = OrtxStringArrayGetItem(output, j, &item);
            if (err != kOrtxOK) {
              OrtxDisposeOnly(output);
              throw std::runtime_error(std::string("Failed to get item") + OrtxGetLastErrorMessage());
            }
            result.push_back(item);
          }
          OrtxDisposeOnly(output);
        }
        return result;
      },
      "Batch detokenize.");

  m.def(
      "delete_object", [](std::uintptr_t h) { OrtxDisposeOnly(reinterpret_cast<OrtxObject*>(h)); },
      "Delete the object created by C API.");
}