Test the GPT-2 tokenizer in both modes, PyOP and native (#40)
* add a flag to enable pyop * test passed * a little polish.
This commit is contained in:
Родитель
4e0af5c582
Коммит
9ec6951516
|
@ -14,5 +14,6 @@ const char c_OpDomain[] = "ai.onnx.contrib";
|
|||
#if defined(PYTHON_OP_SUPPORT)
|
||||
|
||||
const OrtCustomOp* FetchPyCustomOps(size_t& count);
|
||||
bool EnablePyCustomOps(bool enable=true);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "kernels/op_equal.hpp"
|
||||
#include "kernels/op_segment_sum.hpp"
|
||||
#include "kernels/string_hash.hpp"
|
||||
|
@ -40,19 +42,12 @@ OrtCustomOp* operator_lists[] = {
|
|||
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
|
||||
std::set<std::string> pyop_nameset;
|
||||
|
||||
if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
|
||||
return status;
|
||||
}
|
||||
|
||||
OrtCustomOp** ops = operator_lists;
|
||||
while (*ops != nullptr) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
|
||||
return status;
|
||||
}
|
||||
++ops;
|
||||
}
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
size_t count = 0;
|
||||
const OrtCustomOp* c_ops = FetchPyCustomOps(count);
|
||||
|
@ -60,16 +55,31 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
|||
if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
|
||||
return status;
|
||||
}
|
||||
else {
|
||||
pyop_nameset.emplace(c_ops->GetName(c_ops));
|
||||
}
|
||||
++count;
|
||||
c_ops = FetchPyCustomOps(count);
|
||||
}
|
||||
#endif
|
||||
|
||||
OrtCustomOp** ops = operator_lists;
|
||||
while (*ops != nullptr) {
|
||||
if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
++ops;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_TOKENIZER)
|
||||
auto** t_ops = LoadTokenizerSchemaList();
|
||||
const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
|
||||
while (*t_ops != nullptr) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)){
|
||||
return status;
|
||||
if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)){
|
||||
return status;
|
||||
}
|
||||
}
|
||||
t_ops++;
|
||||
}
|
||||
|
|
|
@ -376,6 +376,11 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
|
|||
}
|
||||
|
||||
const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count) {
|
||||
if (!EnablePyCustomOps(true)) {
|
||||
EnablePyCustomOps(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// The result must stay alive
|
||||
std::vector<PyCustomOpFactory>& copy = PyCustomOpDef_python_operator_list();
|
||||
if (count < copy.size())
|
||||
|
@ -390,6 +395,13 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count) {
|
|||
return ptr;
|
||||
}
|
||||
|
||||
// Sets the pyop enable flag to `enabled` and returns the value the flag held
// before this call. The flag is initialized to true, so the very first call
// returns true regardless of the argument.
bool EnablePyCustomOps(bool enabled){
|
||||
// Function-local static: the flag persists across calls.
static bool f_pyop_enabled = true;
|
||||
// Remember the previous state so it can be handed back to the caller.
bool last = f_pyop_enabled;
|
||||
f_pyop_enabled = enabled;
|
||||
return last;
|
||||
}
|
||||
|
||||
// static std::ofstream logger;
|
||||
static int init_numpy() {
|
||||
import_array();
|
||||
|
@ -406,6 +418,7 @@ uint64_t hash_64(const std::string& str, uint64_t num_buckets, bool fast) {
|
|||
}
|
||||
|
||||
// Registers the module-level functions on the pybind11 module `m`:
//   enable_custom_op - bound directly to EnablePyCustomOps
//   add_custom_op    - forwards the given PyCustomOpDef to PyCustomOpDef::AddOp
//   hash_64          - bound directly to hash_64
void AddGlobalMethods(pybind11::module& m) {
|
||||
m.def("enable_custom_op", &EnablePyCustomOps, "Enable or disable pyop functions.");
|
||||
// The lambda adapts the by-reference Python argument to AddOp's pointer parameter.
m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); });
|
||||
m.def("hash_64", &hash_64, "Computes a uint64 hash for a string (from tensorflow).");
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ __author__ = "Microsoft"
|
|||
|
||||
|
||||
from ._ocos import get_library_path # noqa
|
||||
from ._ocos import Opdef, PyCustomOpDef, hash_64 # noqa
|
||||
from ._ocos import Opdef, PyCustomOpDef, hash_64, enable_custom_op # noqa
|
||||
from ._ocos import expand_onnx_inputs # noqa
|
||||
|
||||
onnx_op = Opdef.declare
|
||||
|
|
|
@ -6,9 +6,8 @@
|
|||
import sys
|
||||
import copy
|
||||
from onnx import helper
|
||||
from pathlib import Path
|
||||
from ._ortcustomops import ( # noqa
|
||||
PyCustomOpDef, add_custom_op, hash_64)
|
||||
PyCustomOpDef, enable_custom_op, add_custom_op, hash_64)
|
||||
|
||||
|
||||
def get_library_path():
|
||||
|
|
|
@ -7,6 +7,7 @@ from transformers import GPT2Tokenizer
|
|||
import onnxruntime as _ort
|
||||
from onnxruntime_customops import (
|
||||
onnx_op,
|
||||
enable_custom_op,
|
||||
PyCustomOpDef,
|
||||
expand_onnx_inputs,
|
||||
get_library_path as _get_library_path)
|
||||
|
@ -46,30 +47,41 @@ class TestGPT2Tokenizer(unittest.TestCase):
|
|||
tokjson = _get_test_data_file('data', 'gpt2.vocab')
|
||||
os.environ["GPT2TOKFILE"] = tokjson
|
||||
cls.tokenizer = GPT2Tokenizer(tokjson, tokjson.replace('.vocab', '.merges.txt'))
|
||||
cls.test_sentence = "I can feel the magic, can you?"
|
||||
cls.indexed_tokens = cls.tokenizer.encode(cls.test_sentence)
|
||||
|
||||
@onnx_op(op_type="BBPETokenizer",
|
||||
model = _create_test_model()
|
||||
cls.binded_model = _bind_tokenizer(model)
|
||||
|
||||
@onnx_op(op_type="GPT2Tokenizer",
|
||||
inputs=[PyCustomOpDef.dt_string],
|
||||
outputs=[PyCustomOpDef.dt_int64])
|
||||
def bpe_toenizer(s):
|
||||
# The user custom op implementation here.
|
||||
TestGPT2Tokenizer.pyop_invoked = True
|
||||
return np.array(
|
||||
[TestGPT2Tokenizer.tokenizer.encode(st_) for st_ in s])
|
||||
|
||||
def test_tokenizer(self):
|
||||
test_sentence = "I can feel the magic, can you?"
|
||||
tokenizer = TestGPT2Tokenizer.tokenizer
|
||||
indexed_tokens = tokenizer.encode(test_sentence)
|
||||
|
||||
model = _create_test_model()
|
||||
binded_model = _bind_tokenizer(model)
|
||||
|
||||
def _run_tokenizer(self, pyop_flag):
|
||||
so = _ort.SessionOptions()
|
||||
enable_custom_op(pyop_flag)
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
sess = _ort.InferenceSession(binded_model.SerializeToString(), so)
|
||||
input_text = np.array([test_sentence])
|
||||
sess = _ort.InferenceSession(TestGPT2Tokenizer.binded_model.SerializeToString(), so)
|
||||
input_text = np.array([TestGPT2Tokenizer.test_sentence])
|
||||
txout = sess.run(None, {'string_input': input_text})
|
||||
np.testing.assert_array_equal(txout[0], np.array([self.indexed_tokens]))
|
||||
del sess
|
||||
del so
|
||||
|
||||
np.testing.assert_array_equal(txout[0], np.array([indexed_tokens]))
|
||||
# Runs the tokenizer with pyop_flag=False and asserts that the Python
# implementation did not set TestGPT2Tokenizer.pyop_invoked.
def test_tokenizer(self):
|
||||
# Reset the marker so a previous test cannot influence the assertion below.
TestGPT2Tokenizer.pyop_invoked = False
|
||||
self._run_tokenizer(False)
|
||||
self.assertFalse(TestGPT2Tokenizer.pyop_invoked)
|
||||
|
||||
# Runs the tokenizer with pyop_flag=True and asserts that the Python
# implementation set TestGPT2Tokenizer.pyop_invoked.
def test_tokenizer_pyop(self):
|
||||
# Reset the marker so the assertion reflects this run only.
TestGPT2Tokenizer.pyop_invoked = False
|
||||
self._run_tokenizer(True)
|
||||
self.assertTrue(TestGPT2Tokenizer.pyop_invoked)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Загрузка…
Ссылка в новой задаче