upgrade linux build from c++11 to c++14, enable gsl, update with latest LotusIR

This commit is contained in:
liqfu 2018-07-09 13:29:47 -07:00
Родитель 55b4606b23
Коммит 3c87d2012c
22 изменённых файлов: 154 добавлений и 203 удалений

Просмотреть файл

@ -102,9 +102,9 @@ INCLUDEPATH+=$(GSL_PATH)/include
INCLUDEPATH+=$(ONNX_PATH) INCLUDEPATH+=$(ONNX_PATH)
INCLUDEPATH+=$(ONNX_REPO_PATH) INCLUDEPATH+=$(ONNX_REPO_PATH)
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers. # COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__ COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++14 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__
CPPFLAGS:= CPPFLAGS:=
CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -fopenmp -fpermissive -fPIC -Werror -fcheck-new
LIBPATH:= LIBPATH:=
LIBS_LIST:= LIBS_LIST:=
LDFLAGS:= LDFLAGS:=

Просмотреть файл

@ -178,7 +178,6 @@
<ClInclude Include="proto\onnx\CNTKToONNX.h" /> <ClInclude Include="proto\onnx\CNTKToONNX.h" />
<ClInclude Include="proto\onnx\core\common\code_location.h" /> <ClInclude Include="proto\onnx\core\common\code_location.h" />
<ClInclude Include="proto\onnx\core\common\common.h" /> <ClInclude Include="proto\onnx\core\common\common.h" />
<ClInclude Include="proto\onnx\core\common\CommonSTD.h" />
<ClInclude Include="proto\onnx\core\common\const_pointer_container.h" /> <ClInclude Include="proto\onnx\core\common\const_pointer_container.h" />
<ClInclude Include="proto\onnx\core\common\exceptions.h" /> <ClInclude Include="proto\onnx\core\common\exceptions.h" />
<ClInclude Include="proto\onnx\core\common\logging\capture.h" /> <ClInclude Include="proto\onnx\core\common\logging\capture.h" />

Просмотреть файл

@ -93,18 +93,6 @@
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\traditionalml\defs.cc"> <ClCompile Include="proto\onnx\onnx_repo\onnx\defs\traditionalml\defs.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\traditionalml</Filter> <Filter>proto\onnx\onnx_repo\onnx\defs\traditionalml</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\traditionalml\defs.cpp">
<Filter>proto\onnx\onnx\defs\traditionalml</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\data_type_utils.cpp">
<Filter>proto\onnx\onnx\defs</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\schema.cpp">
<Filter>proto\onnx\onnx\defs</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\core\graph\graph.cpp">
<Filter>proto\onnx\core\graph</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\core\graph\model.cc"> <ClCompile Include="proto\onnx\core\graph\model.cc">
<Filter>proto\onnx\core\graph</Filter> <Filter>proto\onnx\core\graph</Filter>
</ClCompile> </ClCompile>
@ -117,26 +105,37 @@
<ClCompile Include="proto\onnx\core\graph\graph_transformer.cc"> <ClCompile Include="proto\onnx\core\graph\graph_transformer.cc">
<Filter>proto\onnx\core\graph</Filter> <Filter>proto\onnx\core\graph</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\math\old.cpp"> <ClCompile Include="proto\onnx\onnx-ml.pb.cc.VS_wrapper.cpp" />
<Filter>proto\onnx\onnx\defs\math</Filter> <ClCompile Include="proto\onnx\onnx-operators-ml.pb.cc.VS_wrapper.cpp" />
<ClCompile Include="proto\onnx\core\graph\graph.cc">
<Filter>proto\onnx\core\graph</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\nn\old.cpp"> <ClCompile Include="proto\onnx\onnx_repo\onnx\defs\logical\old.cc">
<Filter>proto\onnx\onnx\defs\nn</Filter> <Filter>proto\onnx\onnx_repo\onnx\defs\logical</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\rnn\old.cpp"> <ClCompile Include="proto\onnx\onnx_repo\onnx\defs\tensor\old.cc">
<Filter>proto\onnx\onnx\defs\rnn</Filter> <Filter>proto\onnx\onnx_repo\onnx\defs\tensor</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\tensor\old.cpp"> <ClCompile Include="proto\onnx\onnx_repo\onnx\defs\math\old.cc">
<Filter>proto\onnx\onnx\defs\tensor</Filter> <Filter>proto\onnx\onnx_repo\onnx\defs\math</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\checker.cpp"> <ClCompile Include="proto\onnx\onnx_repo\onnx\defs\rnn\old.cc">
<Filter>proto\onnx\onnx</Filter> <Filter>proto\onnx\onnx_repo\onnx\defs\rnn</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\logical\old.cpp"> <ClCompile Include="proto\onnx\onnx_repo\onnx\defs\nn\old.cc">
<Filter>proto\onnx\onnx\defs\logical</Filter> <Filter>proto\onnx\onnx_repo\onnx\defs\nn</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="proto\onnx\onnx\common\assertions.cpp"> <ClCompile Include="proto\onnx\onnx_repo\onnx\common\assertions.cc">
<Filter>proto\onnx\onnx\common</Filter> <Filter>proto\onnx\onnx_repo\onnx\common</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx_repo\onnx\checker.cc">
<Filter>proto\onnx\onnx_repo\onnx</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\data_type_utils.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\schema.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs</Filter>
</ClCompile> </ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
@ -295,9 +294,6 @@
<ClInclude Include="proto\onnx\core\graph\utils.h"> <ClInclude Include="proto\onnx\core\graph\utils.h">
<Filter>proto\onnx\core\graph</Filter> <Filter>proto\onnx\core\graph</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="proto\onnx\core\common\CommonSTD.h">
<Filter>proto\onnx\core\common</Filter>
</ClInclude>
<ClInclude Include="proto\onnx\onnx_repo\onnx\common\assertions.h"> <ClInclude Include="proto\onnx\onnx_repo\onnx\common\assertions.h">
<Filter>proto\onnx\onnx_repo\onnx\common</Filter> <Filter>proto\onnx\onnx_repo\onnx\common</Filter>
</ClInclude> </ClInclude>

Просмотреть файл

@ -23,6 +23,7 @@ using namespace Microsoft::MSR::CNTK;
using namespace CNTK::ONNX; using namespace CNTK::ONNX;
using namespace CNTK; using namespace CNTK;
using namespace LotusIR; using namespace LotusIR;
using namespace onnx;
const int FreeSequenceLen = 0; const int FreeSequenceLen = 0;
const std::string FreeSequenceDimParam = "None"; const std::string FreeSequenceDimParam = "None";

Просмотреть файл

@ -16,6 +16,7 @@
using namespace LotusIR; using namespace LotusIR;
using namespace CNTK; using namespace CNTK;
using namespace CNTK::ONNX; using namespace CNTK::ONNX;
using namespace onnx;
using namespace Microsoft::MSR::CNTK; using namespace Microsoft::MSR::CNTK;
namespace CNTK namespace CNTK

Просмотреть файл

