upgrade linux build from c++11 to c++14, enable gsl, update with latest LotusIR
This commit is contained in:
Родитель
55b4606b23
Коммит
3c87d2012c
4
Makefile
4
Makefile
|
@ -102,9 +102,9 @@ INCLUDEPATH+=$(GSL_PATH)/include
|
||||||
INCLUDEPATH+=$(ONNX_PATH)
|
INCLUDEPATH+=$(ONNX_PATH)
|
||||||
INCLUDEPATH+=$(ONNX_REPO_PATH)
|
INCLUDEPATH+=$(ONNX_REPO_PATH)
|
||||||
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
|
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
|
||||||
COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__
|
COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++14 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__
|
||||||
CPPFLAGS:=
|
CPPFLAGS:=
|
||||||
CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new
|
CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -fopenmp -fpermissive -fPIC -Werror -fcheck-new
|
||||||
LIBPATH:=
|
LIBPATH:=
|
||||||
LIBS_LIST:=
|
LIBS_LIST:=
|
||||||
LDFLAGS:=
|
LDFLAGS:=
|
||||||
|
|
|
@ -178,7 +178,6 @@
|
||||||
<ClInclude Include="proto\onnx\CNTKToONNX.h" />
|
<ClInclude Include="proto\onnx\CNTKToONNX.h" />
|
||||||
<ClInclude Include="proto\onnx\core\common\code_location.h" />
|
<ClInclude Include="proto\onnx\core\common\code_location.h" />
|
||||||
<ClInclude Include="proto\onnx\core\common\common.h" />
|
<ClInclude Include="proto\onnx\core\common\common.h" />
|
||||||
<ClInclude Include="proto\onnx\core\common\CommonSTD.h" />
|
|
||||||
<ClInclude Include="proto\onnx\core\common\const_pointer_container.h" />
|
<ClInclude Include="proto\onnx\core\common\const_pointer_container.h" />
|
||||||
<ClInclude Include="proto\onnx\core\common\exceptions.h" />
|
<ClInclude Include="proto\onnx\core\common\exceptions.h" />
|
||||||
<ClInclude Include="proto\onnx\core\common\logging\capture.h" />
|
<ClInclude Include="proto\onnx\core\common\logging\capture.h" />
|
||||||
|
|
|
@ -93,18 +93,6 @@
|
||||||
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\traditionalml\defs.cc">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\traditionalml\defs.cc">
|
||||||
<Filter>proto\onnx\onnx_repo\onnx\defs\traditionalml</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\defs\traditionalml</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\traditionalml\defs.cpp">
|
|
||||||
<Filter>proto\onnx\onnx\defs\traditionalml</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\data_type_utils.cpp">
|
|
||||||
<Filter>proto\onnx\onnx\defs</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\schema.cpp">
|
|
||||||
<Filter>proto\onnx\onnx\defs</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="proto\onnx\core\graph\graph.cpp">
|
|
||||||
<Filter>proto\onnx\core\graph</Filter>
|
|
||||||
</ClCompile>
|
|
||||||
<ClCompile Include="proto\onnx\core\graph\model.cc">
|
<ClCompile Include="proto\onnx\core\graph\model.cc">
|
||||||
<Filter>proto\onnx\core\graph</Filter>
|
<Filter>proto\onnx\core\graph</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
|
@ -117,26 +105,37 @@
|
||||||
<ClCompile Include="proto\onnx\core\graph\graph_transformer.cc">
|
<ClCompile Include="proto\onnx\core\graph\graph_transformer.cc">
|
||||||
<Filter>proto\onnx\core\graph</Filter>
|
<Filter>proto\onnx\core\graph</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\math\old.cpp">
|
<ClCompile Include="proto\onnx\onnx-ml.pb.cc.VS_wrapper.cpp" />
|
||||||
<Filter>proto\onnx\onnx\defs\math</Filter>
|
<ClCompile Include="proto\onnx\onnx-operators-ml.pb.cc.VS_wrapper.cpp" />
|
||||||
|
<ClCompile Include="proto\onnx\core\graph\graph.cc">
|
||||||
|
<Filter>proto\onnx\core\graph</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\nn\old.cpp">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\logical\old.cc">
|
||||||
<Filter>proto\onnx\onnx\defs\nn</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\defs\logical</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\rnn\old.cpp">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\tensor\old.cc">
|
||||||
<Filter>proto\onnx\onnx\defs\rnn</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\defs\tensor</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\tensor\old.cpp">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\math\old.cc">
|
||||||
<Filter>proto\onnx\onnx\defs\tensor</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\defs\math</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\checker.cpp">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\rnn\old.cc">
|
||||||
<Filter>proto\onnx\onnx</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\defs\rnn</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\defs\logical\old.cpp">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\nn\old.cc">
|
||||||
<Filter>proto\onnx\onnx\defs\logical</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\defs\nn</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="proto\onnx\onnx\common\assertions.cpp">
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\common\assertions.cc">
|
||||||
<Filter>proto\onnx\onnx\common</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\common</Filter>
|
||||||
|
</ClCompile>
|
||||||
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\checker.cc">
|
||||||
|
<Filter>proto\onnx\onnx_repo\onnx</Filter>
|
||||||
|
</ClCompile>
|
||||||
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\data_type_utils.cc">
|
||||||
|
<Filter>proto\onnx\onnx_repo\onnx\defs</Filter>
|
||||||
|
</ClCompile>
|
||||||
|
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\schema.cc">
|
||||||
|
<Filter>proto\onnx\onnx_repo\onnx\defs</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
@ -295,9 +294,6 @@
|
||||||
<ClInclude Include="proto\onnx\core\graph\utils.h">
|
<ClInclude Include="proto\onnx\core\graph\utils.h">
|
||||||
<Filter>proto\onnx\core\graph</Filter>
|
<Filter>proto\onnx\core\graph</Filter>
|
||||||
</ClInclude>
|
</ClInclude>
|
||||||
<ClInclude Include="proto\onnx\core\common\CommonSTD.h">
|
|
||||||
<Filter>proto\onnx\core\common</Filter>
|
|
||||||
</ClInclude>
|
|
||||||
<ClInclude Include="proto\onnx\onnx_repo\onnx\common\assertions.h">
|
<ClInclude Include="proto\onnx\onnx_repo\onnx\common\assertions.h">
|
||||||
<Filter>proto\onnx\onnx_repo\onnx\common</Filter>
|
<Filter>proto\onnx\onnx_repo\onnx\common</Filter>
|
||||||
</ClInclude>
|
</ClInclude>
|
||||||
|
|
|
@ -23,6 +23,7 @@ using namespace Microsoft::MSR::CNTK;
|
||||||
using namespace CNTK::ONNX;
|
using namespace CNTK::ONNX;
|
||||||
using namespace CNTK;
|
using namespace CNTK;
|
||||||
using namespace LotusIR;
|
using namespace LotusIR;
|
||||||
|
using namespace onnx;
|
||||||
|
|
||||||
const int FreeSequenceLen = 0;
|
const int FreeSequenceLen = 0;
|
||||||
const std::string FreeSequenceDimParam = "None";
|
const std::string FreeSequenceDimParam = "None";
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
using namespace LotusIR;
|
using namespace LotusIR;
|
||||||
using namespace CNTK;
|
using namespace CNTK;
|
||||||
using namespace CNTK::ONNX;
|
using namespace CNTK::ONNX;
|
||||||
|
using namespace onnx;
|
||||||
using namespace Microsoft::MSR::CNTK;
|
using namespace Microsoft::MSR::CNTK;
|
||||||
|
|
||||||
namespace CNTK
|
namespace CNTK
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
#pragma once
|
|
||||||
#include <memory>
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
// to get make_unique definition
|
|
||||||
#include "Platform.h"
|
|
||||||
|
|
||||||
// to add what is missing in gsl
|
|
||||||
#if defined(__GNUC__)
|
|
||||||
namespace std {
|
|
||||||
template<bool _Test,
|
|
||||||
class _Ty = void>
|
|
||||||
using enable_if_t = typename enable_if<_Test, _Ty>::type;
|
|
||||||
|
|
||||||
template<class _Ty>
|
|
||||||
using remove_cv_t = typename remove_cv<_Ty>::type;
|
|
||||||
|
|
||||||
template<bool _Test,
|
|
||||||
class _Ty1,
|
|
||||||
class _Ty2>
|
|
||||||
using conditional_t = typename conditional<_Test, _Ty1, _Ty2>::type;
|
|
||||||
|
|
||||||
template<class _Ty>
|
|
||||||
using add_pointer_t = typename add_pointer<_Ty>::type;
|
|
||||||
|
|
||||||
template<class _Ty>
|
|
||||||
using remove_const_t = typename remove_const<_Ty>::type;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
using std::make_unique;
|
|
||||||
#endif
|
|
||||||
|
|
|
@ -33,7 +33,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/common/code_location.h"
|
#include "core/common/code_location.h"
|
||||||
#include "core/common/exceptions.h"
|
#include "core/common/exceptions.h"
|
||||||
#include "core/common/status.h"
|
#include "core/common/status.h"
|
||||||
|
@ -62,6 +61,8 @@ using std::vector;
|
||||||
#define UNUSED_PARAMETER(x)
|
#define UNUSED_PARAMETER(x)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// std::vector<std::string> GetStackTrace();
|
||||||
|
|
||||||
// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
|
// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
|
||||||
// so we only define it as one for MSVC
|
// so we only define it as one for MSVC
|
||||||
#if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
|
#if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
|
||||||
|
@ -73,7 +74,7 @@ using std::vector;
|
||||||
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__)
|
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__)
|
||||||
|
|
||||||
#define WHERE_WITH_STACK \
|
#define WHERE_WITH_STACK \
|
||||||
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__) // Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__, Lotus::GetStackTrace())
|
Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__) // , Lotus::GetStackTrace())
|
||||||
|
|
||||||
// Throw an exception with optional message.
|
// Throw an exception with optional message.
|
||||||
// NOTE: The arguments get streamed into a string via ostringstream::operator<<
|
// NOTE: The arguments get streamed into a string via ostringstream::operator<<
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/common/logging/capture.h"
|
#include "core/common/logging/capture.h"
|
||||||
#include "core/common/logging/logging.h"
|
#include "core/common/logging/logging.h"
|
||||||
// #include "gsl/span"
|
#include "gsl/span"
|
||||||
#include "gsl/gsl_util"
|
#include "gsl/gsl_util"
|
||||||
|
|
||||||
namespace Lotus {
|
namespace Lotus {
|
||||||
|
@ -21,27 +20,22 @@ void Capture::CapturePrintf(msvc_printf_check const char *format, ...) {
|
||||||
void Capture::ProcessPrintf(msvc_printf_check const char *format, va_list args) {
|
void Capture::ProcessPrintf(msvc_printf_check const char *format, va_list args) {
|
||||||
static constexpr auto kTruncatedWarningText = "[...truncated...]";
|
static constexpr auto kTruncatedWarningText = "[...truncated...]";
|
||||||
static const int kMaxMessageSize = 2048;
|
static const int kMaxMessageSize = 2048;
|
||||||
char finished_message[kMaxMessageSize];
|
char message_buffer[kMaxMessageSize];
|
||||||
|
const auto message = gsl::make_span(message_buffer);
|
||||||
|
|
||||||
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
|
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
|
||||||
const auto finished_message_len = _countof(finished_message);
|
const int nbrcharacters = vsnprintf_s(message.data(), message.size(), _TRUNCATE, format, args);
|
||||||
#else
|
#else
|
||||||
int finished_message_len = sizeof(finished_message);
|
const int nbrcharacters = vsnprintf(message.data(), message.size(), format, args);
|
||||||
#endif
|
|
||||||
|
|
||||||
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
|
|
||||||
const int nbrcharacters = vsnprintf_s(finished_message, finished_message_len, _TRUNCATE, format, args);
|
|
||||||
#else
|
|
||||||
const int nbrcharacters = vsnprintf(finished_message, finished_message_len, format, args);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (nbrcharacters <= 0) {
|
if (nbrcharacters <= 0) {
|
||||||
stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message";
|
stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message";
|
||||||
stream_ << '"' << format << '"' << std::endl;
|
stream_ << '"' << format << '"' << std::endl;
|
||||||
} else if (static_cast<uint32_t>(nbrcharacters) > finished_message_len) {
|
} else if (nbrcharacters > message.size()) {
|
||||||
stream_ << finished_message << kTruncatedWarningText;
|
stream_ << message.data() << kTruncatedWarningText;
|
||||||
} else {
|
} else {
|
||||||
stream_ << finished_message;
|
stream_ << message.data();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/common/exceptions.h"
|
#include "core/common/exceptions.h"
|
||||||
#include "core/common/logging/isink.h"
|
#include "core/common/logging/isink.h"
|
||||||
#include "core/common/logging/logging.h"
|
#include "core/common/logging/logging.h"
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/common/status.h"
|
#include "core/common/status.h"
|
||||||
#include "core/common/common.h"
|
#include "core/common/common.h"
|
||||||
|
|
||||||
|
|
|
@ -8,15 +8,14 @@
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <stack>
|
#include <stack>
|
||||||
|
|
||||||
#include "core/common/CommonSTD.h"
|
#include "gsl/pointers"
|
||||||
// #include "gsl/pointers"
|
|
||||||
#include "core/graph/graph.h"
|
#include "core/graph/graph.h"
|
||||||
#include "core/graph/op.h"
|
#include "core/graph/op.h"
|
||||||
#include "core/graph/utils.h"
|
#include "core/graph/utils.h"
|
||||||
#include "core/common/logging/logging.h"
|
#include "core/common/logging/logging.h"
|
||||||
#include "onnx/checker.h"
|
#include "onnx/checker.h"
|
||||||
#include "core/graph/schema_registry.h"
|
#include "core/graph/schema_registry.h"
|
||||||
|
using namespace onnx;
|
||||||
using namespace onnx::Utils;
|
using namespace onnx::Utils;
|
||||||
using namespace onnx::checker;
|
using namespace onnx::checker;
|
||||||
|
|
||||||
|
@ -172,7 +171,7 @@ void Node::ToProto(NodeProto& proto) const {
|
||||||
// Set attributes.
|
// Set attributes.
|
||||||
proto.clear_attribute();
|
proto.clear_attribute();
|
||||||
for (auto attribute : attributes_) {
|
for (auto attribute : attributes_) {
|
||||||
auto attr = proto.add_attribute();
|
const gsl::not_null<AttributeProto*> attr = proto.add_attribute();
|
||||||
*attr = attribute.second;
|
*attr = attribute.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,12 +345,12 @@ const NodeAttributes& Node::GetAttributes() const noexcept {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const {
|
void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const {
|
||||||
for (const LotusIR::NodeArg* arg : InputDefs()) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> arg : InputDefs()) {
|
||||||
if (!arg->Exists())
|
if (!arg->Exists())
|
||||||
continue;
|
continue;
|
||||||
func(&*arg, true);
|
func(&*arg, true);
|
||||||
}
|
}
|
||||||
for (const LotusIR::NodeArg* arg : OutputDefs()) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> arg : OutputDefs()) {
|
||||||
if (!arg->Exists())
|
if (!arg->Exists())
|
||||||
continue;
|
continue;
|
||||||
func(&*arg, false);
|
func(&*arg, false);
|
||||||
|
@ -359,7 +358,7 @@ void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)
|
||||||
};
|
};
|
||||||
|
|
||||||
void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
|
void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
|
||||||
for (const LotusIR::NodeArg* arg : InputDefs()) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> arg : InputDefs()) {
|
||||||
if (!arg->Exists())
|
if (!arg->Exists())
|
||||||
continue;
|
continue;
|
||||||
func(&*arg);
|
func(&*arg);
|
||||||
|
@ -367,7 +366,7 @@ void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) co
|
||||||
};
|
};
|
||||||
|
|
||||||
void Node::ForEachOutputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
|
void Node::ForEachOutputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
|
||||||
for (const LotusIR::NodeArg* arg : OutputDefs()) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> arg : OutputDefs()) {
|
||||||
if (!arg->Exists())
|
if (!arg->Exists())
|
||||||
continue;
|
continue;
|
||||||
func(&*arg);
|
func(&*arg);
|
||||||
|
@ -378,7 +377,7 @@ void Node::ReplaceDefs(const std::map<const LotusIR::NodeArg*, LotusIR::NodeArg*
|
||||||
std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs};
|
std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs};
|
||||||
|
|
||||||
for (auto pair : replacements)
|
for (auto pair : replacements)
|
||||||
for (auto defs : all_defs)
|
for (const gsl::not_null<std::vector<LotusIR::NodeArg*>*> defs : all_defs)
|
||||||
for (auto& def : *defs)
|
for (auto& def : *defs)
|
||||||
if (def == pair.first)
|
if (def == pair.first)
|
||||||
def = pair.second;
|
def = pair.second;
|
||||||
|
@ -418,14 +417,14 @@ Graph::Graph(GraphProto* graph_proto,
|
||||||
// Copy constant nodes _value to name_to_initial_tensor_
|
// Copy constant nodes _value to name_to_initial_tensor_
|
||||||
for (auto& node : graph_proto_->node()) {
|
for (auto& node : graph_proto_->node()) {
|
||||||
if (node.op_type() == kConstant) {
|
if (node.op_type() == kConstant) {
|
||||||
auto tensor = graph_proto_->add_initializer();
|
const gsl::not_null<TensorProto*> tensor = graph_proto_->add_initializer();
|
||||||
*tensor = node.attribute(0).t();
|
*tensor = node.attribute(0).t();
|
||||||
*(tensor->mutable_name()) = node.output(0);
|
*(tensor->mutable_name()) = node.output(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove constant nodes
|
// remove constant nodes
|
||||||
auto graph_mutable_nodes = graph_proto_->mutable_node();
|
const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes = graph_proto_->mutable_node();
|
||||||
graph_mutable_nodes->erase(
|
graph_mutable_nodes->erase(
|
||||||
std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(),
|
std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(),
|
||||||
[](NodeProto& p) {
|
[](NodeProto& p) {
|
||||||
|
@ -493,7 +492,7 @@ Status GraphBase::VerifyNoDuplicateName(/*in*/ const std::unordered_set<std::str
|
||||||
node_name_to_index[node_name] = node.Index();
|
node_name_to_index[node_name] = node.Index();
|
||||||
|
|
||||||
// Verify node outputs' name should be unique.
|
// Verify node outputs' name should be unique.
|
||||||
for (const NodeArg* output_def : node.OutputDefs()) {
|
for (const gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||||
if (output_def->Exists()) {
|
if (output_def->Exists()) {
|
||||||
auto& output_arg_name = output_def->Name();
|
auto& output_arg_name = output_def->Name();
|
||||||
if (inputs_and_initializers.count(output_arg_name)) {
|
if (inputs_and_initializers.count(output_arg_name)) {
|
||||||
|
@ -546,7 +545,7 @@ Status GraphBase::BuildConnections(const std::unordered_map<std::string, Node*>&
|
||||||
if (input_args.size() > 0) {
|
if (input_args.size() > 0) {
|
||||||
// This node needs inputs.
|
// This node needs inputs.
|
||||||
|
|
||||||
for (const NodeArg* input_arg : input_args) {
|
for (const gsl::not_null<const NodeArg*> input_arg : input_args) {
|
||||||
if (!input_arg->Exists()) {
|
if (!input_arg->Exists()) {
|
||||||
// This input could be optional and it does not exist in this case.
|
// This input could be optional and it does not exist in this case.
|
||||||
continue;
|
continue;
|
||||||
|
@ -655,7 +654,7 @@ void GraphBase::ReverseDFSFrom(const std::vector<const Node*>& from,
|
||||||
sorted_nodes.push_back((*iter));
|
sorted_nodes.push_back((*iter));
|
||||||
}
|
}
|
||||||
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
|
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
|
||||||
for (const LotusIR::Node* in : sorted_nodes) {
|
for (gsl::not_null<const LotusIR::Node*> in : sorted_nodes) {
|
||||||
const NodeIndex idx = in->Index();
|
const NodeIndex idx = in->Index();
|
||||||
if (!visited[idx]) {
|
if (!visited[idx]) {
|
||||||
stack.emplace_back(in, false);
|
stack.emplace_back(in, false);
|
||||||
|
@ -1131,7 +1130,7 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topolo
|
||||||
|
|
||||||
// currently an Op is required by ValidateVersion, so we use gsl::not_null.
|
// currently an Op is required by ValidateVersion, so we use gsl::not_null.
|
||||||
// This may change in the future to allow a null Op
|
// This may change in the future to allow a null Op
|
||||||
const OpSchema* p_op = node.Op();
|
const gsl::not_null<const OpSchema*> p_op = node.Op();
|
||||||
|
|
||||||
// Attribute verification and fill node attribute with
|
// Attribute verification and fill node attribute with
|
||||||
// default value defined in operator definition if needed.
|
// default value defined in operator definition if needed.
|
||||||
|
@ -1218,7 +1217,7 @@ Status Graph::Resolve(bool no_proto_sync_required) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphBase::GetNodesInTopologicalOrder(const std::vector<NodeIndex>** pp_nodes) const {
|
Status GraphBase::GetNodesInTopologicalOrder(gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const {
|
||||||
if (graph_resolve_needed_) {
|
if (graph_resolve_needed_) {
|
||||||
return Status{StatusCategory::LOTUS, StatusCode::FAIL,
|
return Status{StatusCategory::LOTUS, StatusCode::FAIL,
|
||||||
"Resolve() must be called before using the graph as modifications have been made to it."};
|
"Resolve() must be called before using the graph as modifications have been made to it."};
|
||||||
|
@ -1263,7 +1262,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tensorAdded = graph_proto_->add_initializer();
|
const gsl::not_null<TensorProto*> tensorAdded = graph_proto_->add_initializer();
|
||||||
*(tensorAdded) = tensor;
|
*(tensorAdded) = tensor;
|
||||||
name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1;
|
name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1;
|
||||||
name_to_initial_tensor_[tensor.name()] = tensorAdded;
|
name_to_initial_tensor_[tensor.name()] = tensorAdded;
|
||||||
|
@ -1283,7 +1282,7 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const {
|
bool Graph::GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const TensorProto**> value) const {
|
||||||
auto iter = name_to_initial_tensor_.find(tensor_name);
|
auto iter = name_to_initial_tensor_.find(tensor_name);
|
||||||
if (name_to_initial_tensor_.end() == iter) {
|
if (name_to_initial_tensor_.end() == iter) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -1318,7 +1317,7 @@ const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
|
||||||
// Ensure the NodeArgs in the input are created and in this Graph's node arg map
|
// Ensure the NodeArgs in the input are created and in this Graph's node arg map
|
||||||
static void AddNodeArgs(const std::vector<NodeArg*>& input_args,
|
static void AddNodeArgs(const std::vector<NodeArg*>& input_args,
|
||||||
std::unordered_map<std::string, NodeArg*>& node_arg_map) {
|
std::unordered_map<std::string, NodeArg*>& node_arg_map) {
|
||||||
for (auto input_arg : input_args) {
|
for (const gsl::not_null<NodeArg*> input_arg : input_args) {
|
||||||
if (!input_arg->Exists()) continue;
|
if (!input_arg->Exists()) continue;
|
||||||
auto& key = input_arg->Name();
|
auto& key = input_arg->Name();
|
||||||
auto existing_entry = node_arg_map.find(key);
|
auto existing_entry = node_arg_map.find(key);
|
||||||
|
@ -1383,7 +1382,7 @@ Node* GraphBase::AddNode(const Node& other) {
|
||||||
|
|
||||||
Node* GraphBase::AddNode(const NodeProto& node_proto,
|
Node* GraphBase::AddNode(const NodeProto& node_proto,
|
||||||
const ArgNameToTypeMap& name_to_type_map) {
|
const ArgNameToTypeMap& name_to_type_map) {
|
||||||
auto node = AllocateNode();
|
const gsl::not_null<Node*> node = AllocateNode();
|
||||||
|
|
||||||
auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map, node_args_, owned_node_args_);
|
auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map, node_args_, owned_node_args_);
|
||||||
auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map, node_args_, owned_node_args_);
|
auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map, node_args_, owned_node_args_);
|
||||||
|
@ -1460,7 +1459,7 @@ Node* GraphBase::AddNode(const std::string& name,
|
||||||
AddNodeArgs(input_args, node_args_);
|
AddNodeArgs(input_args, node_args_);
|
||||||
AddNodeArgs(output_args, node_args_);
|
AddNodeArgs(output_args, node_args_);
|
||||||
|
|
||||||
auto node = AllocateNode();
|
const gsl::not_null<Node*> node = AllocateNode();
|
||||||
node->Init(name, op_type, description, input_args, output_args, attributes, domain);
|
node->Init(name, op_type, description, input_args, output_args, attributes, domain);
|
||||||
if (0 != op_type.compare(kNoOp)) {
|
if (0 != op_type.compare(kNoOp)) {
|
||||||
graph_proto_sync_needed_ = true;
|
graph_proto_sync_needed_ = true;
|
||||||
|
@ -1505,8 +1504,8 @@ const GraphProto& Graph::ToGraphProto() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto node_proto = graph_proto_->add_node();
|
const gsl::not_null<NodeProto*> node_proto = graph_proto_->add_node();
|
||||||
auto p_node = GetNode(node_idx);
|
const gsl::not_null<Node*> p_node = GetNode(node_idx);
|
||||||
p_node->ToProto(*node_proto);
|
p_node->ToProto(*node_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1551,15 +1550,15 @@ void Graph::SyncGraphInputsOutputs() {
|
||||||
graph_proto_->clear_output();
|
graph_proto_->clear_output();
|
||||||
graph_proto_->clear_value_info();
|
graph_proto_->clear_value_info();
|
||||||
|
|
||||||
for (const LotusIR::NodeArg* input_arg : GetInputs()) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> input_arg : GetInputs()) {
|
||||||
*(graph_proto_->mutable_input()->Add()) = input_arg->ToProto();
|
*(graph_proto_->mutable_input()->Add()) = input_arg->ToProto();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const LotusIR::NodeArg* output_arg : GetOutputs()) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> output_arg : GetOutputs()) {
|
||||||
*(graph_proto_->mutable_output()->Add()) = output_arg->ToProto();
|
*(graph_proto_->mutable_output()->Add()) = output_arg->ToProto();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const LotusIR::NodeArg* value_info : value_info_) {
|
for (const gsl::not_null<const LotusIR::NodeArg*> value_info : value_info_) {
|
||||||
*(graph_proto_->mutable_value_info()->Add()) = value_info->ToProto();
|
*(graph_proto_->mutable_value_info()->Add()) = value_info->ToProto();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1573,11 +1572,11 @@ void Graph::CleanUnusedInitializers() {
|
||||||
for (const auto& pv : name_to_initial_tensor_) {
|
for (const auto& pv : name_to_initial_tensor_) {
|
||||||
const std::string& s = pv.first;
|
const std::string& s = pv.first;
|
||||||
const bool used_as_input = std::any_of(input_args.begin(), input_args.end(),
|
const bool used_as_input = std::any_of(input_args.begin(), input_args.end(),
|
||||||
[&s](const NodeArg* input) noexcept {
|
[&s](const gsl::not_null<const NodeArg*> input) noexcept {
|
||||||
return s == input->Name();
|
return s == input->Name();
|
||||||
});
|
});
|
||||||
const bool used_as_output = std::any_of(GetOutputs().begin(), GetOutputs().end(),
|
const bool used_as_output = std::any_of(GetOutputs().begin(), GetOutputs().end(),
|
||||||
[&s](const NodeArg* output) noexcept {
|
[&s](const gsl::not_null<const NodeArg*> output) noexcept {
|
||||||
return s == output->Name();
|
return s == output->Name();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1639,7 +1638,7 @@ Status Graph::SetGraphInputsOutputs() {
|
||||||
|
|
||||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||||
for (const auto& node : Nodes()) {
|
for (const auto& node : Nodes()) {
|
||||||
for (const NodeArg* output_def : node.OutputDefs()) {
|
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||||
if (specified_graph_outputs.erase(output_def->Name()) >= 1) {
|
if (specified_graph_outputs.erase(output_def->Name()) >= 1) {
|
||||||
graph_outputs.push_back(output_def);
|
graph_outputs.push_back(output_def);
|
||||||
}
|
}
|
||||||
|
@ -1664,7 +1663,7 @@ Status Graph::SetGraphInputsOutputs() {
|
||||||
|
|
||||||
for (const auto& node : Nodes()) {
|
for (const auto& node : Nodes()) {
|
||||||
// Go thru all node's inputs.
|
// Go thru all node's inputs.
|
||||||
for (const NodeArg* input_arg : node.InputDefs()) {
|
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||||
if (!input_arg->Exists()) {
|
if (!input_arg->Exists()) {
|
||||||
// It's an optional input and does not exist in this case.
|
// It's an optional input and does not exist in this case.
|
||||||
continue;
|
continue;
|
||||||
|
@ -1694,7 +1693,7 @@ Status Graph::SetGraphInputsOutputs() {
|
||||||
} else {
|
} else {
|
||||||
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
|
||||||
for (const auto& node : Nodes()) {
|
for (const auto& node : Nodes()) {
|
||||||
for (const NodeArg* output_def : node.OutputDefs()) {
|
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
|
||||||
if (output_def->Exists())
|
if (output_def->Exists())
|
||||||
output_name_to_node_arg.insert({output_def->Name(), output_def});
|
output_name_to_node_arg.insert({output_def->Name(), output_def});
|
||||||
}
|
}
|
||||||
|
@ -1706,7 +1705,7 @@ Status Graph::SetGraphInputsOutputs() {
|
||||||
std::unordered_set<Node*> inner_nodes;
|
std::unordered_set<Node*> inner_nodes;
|
||||||
for (const auto& node : Nodes()) {
|
for (const auto& node : Nodes()) {
|
||||||
// Go thru all node's inputs.
|
// Go thru all node's inputs.
|
||||||
for (const NodeArg* input_arg : node.InputDefs()) {
|
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
|
||||||
if (!input_arg->Exists()) {
|
if (!input_arg->Exists()) {
|
||||||
// It's an optional input and does not exist in this case.
|
// It's an optional input and does not exist in this case.
|
||||||
continue;
|
continue;
|
||||||
|
@ -1719,8 +1718,8 @@ Status Graph::SetGraphInputsOutputs() {
|
||||||
const std::string& name = input_arg->Name();
|
const std::string& name = input_arg->Name();
|
||||||
if (added_input_names.end() == added_input_names.find(name)) {
|
if (added_input_names.end() == added_input_names.find(name)) {
|
||||||
// This graph input has not been added into <graph_inputs_>.
|
// This graph input has not been added into <graph_inputs_>.
|
||||||
// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
|
//// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
|
||||||
graph_inputs.push_back(input_arg);
|
graph_inputs.push_back(input_arg);
|
||||||
added_input_names.insert(input_arg->Name());
|
added_input_names.insert(input_arg->Name());
|
||||||
}
|
}
|
||||||
} else if (graph_output_args.erase(output_arg_iter->first) >= 1) {
|
} else if (graph_output_args.erase(output_arg_iter->first) >= 1) {
|
||||||
|
@ -1759,7 +1758,7 @@ const Node* GraphBase::SinkNode() const {
|
||||||
|
|
||||||
// calling private ctor
|
// calling private ctor
|
||||||
GSL_SUPPRESS(r .11)
|
GSL_SUPPRESS(r .11)
|
||||||
Node* GraphBase::AllocateNode() {
|
gsl::not_null<Node*> GraphBase::AllocateNode() {
|
||||||
std::unique_ptr<Node> new_node(new Node(nodes_.size(), *this));
|
std::unique_ptr<Node> new_node(new Node(nodes_.size(), *this));
|
||||||
Node* node{new_node.get()};
|
Node* node{new_node.get()};
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,7 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "core/common/CommonSTD.h"
|
#include "gsl/pointers"
|
||||||
|
|
||||||
// #include "gsl/pointers"
|
|
||||||
#include "gsl/gsl_util"
|
#include "gsl/gsl_util"
|
||||||
|
|
||||||
#include "core/common/common.h"
|
#include "core/common/common.h"
|
||||||
|
@ -21,21 +19,19 @@
|
||||||
#include "core/graph/utils.h"
|
#include "core/graph/utils.h"
|
||||||
#include "onnx/onnx_pb.h"
|
#include "onnx/onnx_pb.h"
|
||||||
|
|
||||||
using namespace onnx;
|
|
||||||
|
|
||||||
// TODO - Evaluate switching the types below to support transparent comparators and enable
|
// TODO - Evaluate switching the types below to support transparent comparators and enable
|
||||||
// lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations
|
// lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations
|
||||||
// converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>>
|
// converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>>
|
||||||
// instead of std::unordered_map<std::string, foo, [std::less<foo>]>.
|
// instead of std::unordered_map<std::string, foo, [std::less<foo>]>.
|
||||||
|
|
||||||
using NodeAttributes = std::unordered_map<std::string, AttributeProto>;
|
using NodeAttributes = std::unordered_map<std::string, onnx::AttributeProto>;
|
||||||
|
|
||||||
namespace LotusIR {
|
namespace LotusIR {
|
||||||
using NodeIndex = size_t;
|
using NodeIndex = size_t;
|
||||||
using Version = int64_t;
|
using Version = int64_t;
|
||||||
using NodeArgInfo = ValueInfoProto;
|
using NodeArgInfo = onnx::ValueInfoProto;
|
||||||
using InitializedTensorSet = std::unordered_map<std::string, const TensorProto*>;
|
using InitializedTensorSet = std::unordered_map<std::string, const onnx::TensorProto*>;
|
||||||
using ArgNameToTypeMap = std::unordered_map<std::string, TypeProto>;
|
using ArgNameToTypeMap = std::unordered_map<std::string, onnx::TypeProto>;
|
||||||
using ProviderType = const std::string&;
|
using ProviderType = const std::string&;
|
||||||
|
|
||||||
class Graph;
|
class Graph;
|
||||||
|
@ -70,7 +66,7 @@ class NodeArg {
|
||||||
// optional. This is called when loading a <Graph> from <GraphProto>
|
// optional. This is called when loading a <Graph> from <GraphProto>
|
||||||
// normally.
|
// normally.
|
||||||
NodeArg(const std::string& name,
|
NodeArg(const std::string& name,
|
||||||
const TypeProto* p_arg_type);
|
const onnx::TypeProto* p_arg_type);
|
||||||
|
|
||||||
NodeArg(NodeArg&& other) = default;
|
NodeArg(NodeArg&& other) = default;
|
||||||
|
|
||||||
|
@ -78,17 +74,17 @@ class NodeArg {
|
||||||
const std::string& Name() const noexcept;
|
const std::string& Name() const noexcept;
|
||||||
|
|
||||||
// Get node arg type.
|
// Get node arg type.
|
||||||
DataType Type() const noexcept;
|
onnx::DataType Type() const noexcept;
|
||||||
const TypeProto* TypeAsProto() const noexcept;
|
const onnx::TypeProto* TypeAsProto() const noexcept;
|
||||||
|
|
||||||
// Get node arg shape.
|
// Get node arg shape.
|
||||||
// Return null pointer if there's no shape specified.
|
// Return null pointer if there's no shape specified.
|
||||||
const TensorShapeProto* Shape() const;
|
const onnx::TensorShapeProto* Shape() const;
|
||||||
|
|
||||||
// Set node arg shape.
|
// Set node arg shape.
|
||||||
// Shape could only be set after setting type since shape information
|
// Shape could only be set after setting type since shape information
|
||||||
// now is part of TypeProto.
|
// now is part of TypeProto.
|
||||||
void SetShape(const TensorShapeProto& shape);
|
void SetShape(const onnx::TensorShapeProto& shape);
|
||||||
|
|
||||||
// Get node arg info proto.
|
// Get node arg info proto.
|
||||||
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
|
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
|
||||||
|
@ -102,13 +98,13 @@ class NodeArg {
|
||||||
LOTUS_DISALLOW_COPY_AND_ASSIGN(NodeArg);
|
LOTUS_DISALLOW_COPY_AND_ASSIGN(NodeArg);
|
||||||
friend class Graph;
|
friend class Graph;
|
||||||
|
|
||||||
void SetType(DataType p_type);
|
void SetType(onnx::DataType p_type);
|
||||||
void SetType(const TypeProto& type_proto);
|
void SetType(const onnx::TypeProto& type_proto);
|
||||||
|
|
||||||
NodeArg& operator=(NodeArg&& other) = delete;
|
NodeArg& operator=(NodeArg&& other) = delete;
|
||||||
|
|
||||||
// Node arg PType.
|
// Node arg PType.
|
||||||
DataType type_;
|
onnx::DataType type_;
|
||||||
|
|
||||||
// Node arg name, type and shape.
|
// Node arg name, type and shape.
|
||||||
NodeArgInfo node_arg_info_;
|
NodeArgInfo node_arg_info_;
|
||||||
|
@ -159,7 +155,7 @@ class Node {
|
||||||
|
|
||||||
// Get the OperatorSchema this node refers to. ValidateOpType() must have been called previously.
|
// Get the OperatorSchema this node refers to. ValidateOpType() must have been called previously.
|
||||||
// May be null in the future.
|
// May be null in the future.
|
||||||
const OpSchema* Op() const noexcept;
|
const onnx::OpSchema* Op() const noexcept;
|
||||||
|
|
||||||
// Get node description.
|
// Get node description.
|
||||||
const std::string& Description() const noexcept;
|
const std::string& Description() const noexcept;
|
||||||
|
@ -167,8 +163,8 @@ class Node {
|
||||||
// Iterate through Input/OutputDefs() with index, note the loop early terminates with error
|
// Iterate through Input/OutputDefs() with index, note the loop early terminates with error
|
||||||
static Lotus::Common::Status ForEachWithIndex(
|
static Lotus::Common::Status ForEachWithIndex(
|
||||||
const ConstPointerContainer<std::vector<NodeArg*>>& nodeArgVec,
|
const ConstPointerContainer<std::vector<NodeArg*>>& nodeArgVec,
|
||||||
std::function<Lotus::Common::Status(const NodeArg& arg, int index)> func) {
|
std::function<Lotus::Common::Status(const NodeArg& arg, size_t index)> func) {
|
||||||
for (int index = 0; index < nodeArgVec.size(); ++index) {
|
for (size_t index = 0; index < nodeArgVec.size(); ++index) {
|
||||||
auto arg = nodeArgVec[index];
|
auto arg = nodeArgVec[index];
|
||||||
if (!arg->Exists())
|
if (!arg->Exists())
|
||||||
continue;
|
continue;
|
||||||
|
@ -207,7 +203,7 @@ class Node {
|
||||||
const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
|
const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
|
||||||
|
|
||||||
// Add a node attribute with specified attribute name and value.
|
// Add a node attribute with specified attribute name and value.
|
||||||
void AddAttribute(const std::string& attr_name, const AttributeProto& value);
|
void AddAttribute(const std::string& attr_name, const onnx::AttributeProto& value);
|
||||||
|
|
||||||
#define ADD_ATTR_INTERFACES(TypeName) \
|
#define ADD_ATTR_INTERFACES(TypeName) \
|
||||||
void AddAttribute(const std::string& attr_name, \
|
void AddAttribute(const std::string& attr_name, \
|
||||||
|
@ -218,8 +214,8 @@ class Node {
|
||||||
ADD_ATTR_INTERFACES(int64_t)
|
ADD_ATTR_INTERFACES(int64_t)
|
||||||
ADD_ATTR_INTERFACES(float)
|
ADD_ATTR_INTERFACES(float)
|
||||||
ADD_ATTR_INTERFACES(std::string)
|
ADD_ATTR_INTERFACES(std::string)
|
||||||
ADD_ATTR_INTERFACES(TensorProto)
|
ADD_ATTR_INTERFACES(onnx::TensorProto)
|
||||||
ADD_ATTR_INTERFACES(GraphProto)
|
ADD_ATTR_INTERFACES(onnx::GraphProto)
|
||||||
|
|
||||||
// Clear specified node attribute.
|
// Clear specified node attribute.
|
||||||
bool ClearAttribute(const std::string& attr_name);
|
bool ClearAttribute(const std::string& attr_name);
|
||||||
|
@ -235,7 +231,7 @@ class Node {
|
||||||
void SetExecutionProviderType(ProviderType execution_provider_type);
|
void SetExecutionProviderType(ProviderType execution_provider_type);
|
||||||
|
|
||||||
// Get the corresponding <NodeProto>.
|
// Get the corresponding <NodeProto>.
|
||||||
void ToProto(NodeProto& proto) const;
|
void ToProto(onnx::NodeProto& proto) const;
|
||||||
|
|
||||||
// iterate through all input/output defs
|
// iterate through all input/output defs
|
||||||
void ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const;
|
void ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const;
|
||||||
|
@ -355,7 +351,7 @@ class Node {
|
||||||
std::string domain_;
|
std::string domain_;
|
||||||
|
|
||||||
// OperatorSchema that <*this> node refers to.
|
// OperatorSchema that <*this> node refers to.
|
||||||
const OpSchema* op_ = nullptr;
|
const onnx::OpSchema* op_ = nullptr;
|
||||||
|
|
||||||
// Node doc string.
|
// Node doc string.
|
||||||
std::string description_;
|
std::string description_;
|
||||||
|
@ -428,7 +424,7 @@ class GraphBase {
|
||||||
int NumberOfNodes() const noexcept { return num_of_nodes_; }
|
int NumberOfNodes() const noexcept { return num_of_nodes_; }
|
||||||
|
|
||||||
// Get NodeArg by name, or create NodeArg owned by the graph if not found
|
// Get NodeArg by name, or create NodeArg owned by the graph if not found
|
||||||
NodeArg& GetOrCreateNodeArg(const std::string& name, const TypeProto* p_arg_type) {
|
NodeArg& GetOrCreateNodeArg(const std::string& name, const onnx::TypeProto* p_arg_type) {
|
||||||
auto iter = node_args_.find(name);
|
auto iter = node_args_.find(name);
|
||||||
if (iter != node_args_.end())
|
if (iter != node_args_.end())
|
||||||
return *(iter->second);
|
return *(iter->second);
|
||||||
|
@ -492,7 +488,7 @@ class GraphBase {
|
||||||
// TODO(Task:135) See if GraphBase::GetNodesInTopologicalOrder can be made more correctly const
|
// TODO(Task:135) See if GraphBase::GetNodesInTopologicalOrder can be made more correctly const
|
||||||
// by forcing Resolve to have been called directly previously. Simple change is to return error if
|
// by forcing Resolve to have been called directly previously. Simple change is to return error if
|
||||||
// GraphResolveNeeded is true.
|
// GraphResolveNeeded is true.
|
||||||
Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ const std::vector<NodeIndex>** pp_nodes) const;
|
Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const;
|
||||||
|
|
||||||
// Mark Graph as needing Resolve() to be called
|
// Mark Graph as needing Resolve() to be called
|
||||||
GraphBase& SetGraphResolveNeeded() noexcept {
|
GraphBase& SetGraphResolveNeeded() noexcept {
|
||||||
|
@ -545,7 +541,7 @@ class GraphBase {
|
||||||
void AddSourceSinkNodes();
|
void AddSourceSinkNodes();
|
||||||
|
|
||||||
// Add node with specified <node_proto>.
|
// Add node with specified <node_proto>.
|
||||||
Node* AddNode(const NodeProto& node_proto,
|
Node* AddNode(const onnx::NodeProto& node_proto,
|
||||||
const ArgNameToTypeMap& name_to_type);
|
const ArgNameToTypeMap& name_to_type);
|
||||||
|
|
||||||
NodeIndex SourceNodeIndex() const noexcept { return source_node_index_; }
|
NodeIndex SourceNodeIndex() const noexcept { return source_node_index_; }
|
||||||
|
@ -613,13 +609,13 @@ class GraphBase {
|
||||||
// Returns the inferred shape+type for every output of the node in
|
// Returns the inferred shape+type for every output of the node in
|
||||||
// output parameter inferredShapes.
|
// output parameter inferredShapes.
|
||||||
Lotus::Common::Status InferOutputTypesAndShapes(LotusIR::Node& node,
|
Lotus::Common::Status InferOutputTypesAndShapes(LotusIR::Node& node,
|
||||||
/*out*/ std::vector<TypeProto>& inferred_shapes);
|
/*out*/ std::vector<onnx::TypeProto>& inferred_shapes);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// need custom versions to handle the unique_ptr's in nodes_
|
// need custom versions to handle the unique_ptr's in nodes_
|
||||||
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(GraphBase);
|
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(GraphBase);
|
||||||
|
|
||||||
Node* AllocateNode();
|
gsl::not_null<Node*> AllocateNode();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Release the node.
|
Release the node.
|
||||||
|
@ -709,9 +705,9 @@ class Graph : public GraphBase {
|
||||||
void SetDescription(const std::string& description) override;
|
void SetDescription(const std::string& description) override;
|
||||||
|
|
||||||
// Add/Remove/Get initial tensors for some graph inputs.
|
// Add/Remove/Get initial tensors for some graph inputs.
|
||||||
void AddInitializedTensor(const TensorProto& tensor_proto);
|
void AddInitializedTensor(const onnx::TensorProto& tensor_proto);
|
||||||
void RemoveInitializedTensor(const std::string& tensor_name);
|
void RemoveInitializedTensor(const std::string& tensor_name);
|
||||||
bool GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const;
|
bool GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const onnx::TensorProto**> value) const;
|
||||||
const InitializedTensorSet& GetAllInitializedTensors() const noexcept;
|
const InitializedTensorSet& GetAllInitializedTensors() const noexcept;
|
||||||
void CleanAllInitializedTensors() noexcept;
|
void CleanAllInitializedTensors() noexcept;
|
||||||
|
|
||||||
|
@ -719,7 +715,7 @@ class Graph : public GraphBase {
|
||||||
const std::vector<const NodeArg*>& GetValueInfo() const noexcept;
|
const std::vector<const NodeArg*>& GetValueInfo() const noexcept;
|
||||||
|
|
||||||
// Serialize the <Graph> into <GraphProto>.
|
// Serialize the <Graph> into <GraphProto>.
|
||||||
const GraphProto& ToGraphProto();
|
const onnx::GraphProto& ToGraphProto();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Graph);
|
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Graph);
|
||||||
|
@ -730,7 +726,7 @@ class Graph : public GraphBase {
|
||||||
|
|
||||||
// Constructor: Given a <GraphProto> loaded from model file, construct
|
// Constructor: Given a <GraphProto> loaded from model file, construct
|
||||||
// a <Graph> object.
|
// a <Graph> object.
|
||||||
Graph(GraphProto* graph_proto,
|
Graph(onnx::GraphProto* graph_proto,
|
||||||
const std::unordered_map<std::string, int>& domain_to_version,
|
const std::unordered_map<std::string, int>& domain_to_version,
|
||||||
Version ir_version,
|
Version ir_version,
|
||||||
const ILotusOpSchemaCollection* local_registry = nullptr);
|
const ILotusOpSchemaCollection* local_registry = nullptr);
|
||||||
|
@ -754,7 +750,7 @@ class Graph : public GraphBase {
|
||||||
Lotus::Common::Status Resolve(bool no_proto_sync_required);
|
Lotus::Common::Status Resolve(bool no_proto_sync_required);
|
||||||
|
|
||||||
Lotus::Common::Status InferAndVerifyTypeMatch(Node& node,
|
Lotus::Common::Status InferAndVerifyTypeMatch(Node& node,
|
||||||
const OpSchema& op);
|
const onnx::OpSchema& op);
|
||||||
|
|
||||||
// Apply type-inference and type-checking to all inputs and initializers:
|
// Apply type-inference and type-checking to all inputs and initializers:
|
||||||
Lotus::Common::Status TypeCheckInputsAndInitializers();
|
Lotus::Common::Status TypeCheckInputsAndInitializers();
|
||||||
|
@ -783,7 +779,7 @@ class Graph : public GraphBase {
|
||||||
// functions in <Graph> will also be fed into <graph_proto_> so that
|
// functions in <Graph> will also be fed into <graph_proto_> so that
|
||||||
// it's consistent with <*this> graph.
|
// it's consistent with <*this> graph.
|
||||||
// This pointer is owned by parent model.
|
// This pointer is owned by parent model.
|
||||||
GraphProto* graph_proto_;
|
onnx::GraphProto* graph_proto_;
|
||||||
|
|
||||||
// The node which refers to <*this> graph (Function).
|
// The node which refers to <*this> graph (Function).
|
||||||
// Node* node_;
|
// Node* node_;
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/graph/graph_transformer.h"
|
#include "core/graph/graph_transformer.h"
|
||||||
using namespace Lotus;
|
using namespace Lotus;
|
||||||
using namespace Lotus::Common;
|
using namespace Lotus::Common;
|
||||||
|
|
|
@ -53,7 +53,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
|
||||||
// should be stored globally. Otherwise, there will be multiple addresses/pointers
|
// should be stored globally. Otherwise, there will be multiple addresses/pointers
|
||||||
// for the same operator or function. To avoid this, we may use OpSignature ID
|
// for the same operator or function. To avoid this, we may use OpSignature ID
|
||||||
// as the key, which should be name_domain_version.
|
// as the key, which should be name_domain_version.
|
||||||
Lotus::Common::Status Register(const OpSchema* op, std::unique_ptr<RewriteRule> rule) {
|
Lotus::Common::Status Register(const onnx::OpSchema* op, std::unique_ptr<RewriteRule> rule) {
|
||||||
op_to_rules_[op].push_back(std::move(rule));
|
op_to_rules_[op].push_back(std::move(rule));
|
||||||
return Lotus::Common::Status::OK();
|
return Lotus::Common::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
typedef std::unordered_map<const OpSchema*, std::vector<std::unique_ptr<RewriteRule>>>
|
typedef std::unordered_map<const onnx::OpSchema*, std::vector<std::unique_ptr<RewriteRule>>>
|
||||||
RewriteRuleSet;
|
RewriteRuleSet;
|
||||||
|
|
||||||
RewriteRuleSet op_to_rules_;
|
RewriteRuleSet op_to_rules_;
|
||||||
|
|
|
@ -7,16 +7,15 @@
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
#pragma warning(pop)
|
#pragma warning(pop)
|
||||||
#endif
|
#endif
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "core/graph/model.h"
|
#include "core/graph/model.h"
|
||||||
#include "core/graph/utils.h"
|
#include "core/graph/utils.h"
|
||||||
#include "core/graph/schema_registry.h"
|
#include "core/graph/schema_registry.h"
|
||||||
// #include "gsl/pointers"
|
#include "gsl/pointers"
|
||||||
#include "gsl/gsl_util"
|
#include "gsl/gsl_util"
|
||||||
|
using namespace onnx;
|
||||||
using namespace Lotus;
|
using namespace Lotus;
|
||||||
using namespace Lotus::Common;
|
using namespace Lotus::Common;
|
||||||
|
|
||||||
|
@ -27,7 +26,7 @@ Model::Model(const std::string& graph_name, bool is_onnx_domain_only, const Mode
|
||||||
model_proto_->mutable_graph()->set_name(graph_name);
|
model_proto_->mutable_graph()->set_name(graph_name);
|
||||||
model_metadata_ = model_metadata;
|
model_metadata_ = model_metadata;
|
||||||
for (auto& metadata : model_metadata_) {
|
for (auto& metadata : model_metadata_) {
|
||||||
auto prop = model_proto_->add_metadata_props();
|
const gsl::not_null<StringStringEntryProto*> prop = model_proto_->add_metadata_props();
|
||||||
prop->set_key(metadata.first);
|
prop->set_key(metadata.first);
|
||||||
prop->set_value(metadata.second);
|
prop->set_value(metadata.second);
|
||||||
}
|
}
|
||||||
|
@ -138,7 +137,7 @@ ModelProto Model::ToProto() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Model::AddImportOpSets(bool is_onnx_domain_only,
|
void Model::AddImportOpSets(bool is_onnx_domain_only,
|
||||||
/*out*/ std::unordered_map<std::string, int>* domain_to_version,
|
/*out*/ gsl::not_null<std::unordered_map<std::string, int>*> domain_to_version,
|
||||||
const ILotusOpSchemaCollection* local_registry) {
|
const ILotusOpSchemaCollection* local_registry) {
|
||||||
auto& domain_to_version_range_map = OpSchemaRegistry::DomainToVersionRange::Instance().Map();
|
auto& domain_to_version_range_map = OpSchemaRegistry::DomainToVersionRange::Instance().Map();
|
||||||
Domain_To_Version_Map local_domain_to_version_map = local_registry ? local_registry->DomainToVersionMap() : Domain_To_Version_Map();
|
Domain_To_Version_Map local_domain_to_version_map = local_registry ? local_registry->DomainToVersionMap() : Domain_To_Version_Map();
|
||||||
|
@ -158,7 +157,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only,
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ignored = domain_to_version->insert({domainToVersionRange.first, max});
|
auto ignored = domain_to_version->insert({domainToVersionRange.first, max});
|
||||||
auto opset_id_proto = model_proto_->add_opset_import();
|
const gsl::not_null<OperatorSetIdProto*> opset_id_proto = model_proto_->add_opset_import();
|
||||||
opset_id_proto->set_domain(domainToVersionRange.first);
|
opset_id_proto->set_domain(domainToVersionRange.first);
|
||||||
opset_id_proto->set_version(domainToVersionRange.second.second);
|
opset_id_proto->set_version(domainToVersionRange.second.second);
|
||||||
}
|
}
|
||||||
|
@ -174,7 +173,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only,
|
||||||
|
|
||||||
if (domain_to_version_range_map.end() != domain_to_version_range_map.find(local_domain.first)) {
|
if (domain_to_version_range_map.end() != domain_to_version_range_map.find(local_domain.first)) {
|
||||||
auto ignored = domain_to_version->insert({local_domain.first, local_domain.second.second});
|
auto ignored = domain_to_version->insert({local_domain.first, local_domain.second.second});
|
||||||
auto opset_id_proto = model_proto_->add_opset_import();
|
const gsl::not_null<OperatorSetIdProto*> opset_id_proto = model_proto_->add_opset_import();
|
||||||
opset_id_proto->set_domain(local_domain.first);
|
opset_id_proto->set_domain(local_domain.first);
|
||||||
opset_id_proto->set_version(local_domain.second.second);
|
opset_id_proto->set_version(local_domain.second.second);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
#include "core/graph/graph.h"
|
#include "core/graph/graph.h"
|
||||||
|
|
||||||
// #include "gsl/pointers"
|
#include "gsl/pointers"
|
||||||
|
|
||||||
namespace LotusIR {
|
namespace LotusIR {
|
||||||
typedef std::unordered_map<std::string, std::string> ModelMetaData;
|
typedef std::unordered_map<std::string, std::string> ModelMetaData;
|
||||||
|
@ -22,11 +22,11 @@ class Model {
|
||||||
|
|
||||||
// NOTE: after calling this constructor, <*this> model will
|
// NOTE: after calling this constructor, <*this> model will
|
||||||
// hold a copy of <model_proto>.
|
// hold a copy of <model_proto>.
|
||||||
explicit Model(const ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
|
explicit Model(const onnx::ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
|
||||||
|
|
||||||
// NOTE: after calling this constructor, <*this> model will
|
// NOTE: after calling this constructor, <*this> model will
|
||||||
// own the <model_proto>.
|
// own the <model_proto>.
|
||||||
explicit Model(std::unique_ptr<ModelProto> model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
|
explicit Model(std::unique_ptr<onnx::ModelProto> model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
|
||||||
|
|
||||||
// Get model's IR version.
|
// Get model's IR version.
|
||||||
// Return <kNoVersion> if not specified.
|
// Return <kNoVersion> if not specified.
|
||||||
|
@ -71,7 +71,7 @@ class Model {
|
||||||
const Graph* MainGraph() const noexcept;
|
const Graph* MainGraph() const noexcept;
|
||||||
|
|
||||||
// Get model's serialization proto data.
|
// Get model's serialization proto data.
|
||||||
ModelProto ToProto();
|
onnx::ModelProto ToProto();
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
static Lotus::Common::Status Save(Model& model, const std::wstring& file_path);
|
static Lotus::Common::Status Save(Model& model, const std::wstring& file_path);
|
||||||
|
@ -84,7 +84,7 @@ class Model {
|
||||||
|
|
||||||
static Lotus::Common::Status Save(Model& model, int fd);
|
static Lotus::Common::Status Save(Model& model, int fd);
|
||||||
|
|
||||||
static Lotus::Common::Status Load(std::istream& model_istream, ModelProto* p_model_proto);
|
static Lotus::Common::Status Load(std::istream& model_istream, onnx::ModelProto* p_model_proto);
|
||||||
|
|
||||||
static Lotus::Common::Status Load(const std::string& file_path,
|
static Lotus::Common::Status Load(const std::string& file_path,
|
||||||
/*out*/ std::shared_ptr<Model>& p_model,
|
/*out*/ std::shared_ptr<Model>& p_model,
|
||||||
|
@ -97,7 +97,7 @@ class Model {
|
||||||
static Lotus::Common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
|
static Lotus::Common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
|
||||||
const ILotusOpSchemaCollection* local_registry = nullptr);
|
const ILotusOpSchemaCollection* local_registry = nullptr);
|
||||||
|
|
||||||
static Lotus::Common::Status Load(const ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
|
static Lotus::Common::Status Load(const onnx::ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
|
||||||
const ILotusOpSchemaCollection* local_registry = nullptr);
|
const ILotusOpSchemaCollection* local_registry = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -106,11 +106,11 @@ class Model {
|
||||||
// if <is_onnx_domain_only> is true, then only onnx domain will be contained.
|
// if <is_onnx_domain_only> is true, then only onnx domain will be contained.
|
||||||
// otherwise, ml domain will also be contained.
|
// otherwise, ml domain will also be contained.
|
||||||
void AddImportOpSets(bool is_onnx_domain_only,
|
void AddImportOpSets(bool is_onnx_domain_only,
|
||||||
/*out*/ std::unordered_map<std::string, int>* domain_to_version,
|
/*out*/ gsl::not_null<std::unordered_map<std::string, int>*> domain_to_version,
|
||||||
const ILotusOpSchemaCollection* local_registry);
|
const ILotusOpSchemaCollection* local_registry);
|
||||||
|
|
||||||
// Model data.
|
// Model data.
|
||||||
std::unique_ptr<ModelProto> model_proto_;
|
std::unique_ptr<onnx::ModelProto> model_proto_;
|
||||||
|
|
||||||
// This is a duplication of <model_proto_.metadata_props()>.
|
// This is a duplication of <model_proto_.metadata_props()>.
|
||||||
// It gives better accessibility.
|
// It gives better accessibility.
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/graph/constants.h"
|
#include "core/graph/constants.h"
|
||||||
#include "core/graph/op.h"
|
#include "core/graph/op.h"
|
||||||
#include "core/graph/utils.h"
|
#include "core/graph/utils.h"
|
||||||
|
|
||||||
|
using namespace onnx;
|
||||||
namespace LotusIR {
|
namespace LotusIR {
|
||||||
|
|
||||||
bool TypeUtils::IsValidAttribute(const AttributeProto& attr) {
|
bool TypeUtils::IsValidAttribute(const AttributeProto& attr) {
|
||||||
|
|
|
@ -6,12 +6,11 @@
|
||||||
#include "core/common/status.h"
|
#include "core/common/status.h"
|
||||||
#include "core/graph/constants.h"
|
#include "core/graph/constants.h"
|
||||||
|
|
||||||
using namespace onnx;
|
|
||||||
using namespace Lotus::Common;
|
using namespace Lotus::Common;
|
||||||
|
|
||||||
namespace LotusIR {
|
namespace LotusIR {
|
||||||
using AttrType = AttributeProto_AttributeType;
|
using AttrType = onnx::AttributeProto_AttributeType;
|
||||||
using NodeAttributes = std::unordered_map<std::string, AttributeProto>;
|
using NodeAttributes = std::unordered_map<std::string, onnx::AttributeProto>;
|
||||||
|
|
||||||
// This string array should exactly match the AttrType defined above.
|
// This string array should exactly match the AttrType defined above.
|
||||||
/*
|
/*
|
||||||
|
@ -44,8 +43,8 @@ static constexpr const char* kAttrTypeStrings[] =
|
||||||
class TypeUtils {
|
class TypeUtils {
|
||||||
public:
|
public:
|
||||||
// Get attribute type given attribute proto data.
|
// Get attribute type given attribute proto data.
|
||||||
static Status GetType(const AttributeProto& attr, AttrType& type);
|
static Status GetType(const onnx::AttributeProto& attr, AttrType& type);
|
||||||
static bool IsValidAttribute(const AttributeProto& attribute);
|
static bool IsValidAttribute(const onnx::AttributeProto& attribute);
|
||||||
};
|
};
|
||||||
|
|
||||||
class MsOpRegistry {
|
class MsOpRegistry {
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
#include "core/common/CommonSTD.h"
|
|
||||||
#include "core/graph/tensorutils.h"
|
#include "core/graph/tensorutils.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
// #include "gsl/span"
|
#include "gsl/span"
|
||||||
|
|
||||||
namespace Lotus {
|
namespace Lotus {
|
||||||
namespace Utils {
|
namespace Utils {
|
||||||
|
@ -23,9 +22,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
|
||||||
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
|
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
|
||||||
"UnpackTensor: the pre-allocate size does not match the size in proto");
|
"UnpackTensor: the pre-allocate size does not match the size in proto");
|
||||||
|
|
||||||
for (auto& elem : tensor.string_data()) {
|
const auto data = gsl::make_span(p_data, expected_size);
|
||||||
*p_data++ = elem;
|
|
||||||
}
|
auto& string_data = tensor.string_data();
|
||||||
|
std::copy(string_data.cbegin(), string_data.cend(), data.begin());
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -57,9 +57,8 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
|
||||||
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
|
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
|
||||||
"UnpackTensor: the pre-allocate size does not match the size in proto");
|
"UnpackTensor: the pre-allocate size does not match the size in proto");
|
||||||
|
|
||||||
for (auto& elem : tensor.int32_data()) {
|
const auto data = gsl::make_span(p_data, expected_size);
|
||||||
*p_data++ = elem;
|
std::copy(tensor.int32_data().cbegin(), tensor.int32_data().cend(), data.begin());
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -90,9 +89,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
|
||||||
if (tensor.int32_data_size() != expected_size)
|
if (tensor.int32_data_size() != expected_size)
|
||||||
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
|
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
|
||||||
"UnpackTensor: the pre-allocate size does not match the size in proto");
|
"UnpackTensor: the pre-allocate size does not match the size in proto");
|
||||||
for (auto& elem : tensor.int32_data()) {
|
|
||||||
*p_data++ = (uint16_t)elem;
|
const auto data = gsl::make_span(p_data, expected_size);
|
||||||
}
|
for (int i = 0; i < expected_size; i++)
|
||||||
|
data[i] = gsl::narrow_cast<uint16_t>(tensor.int32_data()[i]);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// #include "gsl/pointers"
|
#include "gsl/pointers"
|
||||||
// #include "gsl/span"
|
#include "gsl/span"
|
||||||
|
|
||||||
#include "core/common/common.h"
|
#include "core/common/common.h"
|
||||||
#include "core/common/status.h"
|
#include "core/common/status.h"
|
||||||
|
@ -36,9 +36,9 @@ class TensorUtils {
|
||||||
if (tensor.field_size() != expected_size) \
|
if (tensor.field_size() != expected_size) \
|
||||||
return Status(StatusCategory::LOTUS, StatusCode::FAIL, \
|
return Status(StatusCategory::LOTUS, StatusCode::FAIL, \
|
||||||
"UnpackTensor: the pre-allocated size does not match the size in proto"); \
|
"UnpackTensor: the pre-allocated size does not match the size in proto"); \
|
||||||
for (auto elem : tensor.field_name()) { \
|
const auto span = gsl::make_span(p_data, expected_size); \
|
||||||
*p_data++ = static_cast<T>(elem); \
|
auto& data = tensor.field_name(); \
|
||||||
} \
|
std::copy(data.cbegin(), data.cend(), span.begin()); \
|
||||||
return Status::OK(); \
|
return Status::OK(); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
#include "core/common/status.h"
|
#include "core/common/status.h"
|
||||||
#include "onnx/onnx_pb.h"
|
#include "onnx/onnx_pb.h"
|
||||||
|
|
||||||
// #include "gsl/pointers"
|
#include "gsl/pointers"
|
||||||
|
|
||||||
namespace Lotus{
|
namespace Lotus{
|
||||||
using namespace ::Lotus::Common;
|
using namespace ::Lotus::Common;
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
|
inline Status FileOpenRd(const std::wstring& path, /*out*/ gsl::not_null<int*> p_fd) {
|
||||||
_wsopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
_wsopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||||
if (0 > *p_fd) {
|
if (0 > *p_fd) {
|
||||||
return Status(SYSTEM, errno);
|
return Status(SYSTEM, errno);
|
||||||
|
@ -29,7 +29,7 @@ inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
|
inline Status FileOpenWr(const std::wstring& path, /*out*/ gsl::not_null<int*> p_fd) {
|
||||||
_wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
_wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||||
if (0 > *p_fd) {
|
if (0 > *p_fd) {
|
||||||
return Status(SYSTEM, errno);
|
return Status(SYSTEM, errno);
|
||||||
|
@ -38,7 +38,7 @@ inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
|
inline Status FileOpenRd(const std::string& path, /*out*/ gsl::not_null<int*> p_fd) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
_sopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
_sopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||||
#else
|
#else
|
||||||
|
@ -50,7 +50,7 @@ inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Status FileOpenWr(const std::string& path, /*out*/ int* p_fd) {
|
inline Status FileOpenWr(const std::string& path, /*out*/ gsl::not_null<int*> p_fd) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
_sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
_sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -71,20 +71,20 @@ enum class MLTensorDataType : uint32_t {
|
||||||
};
|
};
|
||||||
|
|
||||||
union MLFloat16 {
|
union MLFloat16 {
|
||||||
uint16_t val;
|
uint16_t val;
|
||||||
|
|
||||||
MLFloat16(uint16_t x) : val(x) {}
|
MLFloat16(uint16_t x) : val(x) {}
|
||||||
MLFloat16() : val(0) {}
|
MLFloat16() : val(0) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool operator==(const MLFloat16& left, const MLFloat16& right)
|
inline bool operator==(const MLFloat16& left, const MLFloat16& right)
|
||||||
{
|
{
|
||||||
return left.val == right.val;
|
return left.val == right.val;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool operator!=(const MLFloat16& left, const MLFloat16& right)
|
inline bool operator!=(const MLFloat16& left, const MLFloat16& right)
|
||||||
{
|
{
|
||||||
return left.val != right.val;
|
return left.val != right.val;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MLMapType {
|
struct MLMapType {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче