This commit is contained in:
David Collier 2021-07-15 10:54:08 +01:00 коммит произвёл GitHub
Родитель 8790b5ae72
Коммит 82aea2c764
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 102 добавлений и 14 удалений

Просмотреть файл

@ -64,13 +64,14 @@ def vrelu3_embedded_ks_checkpointed_map():
embedded_cpp_entry_points = """
namespace ks {
ks::tensor<1, ks::Float> entry(ks::allocator * $alloc, ks::tensor<1, ks::Float> t) {
return ks::vrelu3($alloc, t);
}
ks::tensor<1, ks::Float> entry_vjp(ks::allocator * $alloc, ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
return ks::sufrev_vrelu3($alloc, t, dret);
#include "knossos-entry-points.h"
ks::tensor<1, ks::Float> entry(ks::tensor<1, ks::Float> t) {
return ks::vrelu3(&ks::entry_points::g_alloc, t);
}
ks::tensor<1, ks::Float> entry_vjp(ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
return ks::sufrev_vrelu3(&ks::entry_points::g_alloc, t, dret);
}
"""

84
src/python/ksc/cgen.py Normal file
Просмотреть файл

@ -0,0 +1,84 @@
from ksc import utils
scalar_type_to_cpp_map = {
"Integer": "ks::Integer",
"Float": "ks::Float",
"Bool": "ks::Bool",
"String": "std::string",
}
def ks_cpp_type(t):
if t.is_scalar:
return scalar_type_to_cpp_map[t.kind]
elif t.is_tuple:
return (
"ks::Tuple<"
+ ", ".join(ks_cpp_type(child) for child in t.tuple_elems())
+ ">"
)
elif t.is_tensor:
return f"ks::tensor<{t.tensor_rank}, {ks_cpp_type(t.tensor_elem_type)}>"
else:
raise ValueError(f'Unable to generate C++ type for "{t}"')
def generate_cpp_entry_points(bindings_to_generate, decls):
decls_by_name = {decl.name: decl for decl in decls}
def lookup_decl(structured_name):
if structured_name not in decls_by_name:
raise ValueError(f"No ks definition found for binding: {structured_name}")
return decls_by_name[structured_name]
cpp_entry_points = "".join(
generate_cpp_entry_point(binding_name, lookup_decl(structured_name))
for binding_name, structured_name in bindings_to_generate
)
return f"""
#include "knossos-entry-points.h"
namespace ks {{
namespace entry_points {{
namespace generated {{
{cpp_entry_points}
}}
}}
}}
"""
def generate_cpp_entry_point(cpp_function_name, decl):
return_type = decl.return_type
arg_types = [arg.type_ for arg in decl.args]
if len(arg_types) == 1 and arg_types[0].is_tuple:
arg_types = arg_types[0].children
def arg_name(i):
return f"arg{i}"
args = ", ".join(
f"{ks_cpp_type(arg_type)} {arg_name(i)}" for i, arg_type in enumerate(arg_types)
)
cpp_declaration = f"{ks_cpp_type(return_type)} {cpp_function_name}({args})"
ks_function_name = utils.encode_name(decl.name.mangled())
arg_names = [arg_name(i) for i in range(len(arg_types))]
arg_list = ", ".join(["&g_alloc"] + arg_names)
cpp_call = f"ks::{ks_function_name}({arg_list})"
args_streamed = ' << ", " '.join(f" << {arg}" for arg in arg_names)
return f"""
{cpp_declaration} {{
if (g_logging) {{
std::cerr << "{ks_function_name}("{args_streamed} << ") =" << std::endl;
auto ret = {cpp_call};
std::cerr << ret << std::endl;
return ret;
}} else {{
return {cpp_call};
}}
}}
"""

Просмотреть файл

@ -8,7 +8,8 @@ from tempfile import gettempdir
from torch.utils import cpp_extension
from ksc import utils
from ksc import cgen, utils
from ksc.parse_ks import parse_ks_filename
preserve_temporary_files = False
@ -52,6 +53,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
e = subprocess.run(ksc_command, capture_output=True, check=True,)
print(e.stdout.decode("ascii"))
print(e.stderr.decode("ascii"))
decls = list(parse_ks_filename(fkso.name))
except subprocess.CalledProcessError as e:
print(f"Command failed:\n{' '.join(ksc_command)}")
print(f"files {fks.name} {fkso.name} {fcpp.name}")
@ -62,7 +64,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
# Read from CPP back to string
with open(fcpp.name) as f:
out = f.read()
generated_cpp = f.read()
# only delete these file if no error
if not preserve_temporary_files:
@ -79,7 +81,7 @@ def generate_cpp_from_ks(ks_str, use_aten=False):
os.unlink(fcpp.name)
os.unlink(fkso.name)
return out
return generated_cpp, decls
def build_py_module_from_cpp(cpp_str, profiling=False, use_aten=False):
@ -156,22 +158,23 @@ def generate_cpp_for_py_module_from_ks(
return structured_name.mangled()
bindings = [
(python_name, utils.encode_name(mangled_with_type(structured_name)))
for (python_name, structured_name) in bindings_to_generate
(python_name, "ks::entry_points::generated::" + python_name)
for (python_name, _) in bindings_to_generate
]
cpp_ks_functions = generate_cpp_from_ks(ks_str, use_aten=use_aten)
cpp_ks_functions, decls = generate_cpp_from_ks(ks_str, use_aten=use_aten)
cpp_entry_points = cgen.generate_cpp_entry_points(bindings_to_generate, decls)
cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
bindings, python_module_name
)
return cpp_ks_functions + cpp_pybind_module_declaration
return cpp_ks_functions + cpp_entry_points + cpp_pybind_module_declaration
def generate_cpp_pybind_module_declaration(bindings_to_generate, python_module_name):
def m_def(python_name, cpp_name):
return f"""
m.def("{python_name}", ks::entry_points::with_ks_allocator("{cpp_name}", &ks::{cpp_name}));
m.def("{python_name}", &{cpp_name});
"""
return (