upgrade linux build from c++11 to c++14, enable gsl, update with latest LotusIR

This commit is contained in:
liqfu 2018-07-09 13:29:47 -07:00
Родитель 55b4606b23
Коммит 3c87d2012c
22 изменённых файлов: 154 добавлений и 203 удалений

Просмотреть файл

@ -102,9 +102,9 @@ INCLUDEPATH+=$(GSL_PATH)/include
INCLUDEPATH+=$(ONNX_PATH)
INCLUDEPATH+=$(ONNX_REPO_PATH)
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__
COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++14 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__
CPPFLAGS:=
CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new
CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -fopenmp -fpermissive -fPIC -Werror -fcheck-new
LIBPATH:=
LIBS_LIST:=
LDFLAGS:=

Просмотреть файл

@ -178,7 +178,6 @@
<ClInclude Include="proto\onnx\CNTKToONNX.h" />
<ClInclude Include="proto\onnx\core\common\code_location.h" />
<ClInclude Include="proto\onnx\core\common\common.h" />
<ClInclude Include="proto\onnx\core\common\CommonSTD.h" />
<ClInclude Include="proto\onnx\core\common\const_pointer_container.h" />
<ClInclude Include="proto\onnx\core\common\exceptions.h" />
<ClInclude Include="proto\onnx\core\common\logging\capture.h" />

Просмотреть файл

@ -93,18 +93,6 @@
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\traditionalml\defs.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\traditionalml</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\traditionalml\defs.cpp">
<Filter>proto\onnx\onnx\defs\traditionalml</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\data_type_utils.cpp">
<Filter>proto\onnx\onnx\defs</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\schema.cpp">
<Filter>proto\onnx\onnx\defs</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\core\graph\graph.cpp">
<Filter>proto\onnx\core\graph</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\core\graph\model.cc">
<Filter>proto\onnx\core\graph</Filter>
</ClCompile>
@ -117,26 +105,37 @@
<ClCompile Include="proto\onnx\core\graph\graph_transformer.cc">
<Filter>proto\onnx\core\graph</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\math\old.cpp">
<Filter>proto\onnx\onnx\defs\math</Filter>
<ClCompile Include="proto\onnx\onnx-ml.pb.cc.VS_wrapper.cpp" />
<ClCompile Include="proto\onnx\onnx-operators-ml.pb.cc.VS_wrapper.cpp" />
<ClCompile Include="proto\onnx\core\graph\graph.cc">
<Filter>proto\onnx\core\graph</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\nn\old.cpp">
<Filter>proto\onnx\onnx\defs\nn</Filter>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\logical\old.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\logical</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\rnn\old.cpp">
<Filter>proto\onnx\onnx\defs\rnn</Filter>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\tensor\old.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\tensor</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\tensor\old.cpp">
<Filter>proto\onnx\onnx\defs\tensor</Filter>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\math\old.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\math</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\checker.cpp">
<Filter>proto\onnx\onnx</Filter>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\rnn\old.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\rnn</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\defs\logical\old.cpp">
<Filter>proto\onnx\onnx\defs\logical</Filter>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\nn\old.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs\nn</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx\common\assertions.cpp">
<Filter>proto\onnx\onnx\common</Filter>
<ClCompile Include="proto\onnx\onnx_repo\onnx\common\assertions.cc">
<Filter>proto\onnx\onnx_repo\onnx\common</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx_repo\onnx\checker.cc">
<Filter>proto\onnx\onnx_repo\onnx</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\data_type_utils.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs</Filter>
</ClCompile>
<ClCompile Include="proto\onnx\onnx_repo\onnx\defs\schema.cc">
<Filter>proto\onnx\onnx_repo\onnx\defs</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
@ -295,9 +294,6 @@
<ClInclude Include="proto\onnx\core\graph\utils.h">
<Filter>proto\onnx\core\graph</Filter>
</ClInclude>
<ClInclude Include="proto\onnx\core\common\CommonSTD.h">
<Filter>proto\onnx\core\common</Filter>
</ClInclude>
<ClInclude Include="proto\onnx\onnx_repo\onnx\common\assertions.h">
<Filter>proto\onnx\onnx_repo\onnx\common</Filter>
</ClInclude>

Просмотреть файл

@ -23,6 +23,7 @@ using namespace Microsoft::MSR::CNTK;
using namespace CNTK::ONNX;
using namespace CNTK;
using namespace LotusIR;
using namespace onnx;
const int FreeSequenceLen = 0;
const std::string FreeSequenceDimParam = "None";

Просмотреть файл

@ -16,6 +16,7 @@
using namespace LotusIR;
using namespace CNTK;
using namespace CNTK::ONNX;
using namespace onnx;
using namespace Microsoft::MSR::CNTK;
namespace CNTK

Просмотреть файл

