Generate C++ entry points (#925)
This commit is contained in:
Родитель
8790b5ae72
Коммит
82aea2c764
|
@ -64,13 +64,14 @@ def vrelu3_embedded_ks_checkpointed_map():
|
||||||
|
|
||||||
|
|
||||||
embedded_cpp_entry_points = """
|
embedded_cpp_entry_points = """
|
||||||
namespace ks {
|
#include "knossos-entry-points.h"
|
||||||
ks::tensor<1, ks::Float> entry(ks::allocator * $alloc, ks::tensor<1, ks::Float> t) {
|
|
||||||
return ks::vrelu3($alloc, t);
|
ks::tensor<1, ks::Float> entry(ks::tensor<1, ks::Float> t) {
|
||||||
}
|
return ks::vrelu3(&ks::entry_points::g_alloc, t);
|
||||||
ks::tensor<1, ks::Float> entry_vjp(ks::allocator * $alloc, ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
|
|
||||||
return ks::sufrev_vrelu3($alloc, t, dret);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ks::tensor<1, ks::Float> entry_vjp(ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
|
||||||
|
return ks::sufrev_vrelu3(&ks::entry_points::g_alloc, t, dret);
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
from ksc import utils
|
||||||
|
|
||||||
|
|
||||||
|
# Knossos scalar kind -> spelling of the corresponding C++ type.
scalar_type_to_cpp_map = dict(
    Integer="ks::Integer",
    Float="ks::Float",
    Bool="ks::Bool",
    String="std::string",
)


def ks_cpp_type(t):
    """Return the C++ type expression corresponding to the Knossos type ``t``.

    Scalars are looked up by ``t.kind`` in ``scalar_type_to_cpp_map``. Tuples
    become ``ks::Tuple<...>`` and tensors become ``ks::tensor<rank, elem>``;
    in both cases the nested types are translated recursively.

    Raises:
        ValueError: if ``t`` is neither a scalar, a tuple nor a tensor.
    """
    if t.is_scalar:
        return scalar_type_to_cpp_map[t.kind]

    if t.is_tuple:
        element_types = [ks_cpp_type(elem) for elem in t.tuple_elems()]
        return "ks::Tuple<" + ", ".join(element_types) + ">"

    if t.is_tensor:
        rank = t.tensor_rank
        element_type = ks_cpp_type(t.tensor_elem_type)
        return f"ks::tensor<{rank}, {element_type}>"

    raise ValueError(f'Unable to generate C++ type for "{t}"')
|
||||||
|
|
||||||
|
|
||||||
|
def generate_cpp_entry_points(bindings_to_generate, decls):
    """Generate C++ source text defining one entry point per requested binding.

    Args:
        bindings_to_generate: iterable of ``(binding_name, structured_name)``
            pairs. ``binding_name`` is used as the C++ function name and
            ``structured_name`` selects the declaration to wrap.
        decls: iterable of declarations, each exposing a ``name`` attribute
            that is compared against ``structured_name``.

    Returns:
        C++ source that includes ``knossos-entry-points.h`` and places the
        generated functions inside ``ks::entry_points::generated``.

    Raises:
        ValueError: if a ``structured_name`` matches none of ``decls``.
    """
    decls_by_name = {}
    for decl in decls:
        decls_by_name[decl.name] = decl

    entry_point_sources = []
    for binding_name, structured_name in bindings_to_generate:
        if structured_name not in decls_by_name:
            raise ValueError(f"No ks definition found for binding: {structured_name}")
        entry_point_sources.append(
            generate_cpp_entry_point(binding_name, decls_by_name[structured_name])
        )
    cpp_entry_points = "".join(entry_point_sources)

    return f"""
#include "knossos-entry-points.h"

namespace ks {{
namespace entry_points {{
namespace generated {{
{cpp_entry_points}
}}
}}
}}
"""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_cpp_entry_point(cpp_function_name, decl):
    """Generate the C++ source of one entry-point function wrapping ``decl``.

    The generated function is named ``cpp_function_name`` and takes one
    parameter per argument of ``decl``, named ``arg0``, ``arg1``, ...  It
    forwards them, preceded by ``&g_alloc``, to ``ks::<encoded mangled name>``
    and returns the result. When ``g_logging`` is true, the generated code
    also streams the call and its result to ``std::cerr``.

    Args:
        cpp_function_name: name to give the generated C++ function.
        decl: declaration exposing ``return_type``, ``args`` (each with a
            ``type_`` attribute) and ``name`` (with a ``mangled()`` method).

    Returns:
        The C++ function definition as a string.
    """
    arg_types = [arg.type_ for arg in decl.args]
    if len(arg_types) == 1 and arg_types[0].is_tuple:
        # A single tuple-typed argument is exposed as one C++ parameter per
        # tuple element rather than as one ks::Tuple parameter.
        arg_types = arg_types[0].children

    # Derive each parameter's C++ type and name once; the names are reused
    # for the declaration, the forwarded call and the logging statement.
    params = [
        (ks_cpp_type(arg_type), f"arg{i}") for i, arg_type in enumerate(arg_types)
    ]
    arg_names = [name for _, name in params]

    args = ", ".join(f"{cpp_type} {name}" for cpp_type, name in params)
    cpp_declaration = f"{ks_cpp_type(decl.return_type)} {cpp_function_name}({args})"

    ks_function_name = utils.encode_name(decl.name.mangled())
    arg_list = ", ".join(["&g_alloc"] + arg_names)
    cpp_call = f"ks::{ks_function_name}({arg_list})"
    # Each argument is streamed with a literal ", " between neighbours; with
    # no arguments this is the empty string.
    args_streamed = ' << ", " '.join(f" << {name}" for name in arg_names)

    return f"""
{cpp_declaration} {{
    if (g_logging) {{
        std::cerr << "{ks_function_name}("{args_streamed} << ") =" << std::endl;
        auto ret = {cpp_call};
        std::cerr << ret << std::endl;
        return ret;
    }} else {{
        return {cpp_call};
    }}
}}
"""
|
|
@ -8,7 +8,8 @@ from tempfile import gettempdir
|
||||||
|
|
||||||
from torch.utils import cpp_extension
|
from torch.utils import cpp_extension
|
||||||
|
|
||||||
from ksc import utils
|
from ksc import cgen, utils
|
||||||
|
from ksc.parse_ks import parse_ks_filename
|
||||||
|
|
||||||
preserve_temporary_files = False
|
preserve_temporary_files = False
|
||||||
|
|
||||||
|
@ -52,6 +53,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
|
||||||
e = subprocess.run(ksc_command, capture_output=True, check=True,)
|
e = subprocess.run(ksc_command, capture_output=True, check=True,)
|
||||||
print(e.stdout.decode("ascii"))
|
print(e.stdout.decode("ascii"))
|
||||||
print(e.stderr.decode("ascii"))
|
print(e.stderr.decode("ascii"))
|
||||||
|
decls = list(parse_ks_filename(fkso.name))
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Command failed:\n{' '.join(ksc_command)}")
|
print(f"Command failed:\n{' '.join(ksc_command)}")
|
||||||
print(f"files {fks.name} {fkso.name} {fcpp.name}")
|
print(f"files {fks.name} {fkso.name} {fcpp.name}")
|
||||||
|
@ -62,7 +64,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
|
||||||
|
|
||||||
# Read from CPP back to string
|
# Read from CPP back to string
|
||||||
with open(fcpp.name) as f:
|
with open(fcpp.name) as f:
|
||||||
out = f.read()
|
generated_cpp = f.read()
|
||||||
|
|
||||||
# only delete these file if no error
|
# only delete these file if no error
|
||||||
if not preserve_temporary_files:
|
if not preserve_temporary_files:
|
||||||
|
@ -79,7 +81,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
|
||||||
os.unlink(fcpp.name)
|
os.unlink(fcpp.name)
|
||||||
os.unlink(fkso.name)
|
os.unlink(fkso.name)
|
||||||
|
|
||||||
return out
|
return generated_cpp, decls
|
||||||
|
|
||||||
|
|
||||||
def build_py_module_from_cpp(cpp_str, profiling=False, use_aten=False):
|
def build_py_module_from_cpp(cpp_str, profiling=False, use_aten=False):
|
||||||
|
@ -156,22 +158,23 @@ def generate_cpp_for_py_module_from_ks(
|
||||||
return structured_name.mangled()
|
return structured_name.mangled()
|
||||||
|
|
||||||
bindings = [
|
bindings = [
|
||||||
(python_name, utils.encode_name(mangled_with_type(structured_name)))
|
(python_name, "ks::entry_points::generated::" + python_name)
|
||||||
for (python_name, structured_name) in bindings_to_generate
|
for (python_name, _) in bindings_to_generate
|
||||||
]
|
]
|
||||||
|
|
||||||
cpp_ks_functions = generate_cpp_from_ks(ks_str, use_aten=use_aten)
|
cpp_ks_functions, decls = generate_cpp_from_ks(ks_str, use_aten=use_aten)
|
||||||
|
cpp_entry_points = cgen.generate_cpp_entry_points(bindings_to_generate, decls)
|
||||||
cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
|
cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
|
||||||
bindings, python_module_name
|
bindings, python_module_name
|
||||||
)
|
)
|
||||||
|
|
||||||
return cpp_ks_functions + cpp_pybind_module_declaration
|
return cpp_ks_functions + cpp_entry_points + cpp_pybind_module_declaration
|
||||||
|
|
||||||
|
|
||||||
def generate_cpp_pybind_module_declaration(bindings_to_generate, python_module_name):
|
def generate_cpp_pybind_module_declaration(bindings_to_generate, python_module_name):
|
||||||
def m_def(python_name, cpp_name):
|
def m_def(python_name, cpp_name):
|
||||||
return f"""
|
return f"""
|
||||||
m.def("{python_name}", ks::entry_points::with_ks_allocator("{cpp_name}", &ks::{cpp_name}));
|
m.def("{python_name}", &{cpp_name});
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
Загрузка…
Ссылка в новой задаче