support the operator list for build flags (#122)

* support the operator list for build flags

* revert the flag

* update the file name

* little refinement
This commit is contained in:
Wenbing Li 2021-07-30 12:43:47 -07:00 коммит произвёл GitHub
Родитель a428be447c
Коммит 983de7c0fe
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 73 добавлений и 1 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -14,6 +14,7 @@ dist/
cmake_build
.cmake_build
cmake-build*
_selectedoplist.cmake
gen
.DS_Store
*~

Просмотреть файл

@ -31,7 +31,7 @@ option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
option(OCOS_ENABLE_BLINGFIRE "Enable the Blingfire building" ON)
option(OCOS_ENABLE_MATH "Enable the math tensor operators building" ON)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
option(OCOS_ENABLE_OPLIST_FILE "Enable including the selected_ops tool file" OFF)
if(NOT CC_OPTIMIZE)
message("!!!THE COMPILER OPTIMIZATION HAS BEEN DISABLED, DEBUG-ONLY!!!")
@ -59,6 +59,10 @@ endif()
# External dependencies
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals)
if (OCOS_ENABLE_OPLIST_FILE)
include(_selectedoplist)
endif()
include(FetchContent)
if (OCOS_ENABLE_TF_STRING)
if (NOT TARGET re2::re2)

Просмотреть файл

@ -0,0 +1,35 @@
import inspect
from ._ocos import default_opset_domain
from . import _cuops
ALL_CUSTOM_OPS = {_name: _obj for _name, _obj in inspect.getmembers(_cuops)
if (inspect.isclass(_obj) and issubclass(_obj, _cuops.CustomOp))}
OPMAP_TO_CMAKE_FLAGS = {'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE'
}
def gen_cmake_oplist(opconfig_file, oplist_cmake_file = '_selectedoplist.cmake'):
ext_domain = default_opset_domain()
with open(oplist_cmake_file, 'w') as f:
print("# Auto-Generated File, not edited!!!", file=f)
with open(opconfig_file, 'r') as opfile:
for _ln in opfile:
if _ln.startswith(ext_domain):
items = _ln.strip().split(';')
if len(items) < 3:
raise RuntimeError("The malformated operator config file.")
for _op in items[2].split(','):
if not _op:
continue # is None or ""
if _op not in OPMAP_TO_CMAKE_FLAGS:
raise RuntimeError("Cannot find the custom operator({})\'s build flags, "
+ "Please update the OPMAP_TO_CMAKE_FLAGS dictionary.".format(_op))
print("set({} ON CACHE INTERNAL \"\")".format(OPMAP_TO_CMAKE_FLAGS[_op]), file=f)
print("# End of Building the Operator CMake variables", file=f)
print('The cmake tool file has been generated successfully.')

3
test/data/test.op.config Normal file
Просмотреть файл

@ -0,0 +1,3 @@
#domain;opset;op1,op2...
ai.onnx;12;Add,Cast,Concat,Squeeze
ai.onnx.contrib;1;GPT2Tokenizer,

29
test/test_cmake_helper.py Normal file
Просмотреть файл

@ -0,0 +1,29 @@
import os
import unittest
from pathlib import Path
from onnxruntime_extensions import cmake_helper
def _get_test_data_file(*sub_dirs):
test_dir = Path(__file__).parent
return str(test_dir.joinpath(*sub_dirs))
class TestCMakeHelper(unittest.TestCase):
def test_cmake_file_gen(self):
cfgfile = _get_test_data_file('data', 'test.op.config')
cfile = '_selectedoplist.cmake'
cmake_helper.gen_cmake_oplist(cfgfile, cfile)
found = False
with open(cfile, 'r') as f:
for _ln in f:
if _ln.strip() == "set(OCOS_ENABLE_GPT2_TOKENIZER ON CACHE INTERNAL \"\")":
found = True
break
os.remove(cfile)
self.assertTrue(found)
if __name__ == "__main__":
unittest.main()