[EXECUTOR] Split graph_executor to header file and (runtime) source file (#300)
* [EXECUTOR] Split graph_executor to header file and (runtime) source file * Fix
This commit is contained in:
Родитель
41768cf918
Коммит
1389d20888
|
@ -1,99 +1,23 @@
|
||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2017 by Contributors
|
* Copyright (c) 2017 by Contributors
|
||||||
* \file NNVM Graph executor.
|
* \file graph_executor.cc
|
||||||
*/
|
*/
|
||||||
#include <dmlc/io.h>
|
#include "./graph_executor.h"
|
||||||
#include <dmlc/memory_io.h>
|
|
||||||
#include <tvm/runtime/registry.h>
|
|
||||||
#include <tvm/runtime/packed_func.h>
|
|
||||||
#include <tvm/runtime/module.h>
|
|
||||||
#include <nnvm/graph.h>
|
|
||||||
#include <nnvm/graph_attr_types.h>
|
|
||||||
#include <nnvm/tuple.h>
|
|
||||||
#include <nnvm/pass.h>
|
|
||||||
#include <numeric>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace tvm {
|
namespace tvm {
|
||||||
namespace contrib {
|
namespace contrib {
|
||||||
|
|
||||||
using tvm::runtime::TVMArgs;
|
|
||||||
using tvm::runtime::TVMRetValue;
|
|
||||||
using tvm::runtime::PackedFunc;
|
|
||||||
using nnvm::StorageVector;
|
|
||||||
using nnvm::ShapeVector;
|
|
||||||
using nnvm::TShape;
|
|
||||||
using nnvm::NodeAttrs;
|
|
||||||
|
|
||||||
/*! \brief DLPack compatible data types */
|
|
||||||
using DLTypeVector = std::vector<DLDataType>;
|
|
||||||
|
|
||||||
/*! \brief The executor function */
|
|
||||||
using FOpExec = std::function<void()>;
|
|
||||||
|
|
||||||
/*! \brief macro to do C API call */
|
|
||||||
#define TVM_CCALL(func) \
|
|
||||||
{ \
|
|
||||||
int ret = (func); \
|
|
||||||
CHECK_EQ(ret, 0) \
|
|
||||||
<< TVMGetLastError(); \
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief Graph Executor with TVM runtime */
|
|
||||||
class GraphExecutor : public runtime::ModuleNode {
|
|
||||||
public:
|
|
||||||
const char* type_key() const {
|
|
||||||
return "GraphExecutor";
|
|
||||||
}
|
|
||||||
PackedFunc GetFunction(
|
|
||||||
const std::string& name,
|
|
||||||
const std::shared_ptr<ModuleNode>& sptr_to_self);
|
|
||||||
// Destructor
|
|
||||||
~GraphExecutor();
|
|
||||||
// Setup with a given graph
|
|
||||||
void Init(const nnvm::Graph& g, TVMContext ctx);
|
|
||||||
// Copy data to index-th input
|
|
||||||
void SetInput(int index, DLTensor* data_in);
|
|
||||||
// Copy index-th output to data_out
|
|
||||||
void GetOutput(int index, DLTensor* data_out);
|
|
||||||
// Load parameters from stream
|
|
||||||
void LoadParams(dmlc::Stream* strm);
|
|
||||||
// Load parameters from binary file blob
|
|
||||||
void LoadParamsFromBlob(std::string param_blob);
|
|
||||||
// Execute the graph.
|
|
||||||
void Run();
|
|
||||||
|
|
||||||
private:
|
|
||||||
// functions
|
|
||||||
void SetupStorage();
|
|
||||||
void SetupOpExecs();
|
|
||||||
// Constructor to create TVM op
|
|
||||||
FOpExec CreateTVMOp(const nnvm::NodeAttrs& attrs,
|
|
||||||
std::vector<DLTensor> inputs,
|
|
||||||
size_t num_inputs);
|
|
||||||
// The graph to be executed.
|
|
||||||
nnvm::Graph graph_;
|
|
||||||
// The execution context
|
|
||||||
TVMContext ctx_;
|
|
||||||
// Common storage pool
|
|
||||||
std::vector<DLTensor*> storage_pool_;
|
|
||||||
// The data shape
|
|
||||||
std::vector<TShape> data_shape_;
|
|
||||||
// The data entry
|
|
||||||
std::vector<DLTensor> data_entry_;
|
|
||||||
// The operation lambda on each node
|
|
||||||
std::vector<FOpExec> op_execs_;
|
|
||||||
// The code module.
|
|
||||||
tvm::runtime::Module module_;
|
|
||||||
};
|
|
||||||
|
|
||||||
PackedFunc GraphExecutor::GetFunction(
|
PackedFunc GraphExecutor::GetFunction(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const std::shared_ptr<ModuleNode>& sptr_to_self) {
|
const std::shared_ptr<ModuleNode>& sptr_to_self) {
|
||||||
// return member functions during query.
|
// return member functions during query.
|
||||||
if (name == "set_input") {
|
if (name == "set_input") {
|
||||||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
||||||
this->SetInput(args[0], args[1]);
|
if (args[0].type_code() == kStr) {
|
||||||
|
this->SetInput(this->GetIndex(args[0]), args[1]);
|
||||||
|
} else {
|
||||||
|
this->SetInput(args[0], args[1]);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
} else if (name == "get_output") {
|
} else if (name == "get_output") {
|
||||||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
|
||||||
|
@ -129,10 +53,17 @@ void GraphExecutor::Init(const nnvm::Graph& g, TVMContext ctx) {
|
||||||
graph_ = g;
|
graph_ = g;
|
||||||
ctx_ = ctx;
|
ctx_ = ctx;
|
||||||
module_ = g.GetAttr<tvm::runtime::Module>("module");
|
module_ = g.GetAttr<tvm::runtime::Module>("module");
|
||||||
|
this->SetupNameIndex();
|
||||||
this->SetupStorage();
|
this->SetupStorage();
|
||||||
this->SetupOpExecs();
|
this->SetupOpExecs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GraphExecutor::GetIndex(std::string name) {
|
||||||
|
CHECK(name_idx_.count(name))
|
||||||
|
<< name << " is not in the graph.";
|
||||||
|
return name_idx_.at(name);
|
||||||
|
}
|
||||||
|
|
||||||
void GraphExecutor::SetInput(int index, DLTensor* data_in) {
|
void GraphExecutor::SetInput(int index, DLTensor* data_in) {
|
||||||
const auto& idx = graph_.indexed_graph();
|
const auto& idx = graph_.indexed_graph();
|
||||||
CHECK_LT(static_cast<size_t>(index), idx.input_nodes().size());
|
CHECK_LT(static_cast<size_t>(index), idx.input_nodes().size());
|
||||||
|
@ -147,33 +78,6 @@ void GraphExecutor::GetOutput(int index, DLTensor* data_out) {
|
||||||
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
|
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
|
|
||||||
|
|
||||||
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
|
|
||||||
uint64_t header = kTVMNDArrayMagic, reserved = 0;
|
|
||||||
strm->Write(&header, sizeof(header));
|
|
||||||
strm->Write(&reserved, sizeof(reserved));
|
|
||||||
|
|
||||||
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
|
|
||||||
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
|
|
||||||
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
|
|
||||||
|
|
||||||
int ndim = tensor->ndim;
|
|
||||||
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
|
|
||||||
|
|
||||||
int type_size = tensor->dtype.bits / 8;
|
|
||||||
int64_t size = 1;
|
|
||||||
for (int i = 0; i < ndim; ++i) {
|
|
||||||
size *= tensor->shape[i];
|
|
||||||
}
|
|
||||||
int64_t data_byte_size = type_size * size;
|
|
||||||
strm->Write(&data_byte_size, sizeof(data_byte_size));
|
|
||||||
strm->Write(tensor->data, data_byte_size);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
|
bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
|
||||||
uint64_t header, reserved;
|
uint64_t header, reserved;
|
||||||
CHECK(strm->Read(&header, sizeof(header)))
|
CHECK(strm->Read(&header, sizeof(header)))
|
||||||
|
@ -209,37 +113,6 @@ bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
|
|
||||||
|
|
||||||
TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
|
|
||||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
|
||||||
std::string fname = args[0];
|
|
||||||
int num_params = args[1];
|
|
||||||
std::vector<std::string> names;
|
|
||||||
names.reserve(num_params);
|
|
||||||
std::vector<DLTensor*> arrays;
|
|
||||||
arrays.reserve(num_params);
|
|
||||||
for (int i = 2; i < (2 + 2*num_params); i += 2) {
|
|
||||||
names.emplace_back(args[i].operator std::string());
|
|
||||||
arrays.emplace_back(args[i+1].operator DLTensor*());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
|
||||||
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
|
|
||||||
fo->Write(&header, sizeof(header));
|
|
||||||
fo->Write(&reserved, sizeof(reserved));
|
|
||||||
|
|
||||||
fo->Write(names);
|
|
||||||
{
|
|
||||||
uint64_t sz = static_cast<uint64_t>(arrays.size());
|
|
||||||
fo->Write(&sz, sizeof(sz));
|
|
||||||
for (size_t i = 0; i < sz; ++i) {
|
|
||||||
SaveDLTensor(fo.get(), arrays[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
void GraphExecutor::LoadParams(dmlc::Stream *strm) {
|
void GraphExecutor::LoadParams(dmlc::Stream *strm) {
|
||||||
uint64_t header, reserved;
|
uint64_t header, reserved;
|
||||||
CHECK(strm->Read(&header))
|
CHECK(strm->Read(&header))
|
||||||
|
@ -277,6 +150,15 @@ void GraphExecutor::LoadParamsFromBlob(std::string param_blob) {
|
||||||
this->LoadParams(&strm);
|
this->LoadParams(&strm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GraphExecutor::SetupNameIndex() {
|
||||||
|
nnvm::Symbol s;
|
||||||
|
s.outputs = graph_.outputs;
|
||||||
|
std::vector<std::string> input_names = s.ListInputNames(nnvm::Symbol::kAll);
|
||||||
|
for (size_t i = 0; i < input_names.size(); ++i) {
|
||||||
|
name_idx_[input_names[i]] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void GraphExecutor::SetupStorage() {
|
void GraphExecutor::SetupStorage() {
|
||||||
const auto& idx = graph_.indexed_graph();
|
const auto& idx = graph_.indexed_graph();
|
||||||
// Grab saved optimization plan from graph.
|
// Grab saved optimization plan from graph.
|
||||||
|
@ -399,23 +281,6 @@ FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs,
|
||||||
return fexec;
|
return fexec;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
|
|
||||||
std::string func_name;
|
|
||||||
uint32_t num_inputs;
|
|
||||||
uint32_t num_outputs;
|
|
||||||
bool flatten_data;
|
|
||||||
DMLC_DECLARE_PARAMETER(TVMOpParam) {
|
|
||||||
DMLC_DECLARE_FIELD(func_name);
|
|
||||||
DMLC_DECLARE_FIELD(num_inputs)
|
|
||||||
.set_default(1);
|
|
||||||
DMLC_DECLARE_FIELD(num_outputs)
|
|
||||||
.set_default(1);
|
|
||||||
DMLC_DECLARE_FIELD(flatten_data)
|
|
||||||
.set_default(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
DMLC_REGISTER_PARAMETER(TVMOpParam);
|
|
||||||
|
|
||||||
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
|
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
|
||||||
template<typename PType>
|
template<typename PType>
|
||||||
inline void ParamParser(nnvm::NodeAttrs* attrs) {
|
inline void ParamParser(nnvm::NodeAttrs* attrs) {
|
||||||
|
@ -436,6 +301,8 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) {
|
||||||
attrs->parsed = std::move(param);
|
attrs->parsed = std::move(param);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(TVMOpParam);
|
||||||
|
|
||||||
// ewise tvm op
|
// ewise tvm op
|
||||||
NNVM_REGISTER_OP(tvm_op)
|
NNVM_REGISTER_OP(tvm_op)
|
||||||
.set_attr_parser(ParamParser<TVMOpParam>)
|
.set_attr_parser(ParamParser<TVMOpParam>)
|
||||||
|
@ -448,33 +315,6 @@ NNVM_REGISTER_OP(tvm_op)
|
||||||
return param.num_outputs;
|
return param.num_outputs;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create executor
|
|
||||||
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) {
|
|
||||||
std::shared_ptr<GraphExecutor> exec =
|
|
||||||
std::make_shared<GraphExecutor>();
|
|
||||||
exec->Init(g, ctx);
|
|
||||||
return tvm::runtime::Module(exec);
|
|
||||||
}
|
|
||||||
|
|
||||||
TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
|
|
||||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
|
||||||
void* graph_handle = args[0];
|
|
||||||
int device_type = args[1];
|
|
||||||
int device_id = args[2];
|
|
||||||
TVMContext ctx{static_cast<DLDeviceType>(device_type), device_id};
|
|
||||||
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
|
|
||||||
*rv = CreateExecutor(g, ctx);
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
TVM_REGISTER_GLOBAL("tvm_graph._get_module_from_graph")
|
|
||||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
|
||||||
void* graph_handle = args[0];
|
|
||||||
nnvm::Graph* g = static_cast<nnvm::Graph*>(graph_handle);
|
|
||||||
*rv = g->MoveCopyAttr<tvm::runtime::Module>("module");
|
|
||||||
});
|
|
||||||
|
|
||||||
|
|
||||||
TVM_REGISTER_GLOBAL("tvm_graph._load_executor")
|
TVM_REGISTER_GLOBAL("tvm_graph._load_executor")
|
||||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||||
std::string sym_json = args[0];
|
std::string sym_json = args[0];
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file graph_executor.h
|
||||||
|
*/
|
||||||
|
#ifndef TVM_GRAPH_EXECUTOR_H_
|
||||||
|
#define TVM_GRAPH_EXECUTOR_H_
|
||||||
|
|
||||||
|
#include <dmlc/io.h>
|
||||||
|
#include <dmlc/memory_io.h>
|
||||||
|
#include <tvm/runtime/registry.h>
|
||||||
|
#include <tvm/runtime/packed_func.h>
|
||||||
|
#include <tvm/runtime/module.h>
|
||||||
|
#include <nnvm/graph.h>
|
||||||
|
#include <nnvm/graph_attr_types.h>
|
||||||
|
#include <nnvm/tuple.h>
|
||||||
|
#include <nnvm/pass.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace contrib {
|
||||||
|
|
||||||
|
using tvm::runtime::TVMArgs;
|
||||||
|
using tvm::runtime::TVMRetValue;
|
||||||
|
using tvm::runtime::PackedFunc;
|
||||||
|
using nnvm::StorageVector;
|
||||||
|
using nnvm::ShapeVector;
|
||||||
|
using nnvm::TShape;
|
||||||
|
using nnvm::NodeAttrs;
|
||||||
|
|
||||||
|
/*! \brief DLPack compatible data types */
|
||||||
|
using DLTypeVector = std::vector<DLDataType>;
|
||||||
|
|
||||||
|
/*! \brief The executor function */
|
||||||
|
using FOpExec = std::function<void()>;
|
||||||
|
|
||||||
|
/*! \brief macro to do C API call */
|
||||||
|
#define TVM_CCALL(func) \
|
||||||
|
{ \
|
||||||
|
int ret = (func); \
|
||||||
|
CHECK_EQ(ret, 0) \
|
||||||
|
<< TVMGetLastError(); \
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
|
||||||
|
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
|
||||||
|
|
||||||
|
/*! \brief Graph Executor with TVM runtime */
|
||||||
|
class GraphExecutor : public runtime::ModuleNode {
|
||||||
|
public:
|
||||||
|
const char* type_key() const {
|
||||||
|
return "GraphExecutor";
|
||||||
|
}
|
||||||
|
PackedFunc GetFunction(
|
||||||
|
const std::string& name,
|
||||||
|
const std::shared_ptr<ModuleNode>& sptr_to_self);
|
||||||
|
// Destructor
|
||||||
|
~GraphExecutor();
|
||||||
|
// Setup with a given graph
|
||||||
|
void Init(const nnvm::Graph& g, TVMContext ctx);
|
||||||
|
// Get index of variable
|
||||||
|
int GetIndex(std::string name);
|
||||||
|
// Copy data to index-th input
|
||||||
|
void SetInput(int index, DLTensor* data_in);
|
||||||
|
// Copy index-th output to data_out
|
||||||
|
void GetOutput(int index, DLTensor* data_out);
|
||||||
|
// Load parameters from stream
|
||||||
|
void LoadParams(dmlc::Stream* strm);
|
||||||
|
// Load parameters from binary file blob
|
||||||
|
void LoadParamsFromBlob(std::string param_blob);
|
||||||
|
// Execute the graph.
|
||||||
|
void Run();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// functions
|
||||||
|
void SetupNameIndex();
|
||||||
|
void SetupStorage();
|
||||||
|
void SetupOpExecs();
|
||||||
|
// Constructor to create TVM op
|
||||||
|
FOpExec CreateTVMOp(const nnvm::NodeAttrs& attrs,
|
||||||
|
std::vector<DLTensor> inputs,
|
||||||
|
size_t num_inputs);
|
||||||
|
// The graph to be executed.
|
||||||
|
nnvm::Graph graph_;
|
||||||
|
// The execution context
|
||||||
|
TVMContext ctx_;
|
||||||
|
// Common storage pool
|
||||||
|
std::vector<DLTensor*> storage_pool_;
|
||||||
|
// The data shape
|
||||||
|
std::vector<TShape> data_shape_;
|
||||||
|
// The data entry
|
||||||
|
std::vector<DLTensor> data_entry_;
|
||||||
|
// The operation lambda on each node
|
||||||
|
std::vector<FOpExec> op_execs_;
|
||||||
|
// The code module.
|
||||||
|
tvm::runtime::Module module_;
|
||||||
|
std::unordered_map<std::string, size_t> name_idx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
|
||||||
|
std::string func_name;
|
||||||
|
uint32_t num_inputs;
|
||||||
|
uint32_t num_outputs;
|
||||||
|
bool flatten_data;
|
||||||
|
DMLC_DECLARE_PARAMETER(TVMOpParam) {
|
||||||
|
DMLC_DECLARE_FIELD(func_name);
|
||||||
|
DMLC_DECLARE_FIELD(num_inputs)
|
||||||
|
.set_default(1);
|
||||||
|
DMLC_DECLARE_FIELD(num_outputs)
|
||||||
|
.set_default(1);
|
||||||
|
DMLC_DECLARE_FIELD(flatten_data)
|
||||||
|
.set_default(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace contrib
|
||||||
|
} // namespace tvm
|
||||||
|
|
||||||
|
#endif // TVM_GRAPH_EXECUTOR_H_
|
|
@ -0,0 +1,87 @@
|
||||||
|
/*!
|
||||||
|
* Copyright (c) 2017 by Contributors
|
||||||
|
* \file graph_executor_ext.cc
|
||||||
|
*/
|
||||||
|
#include "./graph_executor.h"
|
||||||
|
|
||||||
|
namespace tvm {
|
||||||
|
namespace contrib {
|
||||||
|
|
||||||
|
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
|
||||||
|
uint64_t header = kTVMNDArrayMagic, reserved = 0;
|
||||||
|
strm->Write(&header, sizeof(header));
|
||||||
|
strm->Write(&reserved, sizeof(reserved));
|
||||||
|
|
||||||
|
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
|
||||||
|
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
|
||||||
|
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
|
||||||
|
|
||||||
|
int ndim = tensor->ndim;
|
||||||
|
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
|
||||||
|
|
||||||
|
int type_size = tensor->dtype.bits / 8;
|
||||||
|
int64_t size = 1;
|
||||||
|
for (int i = 0; i < ndim; ++i) {
|
||||||
|
size *= tensor->shape[i];
|
||||||
|
}
|
||||||
|
int64_t data_byte_size = type_size * size;
|
||||||
|
strm->Write(&data_byte_size, sizeof(data_byte_size));
|
||||||
|
strm->Write(tensor->data, data_byte_size);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||||
|
std::string fname = args[0];
|
||||||
|
int num_params = args[1];
|
||||||
|
std::vector<std::string> names;
|
||||||
|
names.reserve(num_params);
|
||||||
|
std::vector<DLTensor*> arrays;
|
||||||
|
arrays.reserve(num_params);
|
||||||
|
for (int i = 2; i < (2 + 2*num_params); i += 2) {
|
||||||
|
names.emplace_back(args[i].operator std::string());
|
||||||
|
arrays.emplace_back(args[i+1].operator DLTensor*());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||||
|
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
|
||||||
|
fo->Write(&header, sizeof(header));
|
||||||
|
fo->Write(&reserved, sizeof(reserved));
|
||||||
|
|
||||||
|
fo->Write(names);
|
||||||
|
{
|
||||||
|
uint64_t sz = static_cast<uint64_t>(arrays.size());
|
||||||
|
fo->Write(&sz, sizeof(sz));
|
||||||
|
for (size_t i = 0; i < sz; ++i) {
|
||||||
|
SaveDLTensor(fo.get(), arrays[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create executor
|
||||||
|
tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) {
|
||||||
|
std::shared_ptr<GraphExecutor> exec =
|
||||||
|
std::make_shared<GraphExecutor>();
|
||||||
|
exec->Init(g, ctx);
|
||||||
|
return tvm::runtime::Module(exec);
|
||||||
|
}
|
||||||
|
|
||||||
|
TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||||
|
void* graph_handle = args[0];
|
||||||
|
int device_type = args[1];
|
||||||
|
int device_id = args[2];
|
||||||
|
TVMContext ctx{static_cast<DLDeviceType>(device_type), device_id};
|
||||||
|
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
|
||||||
|
*rv = CreateExecutor(g, ctx);
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
TVM_REGISTER_GLOBAL("tvm_graph._get_module_from_graph")
|
||||||
|
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||||
|
void* graph_handle = args[0];
|
||||||
|
nnvm::Graph* g = static_cast<nnvm::Graph*>(graph_handle);
|
||||||
|
*rv = g->MoveCopyAttr<tvm::runtime::Module>("module");
|
||||||
|
});
|
||||||
|
} // namespace contrib
|
||||||
|
} // namespace tvm
|
|
@ -17,8 +17,8 @@ def test_compile():
|
||||||
na = tvm.nd.array(np.ones(shape).astype(dtype))
|
na = tvm.nd.array(np.ones(shape).astype(dtype))
|
||||||
nb = tvm.nd.array(np.ones(shape).astype(dtype))
|
nb = tvm.nd.array(np.ones(shape).astype(dtype))
|
||||||
# set inputs
|
# set inputs
|
||||||
set_input(0, na)
|
set_input('x', na)
|
||||||
set_input(1, nb)
|
set_input('y', nb)
|
||||||
# execute
|
# execute
|
||||||
run()
|
run()
|
||||||
# get outputs
|
# get outputs
|
||||||
|
|
Загрузка…
Ссылка в новой задаче