[RUNTIME] Add function to pack arguments (#452)
This commit is contained in:
Родитель
769544ad7d
Коммит
5061a6da5e
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
#include <tvm/runtime/c_runtime_api.h>
|
#include <tvm/runtime/c_runtime_api.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
@ -31,7 +32,7 @@ union ArgUnion {
|
||||||
* \brief Create a packed function from void addr types.
|
* \brief Create a packed function from void addr types.
|
||||||
*
|
*
|
||||||
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* void_args)
|
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* void_args)
|
||||||
* \param arg_types The arguments that wish to get from
|
* \param arg_types The arguments type information.
|
||||||
* \tparam F the function type
|
* \tparam F the function type
|
||||||
*
|
*
|
||||||
* \return The wrapped packed function.
|
* \return The wrapped packed function.
|
||||||
|
@ -42,13 +43,24 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
|
||||||
* \brief Create a packed function that from function only packs buffer arguments.
|
* \brief Create a packed function that from function only packs buffer arguments.
|
||||||
*
|
*
|
||||||
* \param f with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
|
* \param f with signature (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
|
||||||
* \param arg_types The arguments that wish to get from
|
* \param arg_types The arguments type information.
|
||||||
* \tparam F the function type
|
* \tparam F the function type
|
||||||
*
|
*
|
||||||
* \return The wrapped packed function.
|
* \return The wrapped packed function.
|
||||||
*/
|
*/
|
||||||
template<typename F>
|
template<typename F>
|
||||||
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
|
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
|
||||||
|
/*!
|
||||||
|
* \brief Create a packed function from a function that takes packed arguments.
|
||||||
|
*
|
||||||
|
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
|
||||||
|
* \param arg_types The arguments type information.
|
||||||
|
* \tparam F the function type
|
||||||
|
*
|
||||||
|
* \return The wrapped packed function.
|
||||||
|
*/
|
||||||
|
template<typename F>
|
||||||
|
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types);
|
||||||
/*!
|
/*!
|
||||||
* \brief Extract number of buffer argument from the argument types.
|
* \brief Extract number of buffer argument from the argument types.
|
||||||
* \param arg_types The argument types.
|
* \param arg_types The argument types.
|
||||||
|
@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_(
|
||||||
};
|
};
|
||||||
return PackedFunc(ret);
|
return PackedFunc(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<int N, typename F>
|
||||||
|
inline PackedFunc PackFuncPackedArg_(
|
||||||
|
F f, const std::vector<ArgConvertCode>& codes) {
|
||||||
|
int num_args = static_cast<int>(codes.size());
|
||||||
|
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
|
||||||
|
TempArray<uint64_t, N> pack_(num_args);
|
||||||
|
int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
|
||||||
|
int32_t* ptr = pack;
|
||||||
|
static_assert(sizeof(TVMValue) == 8, "invariant");
|
||||||
|
static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
|
||||||
|
for (int i = 0; i < num_args; ++i) {
|
||||||
|
switch (codes[i]) {
|
||||||
|
case HANDLE_TO_HANDLE: {
|
||||||
|
std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
|
||||||
|
ptr += sizeof(void*) / sizeof(int32_t);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case INT64_TO_INT64:
|
||||||
|
case FLOAT64_TO_FLOAT64: {
|
||||||
|
std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
|
||||||
|
ptr += 2;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case INT64_TO_INT32: {
|
||||||
|
*ptr = static_cast<int32_t>(args.values[i].v_int64);
|
||||||
|
++ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case INT64_TO_UINT32 : {
|
||||||
|
*reinterpret_cast<uint32_t*>(ptr) =
|
||||||
|
static_cast<uint32_t>(args.values[i].v_int64);
|
||||||
|
++ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case FLOAT64_TO_FLOAT32: {
|
||||||
|
*reinterpret_cast<float*>(ptr) =
|
||||||
|
static_cast<float>(args.values[i].v_float64);
|
||||||
|
++ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
LOG(FATAL) << "not reached"; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
|
||||||
|
};
|
||||||
|
return PackedFunc(ret);
|
||||||
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
template<typename F>
|
template<typename F>
|
||||||
|
@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
|
||||||
return detail::PackFuncNonBufferArg_<0>(f, base, codes);
|
return detail::PackFuncNonBufferArg_<0>(f, base, codes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename F>
|
||||||
|
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) {
|
||||||
|
std::vector<detail::ArgConvertCode> codes;
|
||||||
|
for (size_t i = 0; i < arg_types.size(); ++i) {
|
||||||
|
codes.push_back(detail::GetArgConvertCode(arg_types[i]));
|
||||||
|
}
|
||||||
|
size_t nargs = codes.size();
|
||||||
|
// specialization
|
||||||
|
if (nargs <= 4) {
|
||||||
|
return detail::PackFuncPackedArg_<4>(f, codes);
|
||||||
|
} else {
|
||||||
|
return detail::PackFuncPackedArg_<0>(f, codes);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace tvm
|
} // namespace tvm
|
||||||
#endif // TVM_RUNTIME_PACK_ARGS_H_
|
#endif // TVM_RUNTIME_PACK_ARGS_H_
|
||||||
|
|
|
@ -133,7 +133,8 @@ class ROCMWrappedFunc {
|
||||||
// invoke the function with void arguments
|
// invoke the function with void arguments
|
||||||
void operator()(TVMArgs args,
|
void operator()(TVMArgs args,
|
||||||
TVMRetValue* rv,
|
TVMRetValue* rv,
|
||||||
void** void_args) const {
|
void* packed_args,
|
||||||
|
size_t packed_nbytes) const {
|
||||||
int device_id;
|
int device_id;
|
||||||
ROCM_CALL(hipGetDevice(&device_id));
|
ROCM_CALL(hipGetDevice(&device_id));
|
||||||
if (fcache_[device_id] == nullptr) {
|
if (fcache_[device_id] == nullptr) {
|
||||||
|
@ -141,6 +142,11 @@ class ROCMWrappedFunc {
|
||||||
}
|
}
|
||||||
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
|
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
|
||||||
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
|
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
|
||||||
|
void* config[] = {
|
||||||
|
HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args,
|
||||||
|
HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
|
||||||
|
HIP_LAUNCH_PARAM_END
|
||||||
|
};
|
||||||
// HIP supports only extra_args.
|
// HIP supports only extra_args.
|
||||||
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
|
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
|
||||||
fcache_[device_id],
|
fcache_[device_id],
|
||||||
|
@ -150,7 +156,8 @@ class ROCMWrappedFunc {
|
||||||
wl.block_dim(0),
|
wl.block_dim(0),
|
||||||
wl.block_dim(1),
|
wl.block_dim(1),
|
||||||
wl.block_dim(2),
|
wl.block_dim(2),
|
||||||
0, strm, void_args, 0));
|
0, strm, nullptr,
|
||||||
|
reinterpret_cast<void**>(&config)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction(
|
||||||
const FunctionInfo& info = it->second;
|
const FunctionInfo& info = it->second;
|
||||||
ROCMWrappedFunc f;
|
ROCMWrappedFunc f;
|
||||||
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
|
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
|
||||||
return PackFuncVoidAddr(f, info.arg_types);
|
return PackFuncPackedArg(f, info.arg_types);
|
||||||
}
|
}
|
||||||
|
|
||||||
Module ROCMModuleCreate(
|
Module ROCMModuleCreate(
|
||||||
|
|
Загрузка…
Ссылка в новой задаче