[RUNTIME] Add function to pack arguments (#452)
This commit is contained in:
Родитель
769544ad7d
Коммит
5061a6da5e
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include <tvm/runtime/c_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
namespace tvm {
|
||||
namespace runtime {
|
||||
|
@ -31,7 +32,7 @@ union ArgUnion {
|
|||
* \brief Create a packed function from void addr types.
|
||||
*
|
||||
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
|
||||
* \param arg_types The arguments that wish to get from
|
||||
* \param arg_types The arguments type information.
|
||||
* \tparam F the function type
|
||||
*
|
||||
* \return The wrapped packed function.
|
||||
|
@ -42,13 +43,24 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
|
|||
* \brief Create a packed function that from function only packs buffer arguments.
|
||||
*
|
||||
* \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
|
||||
* \param arg_types The arguments that wish to get from
|
||||
* \param arg_types The arguments type information.
|
||||
* \tparam F the function type
|
||||
*
|
||||
* \return The wrapped packed function.
|
||||
*/
|
||||
template<typename F>
|
||||
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
|
||||
/*!
|
||||
* \brief Create a packed function that from function that takes a packed arguments.
|
||||
*
|
||||
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
|
||||
* \param arg_types The arguments that wish to get from
|
||||
* \tparam F the function type
|
||||
*
|
||||
* \return The wrapped packed function.
|
||||
*/
|
||||
template<typename F>
|
||||
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types);
|
||||
/*!
|
||||
* \brief Extract number of buffer argument from the argument types.
|
||||
* \param arg_types The argument types.
|
||||
|
@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_(
|
|||
};
|
||||
return PackedFunc(ret);
|
||||
}
|
||||
|
||||
template<int N, typename F>
|
||||
inline PackedFunc PackFuncPackedArg_(
|
||||
F f, const std::vector<ArgConvertCode>& codes) {
|
||||
int num_args = static_cast<int>(codes.size());
|
||||
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
|
||||
TempArray<uint64_t, N> pack_(num_args);
|
||||
int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
|
||||
int32_t* ptr = pack;
|
||||
static_assert(sizeof(TVMValue) == 8, "invariant");
|
||||
static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
switch (codes[i]) {
|
||||
case HANDLE_TO_HANDLE: {
|
||||
std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
|
||||
ptr += sizeof(void*) / sizeof(int32_t);
|
||||
break;
|
||||
}
|
||||
case INT64_TO_INT64:
|
||||
case FLOAT64_TO_FLOAT64: {
|
||||
std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
|
||||
ptr += 2;
|
||||
break;
|
||||
}
|
||||
case INT64_TO_INT32: {
|
||||
*ptr = static_cast<int32_t>(args.values[i].v_int64);
|
||||
++ptr;
|
||||
break;
|
||||
}
|
||||
case INT64_TO_UINT32 : {
|
||||
*reinterpret_cast<uint32_t*>(ptr) =
|
||||
static_cast<uint32_t>(args.values[i].v_int64);
|
||||
++ptr;
|
||||
break;
|
||||
}
|
||||
case FLOAT64_TO_FLOAT32: {
|
||||
*reinterpret_cast<float*>(ptr) =
|
||||
static_cast<float>(args.values[i].v_float64);
|
||||
++ptr;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "not reached"; break;
|
||||
}
|
||||
}
|
||||
}
|
||||
f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
|
||||
};
|
||||
return PackedFunc(ret);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template<typename F>
|
||||
|
@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
|
|||
return detail::PackFuncNonBufferArg_<0>(f, base, codes);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) {
|
||||
std::vector<detail::ArgConvertCode> codes;
|
||||
for (size_t i = 0; i < arg_types.size(); ++i) {
|
||||
codes.push_back(detail::GetArgConvertCode(arg_types[i]));
|
||||
}
|
||||
size_t nargs = codes.size();
|
||||
// specialization
|
||||
if (nargs <= 4) {
|
||||
return detail::PackFuncPackedArg_<4>(f, codes);
|
||||
} else {
|
||||
return detail::PackFuncPackedArg_<0>(f, codes);
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
#endif // TVM_RUNTIME_PACK_ARGS_H_
|
||||
|
|
|
@ -133,7 +133,8 @@ class ROCMWrappedFunc {
|
|||
// invoke the function with void arguments
|
||||
void operator()(TVMArgs args,
|
||||
TVMRetValue* rv,
|
||||
void** void_args) const {
|
||||
void* packed_args,
|
||||
size_t packed_nbytes) const {
|
||||
int device_id;
|
||||
ROCM_CALL(hipGetDevice(&device_id));
|
||||
if (fcache_[device_id] == nullptr) {
|
||||
|
@ -141,6 +142,11 @@ class ROCMWrappedFunc {
|
|||
}
|
||||
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
|
||||
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
|
||||
void* config[] = {
|
||||
HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args,
|
||||
HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
|
||||
HIP_LAUNCH_PARAM_END
|
||||
};
|
||||
// HIP supports only extra_args.
|
||||
ROCM_DRIVER_CALL(hipModuleLaunchKernel(
|
||||
fcache_[device_id],
|
||||
|
@ -150,7 +156,8 @@ class ROCMWrappedFunc {
|
|||
wl.block_dim(0),
|
||||
wl.block_dim(1),
|
||||
wl.block_dim(2),
|
||||
0, strm, void_args, 0));
|
||||
0, strm, nullptr,
|
||||
reinterpret_cast<void**>(&config)));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction(
|
|||
const FunctionInfo& info = it->second;
|
||||
ROCMWrappedFunc f;
|
||||
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
|
||||
return PackFuncVoidAddr(f, info.arg_types);
|
||||
return PackFuncPackedArg(f, info.arg_types);
|
||||
}
|
||||
|
||||
Module ROCMModuleCreate(
|
||||
|
|
Загрузка…
Ссылка в новой задаче