Use torch::tensor instead of ks::tensor in entry points (#931)
Parent: 2cd938a5e9
Commit: 1305e89be8
@@ -64,14 +64,21 @@ def vrelu3_embedded_ks_checkpointed_map():
 embedded_cpp_entry_points = """
-#include "knossos-entry-points.h"
+#include "knossos-entry-points-torch.h"

-ks::tensor<1, ks::Float> entry(ks::tensor<1, ks::Float> t) {
-    return ks::vrelu3(&ks::entry_points::g_alloc, t);
+torch::Tensor entry(torch::Tensor t) {
+    using namespace ks::entry_points;
+    auto ks_t = convert_argument<ks::tensor<1, ks::Float>>(t);
+    auto ks_ret = ks::vrelu3(&g_alloc, ks_t);
+    return convert_return_value<torch::Tensor>(ks_ret);
 }

-ks::tensor<1, ks::Float> entry_vjp(ks::tensor<1, ks::Float> t, ks::tensor<1, ks::Float> dret) {
-    return ks::sufrev_vrelu3(&ks::entry_points::g_alloc, t, dret);
+torch::Tensor entry_vjp(torch::Tensor t, torch::Tensor dret) {
+    using namespace ks::entry_points;
+    auto ks_t = convert_argument<ks::tensor<1, ks::Float>>(t);
+    auto ks_dret = convert_argument<ks::tensor<1, ks::Float>>(dret);
+    auto ks_ret = ks::sufrev_vrelu3(&g_alloc, ks_t, ks_dret);
+    return convert_return_value<torch::Tensor>(ks_ret);
 }
 """
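The rewritten entry points no longer take ks::tensor directly; they accept and return torch::Tensor and cross the boundary through two helper templates added to knossos-entry-points.h later in this diff:

    template<typename KSType, typename EntryPointType>
    KSType convert_argument(EntryPointType arg);        // entry-point type -> KS type

    template<typename EntryPointType, typename KSType>
    EntryPointType convert_return_value(KSType ret);    // KS type -> entry-point type

In both, the destination type is the first template parameter, so call sites spell out only the target type and let the source type be deduced from the argument.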
@@ -1,4 +1,5 @@
 from ksc import utils
+from ksc.type import Type


 scalar_type_to_cpp_map = {
@@ -24,7 +25,31 @@ def ks_cpp_type(t):
     raise ValueError(f'Unable to generate C++ type for "{t}"')


-def generate_cpp_entry_points(bindings_to_generate, decls):
+def entry_point_cpp_type(t, use_torch):
+    if t.is_scalar:
+        return scalar_type_to_cpp_map[t.kind]
+    elif t.is_tuple:
+        return (
+            "ks::Tuple<"
+            + ", ".join(
+                entry_point_cpp_type(child, use_torch) for child in t.tuple_elems()
+            )
+            + ">"
+        )
+    elif t.is_tensor:
+        if use_torch:
+            if t.tensor_elem_type != Type.Float:
+                raise ValueError(
+                    f'Entry point signatures may only use tensors with floating-point elements (not "{t}")'
+                )
+            return "torch::Tensor"
+        else:
+            raise ValueError(f'Tensors in entry points are not supported "{t}"')
+    else:
+        raise ValueError(f'Unable to generate C++ type for "{t}"')
+
+
+def generate_cpp_entry_points(bindings_to_generate, decls, use_torch=False):
     decls_by_name = {decl.name: decl for decl in decls}

     def lookup_decl(structured_name):
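To make the recursion concrete, here is a worked example of the new type mapping, derived by hand from the branches above rather than captured from a build:

    // entry_point_cpp_type(Tuple(Tensor 1 Float, Tensor 1 Float), use_torch=True)
    // renders the C++ entry-point type:
    //     ks::Tuple<torch::Tensor, torch::Tensor>
    // With use_torch=False the tensor branch raises ValueError instead:
    // tensors are only accepted at the boundary in the torch flavour.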
@@ -33,12 +58,18 @@ def generate_cpp_entry_points(bindings_to_generate, decls):
         return decls_by_name[structured_name]

     cpp_entry_points = "".join(
-        generate_cpp_entry_point(binding_name, lookup_decl(structured_name))
+        generate_cpp_entry_point(
+            binding_name, lookup_decl(structured_name), use_torch=use_torch
+        )
         for binding_name, structured_name in bindings_to_generate
     )

+    entry_point_header = (
+        "knossos-entry-points-torch.h" if use_torch else "knossos-entry-points.h"
+    )
+
     return f"""
-#include "knossos-entry-points.h"
+#include "{entry_point_header}"

 namespace ks {{
 namespace entry_points {{
@@ -50,7 +81,7 @@ namespace generated {{
 """


-def generate_cpp_entry_point(cpp_function_name, decl):
+def generate_cpp_entry_point(cpp_function_name, decl, use_torch):
     return_type = decl.return_type
     arg_types = [arg.type_ for arg in decl.args]
     if len(arg_types) == 1 and arg_types[0].is_tuple:
@@ -59,26 +90,43 @@ def generate_cpp_entry_point(cpp_function_name, decl):
     def arg_name(i):
         return f"arg{i}"

+    def ks_arg_name(i):
+        return f"ks_{arg_name(i)}"
+
     args = ", ".join(
-        f"{ks_cpp_type(arg_type)} {arg_name(i)}" for i, arg_type in enumerate(arg_types)
+        f"{entry_point_cpp_type(arg_type, use_torch)} {arg_name(i)}"
+        for i, arg_type in enumerate(arg_types)
     )
-    cpp_declaration = f"{ks_cpp_type(return_type)} {cpp_function_name}({args})"
+    cpp_declaration = (
+        f"{entry_point_cpp_type(return_type, use_torch)} {cpp_function_name}({args})"
+    )
+
+    convert_arguments = "\n".join(
+        f"    auto {ks_arg_name(i)} = convert_argument<{ks_cpp_type(arg_type)}>({arg_name(i)});"
+        for i, arg_type in enumerate(arg_types)
+    )

     ks_function_name = utils.encode_name(decl.name.mangled())
-    arg_names = [arg_name(i) for i in range(len(arg_types))]
-    arg_list = ", ".join(["&g_alloc"] + arg_names)
-    cpp_call = f"ks::{ks_function_name}({arg_list})"
-    args_streamed = ' << ", " '.join(f" << {arg}" for arg in arg_names)
+    arg_list = ", ".join(["&g_alloc"] + [ks_arg_name(i) for i in range(len(arg_types))])
+    cpp_call = f"""
+        auto ks_ret = ks::{ks_function_name}({arg_list});
+        auto ret = convert_return_value<{entry_point_cpp_type(return_type, use_torch)}>(ks_ret);
+"""
+
+    args_streamed = ' << ", " '.join(
+        f" << {arg_name(i)}" for i in range(len(arg_types))
+    )

     return f"""
 {cpp_declaration} {{
     if (g_logging) {{
         std::cerr << "{ks_function_name}("{args_streamed} << ") =" << std::endl;
-        auto ret = {cpp_call};
-        std::cerr << ret << std::endl;
-        return ret;
-    }} else {{
-        return {cpp_call};
     }}
+{convert_arguments}
+{cpp_call}
+    if (g_logging) {{
+        std::cerr << ret << std::endl;
+    }}
+    return ret;
 }}
 """
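To see what this template actually emits: for a unary Tensor-to-Tensor function such as vrelu3 with use_torch=True, the rendered entry point comes out roughly as below. This is a hand-rendered sketch assembled from the template above; in real output ks_function_name is the encoded mangled name, not the plain "vrelu3", and whitespace differs.

    torch::Tensor entry(torch::Tensor arg0) {
        if (g_logging) {
            std::cerr << "vrelu3(" << arg0 << ") =" << std::endl;
        }
        auto ks_arg0 = convert_argument<ks::tensor<1, ks::Float>>(arg0);
        auto ks_ret = ks::vrelu3(&g_alloc, ks_arg0);
        auto ret = convert_return_value<torch::Tensor>(ks_ret);
        if (g_logging) {
            std::cerr << ret << std::endl;
        }
        return ret;
    }

Note how the logging path no longer duplicates the call: arguments are logged before conversion, the call happens once, and the converted result is logged afterwards.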
@@ -148,7 +148,7 @@ derivatives_to_generate_default = ["fwd", "rev"]


 def generate_cpp_for_py_module_from_ks(
-    ks_str, bindings_to_generate, python_module_name, use_aten=True,
+    ks_str, bindings_to_generate, python_module_name, use_aten=True, use_torch=False
 ):
     def mangled_with_type(structured_name):
         if not structured_name.has_type():

@@ -163,7 +163,9 @@ def generate_cpp_for_py_module_from_ks(
     ]

     cpp_ks_functions, decls = generate_cpp_from_ks(ks_str, use_aten=use_aten)
-    cpp_entry_points = cgen.generate_cpp_entry_points(bindings_to_generate, decls)
+    cpp_entry_points = cgen.generate_cpp_entry_points(
+        bindings_to_generate, decls, use_torch=use_torch
+    )
     cpp_pybind_module_declaration = generate_cpp_pybind_module_declaration(
         bindings, python_module_name
     )
@@ -189,11 +191,6 @@ PYBIND11_MODULE("""
     m.def("allocator_top", &ks::entry_points::allocator_top);
     m.def("allocator_peak", &ks::entry_points::allocator_peak);
     m.def("logging", &ks::entry_points::logging);
-
-    declare_tensor_1<ks::Float>(m, "Tensor_1_Float");
-    declare_tensor_2<ks::Float>(m, "Tensor_2_Float");
-    declare_tensor_2<ks::Integer>(m, "Tensor_2_Integer");
-
 """
     + "\n".join(m_def(*t) for t in bindings_to_generate)
    + """
@@ -204,10 +201,16 @@ PYBIND11_MODULE("""
 )


-def build_py_module_from_ks(ks_str, bindings_to_generate, use_aten=False):
+def build_py_module_from_ks(
+    ks_str, bindings_to_generate, use_aten=False, use_torch=False
+):

     cpp_str = generate_cpp_for_py_module_from_ks(
-        ks_str, bindings_to_generate, "PYTHON_MODULE_NAME", use_aten
+        ks_str,
+        bindings_to_generate,
+        "PYTHON_MODULE_NAME",
+        use_aten=use_aten,
+        use_torch=use_torch,
     )

     cpp_fname = (
@@ -237,7 +240,11 @@ def build_module_using_pytorch_from_ks(
     Each StructuredName must have a type attached
     """
     cpp_str = generate_cpp_for_py_module_from_ks(
-        ks_str, bindings_to_generate, torch_extension_name, use_aten
+        ks_str,
+        bindings_to_generate,
+        torch_extension_name,
+        use_aten=use_aten,
+        use_torch=True,
    )

     return build_module_using_pytorch_from_cpp_backend(
@@ -396,12 +396,14 @@ def torch_from_ks(ks_object):
         return tuple(torch_from_ks(ks) for ks in ks_object)

     if isinstance(ks_object, float):
-        return torch.tensor(ks_object)
+        return torch.tensor(ks_object)  # TODO: use torch::Scalar?

-    return torch.from_numpy(numpy.array(ks_object, copy=True))
+    assert isinstance(ks_object, torch.Tensor)  # TODO: strings, etc.
+
+    return ks_object


-def torch_to_ks(py_mod, val):
+def torch_to_ks(val):
     """
     Return a KS-compatible version of val.
     If val is a scalar, just return it as a float.

@@ -419,13 +421,7 @@ def torch_to_ks(py_mod, val):
     if len(val.shape) == 0:
         return val.item()

-    val = val.contiguous()  # Get data, or copy if not already contiguous
-    if len(val.shape) == 1:
-        ks_tensor = py_mod.Tensor_1_Float(val.data_ptr(), *val.shape)
-    if len(val.shape) == 2:
-        ks_tensor = py_mod.Tensor_2_Float(val.data_ptr(), *val.shape)
-    ks_tensor._torch_val = val  # Stash object inside return value to prevent premature garbage collection
-    return ks_tensor
+    return val.contiguous()  # Get data, or copy if not already contiguous

     raise NotImplementedError()
@@ -446,7 +442,7 @@ def logging(py_mod, flag=True):
 # See https://pytorch.org/docs/stable/notes/extending.html
 def forward_template(py_mod, ctx, *args):
     py_mod.reset_allocator()
-    ks_args = (torch_to_ks(py_mod, x) for x in args)
+    ks_args = (torch_to_ks(x) for x in args)

     # Call it
     outputs = py_mod.entry(*ks_args)
@@ -459,8 +455,8 @@ def forward_template(py_mod, ctx, *args):


 def backward_template(py_mod, ctx, *args):
-    ks_args = make_tuple_if_many_args(torch_to_ks(py_mod, x) for x in ctx.saved_tensors)
-    ks_grad_args = make_tuple_if_many_args(torch_to_ks(py_mod, x) for x in args)
+    ks_args = make_tuple_if_many_args(torch_to_ks(x) for x in ctx.saved_tensors)
+    ks_grad_args = make_tuple_if_many_args(torch_to_ks(x) for x in args)
     outputs = py_mod.entry_vjp(ks_args, ks_grad_args)
     return torch_from_ks(outputs)
@@ -477,7 +473,7 @@ def make_KscAutogradFunction(py_mod):
             "py_mod": py_mod,
             "forward": staticmethod(forward),
             "backward": staticmethod(backward),
-            "adapt": staticmethod(lambda x: torch_to_ks(py_mod, x)),
+            "adapt": staticmethod(lambda x: torch_to_ks(x)),
         },
     )
knossos-entry-points-torch.h (new file):

@@ -0,0 +1,48 @@
+#pragma once
+
+#include "knossos-entry-points.h"
+
+#include <torch/extension.h>
+
+namespace ks {
+namespace entry_points {
+
+constexpr at::ScalarType scalar_type_of_Float = c10::CppTypeToScalarType<Float>::value;
+
+template<>
+struct Converter<ks::tensor<1, Float>, torch::Tensor>
+{
+    static ks::tensor<1, Float> to_ks(torch::Tensor arg) {
+        KS_ASSERT(arg.sizes().size() == 1u);
+        KS_ASSERT(arg.is_contiguous());
+        KS_ASSERT(arg.scalar_type() == scalar_type_of_Float);
+        return ks::tensor<1, Float>((int)arg.size(0), arg.data_ptr<Float>());
+    }
+
+    static torch::Tensor from_ks(ks::tensor<1, Float> ret) {
+        torch::Tensor torch_ret = torch::empty(ret.size(), torch::TensorOptions().dtype(scalar_type_of_Float));
+        std::memcpy(torch_ret.data_ptr(), ret.data(), ret.size() * sizeof(Float));
+        return torch_ret;
+    }
+};
+
+template<>
+struct Converter<ks::tensor<2, Float>, torch::Tensor>
+{
+    static ks::tensor<2, Float> to_ks(torch::Tensor arg) {
+        KS_ASSERT(arg.sizes().size() == 2u);
+        KS_ASSERT(arg.is_contiguous());
+        KS_ASSERT(arg.scalar_type() == scalar_type_of_Float);
+        return ks::tensor<2, Float>({(int)arg.size(0), (int)arg.size(1)}, arg.data_ptr<Float>());
+    }
+
+    static torch::Tensor from_ks(ks::tensor<2, Float> ret) {
+        auto [size0, size1] = ret.size();
+        torch::Tensor torch_ret = torch::empty({size0, size1}, torch::TensorOptions().dtype(scalar_type_of_Float));
+        std::memcpy(torch_ret.data_ptr(), ret.data(), size0 * size1 * sizeof(Float));
+        return torch_ret;
+    }
+};
+
+}
+}
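Together with the generic machinery in the next hunk, these specializations give the generated entry points a one-call conversion in each direction. A minimal usage sketch (round_trip is a hypothetical function for illustration; assumes a translation unit compiled as part of a torch extension):

    #include "knossos-entry-points-torch.h"

    torch::Tensor round_trip(torch::Tensor t) {
        using namespace ks::entry_points;
        // to_ks wraps the torch buffer in place (no copy), so t must outlive
        // ks_t; the KS_ASSERTs reject non-contiguous or wrongly-typed inputs.
        auto ks_t = convert_argument<ks::tensor<1, ks::Float>>(t);
        // from_ks allocates a fresh torch::Tensor and memcpys the data back out.
        return convert_return_value<torch::Tensor>(ks_t);
    }

The asymmetry is deliberate: arguments are borrowed because the caller keeps its tensor alive for the duration of the call, while return values live on the KS allocator and must be copied out before the allocator is reset.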
knossos-entry-points.h:

@@ -33,6 +33,58 @@ auto with_ks_allocator(const char * tracingMessage, RetType(*f)(ks::allocator*,
     };
 }

+template<typename KSType, typename EntryPointType>
+KSType convert_argument(EntryPointType arg);
+
+template<typename EntryPointType, typename KSType>
+EntryPointType convert_return_value(KSType ret);
+
+template<typename KSType, typename EntryPointType>
+struct Converter
+{
+    static_assert(std::is_same<KSType, EntryPointType>::value, "Entry point type is not supported");
+
+    static KSType to_ks(EntryPointType arg) {
+        return arg;
+    }
+
+    static EntryPointType from_ks(KSType ret) {
+        return ret;
+    }
+};
+
+template<typename ...KsElementTypes, typename ...EntryPointElementTypes>
+struct Converter<ks::Tuple<KsElementTypes...>, ks::Tuple<EntryPointElementTypes...>>
+{
+    template<size_t ...Indices>
+    static ks::Tuple<KsElementTypes...> to_ks_impl(ks::Tuple<EntryPointElementTypes...> arg, std::index_sequence<Indices...>) {
+        return ks::make_Tuple(convert_argument<KsElementTypes>(ks::get<Indices>(arg))...);
+    }
+
+    static ks::Tuple<KsElementTypes...> to_ks(ks::Tuple<EntryPointElementTypes...> arg) {
+        return to_ks_impl(arg, std::index_sequence_for<EntryPointElementTypes...>{});
+    }
+
+    template<size_t ...Indices>
+    static ks::Tuple<EntryPointElementTypes...> from_ks_impl(ks::Tuple<KsElementTypes...> ret, std::index_sequence<Indices...>) {
+        return ks::make_Tuple(convert_return_value<EntryPointElementTypes>(ks::get<Indices>(ret))...);
+    }
+
+    static ks::Tuple<EntryPointElementTypes...> from_ks(ks::Tuple<KsElementTypes...> ret) {
+        return from_ks_impl(ret, std::index_sequence_for<KsElementTypes...>{});
+    }
+};
+
+template<typename KSType, typename EntryPointType>
+KSType convert_argument(EntryPointType arg) {
+    return Converter<KSType, EntryPointType>::to_ks(arg);
+}
+
+template<typename EntryPointType, typename KSType>
+EntryPointType convert_return_value(KSType ret) {
+    return Converter<KSType, EntryPointType>::from_ks(ret);
+}
+
 }
 }
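The shape of this machinery is a classic specialization-dispatch pattern: the free functions convert_argument / convert_return_value forward to a Converter<KSType, EntryPointType> struct; the primary template is the identity conversion guarded by a static_assert, and specializations (the torch ones above, the tuple one here, which recurses element-wise via an index sequence) supply real conversions. A stripped-down, self-contained illustration of the same pattern, using hypothetical types rather than Knossos ones:

    #include <iostream>
    #include <string>
    #include <type_traits>

    // Primary template: identity conversion; unsupported pairs fail at compile time.
    template<typename Internal, typename External>
    struct Conv {
        static_assert(std::is_same<Internal, External>::value, "type not supported");
        static Internal to_internal(External x) { return x; }
    };

    // Specialization supplying a real conversion between distinct types.
    template<>
    struct Conv<std::string, const char*> {
        static std::string to_internal(const char* x) { return std::string(x); }
    };

    // Free-function front end, mirroring convert_argument: callers name only the
    // destination type; the source type is deduced from the argument.
    template<typename Internal, typename External>
    Internal to_internal(External x) {
        return Conv<Internal, External>::to_internal(x);
    }

    int main() {
        std::string s = to_internal<std::string>("hello"); // dispatches to the specialization
        double d = to_internal<double>(1.5);               // identity via the primary template
        std::cout << s << " " << d << "\n";
    }

Declaring the free functions before the Converter specializations (and defining them after) lets the tuple specialization call convert_argument recursively on its elements, including elements whose own specializations appear only in later headers such as knossos-entry-points-torch.h.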
@@ -81,59 +81,3 @@ private:

 }}

-
-template<typename T>
-void declare_tensor_2(py::module &m, char const* name) {
-    // Wrap ks_tensor<Dim, T> to point to supplied python memory
-    static constexpr size_t Dim = 2;
-    py::class_<ks::tensor<2, T>>(m, name, py::buffer_protocol(), py::module_local(), py::dynamic_attr())
-        .def(py::init([](std::uintptr_t v, size_t m, size_t n) {
-            ks::tensor_dimension<Dim>::index_type size {int(m), int(n)};
-            return ks::tensor<Dim, T>(size, reinterpret_cast<T*>(v)); // Reference to caller's data
-        }))
-        // And describe buffer shape to Python
-        // Returned tensors will be living on g_alloc, so will become invalid after allocator_reset()
-        .def_buffer([](ks::tensor<Dim, T> &t) -> py::buffer_info {
-            return py::buffer_info(
-                t.data(),                            /* Pointer to buffer */
-                sizeof(T),                           /* Size of one scalar */
-                py::format_descriptor<T>::format(),  /* Python struct-style format descriptor */
-                Dim,                                 /* Number of dimensions */
-                { ks::get_dimension<0>(t.size()), ks::get_dimension<1>(t.size()) }, /* Buffer dimensions */
-                { sizeof(T) * ks::get_dimension<1>(t.size()),                       /* Strides (in bytes) for each index */
-                  sizeof(T) }
-            );
-        })
-        ;
-}
-
-template<typename T>
-void declare_tensor_1(py::module &m, char const* name) {
-    // Wrap ks_tensor<1, T> to point to supplied python memory
-    static constexpr size_t Dim = 1;
-    py::class_<ks::tensor<Dim, T>>(m, name, py::buffer_protocol(), py::module_local(), py::dynamic_attr())
-        .def(py::init([](std::uintptr_t v, size_t n) {
-            ks::tensor_dimension<Dim>::index_type size {int(n)};
-            // Note: We are capturing a reference to the caller's data.
-            // We expect the user to attach a Python object to this class
-            // in order to keep that data alive. See torch_frontend.py:torch_to_ks
-            // OR: of course we could just copy, but it's useful to keep track of the cost,
-            // so preserving an implementation where we can avoid the copy feels
-            // valuable.
-            return ks::tensor<Dim, T>(size, reinterpret_cast<T*>(v)); // Reference to caller's data
-        }))
-        // And describe buffer shape to Python
-        // Returned tensors will be living on g_alloc, so will become invalid after allocator_reset()
-        .def_buffer([](ks::tensor<Dim, T> &t) -> py::buffer_info {
-            return py::buffer_info(
-                t.data(),                            /* Pointer to buffer */
-                sizeof(T),                           /* Size of one scalar */
-                py::format_descriptor<T>::format(),  /* Python struct-style format descriptor */
-                Dim,                                 /* Number of dimensions */
-                { ks::get_dimension<0>(t.size()) },  /* Buffer dimensions */
-                { sizeof(T) }
-            );
-        })
-        ;
-}