Generate C++ entry points (#925)
This commit is contained in:
Родитель
8790b5ae72
Коммит
82aea2c764
|
@ -64,13 +64,14 @@ def vrelu3_embedded_ks_checkpointed_map():
|
|||
|
||||
|
||||
embedded_cpp_entry_points = """
|
||||
namespace ks {
|
||||
ks::tensor<1, ks::Float> entry(ks::allocator * $alloc, ks::tensor<1, ks::Float> t) {
|
||||
return ks::vrelu3($alloc, t);
|
||||
}
|
||||
ks::tensor<1, ks::Float> entry_vjp(ks::allocator * $alloc, ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
|
||||
return ks::sufrev_vrelu3($alloc, t, dret);
|
||||
#include "knossos-entry-points.h"
|
||||
|
||||
ks::tensor<1, ks::Float> entry(ks::tensor<1, ks::Float> t) {
|
||||
return ks::vrelu3(&ks::entry_points::g_alloc, t);
|
||||
}
|
||||
|
||||
ks::tensor<1, ks::Float> entry_vjp(ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
|
||||
return ks::sufrev_vrelu3(&ks::entry_points::g_alloc, t, dret);
|
||||
}
|
||||
"""
|
||||
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
from ksc import utils
|
||||
|
||||
|
||||
# Maps the `kind` of a scalar type to the spelling of the C++ type
# used for it in generated code (looked up by `ks_cpp_type`).
scalar_type_to_cpp_map = {
    "Integer": "ks::Integer",
    "Float": "ks::Float",
    "Bool": "ks::Bool",
    "String": "std::string",
}
|
||||
|
||||
|
||||
def ks_cpp_type(t):
    """Return the C++ spelling of the type `t` as a string.

    Scalars are looked up by `t.kind` in `scalar_type_to_cpp_map`; tuples
    become ``ks::Tuple<...>`` over their recursively converted elements;
    tensors become ``ks::tensor<rank, elem>``.

    Raises:
        ValueError: if `t` is neither a scalar, a tuple nor a tensor.
    """
    if t.is_scalar:
        return scalar_type_to_cpp_map[t.kind]

    if t.is_tuple:
        # Convert every element first, then splice them into the template.
        elem_cpp_types = [ks_cpp_type(child) for child in t.tuple_elems()]
        return "ks::Tuple<" + ", ".join(elem_cpp_types) + ">"

    if t.is_tensor:
        rank = t.tensor_rank
        elem_cpp_type = ks_cpp_type(t.tensor_elem_type)
        return f"ks::tensor<{rank}, {elem_cpp_type}>"

    raise ValueError(f'Unable to generate C++ type for "{t}"')
|
||||
|
||||
|
||||
def generate_cpp_entry_points(bindings_to_generate, decls):
    """Build the C++ source that defines an entry point for each binding.

    Args:
        bindings_to_generate: iterable of ``(binding_name, structured_name)``
            pairs; `binding_name` becomes the name of the generated C++
            function and `structured_name` selects the declaration to wrap.
        decls: iterable of declarations, each exposing a ``name`` attribute
            that is matched against the structured names.

    Returns:
        A string that includes ``knossos-entry-points.h`` and wraps all
        generated functions in ``namespace ks::entry_points::generated``
        (written as three nested namespaces).

    Raises:
        ValueError: if a structured name has no matching declaration.
    """
    known_decls = {}
    for decl in decls:
        known_decls[decl.name] = decl

    pieces = []
    for binding_name, structured_name in bindings_to_generate:
        if structured_name in known_decls:
            decl = known_decls[structured_name]
        else:
            raise ValueError(f"No ks definition found for binding: {structured_name}")
        pieces.append(generate_cpp_entry_point(binding_name, decl))
    cpp_entry_points = "".join(pieces)

    return f"""
#include "knossos-entry-points.h"

namespace ks {{
namespace entry_points {{
namespace generated {{
{cpp_entry_points}
}}
}}
}}
"""
|
||||
|
||||
|
||||
def generate_cpp_entry_point(cpp_function_name, decl):
    """Generate the C++ definition of a single entry-point function.

    The generated function is named `cpp_function_name`, takes parameters
    ``arg0, arg1, ...`` and forwards them, preceded by ``&g_alloc``, to
    the ks function whose name is derived from `decl.name`. When
    ``g_logging`` is set, the generated code also streams the call and
    its result to ``std::cerr``.

    Args:
        cpp_function_name: name to give the generated C++ function.
        decl: declaration providing ``name``, ``args`` (each with a
            ``type_``) and ``return_type``.

    Returns:
        The C++ source text of the function definition.
    """
    return_type = decl.return_type
    param_types = [arg.type_ for arg in decl.args]
    # A single tuple-typed argument is unpacked so that each tuple
    # element becomes its own C++ parameter.
    if len(param_types) == 1 and param_types[0].is_tuple:
        param_types = param_types[0].children

    param_names = [f"arg{index}" for index in range(len(param_types))]

    # Parameter types are converted before the return type, as in the
    # declaration text they precede it being assembled.
    params = ", ".join(
        f"{ks_cpp_type(param_type)} {param_name}"
        for param_type, param_name in zip(param_types, param_names)
    )
    cpp_declaration = f"{ks_cpp_type(return_type)} {cpp_function_name}({params})"

    ks_function_name = utils.encode_name(decl.name.mangled())
    call_args = ", ".join(["&g_alloc", *param_names])
    cpp_call = f"ks::{ks_function_name}({call_args})"
    # Chain of stream insertions printing the arguments separated by ", ".
    args_streamed = ' << ", " '.join(f" << {name}" for name in param_names)

    return f"""
{cpp_declaration} {{
if (g_logging) {{
std::cerr << "{ks_function_name}("{args_streamed} << ") =" << std::endl;
auto ret = {cpp_call};
std::cerr << ret << std::endl;
return ret;
}} else {{
return {cpp_call};
}}
}}
"""
|
|
@ -8,7 +8,8 @@ from tempfile import gettempdir
|
|||
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
from ksc import utils
|
||||
from ksc import cgen, utils
|
||||
from ksc.parse_ks import parse_ks_filename
|
||||
|
||||
preserve_temporary_files = False
|
||||
|
||||
|
@ -52,6 +53,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
|
|||
e = subprocess.run(ksc_command, capture_output=True, check=True,)
|
||||
print(e.stdout.decode("ascii"))
|
||||
print(e.stderr.decode("ascii"))
|
||||
decls = list(parse_ks_filename(fkso.name))
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Command failed:\n{' '.join(ksc_command)}")
|
||||
print(f"files {fks.name} {fkso.name} {fcpp.name}")
|
||||
|
@ -62,7 +64,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
|
|||
|
||||
# Read from CPP back to string
|
||||
with open(fcpp.name) as f:
|
||||
out = f.read()
|
||||
generated_cpp = f.read()
|
||||
|
||||
# only delete these file if no error
|
||||
if not preserve_temporary_files:
|
||||
|
@ -79,7 +81,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
|
|||
os.unlink(fcpp.name)
|
||||
os.unlink(fkso.name)
|
||||
|
||||
return out
|
||||
return generated_cpp, decls
|
||||
|
||||
|
||||
def build_py_module_from_cpp(cpp_str, profiling=False, use_aten=False):
|
||||
|
@ -156,22 +158,23 @@ def generate_cpp_for_py_module_from_ks(
|
|||
return structured_name.mangled()
|
||||
|
||||
bindings = [
|
||||
(python_name, utils.encode_name(mangled_with_type(structured_name)))
|
||||
for (python_name, structured_name) in bindings_to_generate
|
||||
(python_name, "ks::entry_points::generated::" + python_name)
|
||||
for (python_name, _) in bindings_to_generate
|
||||
]
|
||||
|
||||
cpp_ks_functions = generate_cpp_from_ks(ks_str, use_aten=use_aten)
|
||||
cpp_ks_functions, decls = generate_cpp_from_ks(ks_str, use_aten=use_aten)
|
||||
cpp_entry_points = cgen.generate_cpp_entry_points(bindings_to_generate, decls)
|
||||
cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
|
||||
bindings, python_module_name
|
||||
)
|
||||
|
||||
return cpp_ks_functions + cpp_pybind_module_declaration
|
||||
return cpp_ks_functions + cpp_entry_points + cpp_pybind_module_declaration
|
||||
|
||||
|
||||
def generate_cpp_pybind_module_declaration(bindings_to_generate, python_module_name):
|
||||
def m_def(python_name, cpp_name):
|
||||
return f"""
|
||||
m.def("{python_name}", ks::entry_points::with_ks_allocator("{cpp_name}", &ks::{cpp_name}));
|
||||
m.def("{python_name}", &{cpp_name});
|
||||
"""
|
||||
|
||||
return (
|
||||
|
|
Загрузка…
Ссылка в новой задаче