support the operator list for build flags (#122)
* support the operator list for build flags * revert the flag * update the file name * little refinement
This commit is contained in:
Родитель
a428be447c
Коммит
983de7c0fe
|
@ -14,6 +14,7 @@ dist/
|
|||
cmake_build
|
||||
.cmake_build
|
||||
cmake-build*
|
||||
_selectedoplist.cmake
|
||||
gen
|
||||
.DS_Store
|
||||
*~
|
||||
|
|
|
@ -31,7 +31,7 @@ option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
|
|||
option(OCOS_ENABLE_BLINGFIRE "Enable the Blingfire building" ON)
|
||||
option(OCOS_ENABLE_MATH "Enable the math tensor operators building" ON)
|
||||
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
|
||||
|
||||
option(OCOS_ENABLE_OPLIST_FILE "Enable including the selected_ops tool file" OFF)
|
||||
|
||||
if(NOT CC_OPTIMIZE)
|
||||
message("!!!THE COMPILER OPTIMIZATION HAS BEEN DISABLED, DEBUG-ONLY!!!")
|
||||
|
@ -59,6 +59,10 @@ endif()
|
|||
|
||||
# External dependencies
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals)
|
||||
if (OCOS_ENABLE_OPLIST_FILE)
|
||||
include(_selectedoplist)
|
||||
endif()
|
||||
|
||||
include(FetchContent)
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
if (NOT TARGET re2::re2)
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
import inspect
|
||||
from ._ocos import default_opset_domain
|
||||
from . import _cuops
|
||||
|
||||
|
||||
ALL_CUSTOM_OPS = {_name: _obj for _name, _obj in inspect.getmembers(_cuops)
|
||||
if (inspect.isclass(_obj) and issubclass(_obj, _cuops.CustomOp))}
|
||||
|
||||
|
||||
OPMAP_TO_CMAKE_FLAGS = {'GPT2Tokenizer': 'OCOS_ENABLE_GPT2_TOKENIZER',
|
||||
'BlingFireSentenceBreaker': 'OCOS_ENABLE_BLINGFIRE'
|
||||
}
|
||||
|
||||
|
||||
def gen_cmake_oplist(opconfig_file, oplist_cmake_file = '_selectedoplist.cmake'):
|
||||
|
||||
ext_domain = default_opset_domain()
|
||||
with open(oplist_cmake_file, 'w') as f:
|
||||
print("# Auto-Generated File, not edited!!!", file=f)
|
||||
with open(opconfig_file, 'r') as opfile:
|
||||
for _ln in opfile:
|
||||
if _ln.startswith(ext_domain):
|
||||
items = _ln.strip().split(';')
|
||||
if len(items) < 3:
|
||||
raise RuntimeError("The malformated operator config file.")
|
||||
for _op in items[2].split(','):
|
||||
if not _op:
|
||||
continue # is None or ""
|
||||
if _op not in OPMAP_TO_CMAKE_FLAGS:
|
||||
raise RuntimeError("Cannot find the custom operator({})\'s build flags, "
|
||||
+ "Please update the OPMAP_TO_CMAKE_FLAGS dictionary.".format(_op))
|
||||
print("set({} ON CACHE INTERNAL \"\")".format(OPMAP_TO_CMAKE_FLAGS[_op]), file=f)
|
||||
print("# End of Building the Operator CMake variables", file=f)
|
||||
|
||||
print('The cmake tool file has been generated successfully.')
|
|
@ -0,0 +1,3 @@
|
|||
#domain;opset;op1,op2...
|
||||
ai.onnx;12;Add,Cast,Concat,Squeeze
|
||||
ai.onnx.contrib;1;GPT2Tokenizer,
|
|
@ -0,0 +1,29 @@
|
|||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from onnxruntime_extensions import cmake_helper
|
||||
|
||||
|
||||
def _get_test_data_file(*sub_dirs):
|
||||
test_dir = Path(__file__).parent
|
||||
return str(test_dir.joinpath(*sub_dirs))
|
||||
|
||||
|
||||
class TestCMakeHelper(unittest.TestCase):
|
||||
def test_cmake_file_gen(self):
|
||||
cfgfile = _get_test_data_file('data', 'test.op.config')
|
||||
cfile = '_selectedoplist.cmake'
|
||||
cmake_helper.gen_cmake_oplist(cfgfile, cfile)
|
||||
found = False
|
||||
with open(cfile, 'r') as f:
|
||||
for _ln in f:
|
||||
if _ln.strip() == "set(OCOS_ENABLE_GPT2_TOKENIZER ON CACHE INTERNAL \"\")":
|
||||
found = True
|
||||
break
|
||||
|
||||
os.remove(cfile)
|
||||
self.assertTrue(found)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Загрузка…
Ссылка в новой задаче