Use torch::Tensor instead of ks::tensor in entry points (#931)

This commit is contained in:
David Collier 2021-07-16 14:13:37 +01:00 committed by GitHub
Parent 2cd938a5e9
Commit 1305e89be8
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 202 additions and 100 deletions

View file

@@ -64,14 +64,21 @@ def vrelu3_embedded_ks_checkpointed_map():
     embedded_cpp_entry_points = """
-#include "knossos-entry-points.h"
+#include "knossos-entry-points-torch.h"

-ks::tensor<1, ks::Float> entry(ks::tensor<1, ks::Float> t) {
-    return ks::vrelu3(&ks::entry_points::g_alloc, t);
+torch::Tensor entry(torch::Tensor t) {
+    using namespace ks::entry_points;
+    auto ks_t = convert_argument<ks::tensor<1, ks::Float>>(t);
+    auto ks_ret = ks::vrelu3(&g_alloc, ks_t);
+    return convert_return_value<torch::Tensor>(ks_ret);
 }

-ks::tensor<1, ks::Float> entry_vjp(ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
-    return ks::sufrev_vrelu3(&ks::entry_points::g_alloc, t, dret);
+torch::Tensor entry_vjp(torch::Tensor t, torch::Tensor dret) {
+    using namespace ks::entry_points;
+    auto ks_t = convert_argument<ks::tensor<1, ks::Float>>(t);
+    auto ks_dret = convert_argument<ks::tensor<1, ks::Float>>(dret);
+    auto ks_ret = ks::sufrev_vrelu3(&g_alloc, ks_t, ks_dret);
+    return convert_return_value<torch::Tensor>(ks_ret);
 }
 """

View file

@@ -1,4 +1,5 @@
 from ksc import utils
+from ksc.type import Type

 scalar_type_to_cpp_map = {
@@ -24,7 +25,31 @@ def ks_cpp_type(t):
         raise ValueError(f'Unable to generate C++ type for "{t}"')


-def generate_cpp_entry_points(bindings_to_generate, decls):
+def entry_point_cpp_type(t, use_torch):
+    if t.is_scalar:
+        return scalar_type_to_cpp_map[t.kind]
+    elif t.is_tuple:
+        return (
+            "ks::Tuple<"
+            + ", ".join(
+                entry_point_cpp_type(child, use_torch) for child in t.tuple_elems()
+            )
+            + ">"
+        )
+    elif t.is_tensor:
+        if use_torch:
+            if t.tensor_elem_type != Type.Float:
+                raise ValueError(
+                    f'Entry point signatures may only use tensors with floating-point elements (not "{t}")'
+                )
+            return "torch::Tensor"
+        else:
+            raise ValueError(f'Tensors in entry points are not supported "{t}"')
+    else:
+        raise ValueError(f'Unable to generate C++ type for "{t}"')
+
+
+def generate_cpp_entry_points(bindings_to_generate, decls, use_torch=False):
     decls_by_name = {decl.name: decl for decl in decls}

     def lookup_decl(structured_name):
@@ -33,12 +58,18 @@ def generate_cpp_entry_points(bindings_to_generate, decls):
         return decls_by_name[structured_name]

     cpp_entry_points = "".join(
-        generate_cpp_entry_point(binding_name, lookup_decl(structured_name))
+        generate_cpp_entry_point(
+            binding_name, lookup_decl(structured_name), use_torch=use_torch
+        )
         for binding_name, structured_name in bindings_to_generate
     )

+    entry_point_header = (
+        "knossos-entry-points-torch.h" if use_torch else "knossos-entry-points.h"
+    )
     return f"""
-#include "knossos-entry-points.h"
+#include "{entry_point_header}"

 namespace ks {{
 namespace entry_points {{
@@ -50,7 +81,7 @@ namespace generated {{
 """


-def generate_cpp_entry_point(cpp_function_name, decl):
+def generate_cpp_entry_point(cpp_function_name, decl, use_torch):
     return_type = decl.return_type
     arg_types = [arg.type_ for arg in decl.args]
     if len(arg_types) == 1 and arg_types[0].is_tuple:
@@ -59,26 +90,43 @@ def generate_cpp_entry_point(cpp_function_name, decl, use_torch):
     def arg_name(i):
         return f"arg{i}"

+    def ks_arg_name(i):
+        return f"ks_{arg_name(i)}"
+
     args = ", ".join(
-        f"{ks_cpp_type(arg_type)} {arg_name(i)}" for i, arg_type in enumerate(arg_types)
+        f"{entry_point_cpp_type(arg_type, use_torch)} {arg_name(i)}"
+        for i, arg_type in enumerate(arg_types)
+    )
+    cpp_declaration = (
+        f"{entry_point_cpp_type(return_type, use_torch)} {cpp_function_name}({args})"
+    )
+    convert_arguments = "\n".join(
+        f"    auto {ks_arg_name(i)} = convert_argument<{ks_cpp_type(arg_type)}>({arg_name(i)});"
+        for i, arg_type in enumerate(arg_types)
     )
-    cpp_declaration = f"{ks_cpp_type(return_type)} {cpp_function_name}({args})"
     ks_function_name = utils.encode_name(decl.name.mangled())
-    arg_names = [arg_name(i) for i in range(len(arg_types))]
-    arg_list = ", ".join(["&g_alloc"] + arg_names)
-    cpp_call = f"ks::{ks_function_name}({arg_list})"
-    args_streamed = ' << ", " '.join(f" << {arg}" for arg in arg_names)
+    arg_list = ", ".join(["&g_alloc"] + [ks_arg_name(i) for i in range(len(arg_types))])
+    cpp_call = f"""
+    auto ks_ret = ks::{ks_function_name}({arg_list});
+    auto ret = convert_return_value<{entry_point_cpp_type(return_type, use_torch)}>(ks_ret);
+"""
+    args_streamed = ' << ", " '.join(
+        f" << {arg_name(i)}" for i in range(len(arg_types))
+    )
     return f"""
 {cpp_declaration} {{
     if (g_logging) {{
         std::cerr << "{ks_function_name}("{args_streamed} << ") =" << std::endl;
-        auto ret = {cpp_call};
-        std::cerr << ret << std::endl;
-        return ret;
-    }} else {{
-        return {cpp_call};
     }}
+{convert_arguments}
+{cpp_call}
+    if (g_logging) {{
+        std::cerr << ret << std::endl;
+    }}
+    return ret;
 }}
 """

View file

@@ -148,7 +148,7 @@ derivatives_to_generate_default = ["fwd", "rev"]
 def generate_cpp_for_py_module_from_ks(
-    ks_str, bindings_to_generate, python_module_name, use_aten=True,
+    ks_str, bindings_to_generate, python_module_name, use_aten=True, use_torch=False
 ):
     def mangled_with_type(structured_name):
         if not structured_name.has_type():
@@ -163,7 +163,9 @@ def generate_cpp_for_py_module_from_ks(
     ]

     cpp_ks_functions, decls = generate_cpp_from_ks(ks_str, use_aten=use_aten)
-    cpp_entry_points = cgen.generate_cpp_entry_points(bindings_to_generate, decls)
+    cpp_entry_points = cgen.generate_cpp_entry_points(
+        bindings_to_generate, decls, use_torch=use_torch
+    )
     cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
         bindings, python_module_name
     )
@@ -189,11 +191,6 @@ PYBIND11_MODULE("""
     m.def("allocator_top", &ks::entry_points::allocator_top);
     m.def("allocator_peak", &ks::entry_points::allocator_peak);
     m.def("logging", &ks::entry_points::logging);
-
-    declare_tensor_1<ks::Float>(m, "Tensor_1_Float");
-    declare_tensor_2<ks::Float>(m, "Tensor_2_Float");
-    declare_tensor_2<ks::Integer>(m, "Tensor_2_Integer");
-
 """
        + "\n".join(m_def(*t) for t in bindings_to_generate)
        + """
@@ -204,10 +201,16 @@
     )


-def build_py_module_from_ks(ks_str, bindings_to_generate, use_aten=False):
+def build_py_module_from_ks(
+    ks_str, bindings_to_generate, use_aten=False, use_torch=False
+):
     cpp_str = generate_cpp_for_py_module_from_ks(
-        ks_str, bindings_to_generate, "PYTHON_MODULE_NAME", use_aten
+        ks_str,
+        bindings_to_generate,
+        "PYTHON_MODULE_NAME",
+        use_aten=use_aten,
+        use_torch=use_torch,
     )

     cpp_fname = (
@@ -237,7 +240,11 @@ def build_module_using_pytorch_from_ks(
     Each StructuredName must have a type attached
     """
     cpp_str = generate_cpp_for_py_module_from_ks(
-        ks_str, bindings_to_generate, torch_extension_name, use_aten
+        ks_str,
+        bindings_to_generate,
+        torch_extension_name,
+        use_aten=use_aten,
+        use_torch=True,
     )

     return build_module_using_pytorch_from_cpp_backend(
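With the declare_tensor_* registrations gone, the generated pybind module shrinks to the allocator/logging helpers plus one m.def per binding. The sketch below shows the approximate shape for a single binding named "entry"; it assumes m_def emits a plain m.def into ks::entry_points::generated (a namespace visible in the hunks above) and that reset_allocator is exposed as the frontend's py_mod.reset_allocator() call implies. Treat it as orientation, not exact codegen output.

// Illustrative only: approximate generated pybind module after this change.
PYBIND11_MODULE(PYTHON_MODULE_NAME, m) {
    m.def("reset_allocator", &ks::entry_points::reset_allocator);
    m.def("allocator_top", &ks::entry_points::allocator_top);
    m.def("allocator_peak", &ks::entry_points::allocator_peak);
    m.def("logging", &ks::entry_points::logging);
    // One of these per binding; "entry" now takes and returns torch::Tensor.
    m.def("entry", &ks::entry_points::generated::entry);
}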

View file

@@ -396,12 +396,14 @@ def torch_from_ks(ks_object):
         return tuple(torch_from_ks(ks) for ks in ks_object)

     if isinstance(ks_object, float):
-        return torch.tensor(ks_object)
+        return torch.tensor(ks_object)  # TODO: use torch::Scalar?

-    return torch.from_numpy(numpy.array(ks_object, copy=True))
+    assert isinstance(ks_object, torch.Tensor)  # TODO: strings, etc.
+    return ks_object


-def torch_to_ks(py_mod, val):
+def torch_to_ks(val):
     """
     Return a KS-compatible version of val.
     If val is a scalar, just return it as a float.
@@ -419,13 +421,7 @@ def torch_to_ks(py_mod, val):
         if len(val.shape) == 0:
             return val.item()

-        val = val.contiguous()  # Get data, or copy if not already contiguous
-        if len(val.shape) == 1:
-            ks_tensor = py_mod.Tensor_1_Float(val.data_ptr(), *val.shape)
-        if len(val.shape) == 2:
-            ks_tensor = py_mod.Tensor_2_Float(val.data_ptr(), *val.shape)
-        ks_tensor._torch_val = val  # Stash object inside return value to prevent premature garbage collection
-        return ks_tensor
+        return val.contiguous()  # Get data, or copy if not already contiguous

     raise NotImplementedError()
@@ -446,7 +442,7 @@
 # See https://pytorch.org/docs/stable/notes/extending.html
 def forward_template(py_mod, ctx, *args):
     py_mod.reset_allocator()
-    ks_args = (torch_to_ks(py_mod, x) for x in args)
+    ks_args = (torch_to_ks(x) for x in args)

     # Call it
     outputs = py_mod.entry(*ks_args)
@@ -459,8 +455,8 @@ def forward_template(py_mod, ctx, *args):

 def backward_template(py_mod, ctx, *args):
-    ks_args = make_tuple_if_many_args(torch_to_ks(py_mod, x) for x in ctx.saved_tensors)
-    ks_grad_args = make_tuple_if_many_args(torch_to_ks(py_mod, x) for x in args)
+    ks_args = make_tuple_if_many_args(torch_to_ks(x) for x in ctx.saved_tensors)
+    ks_grad_args = make_tuple_if_many_args(torch_to_ks(x) for x in args)
     outputs = py_mod.entry_vjp(ks_args, ks_grad_args)
     return torch_from_ks(outputs)
@@ -477,7 +473,7 @@ def make_KscAutogradFunction(py_mod):
             "py_mod": py_mod,
             "forward": staticmethod(forward),
             "backward": staticmethod(backward),
-            "adapt": staticmethod(lambda x: torch_to_ks(py_mod, x)),
+            "adapt": staticmethod(lambda x: torch_to_ks(x)),
         },
     )

View file

@@ -0,0 +1,48 @@
+#pragma once
+
+#include "knossos-entry-points.h"
+
+#include <torch/extension.h>
+
+namespace ks {
+namespace entry_points {
+
+constexpr at::ScalarType scalar_type_of_Float = c10::CppTypeToScalarType<Float>::value;
+
+template<>
+struct Converter<ks::tensor<1, Float>, torch::Tensor>
+{
+    static ks::tensor<1, Float> to_ks(torch::Tensor arg) {
+        KS_ASSERT(arg.sizes().size() == 1u);
+        KS_ASSERT(arg.is_contiguous());
+        KS_ASSERT(arg.scalar_type() == scalar_type_of_Float);
+        return ks::tensor<1, Float>((int)arg.size(0), arg.data_ptr<Float>());
+    }
+
+    static torch::Tensor from_ks(ks::tensor<1, Float> ret) {
+        torch::Tensor torch_ret = torch::empty(ret.size(), torch::TensorOptions().dtype(scalar_type_of_Float));
+        std::memcpy(torch_ret.data_ptr(), ret.data(), ret.size() * sizeof(Float));
+        return torch_ret;
+    }
+};
+
+template<>
+struct Converter<ks::tensor<2, Float>, torch::Tensor>
+{
+    static ks::tensor<2, Float> to_ks(torch::Tensor arg) {
+        KS_ASSERT(arg.sizes().size() == 2u);
+        KS_ASSERT(arg.is_contiguous());
+        KS_ASSERT(arg.scalar_type() == scalar_type_of_Float);
+        return ks::tensor<2, Float>({(int)arg.size(0), (int)arg.size(1)}, arg.data_ptr<Float>());
+    }
+
+    static torch::Tensor from_ks(ks::tensor<2, Float> ret) {
+        auto [size0, size1] = ret.size();
+        torch::Tensor torch_ret = torch::empty({size0, size1}, torch::TensorOptions().dtype(scalar_type_of_Float));
+        std::memcpy(torch_ret.data_ptr(), ret.data(), size0 * size1 * sizeof(Float));
+        return torch_ret;
+    }
+};
+
+}
+}
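A usage sketch (not part of the commit) of a round trip through these converters. It assumes ks::Float is a floating-point type so the tensor's dtype matches scalar_type_of_Float; note the asymmetry that to_ks wraps the torch storage in place, while from_ks copies into a freshly allocated torch::Tensor.

#include "knossos-entry-points-torch.h"
#include <iostream>

int main() {
    using namespace ks::entry_points;
    // Dtype must equal scalar_type_of_Float, or to_ks's KS_ASSERT fires.
    torch::Tensor t = torch::arange(4, torch::TensorOptions().dtype(scalar_type_of_Float));
    // Wraps t's buffer without copying; t must outlive ks_t.
    auto ks_t = convert_argument<ks::tensor<1, ks::Float>>(t);
    // Copies the ks tensor back into a new torch::Tensor.
    torch::Tensor back = convert_return_value<torch::Tensor>(ks_t);
    std::cout << back << std::endl;  // 0 1 2 3
}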

View file

@@ -33,6 +33,58 @@ auto with_ks_allocator(const char * tracingMessage, RetType(*f)(ks::allocator*, ParamTypes...)) {
     };
 }

+template<typename KSType, typename EntryPointType>
+KSType convert_argument(EntryPointType arg);
+
+template<typename EntryPointType, typename KSType>
+EntryPointType convert_return_value(KSType ret);
+
+template<typename KSType, typename EntryPointType>
+struct Converter
+{
+    static_assert(std::is_same<KSType, EntryPointType>::value, "Entry point type is not supported");
+
+    static KSType to_ks(EntryPointType arg) {
+        return arg;
+    }
+
+    static EntryPointType from_ks(KSType ret) {
+        return ret;
+    }
+};
+
+template<typename ...KsElementTypes, typename ...EntryPointElementTypes>
+struct Converter<ks::Tuple<KsElementTypes...>, ks::Tuple<EntryPointElementTypes...>>
+{
+    template<size_t ...Indices>
+    static ks::Tuple<KsElementTypes...> to_ks_impl(ks::Tuple<EntryPointElementTypes...> arg, std::index_sequence<Indices...>) {
+        return ks::make_Tuple(convert_argument<KsElementTypes>(ks::get<Indices>(arg))...);
+    }
+
+    static ks::Tuple<KsElementTypes...> to_ks(ks::Tuple<EntryPointElementTypes...> arg) {
+        return to_ks_impl(arg, std::index_sequence_for<EntryPointElementTypes...>{});
+    }
+
+    template<size_t ...Indices>
+    static ks::Tuple<EntryPointElementTypes...> from_ks_impl(ks::Tuple<KsElementTypes...> ret, std::index_sequence<Indices...>) {
+        return ks::make_Tuple(convert_return_value<EntryPointElementTypes>(ks::get<Indices>(ret))...);
+    }
+
+    static ks::Tuple<EntryPointElementTypes...> from_ks(ks::Tuple<KsElementTypes...> ret) {
+        return from_ks_impl(ret, std::index_sequence_for<KsElementTypes...>{});
+    }
+};
+
+template<typename KSType, typename EntryPointType>
+KSType convert_argument(EntryPointType arg) {
+    return Converter<KSType, EntryPointType>::to_ks(arg);
+}
+
+template<typename EntryPointType, typename KSType>
+EntryPointType convert_return_value(KSType ret) {
+    return Converter<KSType, EntryPointType>::from_ks(ret);
+}
+
 }
 }
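A minimal sketch of how the tuple specialization composes with the other converters: the tensor element routes through the torch::Tensor specialization, while a plain scalar falls through the identity primary template. It assumes ks::Float is double and that knossos-entry-points-torch.h is included; illustrative only.

using namespace ks::entry_points;

// A (tensor, scalar) pair as it would arrive from the entry-point side.
torch::Tensor t = torch::ones(3, torch::TensorOptions().dtype(scalar_type_of_Float));
auto ep_pair = ks::make_Tuple(t, 2.0);

// Element-wise conversion: tensor via the torch specialization,
// double via the identity Converter (static_assert guards mismatches).
auto ks_pair = convert_argument<ks::Tuple<ks::tensor<1, ks::Float>, ks::Float>>(ep_pair);

// And back out: the tensor element is copied into a new torch::Tensor.
auto ep_again = convert_return_value<ks::Tuple<torch::Tensor, ks::Float>>(ks_pair);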

View file

@@ -81,59 +81,3 @@ private:
 }}

-template<typename T>
-void declare_tensor_2(py::module &m, char const* name) {
-    // Wrap ks_tensor<Dim, T> to point to supplied python memory
-    static constexpr size_t Dim = 2;
-    py::class_<ks::tensor<2, T>>(m, name, py::buffer_protocol(), py::module_local(), py::dynamic_attr())
-        .def(py::init([](std::uintptr_t v, size_t m, size_t n) {
-            ks::tensor_dimension<Dim>::index_type size {int(m),int(n)};
-            return ks::tensor<Dim, T>(size, reinterpret_cast<T*>(v)); // Reference to caller's data
-        }))
-        // And describe buffer shape to Python
-        // Returned tensors will be living on g_alloc, so will become invalid after allocator_reset()
-        .def_buffer([](ks::tensor<Dim, T> &t) -> py::buffer_info {
-            return py::buffer_info(
-                t.data(),                           /* Pointer to buffer */
-                sizeof(T),                          /* Size of one scalar */
-                py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
-                Dim,                                /* Number of dimensions */
-                { ks::get_dimension<0>(t.size()), ks::get_dimension<1>(t.size()) }, /* Buffer dimensions */
-                { sizeof(T) * ks::get_dimension<1>(t.size()), /* Strides (in bytes) for each index */
-                  sizeof(T) }
-            );
-        })
-        ;
-}
-
-template<typename T>
-void declare_tensor_1(py::module &m, char const* name) {
-    // Wrap ks_tensor<1, T> to point to supplied python memory
-    static constexpr size_t Dim = 1;
-    py::class_<ks::tensor<Dim, T>>(m, name, py::buffer_protocol(), py::module_local(), py::dynamic_attr())
-        .def(py::init([](std::uintptr_t v, size_t n) {
-            ks::tensor_dimension<Dim>::index_type size {int(n)};
-            // Note: We are capturing a reference to the caller's data.
-            // We expect the user to attach a Python object to this class
-            // in order to keep that data alive. See torch_frontend.py:torch_to_ks
-            // OR: of course we could just copy, but it's useful to keep track of the cost,
-            // so preserving an implementation where we can avoid the copy feels
-            // valuable.
-            return ks::tensor<Dim, T>(size, reinterpret_cast<T*>(v)); // Reference to caller's data
-        }))
-        // And describe buffer shape to Python
-        // Returned tensors will be living on g_alloc, so will become invalid after allocator_reset()
-        .def_buffer([](ks::tensor<Dim, T> &t) -> py::buffer_info {
-            return py::buffer_info(
-                t.data(),                           /* Pointer to buffer */
-                sizeof(T),                          /* Size of one scalar */
-                py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
-                Dim,                                /* Number of dimensions */
-                { ks::get_dimension<0>(t.size()) }, /* Buffer dimensions */
-                { sizeof(T) }
-            );
-        })
-        ;
-}