Test the Gpt2 Tokenizer in two modes, PyOP and Native (#40)

* add a flag to enable pyop

* test passed

* a little polish.
Wenbing Li 2021-01-12 16:11:02 -08:00 committed by GitHub
Parent 4e0af5c582
Commit 9ec6951516
No known key found for this signature
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 61 additions and 26 deletions
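The toggle added here is driven from Python: the test below switches the GPT2Tokenizer op between its PyOp (Python) implementation and the native kernel before building a session. A minimal sketch of that usage, assuming `model_bytes` is a serialized ONNX model that already calls the custom op (the real test builds it with `_create_test_model` and `_bind_tokenizer`):

# Sketch only: toggling between the PyOp and native GPT2Tokenizer kernels.
# `model_bytes` is assumed to be an ONNX model that already invokes the
# custom op; 'string_input' matches the input name used by the test model.
import numpy as np
import onnxruntime as ort
from onnxruntime_customops import enable_custom_op, get_library_path

def run_tokenizer(model_bytes, text, use_pyop):
    enable_custom_op(use_pyop)  # True -> PyOp kernel, False -> native kernel
    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    sess = ort.InferenceSession(model_bytes, so)
    return sess.run(None, {'string_input': np.array([text])})[0]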

View File

@@ -14,5 +14,6 @@ const char c_OpDomain[] = "ai.onnx.contrib";
 #if defined(PYTHON_OP_SUPPORT)
 const OrtCustomOp* FetchPyCustomOps(size_t& count);
+bool EnablePyCustomOps(bool enable=true);
 #endif

View File

@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include <set>
 #include "kernels/op_equal.hpp"
 #include "kernels/op_segment_sum.hpp"
 #include "kernels/string_hash.hpp"
@@ -40,19 +42,12 @@ OrtCustomOp* operator_lists[] = {
 extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
   OrtCustomOpDomain* domain = nullptr;
   const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
+  std::set<std::string> pyop_nameset;
   if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
     return status;
   }
-  OrtCustomOp** ops = operator_lists;
-  while (*ops != nullptr) {
-    if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
-      return status;
-    }
-    ++ops;
-  }
 #if defined(PYTHON_OP_SUPPORT)
   size_t count = 0;
   const OrtCustomOp* c_ops = FetchPyCustomOps(count);
@@ -60,16 +55,31 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
     if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
       return status;
     }
+    else {
+      pyop_nameset.emplace(c_ops->GetName(c_ops));
+    }
     ++count;
     c_ops = FetchPyCustomOps(count);
   }
 #endif
+  OrtCustomOp** ops = operator_lists;
+  while (*ops != nullptr) {
+    if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
+      if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
+        return status;
+      }
+    }
+    ++ops;
+  }
 #if defined(ENABLE_TOKENIZER)
-  auto** t_ops = LoadTokenizerSchemaList();
+  const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
   while (*t_ops != nullptr) {
-    if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)){
-      return status;
+    if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
+      if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)){
+        return status;
+      }
     }
     t_ops++;
   }
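With this change, registration gives PyOp implementations precedence: every PyOp kernel is added first and its name recorded in pyop_nameset, and a built-in or tokenizer kernel is only registered when no enabled PyOp claims the same operator name. A rough Python model of that precedence rule (illustrative only, with made-up op names, not the real C++ API):

# Illustrative model of the precedence rule in RegisterCustomOps above.
def register_ops(pyop_names, native_names):
    registered = []
    claimed = set()                       # mirrors pyop_nameset
    for name in pyop_names:               # empty when PyOp support is disabled
        registered.append(('pyop', name))
        claimed.add(name)
    for name in native_names:             # built-in and tokenizer kernels
        if name not in claimed:
            registered.append(('native', name))
    return registered

# register_ops(['GPT2Tokenizer'], ['GPT2Tokenizer', 'SegmentSum'])
# -> [('pyop', 'GPT2Tokenizer'), ('native', 'SegmentSum')]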

View File

@@ -376,6 +376,11 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
 }
 const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count) {
+  if (!EnablePyCustomOps(true)) {
+    EnablePyCustomOps(false);
+    return nullptr;
+  }
   // The result must stay alive
   std::vector<PyCustomOpFactory>& copy = PyCustomOpDef_python_operator_list();
   if (count < copy.size())
@@ -390,6 +395,13 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count) {
   return ptr;
 }
+bool EnablePyCustomOps(bool enabled){
+  static bool f_pyop_enabled = true;
+  bool last = f_pyop_enabled;
+  f_pyop_enabled = enabled;
+  return last;
+}
 // static std::ofstream logger;
 static int init_numpy() {
   import_array();
@@ -406,6 +418,7 @@ uint64_t hash_64(const std::string& str, uint64_t num_buckets, bool fast) {
 }
 void AddGlobalMethods(pybind11::module& m) {
+  m.def("enable_custom_op", &EnablePyCustomOps, "Enable or disable pyop functions.");
   m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); });
   m.def("hash_64", &hash_64, "Computes a uint64 hash for a string (from tensorflow).");
 }
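The pybind11 binding above exposes EnablePyCustomOps to Python as enable_custom_op; since the C++ function returns the previous flag value, a caller can flip the mode temporarily and restore it afterwards. A small usage sketch:

# Sketch: enable_custom_op stores the new mode and returns the previous one.
from onnxruntime_customops import enable_custom_op

previous = enable_custom_op(False)   # force the native kernels for this run
try:
    pass  # build the InferenceSession and run the model here
finally:
    enable_custom_op(previous)       # restore the earlier mode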

View File

@@ -12,7 +12,7 @@ __author__ = "Microsoft"
 from ._ocos import get_library_path # noqa
-from ._ocos import Opdef, PyCustomOpDef, hash_64 # noqa
+from ._ocos import Opdef, PyCustomOpDef, hash_64, enable_custom_op # noqa
 from ._ocos import expand_onnx_inputs # noqa
 onnx_op = Opdef.declare

View File

@@ -6,9 +6,8 @@
 import sys
 import copy
 from onnx import helper
 from pathlib import Path
 from ._ortcustomops import ( # noqa
-    PyCustomOpDef, add_custom_op, hash_64)
+    PyCustomOpDef, enable_custom_op, add_custom_op, hash_64)
 def get_library_path():

View File

@@ -7,6 +7,7 @@ from transformers import GPT2Tokenizer
 import onnxruntime as _ort
 from onnxruntime_customops import (
     onnx_op,
+    enable_custom_op,
     PyCustomOpDef,
     expand_onnx_inputs,
     get_library_path as _get_library_path)
@@ -46,30 +47,41 @@ class TestGPT2Tokenizer(unittest.TestCase):
         tokjson = _get_test_data_file('data', 'gpt2.vocab')
         os.environ["GPT2TOKFILE"] = tokjson
         cls.tokenizer = GPT2Tokenizer(tokjson, tokjson.replace('.vocab', '.merges.txt'))
+        cls.test_sentence = "I can feel the magic, can you?"
+        cls.indexed_tokens = cls.tokenizer.encode(cls.test_sentence)
-        @onnx_op(op_type="BBPETokenizer",
+        model = _create_test_model()
+        cls.binded_model = _bind_tokenizer(model)
+        @onnx_op(op_type="GPT2Tokenizer",
                  inputs=[PyCustomOpDef.dt_string],
                  outputs=[PyCustomOpDef.dt_int64])
         def bpe_toenizer(s):
             # The user custom op implementation here.
+            TestGPT2Tokenizer.pyop_invoked = True
             return np.array(
                 [TestGPT2Tokenizer.tokenizer.encode(st_) for st_ in s])
-    def test_tokenizer(self):
-        test_sentence = "I can feel the magic, can you?"
-        tokenizer = TestGPT2Tokenizer.tokenizer
-        indexed_tokens = tokenizer.encode(test_sentence)
-        model = _create_test_model()
-        binded_model = _bind_tokenizer(model)
+    def _run_tokenizer(self, pyop_flag):
         so = _ort.SessionOptions()
+        enable_custom_op(pyop_flag)
         so.register_custom_ops_library(_get_library_path())
-        sess = _ort.InferenceSession(binded_model.SerializeToString(), so)
-        input_text = np.array([test_sentence])
+        sess = _ort.InferenceSession(TestGPT2Tokenizer.binded_model.SerializeToString(), so)
+        input_text = np.array([TestGPT2Tokenizer.test_sentence])
         txout = sess.run(None, {'string_input': input_text})
+        np.testing.assert_array_equal(txout[0], np.array([self.indexed_tokens]))
+        del sess
+        del so
-        np.testing.assert_array_equal(txout[0], np.array([indexed_tokens]))
+    def test_tokenizer(self):
+        TestGPT2Tokenizer.pyop_invoked = False
+        self._run_tokenizer(False)
+        self.assertFalse(TestGPT2Tokenizer.pyop_invoked)
+    def test_tokenizer_pyop(self):
+        TestGPT2Tokenizer.pyop_invoked = False
+        self._run_tokenizer(True)
+        self.assertTrue(TestGPT2Tokenizer.pyop_invoked)
 if __name__ == "__main__":