@ -1,32 +0,0 @@
#pragma once
#include <memory>
#include <type_traits>
// to get make_unique definition
#include "Platform.h"
// to add what is missing in gsl
#if defined(__GNUC__)
namespace std {
template<bool _Test,
class _Ty = void>
using enable_if_t = typename enable_if<_Test, _Ty>::type;
template<class _Ty>
using remove_cv_t = typename remove_cv<_Ty>::type;
template<bool _Test,
class _Ty1,
class _Ty2>
using conditional_t = typename conditional<_Test, _Ty1, _Ty2>::type;
template<class _Ty>
using add_pointer_t = typename add_pointer<_Ty>::type;
template<class _Ty>
using remove_const_t = typename remove_const<_Ty>::type;
}
#else
using std::make_unique;
#endif

Просмотреть файл

@ -33,7 +33,6 @@
#include <vector>
#include <chrono>
#include "core/common/CommonSTD.h"
#include "core/common/code_location.h"
#include "core/common/exceptions.h"
#include "core/common/status.h"
@ -62,6 +61,8 @@ using std::vector;
#define UNUSED_PARAMETER(x)
#endif
// std::vector<std::string> GetStackTrace();
// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
// so we only define it as one for MSVC
#if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
@ -73,7 +74,7 @@ using std::vector;
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__)
#define WHERE_WITH_STACK \
Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__) // Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__, Lotus::GetStackTrace())
Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__) // , Lotus::GetStackTrace())
// Throw an exception with optional message.
// NOTE: The arguments get streamed into a string via ostringstream::operator<<

Просмотреть файл