@ -1,32 +0,0 @@
#pragma once
#include <memory>
#include <type_traits>
// to get make_unique definition
#include "Platform.h"
// to add what is missing in gsl
#if defined(__GNUC__)
namespace std {
template<bool _Test,
class _Ty = void>
using enable_if_t = typename enable_if<_Test, _Ty>::type;
template<class _Ty>
using remove_cv_t = typename remove_cv<_Ty>::type;
template<bool _Test,
class _Ty1,
class _Ty2>
using conditional_t = typename conditional<_Test, _Ty1, _Ty2>::type;
template<class _Ty>
using add_pointer_t = typename add_pointer<_Ty>::type;
template<class _Ty>
using remove_const_t = typename remove_const<_Ty>::type;
}
#else
using std::make_unique;
#endif

Просмотреть файл

@ -33,7 +33,6 @@
#include <vector> #include <vector>
#include <chrono> #include <chrono>
#include "core/common/CommonSTD.h"
#include "core/common/code_location.h" #include "core/common/code_location.h"
#include "core/common/exceptions.h" #include "core/common/exceptions.h"
#include "core/common/status.h" #include "core/common/status.h"
@ -62,6 +61,8 @@ using std::vector;
#define UNUSED_PARAMETER(x) #define UNUSED_PARAMETER(x)
#endif #endif
// std::vector<std::string> GetStackTrace();
// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER // __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
// so we only define it as one for MSVC // so we only define it as one for MSVC
#if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) #if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
@ -73,7 +74,7 @@ using std::vector;
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__) Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__)
#define WHERE_WITH_STACK \ #define WHERE_WITH_STACK \
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__) // Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__, Lotus::GetStackTrace()) Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__) // , Lotus::GetStackTrace())
// Throw an exception with optional message. // Throw an exception with optional message.
// NOTE: The arguments get streamed into a string via ostringstream::operator<< // NOTE: The arguments get streamed into a string via ostringstream::operator<<

Просмотреть файл

@ -1,7 +1,6 @@
#include "core/common/CommonSTD.h"
#include "core/common/logging/capture.h" #include "core/common/logging/capture.h"
#include "core/common/logging/logging.h" #include "core/common/logging/logging.h"
// #include "gsl/span" #include "gsl/span"
#include "gsl/gsl_util" #include "gsl/gsl_util"
namespace Lotus { namespace Lotus {
@ -21,27 +20,22 @@ void Capture::CapturePrintf(msvc_printf_check const char *format, ...) {
void Capture::ProcessPrintf(msvc_printf_check const char *format, va_list args) { void Capture::ProcessPrintf(msvc_printf_check const char *format, va_list args) {
static constexpr auto kTruncatedWarningText = "[...truncated...]"; static constexpr auto kTruncatedWarningText = "[...truncated...]";
static const int kMaxMessageSize = 2048; static const int kMaxMessageSize = 2048;
char finished_message[kMaxMessageSize]; char message_buffer[kMaxMessageSize];
const auto message = gsl::make_span(message_buffer);
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__)) #if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
const auto finished_message_len = _countof(finished_message); const int nbrcharacters = vsnprintf_s(message.data(), message.size(), _TRUNCATE, format, args);
#else #else
int finished_message_len = sizeof(finished_message); const int nbrcharacters = vsnprintf(message.data(), message.size(), format, args);
#endif
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
const int nbrcharacters = vsnprintf_s(finished_message, finished_message_len, _TRUNCATE, format, args);
#else
const int nbrcharacters = vsnprintf(finished_message, finished_message_len, format, args);
#endif #endif
if (nbrcharacters <= 0) { if (nbrcharacters <= 0) {
stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message"; stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message";
stream_ << '"' << format << '"' << std::endl; stream_ << '"' << format << '"' << std::endl;
} else if (static_cast<uint32_t>(nbrcharacters) > finished_message_len) { } else if (nbrcharacters > message.size()) {
stream_ << finished_message << kTruncatedWarningText; stream_ << message.data() << kTruncatedWarningText;
} else { } else {
stream_ << finished_message; stream_ << message.data();
} }
} }

Просмотреть файл

@ -1,6 +1,6 @@
#include <exception> #include <exception>
#include <ctime> #include <ctime>
#include "core/common/CommonSTD.h"
#include "core/common/exceptions.h" #include "core/common/exceptions.h"
#include "core/common/logging/isink.h" #include "core/common/logging/isink.h"
#include "core/common/logging/logging.h" #include "core/common/logging/logging.h"

Просмотреть файл

@ -1,4 +1,3 @@
#include "core/common/CommonSTD.h"
#include "core/common/status.h" #include "core/common/status.h"
#include "core/common/common.h" #include "core/common/common.h"

Просмотреть файл

@ -8,15 +8,14 @@
#include <numeric> #include <numeric>
#include <stack> #include <stack>
#include "core/common/CommonSTD.h" #include "gsl/pointers"
// #include "gsl/pointers"
#include "core/graph/graph.h" #include "core/graph/graph.h"
#include "core/graph/op.h" #include "core/graph/op.h"
#include "core/graph/utils.h" #include "core/graph/utils.h"
#include "core/common/logging/logging.h" #include "core/common/logging/logging.h"
#include "onnx/checker.h" #include "onnx/checker.h"
#include "core/graph/schema_registry.h" #include "core/graph/schema_registry.h"
using namespace onnx;
using namespace onnx::Utils; using namespace onnx::Utils;
using namespace onnx::checker; using namespace onnx::checker;
@ -172,7 +171,7 @@ void Node::ToProto(NodeProto& proto) const {
// Set attributes. // Set attributes.
proto.clear_attribute(); proto.clear_attribute();
for (auto attribute : attributes_) { for (auto attribute : attributes_) {
auto attr = proto.add_attribute(); const gsl::not_null<AttributeProto*> attr = proto.add_attribute();
*attr = attribute.second; *attr = attribute.second;
} }
@ -346,12 +345,12 @@ const NodeAttributes& Node::GetAttributes() const noexcept {
} }
void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const { void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const {
for (const LotusIR::NodeArg* arg : InputDefs()) { for (const gsl::not_null<const LotusIR::NodeArg*> arg : InputDefs()) {
if (!arg->Exists()) if (!arg->Exists())
continue; continue;
func(&*arg, true); func(&*arg, true);
} }
for (const LotusIR::NodeArg* arg : OutputDefs()) { for (const gsl::not_null<const LotusIR::NodeArg*> arg : OutputDefs()) {
if (!arg->Exists()) if (!arg->Exists())
continue; continue;
func(&*arg, false); func(&*arg, false);
@ -359,7 +358,7 @@ void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)
}; };
void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) const { void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
for (const LotusIR::NodeArg* arg : InputDefs()) { for (const gsl::not_null<const LotusIR::NodeArg*> arg : InputDefs()) {
if (!arg->Exists()) if (!arg->Exists())
continue; continue;
func(&*arg); func(&*arg);
@ -367,7 +366,7 @@ void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) co
}; };
void Node::ForEachOutputDef(std::function<void(const LotusIR::NodeArg*)> func) const { void Node::ForEachOutputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
for (const LotusIR::NodeArg* arg : OutputDefs()) { for (const gsl::not_null<const LotusIR::NodeArg*> arg : OutputDefs()) {
if (!arg->Exists()) if (!arg->Exists())
continue; continue;
func(&*arg); func(&*arg);
@ -378,7 +377,7 @@ void Node::ReplaceDefs(const std::map<const LotusIR::NodeArg*, LotusIR::NodeArg*
std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs}; std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs};
for (auto pair : replacements) for (auto pair : replacements)
for (auto defs : all_defs) for (const gsl::not_null<std::vector<LotusIR::NodeArg*>*> defs : all_defs)
for (auto& def : *defs) for (auto& def : *defs)
if (def == pair.first) if (def == pair.first)
def = pair.second; def = pair.second;
@ -418,14 +417,14 @@ Graph::Graph(GraphProto* graph_proto,
// Copy constant nodes _value to name_to_initial_tensor_ // Copy constant nodes _value to name_to_initial_tensor_
for (auto& node : graph_proto_->node()) { for (auto& node : graph_proto_->node()) {
if (node.op_type() == kConstant) { if (node.op_type() == kConstant) {
auto tensor = graph_proto_->add_initializer(); const gsl::not_null<TensorProto*> tensor = graph_proto_->add_initializer();
*tensor = node.attribute(0).t(); *tensor = node.attribute(0).t();
*(tensor->mutable_name()) = node.output(0); *(tensor->mutable_name()) = node.output(0);
} }
} }
// remove constant nodes // remove constant nodes
auto graph_mutable_nodes = graph_proto_->mutable_node(); const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes = graph_proto_->mutable_node();
graph_mutable_nodes->erase( graph_mutable_nodes->erase(
std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(), std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(),
[](NodeProto& p) { [](NodeProto& p) {
@ -493,7 +492,7 @@ Status GraphBase::VerifyNoDuplicateName(/*in*/ const std::unordered_set<std::str
node_name_to_index[node_name] = node.Index(); node_name_to_index[node_name] = node.Index();
// Verify node outputs' name should be unique. // Verify node outputs' name should be unique.
for (const NodeArg* output_def : node.OutputDefs()) { for (const gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
if (output_def->Exists()) { if (output_def->Exists()) {
auto& output_arg_name = output_def->Name(); auto& output_arg_name = output_def->Name();
if (inputs_and_initializers.count(output_arg_name)) { if (inputs_and_initializers.count(output_arg_name)) {
@ -546,7 +545,7 @@ Status GraphBase::BuildConnections(const std::unordered_map<std::string, Node*>&
if (input_args.size() > 0) { if (input_args.size() > 0) {
// This node needs inputs. // This node needs inputs.
for (const NodeArg* input_arg : input_args) { for (const gsl::not_null<const NodeArg*> input_arg : input_args) {
if (!input_arg->Exists()) { if (!input_arg->Exists()) {
// This input could be optional and it does not exist in this case. // This input could be optional and it does not exist in this case.
continue; continue;
@ -655,7 +654,7 @@ void GraphBase::ReverseDFSFrom(const std::vector<const Node*>& from,
sorted_nodes.push_back((*iter)); sorted_nodes.push_back((*iter));
} }
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp); std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
for (const LotusIR::Node* in : sorted_nodes) { for (gsl::not_null<const LotusIR::Node*> in : sorted_nodes) {
const NodeIndex idx = in->Index(); const NodeIndex idx = in->Index();
if (!visited[idx]) { if (!visited[idx]) {
stack.emplace_back(in, false); stack.emplace_back(in, false);
@ -1131,7 +1130,7 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topolo
// currently an Op is required by ValidateVersion, so we use gsl::not_null. // currently an Op is required by ValidateVersion, so we use gsl::not_null.
// This may change in the future to allow a null Op // This may change in the future to allow a null Op
const OpSchema* p_op = node.Op(); const gsl::not_null<const OpSchema*> p_op = node.Op();
// Attribute verification and fill node attribute with // Attribute verification and fill node attribute with
// default value defined in operator definition if needed. // default value defined in operator definition if needed.
@ -1218,7 +1217,7 @@ Status Graph::Resolve(bool no_proto_sync_required) {
return Status::OK(); return Status::OK();
} }
Status GraphBase::GetNodesInTopologicalOrder(const std::vector<NodeIndex>** pp_nodes) const { Status GraphBase::GetNodesInTopologicalOrder(gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const {
if (graph_resolve_needed_) { if (graph_resolve_needed_) {
return Status{StatusCategory::LOTUS, StatusCode::FAIL, return Status{StatusCategory::LOTUS, StatusCode::FAIL,
"Resolve() must be called before using the graph as modifications have been made to it."}; "Resolve() must be called before using the graph as modifications have been made to it."};
@ -1263,7 +1262,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
return; return;
} }
auto tensorAdded = graph_proto_->add_initializer(); const gsl::not_null<TensorProto*> tensorAdded = graph_proto_->add_initializer();
*(tensorAdded) = tensor; *(tensorAdded) = tensor;
name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1; name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1;
name_to_initial_tensor_[tensor.name()] = tensorAdded; name_to_initial_tensor_[tensor.name()] = tensorAdded;
@ -1283,7 +1282,7 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
} }
} }
bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const { bool Graph::GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const TensorProto**> value) const {
auto iter = name_to_initial_tensor_.find(tensor_name); auto iter = name_to_initial_tensor_.find(tensor_name);
if (name_to_initial_tensor_.end() == iter) { if (name_to_initial_tensor_.end() == iter) {
return false; return false;
@ -1318,7 +1317,7 @@ const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
// Ensure the NodeArgs in the input are created and in this Graph's node arg map // Ensure the NodeArgs in the input are created and in this Graph's node arg map
static void AddNodeArgs(const std::vector<NodeArg*>& input_args, static void AddNodeArgs(const std::vector<NodeArg*>& input_args,
std::unordered_map<std::string, NodeArg*>& node_arg_map) { std::unordered_map<std::string, NodeArg*>& node_arg_map) {
for (auto input_arg : input_args) { for (const gsl::not_null<NodeArg*> input_arg : input_args) {
if (!input_arg->Exists()) continue; if (!input_arg->Exists()) continue;
auto& key = input_arg->Name(); auto& key = input_arg->Name();
auto existing_entry = node_arg_map.find(key); auto existing_entry = node_arg_map.find(key);
@ -1383,7 +1382,7 @@ Node* GraphBase::AddNode(const Node& other) {
Node* GraphBase::AddNode(const NodeProto& node_proto, Node* GraphBase::AddNode(const NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type_map) { const ArgNameToTypeMap& name_to_type_map) {
auto node = AllocateNode(); const gsl::not_null<Node*> node = AllocateNode();
auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map, node_args_, owned_node_args_); auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map, node_args_, owned_node_args_);
auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map, node_args_, owned_node_args_); auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map, node_args_, owned_node_args_);
@ -1460,7 +1459,7 @@ Node* GraphBase::AddNode(const std::string& name,
AddNodeArgs(input_args, node_args_); AddNodeArgs(input_args, node_args_);
AddNodeArgs(output_args, node_args_); AddNodeArgs(output_args, node_args_);
auto node = AllocateNode(); const gsl::not_null<Node*> node = AllocateNode();
node->Init(name, op_type, description, input_args, output_args, attributes, domain); node->Init(name, op_type, description, input_args, output_args, attributes, domain);
if (0 != op_type.compare(kNoOp)) { if (0 != op_type.compare(kNoOp)) {
graph_proto_sync_needed_ = true; graph_proto_sync_needed_ = true;
@ -1505,8 +1504,8 @@ const GraphProto& Graph::ToGraphProto() {
continue; continue;
} }
auto node_proto = graph_proto_->add_node(); const gsl::not_null<NodeProto*> node_proto = graph_proto_->add_node();
auto p_node = GetNode(node_idx); const gsl::not_null<Node*> p_node = GetNode(node_idx);
p_node->ToProto(*node_proto); p_node->ToProto(*node_proto);
} }
@ -1551,15 +1550,15 @@ void Graph::SyncGraphInputsOutputs() {
graph_proto_->clear_output(); graph_proto_->clear_output();
graph_proto_->clear_value_info(); graph_proto_->clear_value_info();
for (const LotusIR::NodeArg* input_arg : GetInputs()) { for (const gsl::not_null<const LotusIR::NodeArg*> input_arg : GetInputs()) {
*(graph_proto_->mutable_input()->Add()) = input_arg->ToProto(); *(graph_proto_->mutable_input()->Add()) = input_arg->ToProto();
} }
for (const LotusIR::NodeArg* output_arg : GetOutputs()) { for (const gsl::not_null<const LotusIR::NodeArg*> output_arg : GetOutputs()) {
*(graph_proto_->mutable_output()->Add()) = output_arg->ToProto(); *(graph_proto_->mutable_output()->Add()) = output_arg->ToProto();
} }
for (const LotusIR::NodeArg* value_info : value_info_) { for (const gsl::not_null<const LotusIR::NodeArg*> value_info : value_info_) {
*(graph_proto_->mutable_value_info()->Add()) = value_info->ToProto(); *(graph_proto_->mutable_value_info()->Add()) = value_info->ToProto();
} }
} }
@ -1573,11 +1572,11 @@ void Graph::CleanUnusedInitializers() {
for (const auto& pv : name_to_initial_tensor_) { for (const auto& pv : name_to_initial_tensor_) {
const std::string& s = pv.first; const std::string& s = pv.first;
const bool used_as_input = std::any_of(input_args.begin(), input_args.end(), const bool used_as_input = std::any_of(input_args.begin(), input_args.end(),
[&s](const NodeArg* input) noexcept { [&s](const gsl::not_null<const NodeArg*> input) noexcept {
return s == input->Name(); return s == input->Name();
}); });
const bool used_as_output = std::any_of(GetOutputs().begin(), GetOutputs().end(), const bool used_as_output = std::any_of(GetOutputs().begin(), GetOutputs().end(),
[&s](const NodeArg* output) noexcept { [&s](const gsl::not_null<const NodeArg*> output) noexcept {
return s == output->Name(); return s == output->Name();
}); });
@ -1639,7 +1638,7 @@ Status Graph::SetGraphInputsOutputs() {
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg; std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
for (const auto& node : Nodes()) { for (const auto& node : Nodes()) {
for (const NodeArg* output_def : node.OutputDefs()) { for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
if (specified_graph_outputs.erase(output_def->Name()) >= 1) { if (specified_graph_outputs.erase(output_def->Name()) >= 1) {
graph_outputs.push_back(output_def); graph_outputs.push_back(output_def);
} }
@ -1664,7 +1663,7 @@ Status Graph::SetGraphInputsOutputs() {
for (const auto& node : Nodes()) { for (const auto& node : Nodes()) {
// Go thru all node's inputs. // Go thru all node's inputs.
for (const NodeArg* input_arg : node.InputDefs()) { for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
if (!input_arg->Exists()) { if (!input_arg->Exists()) {
// It's an optional input and does not exist in this case. // It's an optional input and does not exist in this case.
continue; continue;
@ -1694,7 +1693,7 @@ Status Graph::SetGraphInputsOutputs() {
} else { } else {
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg; std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
for (const auto& node : Nodes()) { for (const auto& node : Nodes()) {
for (const NodeArg* output_def : node.OutputDefs()) { for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
if (output_def->Exists()) if (output_def->Exists())
output_name_to_node_arg.insert({output_def->Name(), output_def}); output_name_to_node_arg.insert({output_def->Name(), output_def});
} }
@ -1706,7 +1705,7 @@ Status Graph::SetGraphInputsOutputs() {
std::unordered_set<Node*> inner_nodes; std::unordered_set<Node*> inner_nodes;
for (const auto& node : Nodes()) { for (const auto& node : Nodes()) {
// Go thru all node's inputs. // Go thru all node's inputs.
for (const NodeArg* input_arg : node.InputDefs()) { for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
if (!input_arg->Exists()) { if (!input_arg->Exists()) {
// It's an optional input and does not exist in this case. // It's an optional input and does not exist in this case.
continue; continue;
@ -1719,8 +1718,8 @@ Status Graph::SetGraphInputsOutputs() {
const std::string& name = input_arg->Name(); const std::string& name = input_arg->Name();
if (added_input_names.end() == added_input_names.find(name)) { if (added_input_names.end() == added_input_names.find(name)) {
// This graph input has not been added into <graph_inputs_>. // This graph input has not been added into <graph_inputs_>.
// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) //// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
graph_inputs.push_back(input_arg); graph_inputs.push_back(input_arg);
added_input_names.insert(input_arg->Name()); added_input_names.insert(input_arg->Name());
} }
} else if (graph_output_args.erase(output_arg_iter->first) >= 1) { } else if (graph_output_args.erase(output_arg_iter->first) >= 1) {
@ -1759,7 +1758,7 @@ const Node* GraphBase::SinkNode() const {
// calling private ctor // calling private ctor
GSL_SUPPRESS(r .11) GSL_SUPPRESS(r .11)
Node* GraphBase::AllocateNode() { gsl::not_null<Node*> GraphBase::AllocateNode() {
std::unique_ptr<Node> new_node(new Node(nodes_.size(), *this)); std::unique_ptr<Node> new_node(new Node(nodes_.size(), *this));
Node* node{new_node.get()}; Node* node{new_node.get()};

Просмотреть файл

@ -7,9 +7,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "core/common/CommonSTD.h" #include "gsl/pointers"
// #include "gsl/pointers"
#include "gsl/gsl_util" #include "gsl/gsl_util"
#include "core/common/common.h" #include "core/common/common.h"
@ -21,21 +19,19 @@
#include "core/graph/utils.h" #include "core/graph/utils.h"
#include "onnx/onnx_pb.h" #include "onnx/onnx_pb.h"
using namespace onnx;
// TODO - Evaluate switching the types below to support transparent comparators and enable // TODO - Evaluate switching the types below to support transparent comparators and enable
// lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations // lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations
// converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>> // converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>>
// instead of std::unordered_map<std::string, foo, [std::less<foo>]>. // instead of std::unordered_map<std::string, foo, [std::less<foo>]>.
using NodeAttributes = std::unordered_map<std::string, AttributeProto>; using NodeAttributes = std::unordered_map<std::string, onnx::AttributeProto>;
namespace LotusIR { namespace LotusIR {
using NodeIndex = size_t; using NodeIndex = size_t;
using Version = int64_t; using Version = int64_t;
using NodeArgInfo = ValueInfoProto; using NodeArgInfo = onnx::ValueInfoProto;
using InitializedTensorSet = std::unordered_map<std::string, const TensorProto*>; using InitializedTensorSet = std::unordered_map<std::string, const onnx::TensorProto*>;
using ArgNameToTypeMap = std::unordered_map<std::string, TypeProto>; using ArgNameToTypeMap = std::unordered_map<std::string, onnx::TypeProto>;
using ProviderType = const std::string&; using ProviderType = const std::string&;
class Graph; class Graph;
@ -70,7 +66,7 @@ class NodeArg {
// optional. This is called when loading a <Graph> from <GraphProto> // optional. This is called when loading a <Graph> from <GraphProto>
// normally. // normally.
NodeArg(const std::string& name, NodeArg(const std::string& name,
const TypeProto* p_arg_type); const onnx::TypeProto* p_arg_type);
NodeArg(NodeArg&& other) = default; NodeArg(NodeArg&& other) = default;
@ -78,17 +74,17 @@ class NodeArg {
const std::string& Name() const noexcept; const std::string& Name() const noexcept;
// Get node arg type. // Get node arg type.
DataType Type() const noexcept; onnx::DataType Type() const noexcept;
const TypeProto* TypeAsProto() const noexcept; const onnx::TypeProto* TypeAsProto() const noexcept;
// Get node arg shape. // Get node arg shape.
// Return null pointer if there's no shape specified. // Return null pointer if there's no shape specified.
const TensorShapeProto* Shape() const; const onnx::TensorShapeProto* Shape() const;
// Set node arg shape. // Set node arg shape.
// Shape could only be set after setting type since shape information // Shape could only be set after setting type since shape information
// now is part of TypeProto. // now is part of TypeProto.
void SetShape(const TensorShapeProto& shape); void SetShape(const onnx::TensorShapeProto& shape);
// Get node arg info proto. // Get node arg info proto.
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
@ -102,13 +98,13 @@ class NodeArg {
LOTUS_DISALLOW_COPY_AND_ASSIGN(NodeArg); LOTUS_DISALLOW_COPY_AND_ASSIGN(NodeArg);
friend class Graph; friend class Graph;
void SetType(DataType p_type); void SetType(onnx::DataType p_type);
void SetType(const TypeProto& type_proto); void SetType(const onnx::TypeProto& type_proto);
NodeArg& operator=(NodeArg&& other) = delete; NodeArg& operator=(NodeArg&& other) = delete;
// Node arg PType. // Node arg PType.
DataType type_; onnx::DataType type_;
// Node arg name, type and shape. // Node arg name, type and shape.
NodeArgInfo node_arg_info_; NodeArgInfo node_arg_info_;
@ -159,7 +155,7 @@ class Node {
// Get the OperatorSchema this node refers to. ValidateOpType() must have been called previously. // Get the OperatorSchema this node refers to. ValidateOpType() must have been called previously.
// May be null in the future. // May be null in the future.
const OpSchema* Op() const noexcept; const onnx::OpSchema* Op() const noexcept;
// Get node description. // Get node description.
const std::string& Description() const noexcept; const std::string& Description() const noexcept;
@ -167,8 +163,8 @@ class Node {
// Iterate through Input/OutputDefs() with index, note the loop early terminates with error // Iterate through Input/OutputDefs() with index, note the loop early terminates with error
static Lotus::Common::Status ForEachWithIndex( static Lotus::Common::Status ForEachWithIndex(
const ConstPointerContainer<std::vector<NodeArg*>>& nodeArgVec, const ConstPointerContainer<std::vector<NodeArg*>>& nodeArgVec,
std::function<Lotus::Common::Status(const NodeArg& arg, int index)> func) { std::function<Lotus::Common::Status(const NodeArg& arg, size_t index)> func) {
for (int index = 0; index < nodeArgVec.size(); ++index) { for (size_t index = 0; index < nodeArgVec.size(); ++index) {
auto arg = nodeArgVec[index]; auto arg = nodeArgVec[index];
if (!arg->Exists()) if (!arg->Exists())
continue; continue;
@ -207,7 +203,7 @@ class Node {
const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; } const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
// Add a node attribute with specified attribute name and value. // Add a node attribute with specified attribute name and value.
void AddAttribute(const std::string& attr_name, const AttributeProto& value); void AddAttribute(const std::string& attr_name, const onnx::AttributeProto& value);
#define ADD_ATTR_INTERFACES(TypeName) \ #define ADD_ATTR_INTERFACES(TypeName) \
void AddAttribute(const std::string& attr_name, \ void AddAttribute(const std::string& attr_name, \
@ -218,8 +214,8 @@ class Node {
ADD_ATTR_INTERFACES(int64_t) ADD_ATTR_INTERFACES(int64_t)
ADD_ATTR_INTERFACES(float) ADD_ATTR_INTERFACES(float)
ADD_ATTR_INTERFACES(std::string) ADD_ATTR_INTERFACES(std::string)
ADD_ATTR_INTERFACES(TensorProto) ADD_ATTR_INTERFACES(onnx::TensorProto)
ADD_ATTR_INTERFACES(GraphProto) ADD_ATTR_INTERFACES(onnx::GraphProto)
// Clear specified node attribute. // Clear specified node attribute.
bool ClearAttribute(const std::string& attr_name); bool ClearAttribute(const std::string& attr_name);
@ -235,7 +231,7 @@ class Node {
void SetExecutionProviderType(ProviderType execution_provider_type); void SetExecutionProviderType(ProviderType execution_provider_type);
// Get the corresponding <NodeProto>. // Get the corresponding <NodeProto>.
void ToProto(NodeProto& proto) const; void ToProto(onnx::NodeProto& proto) const;
// iterate through all input/output defs // iterate through all input/output defs
void ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const; void ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const;
@ -355,7 +351,7 @@ class Node {
std::string domain_; std::string domain_;
// OperatorSchema that <*this> node refers to. // OperatorSchema that <*this> node refers to.
const OpSchema* op_ = nullptr; const onnx::OpSchema* op_ = nullptr;
// Node doc string. // Node doc string.
std::string description_; std::string description_;
@ -428,7 +424,7 @@ class GraphBase {
int NumberOfNodes() const noexcept { return num_of_nodes_; } int NumberOfNodes() const noexcept { return num_of_nodes_; }
// Get NodeArg by name, or create NodeArg owned by the graph if not found // Get NodeArg by name, or create NodeArg owned by the graph if not found
NodeArg& GetOrCreateNodeArg(const std::string& name, const TypeProto* p_arg_type) { NodeArg& GetOrCreateNodeArg(const std::string& name, const onnx::TypeProto* p_arg_type) {
auto iter = node_args_.find(name); auto iter = node_args_.find(name);
if (iter != node_args_.end()) if (iter != node_args_.end())
return *(iter->second); return *(iter->second);
@ -492,7 +488,7 @@ class GraphBase {
// TODO(Task:135) See if GraphBase::GetNodesInTopologicalOrder can be made more correctly const // TODO(Task:135) See if GraphBase::GetNodesInTopologicalOrder can be made more correctly const
// by forcing Resolve to have been called directly previously. Simple change is to return error if // by forcing Resolve to have been called directly previously. Simple change is to return error if
// GraphResolveNeeded is true. // GraphResolveNeeded is true.
Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ const std::vector<NodeIndex>** pp_nodes) const; Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const;
// Mark Graph as needing Resolve() to be called // Mark Graph as needing Resolve() to be called
GraphBase& SetGraphResolveNeeded() noexcept { GraphBase& SetGraphResolveNeeded() noexcept {
@ -545,7 +541,7 @@ class GraphBase {
void AddSourceSinkNodes(); void AddSourceSinkNodes();
// Add node with specified <node_proto>. // Add node with specified <node_proto>.
Node* AddNode(const NodeProto& node_proto, Node* AddNode(const onnx::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type); const ArgNameToTypeMap& name_to_type);
NodeIndex SourceNodeIndex() const noexcept { return source_node_index_; } NodeIndex SourceNodeIndex() const noexcept { return source_node_index_; }
@ -613,13 +609,13 @@ class GraphBase {
// Returns the inferred shape+type for every output of the node in // Returns the inferred shape+type for every output of the node in
// output parameter inferredShapes. // output parameter inferredShapes.
Lotus::Common::Status InferOutputTypesAndShapes(LotusIR::Node& node, Lotus::Common::Status InferOutputTypesAndShapes(LotusIR::Node& node,
/*out*/ std::vector<TypeProto>& inferred_shapes); /*out*/ std::vector<onnx::TypeProto>& inferred_shapes);
private: private:
// need custom versions to handle the unique_ptr's in nodes_ // need custom versions to handle the unique_ptr's in nodes_
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(GraphBase); LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(GraphBase);
Node* AllocateNode(); gsl::not_null<Node*> AllocateNode();
/** /**
Release the node. Release the node.
@ -709,9 +705,9 @@ class Graph : public GraphBase {
void SetDescription(const std::string& description) override; void SetDescription(const std::string& description) override;
// Add/Remove/Get initial tensors for some graph inputs. // Add/Remove/Get initial tensors for some graph inputs.
void AddInitializedTensor(const TensorProto& tensor_proto); void AddInitializedTensor(const onnx::TensorProto& tensor_proto);
void RemoveInitializedTensor(const std::string& tensor_name); void RemoveInitializedTensor(const std::string& tensor_name);
bool GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const; bool GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const onnx::TensorProto**> value) const;
const InitializedTensorSet& GetAllInitializedTensors() const noexcept; const InitializedTensorSet& GetAllInitializedTensors() const noexcept;
void CleanAllInitializedTensors() noexcept; void CleanAllInitializedTensors() noexcept;
@ -719,7 +715,7 @@ class Graph : public GraphBase {
const std::vector<const NodeArg*>& GetValueInfo() const noexcept; const std::vector<const NodeArg*>& GetValueInfo() const noexcept;
// Serialize the <Graph> into <GraphProto>. // Serialize the <Graph> into <GraphProto>.
const GraphProto& ToGraphProto(); const onnx::GraphProto& ToGraphProto();
private: private:
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Graph); LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Graph);
@ -730,7 +726,7 @@ class Graph : public GraphBase {
// Constructor: Given a <GraphProto> loaded from model file, construct // Constructor: Given a <GraphProto> loaded from model file, construct
// a <Graph> object. // a <Graph> object.
Graph(GraphProto* graph_proto, Graph(onnx::GraphProto* graph_proto,
const std::unordered_map<std::string, int>& domain_to_version, const std::unordered_map<std::string, int>& domain_to_version,
Version ir_version, Version ir_version,
const ILotusOpSchemaCollection* local_registry = nullptr); const ILotusOpSchemaCollection* local_registry = nullptr);
@ -754,7 +750,7 @@ class Graph : public GraphBase {
Lotus::Common::Status Resolve(bool no_proto_sync_required); Lotus::Common::Status Resolve(bool no_proto_sync_required);
Lotus::Common::Status InferAndVerifyTypeMatch(Node& node, Lotus::Common::Status InferAndVerifyTypeMatch(Node& node,
const OpSchema& op); const onnx::OpSchema& op);
// Apply type-inference and type-checking to all inputs and initializers: // Apply type-inference and type-checking to all inputs and initializers:
Lotus::Common::Status TypeCheckInputsAndInitializers(); Lotus::Common::Status TypeCheckInputsAndInitializers();
@ -783,7 +779,7 @@ class Graph : public GraphBase {
// functions in <Graph> will also be fed into <graph_proto_> so that // functions in <Graph> will also be fed into <graph_proto_> so that
// it's consistent with <*this> graph. // it's consistent with <*this> graph.
// This pointer is owned by parent model. // This pointer is owned by parent model.
GraphProto* graph_proto_; onnx::GraphProto* graph_proto_;
// The node which refers to <*this> graph (Function). // The node which refers to <*this> graph (Function).
// Node* node_; // Node* node_;

Просмотреть файл

@ -1,4 +1,3 @@
#include "core/common/CommonSTD.h"
#include "core/graph/graph_transformer.h" #include "core/graph/graph_transformer.h"
using namespace Lotus; using namespace Lotus;
using namespace Lotus::Common; using namespace Lotus::Common;

Просмотреть файл

@ -53,7 +53,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
// should be stored globally. Otherwise, there will be multiple addresses/pointers // should be stored globally. Otherwise, there will be multiple addresses/pointers
// for the same operator or function. To avoid this, we may use OpSignature ID // for the same operator or function. To avoid this, we may use OpSignature ID
// as the key, which should be name_domain_version. // as the key, which should be name_domain_version.
Lotus::Common::Status Register(const OpSchema* op, std::unique_ptr<RewriteRule> rule) { Lotus::Common::Status Register(const onnx::OpSchema* op, std::unique_ptr<RewriteRule> rule) {
op_to_rules_[op].push_back(std::move(rule)); op_to_rules_[op].push_back(std::move(rule));
return Lotus::Common::Status::OK(); return Lotus::Common::Status::OK();
} }
@ -66,7 +66,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
} }
private: private:
typedef std::unordered_map<const OpSchema*, std::vector<std::unique_ptr<RewriteRule>>> typedef std::unordered_map<const onnx::OpSchema*, std::vector<std::unique_ptr<RewriteRule>>>
RewriteRuleSet; RewriteRuleSet;
RewriteRuleSet op_to_rules_; RewriteRuleSet op_to_rules_;

Просмотреть файл

@ -7,16 +7,15 @@
#ifdef _MSC_VER #ifdef _MSC_VER
#pragma warning(pop) #pragma warning(pop)
#endif #endif
#include "core/common/CommonSTD.h"
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
#include <memory> #include <memory>
#include "core/graph/model.h" #include "core/graph/model.h"
#include "core/graph/utils.h" #include "core/graph/utils.h"
#include "core/graph/schema_registry.h" #include "core/graph/schema_registry.h"
// #include "gsl/pointers" #include "gsl/pointers"
#include "gsl/gsl_util" #include "gsl/gsl_util"
using namespace onnx;
using namespace Lotus; using namespace Lotus;
using namespace Lotus::Common; using namespace Lotus::Common;
@ -27,7 +26,7 @@ Model::Model(const std::string& graph_name, bool is_onnx_domain_only, const Mode
model_proto_->mutable_graph()->set_name(graph_name); model_proto_->mutable_graph()->set_name(graph_name);
model_metadata_ = model_metadata; model_metadata_ = model_metadata;
for (auto& metadata : model_metadata_) { for (auto& metadata : model_metadata_) {
auto prop = model_proto_->add_metadata_props(); const gsl::not_null<StringStringEntryProto*> prop = model_proto_->add_metadata_props();
prop->set_key(metadata.first); prop->set_key(metadata.first);
prop->set_value(metadata.second); prop->set_value(metadata.second);
} }
@ -138,7 +137,7 @@ ModelProto Model::ToProto() {
} }
void Model::AddImportOpSets(bool is_onnx_domain_only, void Model::AddImportOpSets(bool is_onnx_domain_only,
/*out*/ std::unordered_map<std::string, int>* domain_to_version, /*out*/ gsl::not_null<std::unordered_map<std::string, int>*> domain_to_version,
const ILotusOpSchemaCollection* local_registry) { const ILotusOpSchemaCollection* local_registry) {
auto& domain_to_version_range_map = OpSchemaRegistry::DomainToVersionRange::Instance().Map(); auto& domain_to_version_range_map = OpSchemaRegistry::DomainToVersionRange::Instance().Map();
Domain_To_Version_Map local_domain_to_version_map = local_registry ? local_registry->DomainToVersionMap() : Domain_To_Version_Map(); Domain_To_Version_Map local_domain_to_version_map = local_registry ? local_registry->DomainToVersionMap() : Domain_To_Version_Map();
@ -158,7 +157,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only,
} }
auto ignored = domain_to_version->insert({domainToVersionRange.first, max}); auto ignored = domain_to_version->insert({domainToVersionRange.first, max});
auto opset_id_proto = model_proto_->add_opset_import(); const gsl::not_null<OperatorSetIdProto*> opset_id_proto = model_proto_->add_opset_import();
opset_id_proto->set_domain(domainToVersionRange.first); opset_id_proto->set_domain(domainToVersionRange.first);
opset_id_proto->set_version(domainToVersionRange.second.second); opset_id_proto->set_version(domainToVersionRange.second.second);
} }
@ -174,7 +173,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only,
if (domain_to_version_range_map.end() != domain_to_version_range_map.find(local_domain.first)) { if (domain_to_version_range_map.end() != domain_to_version_range_map.find(local_domain.first)) {
auto ignored = domain_to_version->insert({local_domain.first, local_domain.second.second}); auto ignored = domain_to_version->insert({local_domain.first, local_domain.second.second});
auto opset_id_proto = model_proto_->add_opset_import(); const gsl::not_null<OperatorSetIdProto*> opset_id_proto = model_proto_->add_opset_import();
opset_id_proto->set_domain(local_domain.first); opset_id_proto->set_domain(local_domain.first);
opset_id_proto->set_version(local_domain.second.second); opset_id_proto->set_version(local_domain.second.second);
} }

Просмотреть файл

@ -2,7 +2,7 @@
#include "core/graph/graph.h" #include "core/graph/graph.h"
// #include "gsl/pointers" #include "gsl/pointers"
namespace LotusIR { namespace LotusIR {
typedef std::unordered_map<std::string, std::string> ModelMetaData; typedef std::unordered_map<std::string, std::string> ModelMetaData;
@ -22,11 +22,11 @@ class Model {
// NOTE: after calling this constructor, <*this> model will // NOTE: after calling this constructor, <*this> model will
// hold a copy of <model_proto>. // hold a copy of <model_proto>.
explicit Model(const ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr); explicit Model(const onnx::ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
// NOTE: after calling this constructor, <*this> model will // NOTE: after calling this constructor, <*this> model will
// own the <model_proto>. // own the <model_proto>.
explicit Model(std::unique_ptr<ModelProto> model_proto, const ILotusOpSchemaCollection* local_registry = nullptr); explicit Model(std::unique_ptr<onnx::ModelProto> model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
// Get model's IR version. // Get model's IR version.
// Return <kNoVersion> if not specified. // Return <kNoVersion> if not specified.
@ -71,7 +71,7 @@ class Model {
const Graph* MainGraph() const noexcept; const Graph* MainGraph() const noexcept;
// Get model's serialization proto data. // Get model's serialization proto data.
ModelProto ToProto(); onnx::ModelProto ToProto();
#ifdef _WIN32 #ifdef _WIN32
static Lotus::Common::Status Save(Model& model, const std::wstring& file_path); static Lotus::Common::Status Save(Model& model, const std::wstring& file_path);
@ -84,7 +84,7 @@ class Model {
static Lotus::Common::Status Save(Model& model, int fd); static Lotus::Common::Status Save(Model& model, int fd);
static Lotus::Common::Status Load(std::istream& model_istream, ModelProto* p_model_proto); static Lotus::Common::Status Load(std::istream& model_istream, onnx::ModelProto* p_model_proto);
static Lotus::Common::Status Load(const std::string& file_path, static Lotus::Common::Status Load(const std::string& file_path,
/*out*/ std::shared_ptr<Model>& p_model, /*out*/ std::shared_ptr<Model>& p_model,
@ -97,7 +97,7 @@ class Model {
static Lotus::Common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model, static Lotus::Common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
const ILotusOpSchemaCollection* local_registry = nullptr); const ILotusOpSchemaCollection* local_registry = nullptr);
static Lotus::Common::Status Load(const ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model, static Lotus::Common::Status Load(const onnx::ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
const ILotusOpSchemaCollection* local_registry = nullptr); const ILotusOpSchemaCollection* local_registry = nullptr);
private: private:
@ -106,11 +106,11 @@ class Model {
// if <is_onnx_domain_only> is true, then only onnx domain will be contained. // if <is_onnx_domain_only> is true, then only onnx domain will be contained.
// otherwise, ml domain will also be contained. // otherwise, ml domain will also be contained.
void AddImportOpSets(bool is_onnx_domain_only, void AddImportOpSets(bool is_onnx_domain_only,
/*out*/ std::unordered_map<std::string, int>* domain_to_version, /*out*/ gsl::not_null<std::unordered_map<std::string, int>*> domain_to_version,
const ILotusOpSchemaCollection* local_registry); const ILotusOpSchemaCollection* local_registry);
// Model data. // Model data.
std::unique_ptr<ModelProto> model_proto_; std::unique_ptr<onnx::ModelProto> model_proto_;
// This is a duplication of <model_proto_.metadata_props()>. // This is a duplication of <model_proto_.metadata_props()>.
// It gives better accessibility. // It gives better accessibility.

Просмотреть файл

@ -1,9 +1,9 @@
#include <cstring> #include <cstring>
#include "core/common/CommonSTD.h"
#include "core/graph/constants.h" #include "core/graph/constants.h"
#include "core/graph/op.h" #include "core/graph/op.h"
#include "core/graph/utils.h" #include "core/graph/utils.h"
using namespace onnx;
namespace LotusIR { namespace LotusIR {
bool TypeUtils::IsValidAttribute(const AttributeProto& attr) { bool TypeUtils::IsValidAttribute(const AttributeProto& attr) {

Просмотреть файл

@ -6,12 +6,11 @@
#include "core/common/status.h" #include "core/common/status.h"
#include "core/graph/constants.h" #include "core/graph/constants.h"
using namespace onnx;
using namespace Lotus::Common; using namespace Lotus::Common;
namespace LotusIR { namespace LotusIR {
using AttrType = AttributeProto_AttributeType; using AttrType = onnx::AttributeProto_AttributeType;
using NodeAttributes = std::unordered_map<std::string, AttributeProto>; using NodeAttributes = std::unordered_map<std::string, onnx::AttributeProto>;
// This string array should exactly match the AttrType defined above. // This string array should exactly match the AttrType defined above.
/* /*
@ -44,8 +43,8 @@ static constexpr const char* kAttrTypeStrings[] =
class TypeUtils { class TypeUtils {
public: public:
// Get attribute type given attribute proto data. // Get attribute type given attribute proto data.
static Status GetType(const AttributeProto& attr, AttrType& type); static Status GetType(const onnx::AttributeProto& attr, AttrType& type);
static bool IsValidAttribute(const AttributeProto& attribute); static bool IsValidAttribute(const onnx::AttributeProto& attribute);
}; };
class MsOpRegistry { class MsOpRegistry {

Просмотреть файл

@ -1,8 +1,7 @@
#include "core/common/CommonSTD.h"
#include "core/graph/tensorutils.h" #include "core/graph/tensorutils.h"
#include <algorithm> #include <algorithm>
// #include "gsl/span" #include "gsl/span"
namespace Lotus { namespace Lotus {
namespace Utils { namespace Utils {
@ -23,9 +22,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
return Status(StatusCategory::LOTUS, StatusCode::FAIL, return Status(StatusCategory::LOTUS, StatusCode::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto"); "UnpackTensor: the pre-allocate size does not match the size in proto");
for (auto& elem : tensor.string_data()) { const auto data = gsl::make_span(p_data, expected_size);
*p_data++ = elem;
} auto& string_data = tensor.string_data();
std::copy(string_data.cbegin(), string_data.cend(), data.begin());
return Status::OK(); return Status::OK();
} }
@ -57,9 +57,8 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
return Status(StatusCategory::LOTUS, StatusCode::FAIL, return Status(StatusCategory::LOTUS, StatusCode::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto"); "UnpackTensor: the pre-allocate size does not match the size in proto");
for (auto& elem : tensor.int32_data()) { const auto data = gsl::make_span(p_data, expected_size);
*p_data++ = elem; std::copy(tensor.int32_data().cbegin(), tensor.int32_data().cend(), data.begin());
}
return Status::OK(); return Status::OK();
} }
@ -90,9 +89,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
if (tensor.int32_data_size() != expected_size) if (tensor.int32_data_size() != expected_size)
return Status(StatusCategory::LOTUS, StatusCode::FAIL, return Status(StatusCategory::LOTUS, StatusCode::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto"); "UnpackTensor: the pre-allocate size does not match the size in proto");
for (auto& elem : tensor.int32_data()) {
*p_data++ = (uint16_t)elem; const auto data = gsl::make_span(p_data, expected_size);
} for (int i = 0; i < expected_size; i++)
data[i] = gsl::narrow_cast<uint16_t>(tensor.int32_data()[i]);
return Status::OK(); return Status::OK();
} }

Просмотреть файл

@ -3,8 +3,8 @@
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
// #include "gsl/pointers" #include "gsl/pointers"
// #include "gsl/span" #include "gsl/span"
#include "core/common/common.h" #include "core/common/common.h"
#include "core/common/status.h" #include "core/common/status.h"
@ -36,9 +36,9 @@ class TensorUtils {
if (tensor.field_size() != expected_size) \ if (tensor.field_size() != expected_size) \
return Status(StatusCategory::LOTUS, StatusCode::FAIL, \ return Status(StatusCategory::LOTUS, StatusCode::FAIL, \
"UnpackTensor: the pre-allocated size does not match the size in proto"); \ "UnpackTensor: the pre-allocated size does not match the size in proto"); \
for (auto elem : tensor.field_name()) { \ const auto span = gsl::make_span(p_data, expected_size); \
*p_data++ = static_cast<T>(elem); \ auto& data = tensor.field_name(); \
} \ std::copy(data.cbegin(), data.cend(), span.begin()); \
return Status::OK(); \ return Status::OK(); \
} }

Просмотреть файл

@ -16,12 +16,12 @@
#include "core/common/status.h" #include "core/common/status.h"
#include "onnx/onnx_pb.h" #include "onnx/onnx_pb.h"
// #include "gsl/pointers" #include "gsl/pointers"
namespace Lotus{ namespace Lotus{
using namespace ::Lotus::Common; using namespace ::Lotus::Common;
#ifdef _WIN32 #ifdef _WIN32
inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) { inline Status FileOpenRd(const std::wstring& path, /*out*/ gsl::not_null<int*> p_fd) {
_wsopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE); _wsopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
if (0 > *p_fd) { if (0 > *p_fd) {
return Status(SYSTEM, errno); return Status(SYSTEM, errno);
@ -29,7 +29,7 @@ inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
return Status::OK(); return Status::OK();
} }
inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) { inline Status FileOpenWr(const std::wstring& path, /*out*/ gsl::not_null<int*> p_fd) {
_wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE); _wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
if (0 > *p_fd) { if (0 > *p_fd) {
return Status(SYSTEM, errno); return Status(SYSTEM, errno);
@ -38,7 +38,7 @@ inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
} }
#endif #endif
inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) { inline Status FileOpenRd(const std::string& path, /*out*/ gsl::not_null<int*> p_fd) {
#ifdef _WIN32 #ifdef _WIN32
_sopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE); _sopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
#else #else
@ -50,7 +50,7 @@ inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
return Status::OK(); return Status::OK();
} }
inline Status FileOpenWr(const std::string& path, /*out*/ int* p_fd) { inline Status FileOpenWr(const std::string& path, /*out*/ gsl::not_null<int*> p_fd) {
#ifdef _WIN32 #ifdef _WIN32
_sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE); _sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
#else #else

Просмотреть файл

@ -71,20 +71,20 @@ enum class MLTensorDataType : uint32_t {
}; };
union MLFloat16 { union MLFloat16 {
uint16_t val; uint16_t val;
MLFloat16(uint16_t x) : val(x) {} MLFloat16(uint16_t x) : val(x) {}
MLFloat16() : val(0) {} MLFloat16() : val(0) {}
}; };
inline bool operator==(const MLFloat16& left, const MLFloat16& right) inline bool operator==(const MLFloat16& left, const MLFloat16& right)
{ {
return left.val == right.val; return left.val == right.val;
} }
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) inline bool operator!=(const MLFloat16& left, const MLFloat16& right)
{ {
return left.val != right.val; return left.val != right.val;
} }
struct MLMapType { struct MLMapType {