@ -1,7 +1,6 @@
#include "core/common/CommonSTD.h"
#include "core/common/logging/capture.h"
#include "core/common/logging/logging.h"
// #include "gsl/span"
#include "gsl/span"
#include "gsl/gsl_util"
namespace Lotus {
@ -21,27 +20,22 @@ void Capture::CapturePrintf(msvc_printf_check const char *format, ...) {
void Capture::ProcessPrintf(msvc_printf_check const char *format, va_list args) {
static constexpr auto kTruncatedWarningText = "[...truncated...]";
static const int kMaxMessageSize = 2048;
char finished_message[kMaxMessageSize];
char message_buffer[kMaxMessageSize];
const auto message = gsl::make_span(message_buffer);
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
const auto finished_message_len = _countof(finished_message);
const int nbrcharacters = vsnprintf_s(message.data(), message.size(), _TRUNCATE, format, args);
#else
int finished_message_len = sizeof(finished_message);
#endif
#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
const int nbrcharacters = vsnprintf_s(finished_message, finished_message_len, _TRUNCATE, format, args);
#else
const int nbrcharacters = vsnprintf(finished_message, finished_message_len, format, args);
const int nbrcharacters = vsnprintf(message.data(), message.size(), format, args);
#endif
if (nbrcharacters <= 0) {
stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message";
stream_ << '"' << format << '"' << std::endl;
} else if (static_cast<uint32_t>(nbrcharacters) > finished_message_len) {
stream_ << finished_message << kTruncatedWarningText;
} else if (nbrcharacters > message.size()) {
stream_ << message.data() << kTruncatedWarningText;
} else {
stream_ << finished_message;
stream_ << message.data();
}
}

Просмотреть файл

@ -1,6 +1,6 @@
#include <exception>
#include <ctime>
#include "core/common/CommonSTD.h"
#include "core/common/exceptions.h"
#include "core/common/logging/isink.h"
#include "core/common/logging/logging.h"

Просмотреть файл

@ -1,4 +1,3 @@
#include "core/common/CommonSTD.h"
#include "core/common/status.h"
#include "core/common/common.h"

Просмотреть файл

@ -8,15 +8,14 @@
#include <numeric>
#include <stack>
#include "core/common/CommonSTD.h"
// #include "gsl/pointers"
#include "gsl/pointers"
#include "core/graph/graph.h"
#include "core/graph/op.h"
#include "core/graph/utils.h"
#include "core/common/logging/logging.h"
#include "onnx/checker.h"
#include "core/graph/schema_registry.h"
using namespace onnx;
using namespace onnx::Utils;
using namespace onnx::checker;
@ -172,7 +171,7 @@ void Node::ToProto(NodeProto& proto) const {
// Set attributes.
proto.clear_attribute();
for (auto attribute : attributes_) {
auto attr = proto.add_attribute();
const gsl::not_null<AttributeProto*> attr = proto.add_attribute();
*attr = attribute.second;
}
@ -346,12 +345,12 @@ const NodeAttributes& Node::GetAttributes() const noexcept {
}
void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const {
for (const LotusIR::NodeArg* arg : InputDefs()) {
for (const gsl::not_null<const LotusIR::NodeArg*> arg : InputDefs()) {
if (!arg->Exists())
continue;
func(&*arg, true);
}
for (const LotusIR::NodeArg* arg : OutputDefs()) {
for (const gsl::not_null<const LotusIR::NodeArg*> arg : OutputDefs()) {
if (!arg->Exists())
continue;
func(&*arg, false);
@ -359,7 +358,7 @@ void Node::ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)
};
void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
for (const LotusIR::NodeArg* arg : InputDefs()) {
for (const gsl::not_null<const LotusIR::NodeArg*> arg : InputDefs()) {
if (!arg->Exists())
continue;
func(&*arg);
@ -367,7 +366,7 @@ void Node::ForEachInputDef(std::function<void(const LotusIR::NodeArg*)> func) co
};
void Node::ForEachOutputDef(std::function<void(const LotusIR::NodeArg*)> func) const {
for (const LotusIR::NodeArg* arg : OutputDefs()) {
for (const gsl::not_null<const LotusIR::NodeArg*> arg : OutputDefs()) {
if (!arg->Exists())
continue;
func(&*arg);
@ -378,7 +377,7 @@ void Node::ReplaceDefs(const std::map<const LotusIR::NodeArg*, LotusIR::NodeArg*
std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs};
for (auto pair : replacements)
for (auto defs : all_defs)
for (const gsl::not_null<std::vector<LotusIR::NodeArg*>*> defs : all_defs)
for (auto& def : *defs)
if (def == pair.first)
def = pair.second;
@ -418,14 +417,14 @@ Graph::Graph(GraphProto* graph_proto,
// Copy constant nodes _value to name_to_initial_tensor_
for (auto& node : graph_proto_->node()) {
if (node.op_type() == kConstant) {
auto tensor = graph_proto_->add_initializer();
const gsl::not_null<TensorProto*> tensor = graph_proto_->add_initializer();
*tensor = node.attribute(0).t();
*(tensor->mutable_name()) = node.output(0);
}
}
// remove constant nodes
auto graph_mutable_nodes = graph_proto_->mutable_node();
const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes = graph_proto_->mutable_node();
graph_mutable_nodes->erase(
std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(),
[](NodeProto& p) {
@ -493,7 +492,7 @@ Status GraphBase::VerifyNoDuplicateName(/*in*/ const std::unordered_set<std::str
node_name_to_index[node_name] = node.Index();
// Verify node outputs' name should be unique.
for (const NodeArg* output_def : node.OutputDefs()) {
for (const gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
if (output_def->Exists()) {
auto& output_arg_name = output_def->Name();
if (inputs_and_initializers.count(output_arg_name)) {
@ -546,7 +545,7 @@ Status GraphBase::BuildConnections(const std::unordered_map<std::string, Node*>&
if (input_args.size() > 0) {
// This node needs inputs.
for (const NodeArg* input_arg : input_args) {
for (const gsl::not_null<const NodeArg*> input_arg : input_args) {
if (!input_arg->Exists()) {
// This input could be optional and it does not exist in this case.
continue;
@ -655,7 +654,7 @@ void GraphBase::ReverseDFSFrom(const std::vector<const Node*>& from,
sorted_nodes.push_back((*iter));
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
for (const LotusIR::Node* in : sorted_nodes) {
for (gsl::not_null<const LotusIR::Node*> in : sorted_nodes) {
const NodeIndex idx = in->Index();
if (!visited[idx]) {
stack.emplace_back(in, false);
@ -1131,7 +1130,7 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topolo
// currently an Op is required by ValidateVersion, so we use gsl::not_null.
// This may change in the future to allow a null Op
const OpSchema* p_op = node.Op();
const gsl::not_null<const OpSchema*> p_op = node.Op();
// Attribute verification and fill node attribute with
// default value defined in operator definition if needed.
@ -1218,7 +1217,7 @@ Status Graph::Resolve(bool no_proto_sync_required) {
return Status::OK();
}
Status GraphBase::GetNodesInTopologicalOrder(const std::vector<NodeIndex>** pp_nodes) const {
Status GraphBase::GetNodesInTopologicalOrder(gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const {
if (graph_resolve_needed_) {
return Status{StatusCategory::LOTUS, StatusCode::FAIL,
"Resolve() must be called before using the graph as modifications have been made to it."};
@ -1263,7 +1262,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
return;
}
auto tensorAdded = graph_proto_->add_initializer();
const gsl::not_null<TensorProto*> tensorAdded = graph_proto_->add_initializer();
*(tensorAdded) = tensor;
name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1;
name_to_initial_tensor_[tensor.name()] = tensorAdded;
@ -1283,7 +1282,7 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
}
}
bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const {
bool Graph::GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const TensorProto**> value) const {
auto iter = name_to_initial_tensor_.find(tensor_name);
if (name_to_initial_tensor_.end() == iter) {
return false;
@ -1318,7 +1317,7 @@ const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
// Ensure the NodeArgs in the input are created and in this Graph's node arg map
static void AddNodeArgs(const std::vector<NodeArg*>& input_args,
std::unordered_map<std::string, NodeArg*>& node_arg_map) {
for (auto input_arg : input_args) {
for (const gsl::not_null<NodeArg*> input_arg : input_args) {
if (!input_arg->Exists()) continue;
auto& key = input_arg->Name();
auto existing_entry = node_arg_map.find(key);
@ -1383,7 +1382,7 @@ Node* GraphBase::AddNode(const Node& other) {
Node* GraphBase::AddNode(const NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type_map) {
auto node = AllocateNode();
const gsl::not_null<Node*> node = AllocateNode();
auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map, node_args_, owned_node_args_);
auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map, node_args_, owned_node_args_);
@ -1460,7 +1459,7 @@ Node* GraphBase::AddNode(const std::string& name,
AddNodeArgs(input_args, node_args_);
AddNodeArgs(output_args, node_args_);
auto node = AllocateNode();
const gsl::not_null<Node*> node = AllocateNode();
node->Init(name, op_type, description, input_args, output_args, attributes, domain);
if (0 != op_type.compare(kNoOp)) {
graph_proto_sync_needed_ = true;
@ -1505,8 +1504,8 @@ const GraphProto& Graph::ToGraphProto() {
continue;
}
auto node_proto = graph_proto_->add_node();
auto p_node = GetNode(node_idx);
const gsl::not_null<NodeProto*> node_proto = graph_proto_->add_node();
const gsl::not_null<Node*> p_node = GetNode(node_idx);
p_node->ToProto(*node_proto);
}
@ -1551,15 +1550,15 @@ void Graph::SyncGraphInputsOutputs() {
graph_proto_->clear_output();
graph_proto_->clear_value_info();
for (const LotusIR::NodeArg* input_arg : GetInputs()) {
for (const gsl::not_null<const LotusIR::NodeArg*> input_arg : GetInputs()) {
*(graph_proto_->mutable_input()->Add()) = input_arg->ToProto();
}
for (const LotusIR::NodeArg* output_arg : GetOutputs()) {
for (const gsl::not_null<const LotusIR::NodeArg*> output_arg : GetOutputs()) {
*(graph_proto_->mutable_output()->Add()) = output_arg->ToProto();
}
for (const LotusIR::NodeArg* value_info : value_info_) {
for (const gsl::not_null<const LotusIR::NodeArg*> value_info : value_info_) {
*(graph_proto_->mutable_value_info()->Add()) = value_info->ToProto();
}
}
@ -1573,11 +1572,11 @@ void Graph::CleanUnusedInitializers() {
for (const auto& pv : name_to_initial_tensor_) {
const std::string& s = pv.first;
const bool used_as_input = std::any_of(input_args.begin(), input_args.end(),
[&s](const NodeArg* input) noexcept {
[&s](const gsl::not_null<const NodeArg*> input) noexcept {
return s == input->Name();
});
const bool used_as_output = std::any_of(GetOutputs().begin(), GetOutputs().end(),
[&s](const NodeArg* output) noexcept {
[&s](const gsl::not_null<const NodeArg*> output) noexcept {
return s == output->Name();
});
@ -1639,7 +1638,7 @@ Status Graph::SetGraphInputsOutputs() {
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
for (const auto& node : Nodes()) {
for (const NodeArg* output_def : node.OutputDefs()) {
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
if (specified_graph_outputs.erase(output_def->Name()) >= 1) {
graph_outputs.push_back(output_def);
}
@ -1664,7 +1663,7 @@ Status Graph::SetGraphInputsOutputs() {
for (const auto& node : Nodes()) {
// Go thru all node's inputs.
for (const NodeArg* input_arg : node.InputDefs()) {
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
if (!input_arg->Exists()) {
// It's an optional input and does not exist in this case.
continue;
@ -1694,7 +1693,7 @@ Status Graph::SetGraphInputsOutputs() {
} else {
std::unordered_map<std::string, const NodeArg*> output_name_to_node_arg;
for (const auto& node : Nodes()) {
for (const NodeArg* output_def : node.OutputDefs()) {
for (gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
if (output_def->Exists())
output_name_to_node_arg.insert({output_def->Name(), output_def});
}
@ -1706,7 +1705,7 @@ Status Graph::SetGraphInputsOutputs() {
std::unordered_set<Node*> inner_nodes;
for (const auto& node : Nodes()) {
// Go thru all node's inputs.
for (const NodeArg* input_arg : node.InputDefs()) {
for (const gsl::not_null<const NodeArg*> input_arg : node.InputDefs()) {
if (!input_arg->Exists()) {
// It's an optional input and does not exist in this case.
continue;
@ -1719,7 +1718,7 @@ Status Graph::SetGraphInputsOutputs() {
const std::string& name = input_arg->Name();
if (added_input_names.end() == added_input_names.find(name)) {
// This graph input has not been added into <graph_inputs_>.
// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
//// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end())
graph_inputs.push_back(input_arg);
added_input_names.insert(input_arg->Name());
}
@ -1759,7 +1758,7 @@ const Node* GraphBase::SinkNode() const {
// calling private ctor
GSL_SUPPRESS(r .11)
Node* GraphBase::AllocateNode() {
gsl::not_null<Node*> GraphBase::AllocateNode() {
std::unique_ptr<Node> new_node(new Node(nodes_.size(), *this));
Node* node{new_node.get()};

Просмотреть файл

@ -7,9 +7,7 @@
#include <unordered_map>
#include <unordered_set>
#include "core/common/CommonSTD.h"
// #include "gsl/pointers"
#include "gsl/pointers"
#include "gsl/gsl_util"
#include "core/common/common.h"
@ -21,21 +19,19 @@
#include "core/graph/utils.h"
#include "onnx/onnx_pb.h"
using namespace onnx;
// TODO - Evaluate switching the types below to support transparent comparators and enable
// lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations
// converting to std::string, but requires conversion to std::map<std::string, foo, std::less<>>
// instead of std::unordered_map<std::string, foo, [std::less<foo>]>.
using NodeAttributes = std::unordered_map<std::string, AttributeProto>;
using NodeAttributes = std::unordered_map<std::string, onnx::AttributeProto>;
namespace LotusIR {
using NodeIndex = size_t;
using Version = int64_t;
using NodeArgInfo = ValueInfoProto;
using InitializedTensorSet = std::unordered_map<std::string, const TensorProto*>;
using ArgNameToTypeMap = std::unordered_map<std::string, TypeProto>;
using NodeArgInfo = onnx::ValueInfoProto;
using InitializedTensorSet = std::unordered_map<std::string, const onnx::TensorProto*>;
using ArgNameToTypeMap = std::unordered_map<std::string, onnx::TypeProto>;
using ProviderType = const std::string&;
class Graph;
@ -70,7 +66,7 @@ class NodeArg {
// optional. This is called when loading a <Graph> from <GraphProto>
// normally.
NodeArg(const std::string& name,
const TypeProto* p_arg_type);
const onnx::TypeProto* p_arg_type);
NodeArg(NodeArg&& other) = default;
@ -78,17 +74,17 @@ class NodeArg {
const std::string& Name() const noexcept;
// Get node arg type.
DataType Type() const noexcept;
const TypeProto* TypeAsProto() const noexcept;
onnx::DataType Type() const noexcept;
const onnx::TypeProto* TypeAsProto() const noexcept;
// Get node arg shape.
// Return null pointer if there's no shape specified.
const TensorShapeProto* Shape() const;
const onnx::TensorShapeProto* Shape() const;
// Set node arg shape.
// Shape could only be set after setting type since shape information
// now is part of TypeProto.
void SetShape(const TensorShapeProto& shape);
void SetShape(const onnx::TensorShapeProto& shape);
// Get node arg info proto.
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
@ -102,13 +98,13 @@ class NodeArg {
LOTUS_DISALLOW_COPY_AND_ASSIGN(NodeArg);
friend class Graph;
void SetType(DataType p_type);
void SetType(const TypeProto& type_proto);
void SetType(onnx::DataType p_type);
void SetType(const onnx::TypeProto& type_proto);
NodeArg& operator=(NodeArg&& other) = delete;
// Node arg PType.
DataType type_;
onnx::DataType type_;
// Node arg name, type and shape.
NodeArgInfo node_arg_info_;
@ -159,7 +155,7 @@ class Node {
// Get the OperatorSchema this node refers to. ValidateOpType() must have been called previously.
// May be null in the future.
const OpSchema* Op() const noexcept;
const onnx::OpSchema* Op() const noexcept;
// Get node description.
const std::string& Description() const noexcept;
@ -167,8 +163,8 @@ class Node {
// Iterate through Input/OutputDefs() with index, note the loop early terminates with error
static Lotus::Common::Status ForEachWithIndex(
const ConstPointerContainer<std::vector<NodeArg*>>& nodeArgVec,
std::function<Lotus::Common::Status(const NodeArg& arg, int index)> func) {
for (int index = 0; index < nodeArgVec.size(); ++index) {
std::function<Lotus::Common::Status(const NodeArg& arg, size_t index)> func) {
for (size_t index = 0; index < nodeArgVec.size(); ++index) {
auto arg = nodeArgVec[index];
if (!arg->Exists())
continue;
@ -207,7 +203,7 @@ class Node {
const std::set<std::string>& ControlInputs() const noexcept { return relationships_.control_inputs; }
// Add a node attribute with specified attribute name and value.
void AddAttribute(const std::string& attr_name, const AttributeProto& value);
void AddAttribute(const std::string& attr_name, const onnx::AttributeProto& value);
#define ADD_ATTR_INTERFACES(TypeName) \
void AddAttribute(const std::string& attr_name, \
@ -218,8 +214,8 @@ class Node {
ADD_ATTR_INTERFACES(int64_t)
ADD_ATTR_INTERFACES(float)
ADD_ATTR_INTERFACES(std::string)
ADD_ATTR_INTERFACES(TensorProto)
ADD_ATTR_INTERFACES(GraphProto)
ADD_ATTR_INTERFACES(onnx::TensorProto)
ADD_ATTR_INTERFACES(onnx::GraphProto)
// Clear specified node attribute.
bool ClearAttribute(const std::string& attr_name);
@ -235,7 +231,7 @@ class Node {
void SetExecutionProviderType(ProviderType execution_provider_type);
// Get the corresponding <NodeProto>.
void ToProto(NodeProto& proto) const;
void ToProto(onnx::NodeProto& proto) const;
// iterate through all input/output defs
void ForEachDef(std::function<void(const LotusIR::NodeArg*, bool is_input)> func) const;
@ -355,7 +351,7 @@ class Node {
std::string domain_;
// OperatorSchema that <*this> node refers to.
const OpSchema* op_ = nullptr;
const onnx::OpSchema* op_ = nullptr;
// Node doc string.
std::string description_;
@ -428,7 +424,7 @@ class GraphBase {
int NumberOfNodes() const noexcept { return num_of_nodes_; }
// Get NodeArg by name, or create NodeArg owned by the graph if not found
NodeArg& GetOrCreateNodeArg(const std::string& name, const TypeProto* p_arg_type) {
NodeArg& GetOrCreateNodeArg(const std::string& name, const onnx::TypeProto* p_arg_type) {
auto iter = node_args_.find(name);
if (iter != node_args_.end())
return *(iter->second);
@ -492,7 +488,7 @@ class GraphBase {
// TODO(Task:135) See if GraphBase::GetNodesInTopologicalOrder can be made more correctly const
// by forcing Resolve to have been called directly previously. Simple change is to return error if
// GraphResolveNeeded is true.
Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ const std::vector<NodeIndex>** pp_nodes) const;
Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const;
// Mark Graph as needing Resolve() to be called
GraphBase& SetGraphResolveNeeded() noexcept {
@ -545,7 +541,7 @@ class GraphBase {
void AddSourceSinkNodes();
// Add node with specified <node_proto>.
Node* AddNode(const NodeProto& node_proto,
Node* AddNode(const onnx::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type);
NodeIndex SourceNodeIndex() const noexcept { return source_node_index_; }
@ -613,13 +609,13 @@ class GraphBase {
// Returns the inferred shape+type for every output of the node in
// output parameter inferredShapes.
Lotus::Common::Status InferOutputTypesAndShapes(LotusIR::Node& node,
/*out*/ std::vector<TypeProto>& inferred_shapes);
/*out*/ std::vector<onnx::TypeProto>& inferred_shapes);
private:
// need custom versions to handle the unique_ptr's in nodes_
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(GraphBase);
Node* AllocateNode();
gsl::not_null<Node*> AllocateNode();
/**
Release the node.
@ -709,9 +705,9 @@ class Graph : public GraphBase {
void SetDescription(const std::string& description) override;
// Add/Remove/Get initial tensors for some graph inputs.
void AddInitializedTensor(const TensorProto& tensor_proto);
void AddInitializedTensor(const onnx::TensorProto& tensor_proto);
void RemoveInitializedTensor(const std::string& tensor_name);
bool GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const;
bool GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const onnx::TensorProto**> value) const;
const InitializedTensorSet& GetAllInitializedTensors() const noexcept;
void CleanAllInitializedTensors() noexcept;
@ -719,7 +715,7 @@ class Graph : public GraphBase {
const std::vector<const NodeArg*>& GetValueInfo() const noexcept;
// Serialize the <Graph> into <GraphProto>.
const GraphProto& ToGraphProto();
const onnx::GraphProto& ToGraphProto();
private:
LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Graph);
@ -730,7 +726,7 @@ class Graph : public GraphBase {
// Constructor: Given a <GraphProto> loaded from model file, construct
// a <Graph> object.
Graph(GraphProto* graph_proto,
Graph(onnx::GraphProto* graph_proto,
const std::unordered_map<std::string, int>& domain_to_version,
Version ir_version,
const ILotusOpSchemaCollection* local_registry = nullptr);
@ -754,7 +750,7 @@ class Graph : public GraphBase {
Lotus::Common::Status Resolve(bool no_proto_sync_required);
Lotus::Common::Status InferAndVerifyTypeMatch(Node& node,
const OpSchema& op);
const onnx::OpSchema& op);
// Apply type-inference and type-checking to all inputs and initializers:
Lotus::Common::Status TypeCheckInputsAndInitializers();
@ -783,7 +779,7 @@ class Graph : public GraphBase {
// functions in <Graph> will also be fed into <graph_proto_> so that
// it's consistent with <*this> graph.
// This pointer is owned by parent model.
GraphProto* graph_proto_;
onnx::GraphProto* graph_proto_;
// The node which refers to <*this> graph (Function).
// Node* node_;

Просмотреть файл

@ -1,4 +1,3 @@
#include "core/common/CommonSTD.h"
#include "core/graph/graph_transformer.h"
using namespace Lotus;
using namespace Lotus::Common;

Просмотреть файл

@ -53,7 +53,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
// should be stored globally. Otherwise, there will be multiple addresses/pointers
// for the same operator or function. To avoid this, we may use OpSignature ID
// as the key, which should be name_domain_version.
Lotus::Common::Status Register(const OpSchema* op, std::unique_ptr<RewriteRule> rule) {
Lotus::Common::Status Register(const onnx::OpSchema* op, std::unique_ptr<RewriteRule> rule) {
op_to_rules_[op].push_back(std::move(rule));
return Lotus::Common::Status::OK();
}
@ -66,7 +66,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
}
private:
typedef std::unordered_map<const OpSchema*, std::vector<std::unique_ptr<RewriteRule>>>
typedef std::unordered_map<const onnx::OpSchema*, std::vector<std::unique_ptr<RewriteRule>>>
RewriteRuleSet;
RewriteRuleSet op_to_rules_;

Просмотреть файл

@ -7,16 +7,15 @@
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#include "core/common/CommonSTD.h"
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <memory>
#include "core/graph/model.h"
#include "core/graph/utils.h"
#include "core/graph/schema_registry.h"
// #include "gsl/pointers"
#include "gsl/pointers"
#include "gsl/gsl_util"
using namespace onnx;
using namespace Lotus;
using namespace Lotus::Common;
@ -27,7 +26,7 @@ Model::Model(const std::string& graph_name, bool is_onnx_domain_only, const Mode
model_proto_->mutable_graph()->set_name(graph_name);
model_metadata_ = model_metadata;
for (auto& metadata : model_metadata_) {
auto prop = model_proto_->add_metadata_props();
const gsl::not_null<StringStringEntryProto*> prop = model_proto_->add_metadata_props();
prop->set_key(metadata.first);
prop->set_value(metadata.second);
}
@ -138,7 +137,7 @@ ModelProto Model::ToProto() {
}
void Model::AddImportOpSets(bool is_onnx_domain_only,
/*out*/ std::unordered_map<std::string, int>* domain_to_version,
/*out*/ gsl::not_null<std::unordered_map<std::string, int>*> domain_to_version,
const ILotusOpSchemaCollection* local_registry) {
auto& domain_to_version_range_map = OpSchemaRegistry::DomainToVersionRange::Instance().Map();
Domain_To_Version_Map local_domain_to_version_map = local_registry ? local_registry->DomainToVersionMap() : Domain_To_Version_Map();
@ -158,7 +157,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only,
}
auto ignored = domain_to_version->insert({domainToVersionRange.first, max});
auto opset_id_proto = model_proto_->add_opset_import();
const gsl::not_null<OperatorSetIdProto*> opset_id_proto = model_proto_->add_opset_import();
opset_id_proto->set_domain(domainToVersionRange.first);
opset_id_proto->set_version(domainToVersionRange.second.second);
}
@ -174,7 +173,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only,
if (domain_to_version_range_map.end() != domain_to_version_range_map.find(local_domain.first)) {
auto ignored = domain_to_version->insert({local_domain.first, local_domain.second.second});
auto opset_id_proto = model_proto_->add_opset_import();
const gsl::not_null<OperatorSetIdProto*> opset_id_proto = model_proto_->add_opset_import();
opset_id_proto->set_domain(local_domain.first);
opset_id_proto->set_version(local_domain.second.second);
}

Просмотреть файл

@ -2,7 +2,7 @@
#include "core/graph/graph.h"
// #include "gsl/pointers"
#include "gsl/pointers"
namespace LotusIR {
typedef std::unordered_map<std::string, std::string> ModelMetaData;
@ -22,11 +22,11 @@ class Model {
// NOTE: after calling this constructor, <*this> model will
// hold a copy of <model_proto>.
explicit Model(const ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
explicit Model(const onnx::ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
// NOTE: after calling this constructor, <*this> model will
// own the <model_proto>.
explicit Model(std::unique_ptr<ModelProto> model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
explicit Model(std::unique_ptr<onnx::ModelProto> model_proto, const ILotusOpSchemaCollection* local_registry = nullptr);
// Get model's IR version.
// Return <kNoVersion> if not specified.
@ -71,7 +71,7 @@ class Model {
const Graph* MainGraph() const noexcept;
// Get model's serialization proto data.
ModelProto ToProto();
onnx::ModelProto ToProto();
#ifdef _WIN32
static Lotus::Common::Status Save(Model& model, const std::wstring& file_path);
@ -84,7 +84,7 @@ class Model {
static Lotus::Common::Status Save(Model& model, int fd);
static Lotus::Common::Status Load(std::istream& model_istream, ModelProto* p_model_proto);
static Lotus::Common::Status Load(std::istream& model_istream, onnx::ModelProto* p_model_proto);
static Lotus::Common::Status Load(const std::string& file_path,
/*out*/ std::shared_ptr<Model>& p_model,
@ -97,7 +97,7 @@ class Model {
static Lotus::Common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr<Model>& p_model,
const ILotusOpSchemaCollection* local_registry = nullptr);
static Lotus::Common::Status Load(const ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
static Lotus::Common::Status Load(const onnx::ModelProto& model_proto, /*out*/ std::shared_ptr<Model>& p_model,
const ILotusOpSchemaCollection* local_registry = nullptr);
private:
@ -106,11 +106,11 @@ class Model {
// if <is_onnx_domain_only> is true, then only onnx domain will be contained.
// otherwise, ml domain will also be contained.
void AddImportOpSets(bool is_onnx_domain_only,
/*out*/ std::unordered_map<std::string, int>* domain_to_version,
/*out*/ gsl::not_null<std::unordered_map<std::string, int>*> domain_to_version,
const ILotusOpSchemaCollection* local_registry);
// Model data.
std::unique_ptr<ModelProto> model_proto_;
std::unique_ptr<onnx::ModelProto> model_proto_;
// This is a duplication of <model_proto_.metadata_props()>.
// It gives better accessibility.

Просмотреть файл

@ -1,9 +1,9 @@
#include <cstring>
#include "core/common/CommonSTD.h"
#include "core/graph/constants.h"
#include "core/graph/op.h"
#include "core/graph/utils.h"
using namespace onnx;
namespace LotusIR {
bool TypeUtils::IsValidAttribute(const AttributeProto& attr) {

Просмотреть файл

@ -6,12 +6,11 @@
#include "core/common/status.h"
#include "core/graph/constants.h"
using namespace onnx;
using namespace Lotus::Common;
namespace LotusIR {
using AttrType = AttributeProto_AttributeType;
using NodeAttributes = std::unordered_map<std::string, AttributeProto>;
using AttrType = onnx::AttributeProto_AttributeType;
using NodeAttributes = std::unordered_map<std::string, onnx::AttributeProto>;
// This string array should exactly match the AttrType defined above.
/*
@ -44,8 +43,8 @@ static constexpr const char* kAttrTypeStrings[] =
class TypeUtils {
public:
// Get attribute type given attribute proto data.
static Status GetType(const AttributeProto& attr, AttrType& type);
static bool IsValidAttribute(const AttributeProto& attribute);
static Status GetType(const onnx::AttributeProto& attr, AttrType& type);
static bool IsValidAttribute(const onnx::AttributeProto& attribute);
};
class MsOpRegistry {

Просмотреть файл

@ -1,8 +1,7 @@
#include "core/common/CommonSTD.h"
#include "core/graph/tensorutils.h"
#include <algorithm>
// #include "gsl/span"
#include "gsl/span"
namespace Lotus {
namespace Utils {
@ -23,9 +22,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto");
for (auto& elem : tensor.string_data()) {
*p_data++ = elem;
}
const auto data = gsl::make_span(p_data, expected_size);
auto& string_data = tensor.string_data();
std::copy(string_data.cbegin(), string_data.cend(), data.begin());
return Status::OK();
}
@ -57,9 +57,8 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto");
for (auto& elem : tensor.int32_data()) {
*p_data++ = elem;
}
const auto data = gsl::make_span(p_data, expected_size);
std::copy(tensor.int32_data().cbegin(), tensor.int32_data().cend(), data.begin());
return Status::OK();
}
@ -90,9 +89,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor,
if (tensor.int32_data_size() != expected_size)
return Status(StatusCategory::LOTUS, StatusCode::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto");
for (auto& elem : tensor.int32_data()) {
*p_data++ = (uint16_t)elem;
}
const auto data = gsl::make_span(p_data, expected_size);
for (int i = 0; i < expected_size; i++)
data[i] = gsl::narrow_cast<uint16_t>(tensor.int32_data()[i]);
return Status::OK();
}

Просмотреть файл

@ -3,8 +3,8 @@
#include <type_traits>
#include <vector>
// #include "gsl/pointers"
// #include "gsl/span"
#include "gsl/pointers"
#include "gsl/span"
#include "core/common/common.h"
#include "core/common/status.h"
@ -36,9 +36,9 @@ class TensorUtils {
if (tensor.field_size() != expected_size) \
return Status(StatusCategory::LOTUS, StatusCode::FAIL, \
"UnpackTensor: the pre-allocated size does not match the size in proto"); \
for (auto elem : tensor.field_name()) { \
*p_data++ = static_cast<T>(elem); \
} \
const auto span = gsl::make_span(p_data, expected_size); \
auto& data = tensor.field_name(); \
std::copy(data.cbegin(), data.cend(), span.begin()); \
return Status::OK(); \
}

Просмотреть файл

@ -16,12 +16,12 @@
#include "core/common/status.h"
#include "onnx/onnx_pb.h"
// #include "gsl/pointers"
#include "gsl/pointers"
namespace Lotus{
using namespace ::Lotus::Common;
#ifdef _WIN32
inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
inline Status FileOpenRd(const std::wstring& path, /*out*/ gsl::not_null<int*> p_fd) {
_wsopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
if (0 > *p_fd) {
return Status(SYSTEM, errno);
@ -29,7 +29,7 @@ inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) {
return Status::OK();
}
inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
inline Status FileOpenWr(const std::wstring& path, /*out*/ gsl::not_null<int*> p_fd) {
_wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
if (0 > *p_fd) {
return Status(SYSTEM, errno);
@ -38,7 +38,7 @@ inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) {
}
#endif
inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
inline Status FileOpenRd(const std::string& path, /*out*/ gsl::not_null<int*> p_fd) {
#ifdef _WIN32
_sopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
#else
@ -50,7 +50,7 @@ inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) {
return Status::OK();
}
inline Status FileOpenWr(const std::string& path, /*out*/ int* p_fd) {
inline Status FileOpenWr(const std::string& path, /*out*/ gsl::not_null<int*> p_fd) {
#ifdef _WIN32
_sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
#else