diff --git a/Makefile b/Makefile index e0ee3e27d..ac4f74739 100644 --- a/Makefile +++ b/Makefile @@ -102,9 +102,9 @@ INCLUDEPATH+=$(GSL_PATH)/include INCLUDEPATH+=$(ONNX_PATH) INCLUDEPATH+=$(ONNX_REPO_PATH) # COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers. -COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++11 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__ +COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++14 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__ CPPFLAGS:= -CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -std=c++0x -fopenmp -fpermissive -fPIC -Werror -fcheck-new +CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -fopenmp -fpermissive -fPIC -Werror -fcheck-new LIBPATH:= LIBS_LIST:= LDFLAGS:= diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj index f4a702de8..43b5a8861 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj @@ -178,7 +178,6 @@ - diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters index 556c72a0a..39e9107fa 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters @@ -93,18 +93,6 @@ proto\onnx\onnx_repo\onnx\defs\traditionalml - - proto\onnx\onnx\defs\traditionalml - - - proto\onnx\onnx\defs - - - proto\onnx\onnx\defs - - - proto\onnx\core\graph - proto\onnx\core\graph @@ -117,26 +105,37 @@ proto\onnx\core\graph - - proto\onnx\onnx\defs\math + + + + proto\onnx\core\graph - - proto\onnx\onnx\defs\nn + + proto\onnx\onnx_repo\onnx\defs\logical - - proto\onnx\onnx\defs\rnn + + proto\onnx\onnx_repo\onnx\defs\tensor - - proto\onnx\onnx\defs\tensor + + proto\onnx\onnx_repo\onnx\defs\math - - proto\onnx\onnx + + proto\onnx\onnx_repo\onnx\defs\rnn - - proto\onnx\onnx\defs\logical + + proto\onnx\onnx_repo\onnx\defs\nn - - proto\onnx\onnx\common + + proto\onnx\onnx_repo\onnx\common + + + proto\onnx\onnx_repo\onnx + + + proto\onnx\onnx_repo\onnx\defs + + + proto\onnx\onnx_repo\onnx\defs @@ -295,9 +294,6 @@ proto\onnx\core\graph - - proto\onnx\core\common - proto\onnx\onnx_repo\onnx\common diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 6a2349503..4b13f206d 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -23,6 +23,7 @@ using namespace Microsoft::MSR::CNTK; using namespace CNTK::ONNX; using namespace CNTK; using namespace LotusIR; +using namespace onnx; const int FreeSequenceLen = 0; const std::string FreeSequenceDimParam = "None"; diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index 815230e49..ac5865b25 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -16,6 +16,7 @@ using namespace LotusIR; using namespace CNTK; using namespace CNTK::ONNX; +using namespace onnx; using namespace Microsoft::MSR::CNTK; namespace CNTK diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/CommonSTD.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/CommonSTD.h deleted file mode 100644 index 24c692cda..000000000 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/CommonSTD.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once -#include -#include - -// to get make_unique definition -#include "Platform.h" - -// to add what is missing in gsl -#if defined(__GNUC__) -namespace std { - template - using enable_if_t = typename enable_if<_Test, _Ty>::type; - - template - using remove_cv_t = typename remove_cv<_Ty>::type; - - template - using conditional_t = typename conditional<_Test, _Ty1, _Ty2>::type; - - template - using add_pointer_t = typename add_pointer<_Ty>::type; - - template - using remove_const_t = typename remove_const<_Ty>::type; -} -#else -using std::make_unique; -#endif - diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/common.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/common.h index 763a8c14b..665ebf20f 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/common.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/common.h @@ -33,7 +33,6 @@ #include #include -#include "core/common/CommonSTD.h" #include "core/common/code_location.h" #include "core/common/exceptions.h" #include "core/common/status.h" @@ -62,6 +61,8 @@ using std::vector; #define UNUSED_PARAMETER(x) #endif +// std::vector GetStackTrace(); + // __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER // so we only define it as one for MSVC #if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) @@ -73,7 +74,7 @@ using std::vector; Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__) #define WHERE_WITH_STACK \ - Lotus::CodeLocation(__FILE__, __LINE__, __FUNCTION__) // Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__, Lotus::GetStackTrace()) + Lotus::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__) // , Lotus::GetStackTrace()) // Throw an exception with optional message. // NOTE: The arguments get streamed into a string via ostringstream::operator<< diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc index a24405b25..10750cd0f 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc @@ -1,7 +1,6 @@ -#include "core/common/CommonSTD.h" #include "core/common/logging/capture.h" #include "core/common/logging/logging.h" -// #include "gsl/span" +#include "gsl/span" #include "gsl/gsl_util" namespace Lotus { @@ -21,27 +20,22 @@ void Capture::CapturePrintf(msvc_printf_check const char *format, ...) { void Capture::ProcessPrintf(msvc_printf_check const char *format, va_list args) { static constexpr auto kTruncatedWarningText = "[...truncated...]"; static const int kMaxMessageSize = 2048; - char finished_message[kMaxMessageSize]; + char message_buffer[kMaxMessageSize]; + const auto message = gsl::make_span(message_buffer); #if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__)) - const auto finished_message_len = _countof(finished_message); + const int nbrcharacters = vsnprintf_s(message.data(), message.size(), _TRUNCATE, format, args); #else - int finished_message_len = sizeof(finished_message); -#endif - -#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__)) - const int nbrcharacters = vsnprintf_s(finished_message, finished_message_len, _TRUNCATE, format, args); -#else - const int nbrcharacters = vsnprintf(finished_message, finished_message_len, format, args); + const int nbrcharacters = vsnprintf(message.data(), message.size(), format, args); #endif if (nbrcharacters <= 0) { stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message"; stream_ << '"' << format << '"' << std::endl; - } else if (static_cast(nbrcharacters) > finished_message_len) { - stream_ << finished_message << kTruncatedWarningText; + } else if (nbrcharacters > message.size()) { + stream_ << message.data() << kTruncatedWarningText; } else { - stream_ << finished_message; + stream_ << message.data(); } } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc index ef7cbeedc..af57bf908 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc @@ -1,6 +1,6 @@ #include #include -#include "core/common/CommonSTD.h" + #include "core/common/exceptions.h" #include "core/common/logging/isink.h" #include "core/common/logging/logging.h" diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc index 4871cbea4..7f8ac82e3 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc @@ -1,4 +1,3 @@ -#include "core/common/CommonSTD.h" #include "core/common/status.h" #include "core/common/common.h" diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc index 729610227..55a360a49 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc @@ -8,15 +8,14 @@ #include #include -#include "core/common/CommonSTD.h" -// #include "gsl/pointers" +#include "gsl/pointers" #include "core/graph/graph.h" #include "core/graph/op.h" #include "core/graph/utils.h" #include "core/common/logging/logging.h" #include "onnx/checker.h" #include "core/graph/schema_registry.h" - +using namespace onnx; using namespace onnx::Utils; using namespace onnx::checker; @@ -172,7 +171,7 @@ void Node::ToProto(NodeProto& proto) const { // Set attributes. proto.clear_attribute(); for (auto attribute : attributes_) { - auto attr = proto.add_attribute(); + const gsl::not_null attr = proto.add_attribute(); *attr = attribute.second; } @@ -346,12 +345,12 @@ const NodeAttributes& Node::GetAttributes() const noexcept { } void Node::ForEachDef(std::function func) const { - for (const LotusIR::NodeArg* arg : InputDefs()) { + for (const gsl::not_null arg : InputDefs()) { if (!arg->Exists()) continue; func(&*arg, true); } - for (const LotusIR::NodeArg* arg : OutputDefs()) { + for (const gsl::not_null arg : OutputDefs()) { if (!arg->Exists()) continue; func(&*arg, false); @@ -359,7 +358,7 @@ void Node::ForEachDef(std::function func) const { - for (const LotusIR::NodeArg* arg : InputDefs()) { + for (const gsl::not_null arg : InputDefs()) { if (!arg->Exists()) continue; func(&*arg); @@ -367,7 +366,7 @@ void Node::ForEachInputDef(std::function func) co }; void Node::ForEachOutputDef(std::function func) const { - for (const LotusIR::NodeArg* arg : OutputDefs()) { + for (const gsl::not_null arg : OutputDefs()) { if (!arg->Exists()) continue; func(&*arg); @@ -378,7 +377,7 @@ void Node::ReplaceDefs(const std::map*> all_defs = {&definitions_.input_defs, &definitions_.output_defs}; for (auto pair : replacements) - for (auto defs : all_defs) + for (const gsl::not_null*> defs : all_defs) for (auto& def : *defs) if (def == pair.first) def = pair.second; @@ -418,14 +417,14 @@ Graph::Graph(GraphProto* graph_proto, // Copy constant nodes _value to name_to_initial_tensor_ for (auto& node : graph_proto_->node()) { if (node.op_type() == kConstant) { - auto tensor = graph_proto_->add_initializer(); + const gsl::not_null tensor = graph_proto_->add_initializer(); *tensor = node.attribute(0).t(); *(tensor->mutable_name()) = node.output(0); } } // remove constant nodes - auto graph_mutable_nodes = graph_proto_->mutable_node(); + const gsl::not_null*> graph_mutable_nodes = graph_proto_->mutable_node(); graph_mutable_nodes->erase( std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(), [](NodeProto& p) { @@ -493,7 +492,7 @@ Status GraphBase::VerifyNoDuplicateName(/*in*/ const std::unordered_set output_def : node.OutputDefs()) { if (output_def->Exists()) { auto& output_arg_name = output_def->Name(); if (inputs_and_initializers.count(output_arg_name)) { @@ -546,7 +545,7 @@ Status GraphBase::BuildConnections(const std::unordered_map& if (input_args.size() > 0) { // This node needs inputs. - for (const NodeArg* input_arg : input_args) { + for (const gsl::not_null input_arg : input_args) { if (!input_arg->Exists()) { // This input could be optional and it does not exist in this case. continue; @@ -655,7 +654,7 @@ void GraphBase::ReverseDFSFrom(const std::vector& from, sorted_nodes.push_back((*iter)); } std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp); - for (const LotusIR::Node* in : sorted_nodes) { + for (gsl::not_null in : sorted_nodes) { const NodeIndex idx = in->Index(); if (!visited[idx]) { stack.emplace_back(in, false); @@ -1131,7 +1130,7 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector& nodes_in_topolo // currently an Op is required by ValidateVersion, so we use gsl::not_null. // This may change in the future to allow a null Op - const OpSchema* p_op = node.Op(); + const gsl::not_null p_op = node.Op(); // Attribute verification and fill node attribute with // default value defined in operator definition if needed. @@ -1218,7 +1217,7 @@ Status Graph::Resolve(bool no_proto_sync_required) { return Status::OK(); } -Status GraphBase::GetNodesInTopologicalOrder(const std::vector** pp_nodes) const { +Status GraphBase::GetNodesInTopologicalOrder(gsl::not_null**> pp_nodes) const { if (graph_resolve_needed_) { return Status{StatusCategory::LOTUS, StatusCode::FAIL, "Resolve() must be called before using the graph as modifications have been made to it."}; @@ -1263,7 +1262,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { return; } - auto tensorAdded = graph_proto_->add_initializer(); + const gsl::not_null tensorAdded = graph_proto_->add_initializer(); *(tensorAdded) = tensor; name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1; name_to_initial_tensor_[tensor.name()] = tensorAdded; @@ -1283,7 +1282,7 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { } } -bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const { +bool Graph::GetInitializedTensor(const std::string& tensor_name, gsl::not_null value) const { auto iter = name_to_initial_tensor_.find(tensor_name); if (name_to_initial_tensor_.end() == iter) { return false; @@ -1318,7 +1317,7 @@ const std::vector& Graph::GetValueInfo() const noexcept { // Ensure the NodeArgs in the input are created and in this Graph's node arg map static void AddNodeArgs(const std::vector& input_args, std::unordered_map& node_arg_map) { - for (auto input_arg : input_args) { + for (const gsl::not_null input_arg : input_args) { if (!input_arg->Exists()) continue; auto& key = input_arg->Name(); auto existing_entry = node_arg_map.find(key); @@ -1383,7 +1382,7 @@ Node* GraphBase::AddNode(const Node& other) { Node* GraphBase::AddNode(const NodeProto& node_proto, const ArgNameToTypeMap& name_to_type_map) { - auto node = AllocateNode(); + const gsl::not_null node = AllocateNode(); auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map, node_args_, owned_node_args_); auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map, node_args_, owned_node_args_); @@ -1460,7 +1459,7 @@ Node* GraphBase::AddNode(const std::string& name, AddNodeArgs(input_args, node_args_); AddNodeArgs(output_args, node_args_); - auto node = AllocateNode(); + const gsl::not_null node = AllocateNode(); node->Init(name, op_type, description, input_args, output_args, attributes, domain); if (0 != op_type.compare(kNoOp)) { graph_proto_sync_needed_ = true; @@ -1505,8 +1504,8 @@ const GraphProto& Graph::ToGraphProto() { continue; } - auto node_proto = graph_proto_->add_node(); - auto p_node = GetNode(node_idx); + const gsl::not_null node_proto = graph_proto_->add_node(); + const gsl::not_null p_node = GetNode(node_idx); p_node->ToProto(*node_proto); } @@ -1551,15 +1550,15 @@ void Graph::SyncGraphInputsOutputs() { graph_proto_->clear_output(); graph_proto_->clear_value_info(); - for (const LotusIR::NodeArg* input_arg : GetInputs()) { + for (const gsl::not_null input_arg : GetInputs()) { *(graph_proto_->mutable_input()->Add()) = input_arg->ToProto(); } - for (const LotusIR::NodeArg* output_arg : GetOutputs()) { + for (const gsl::not_null output_arg : GetOutputs()) { *(graph_proto_->mutable_output()->Add()) = output_arg->ToProto(); } - for (const LotusIR::NodeArg* value_info : value_info_) { + for (const gsl::not_null value_info : value_info_) { *(graph_proto_->mutable_value_info()->Add()) = value_info->ToProto(); } } @@ -1573,11 +1572,11 @@ void Graph::CleanUnusedInitializers() { for (const auto& pv : name_to_initial_tensor_) { const std::string& s = pv.first; const bool used_as_input = std::any_of(input_args.begin(), input_args.end(), - [&s](const NodeArg* input) noexcept { + [&s](const gsl::not_null input) noexcept { return s == input->Name(); }); const bool used_as_output = std::any_of(GetOutputs().begin(), GetOutputs().end(), - [&s](const NodeArg* output) noexcept { + [&s](const gsl::not_null output) noexcept { return s == output->Name(); }); @@ -1639,7 +1638,7 @@ Status Graph::SetGraphInputsOutputs() { std::unordered_map output_name_to_node_arg; for (const auto& node : Nodes()) { - for (const NodeArg* output_def : node.OutputDefs()) { + for (gsl::not_null output_def : node.OutputDefs()) { if (specified_graph_outputs.erase(output_def->Name()) >= 1) { graph_outputs.push_back(output_def); } @@ -1664,7 +1663,7 @@ Status Graph::SetGraphInputsOutputs() { for (const auto& node : Nodes()) { // Go thru all node's inputs. - for (const NodeArg* input_arg : node.InputDefs()) { + for (const gsl::not_null input_arg : node.InputDefs()) { if (!input_arg->Exists()) { // It's an optional input and does not exist in this case. continue; @@ -1694,7 +1693,7 @@ Status Graph::SetGraphInputsOutputs() { } else { std::unordered_map output_name_to_node_arg; for (const auto& node : Nodes()) { - for (const NodeArg* output_def : node.OutputDefs()) { + for (gsl::not_null output_def : node.OutputDefs()) { if (output_def->Exists()) output_name_to_node_arg.insert({output_def->Name(), output_def}); } @@ -1706,7 +1705,7 @@ Status Graph::SetGraphInputsOutputs() { std::unordered_set inner_nodes; for (const auto& node : Nodes()) { // Go thru all node's inputs. - for (const NodeArg* input_arg : node.InputDefs()) { + for (const gsl::not_null input_arg : node.InputDefs()) { if (!input_arg->Exists()) { // It's an optional input and does not exist in this case. continue; @@ -1719,8 +1718,8 @@ Status Graph::SetGraphInputsOutputs() { const std::string& name = input_arg->Name(); if (added_input_names.end() == added_input_names.find(name)) { // This graph input has not been added into . - // if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) - graph_inputs.push_back(input_arg); + //// if (name_to_initial_tensor_.find(name) == name_to_initial_tensor_.end()) + graph_inputs.push_back(input_arg); added_input_names.insert(input_arg->Name()); } } else if (graph_output_args.erase(output_arg_iter->first) >= 1) { @@ -1759,7 +1758,7 @@ const Node* GraphBase::SinkNode() const { // calling private ctor GSL_SUPPRESS(r .11) -Node* GraphBase::AllocateNode() { +gsl::not_null GraphBase::AllocateNode() { std::unique_ptr new_node(new Node(nodes_.size(), *this)); Node* node{new_node.get()}; diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.h index b14103211..c23e62acd 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.h @@ -7,9 +7,7 @@ #include #include -#include "core/common/CommonSTD.h" - -// #include "gsl/pointers" +#include "gsl/pointers" #include "gsl/gsl_util" #include "core/common/common.h" @@ -21,21 +19,19 @@ #include "core/graph/utils.h" #include "onnx/onnx_pb.h" -using namespace onnx; - // TODO - Evaluate switching the types below to support transparent comparators and enable // lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations // converting to std::string, but requires conversion to std::map> // instead of std::unordered_map]>. -using NodeAttributes = std::unordered_map; +using NodeAttributes = std::unordered_map; namespace LotusIR { using NodeIndex = size_t; using Version = int64_t; -using NodeArgInfo = ValueInfoProto; -using InitializedTensorSet = std::unordered_map; -using ArgNameToTypeMap = std::unordered_map; +using NodeArgInfo = onnx::ValueInfoProto; +using InitializedTensorSet = std::unordered_map; +using ArgNameToTypeMap = std::unordered_map; using ProviderType = const std::string&; class Graph; @@ -70,7 +66,7 @@ class NodeArg { // optional. This is called when loading a from // normally. NodeArg(const std::string& name, - const TypeProto* p_arg_type); + const onnx::TypeProto* p_arg_type); NodeArg(NodeArg&& other) = default; @@ -78,17 +74,17 @@ class NodeArg { const std::string& Name() const noexcept; // Get node arg type. - DataType Type() const noexcept; - const TypeProto* TypeAsProto() const noexcept; + onnx::DataType Type() const noexcept; + const onnx::TypeProto* TypeAsProto() const noexcept; // Get node arg shape. // Return null pointer if there's no shape specified. - const TensorShapeProto* Shape() const; + const onnx::TensorShapeProto* Shape() const; // Set node arg shape. // Shape could only be set after setting type since shape information // now is part of TypeProto. - void SetShape(const TensorShapeProto& shape); + void SetShape(const onnx::TensorShapeProto& shape); // Get node arg info proto. const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } @@ -102,13 +98,13 @@ class NodeArg { LOTUS_DISALLOW_COPY_AND_ASSIGN(NodeArg); friend class Graph; - void SetType(DataType p_type); - void SetType(const TypeProto& type_proto); + void SetType(onnx::DataType p_type); + void SetType(const onnx::TypeProto& type_proto); NodeArg& operator=(NodeArg&& other) = delete; // Node arg PType. - DataType type_; + onnx::DataType type_; // Node arg name, type and shape. NodeArgInfo node_arg_info_; @@ -159,7 +155,7 @@ class Node { // Get the OperatorSchema this node refers to. ValidateOpType() must have been called previously. // May be null in the future. - const OpSchema* Op() const noexcept; + const onnx::OpSchema* Op() const noexcept; // Get node description. const std::string& Description() const noexcept; @@ -167,8 +163,8 @@ class Node { // Iterate through Input/OutputDefs() with index, note the loop early terminates with error static Lotus::Common::Status ForEachWithIndex( const ConstPointerContainer>& nodeArgVec, - std::function func) { - for (int index = 0; index < nodeArgVec.size(); ++index) { + std::function func) { + for (size_t index = 0; index < nodeArgVec.size(); ++index) { auto arg = nodeArgVec[index]; if (!arg->Exists()) continue; @@ -207,7 +203,7 @@ class Node { const std::set& ControlInputs() const noexcept { return relationships_.control_inputs; } // Add a node attribute with specified attribute name and value. - void AddAttribute(const std::string& attr_name, const AttributeProto& value); + void AddAttribute(const std::string& attr_name, const onnx::AttributeProto& value); #define ADD_ATTR_INTERFACES(TypeName) \ void AddAttribute(const std::string& attr_name, \ @@ -218,8 +214,8 @@ class Node { ADD_ATTR_INTERFACES(int64_t) ADD_ATTR_INTERFACES(float) ADD_ATTR_INTERFACES(std::string) - ADD_ATTR_INTERFACES(TensorProto) - ADD_ATTR_INTERFACES(GraphProto) + ADD_ATTR_INTERFACES(onnx::TensorProto) + ADD_ATTR_INTERFACES(onnx::GraphProto) // Clear specified node attribute. bool ClearAttribute(const std::string& attr_name); @@ -235,7 +231,7 @@ class Node { void SetExecutionProviderType(ProviderType execution_provider_type); // Get the corresponding . - void ToProto(NodeProto& proto) const; + void ToProto(onnx::NodeProto& proto) const; // iterate through all input/output defs void ForEachDef(std::function func) const; @@ -355,7 +351,7 @@ class Node { std::string domain_; // OperatorSchema that <*this> node refers to. - const OpSchema* op_ = nullptr; + const onnx::OpSchema* op_ = nullptr; // Node doc string. std::string description_; @@ -428,7 +424,7 @@ class GraphBase { int NumberOfNodes() const noexcept { return num_of_nodes_; } // Get NodeArg by name, or create NodeArg owned by the graph if not found - NodeArg& GetOrCreateNodeArg(const std::string& name, const TypeProto* p_arg_type) { + NodeArg& GetOrCreateNodeArg(const std::string& name, const onnx::TypeProto* p_arg_type) { auto iter = node_args_.find(name); if (iter != node_args_.end()) return *(iter->second); @@ -492,7 +488,7 @@ class GraphBase { // TODO(Task:135) See if GraphBase::GetNodesInTopologicalOrder can be made more correctly const // by forcing Resolve to have been called directly previously. Simple change is to return error if // GraphResolveNeeded is true. - Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ const std::vector** pp_nodes) const; + Lotus::Common::Status GetNodesInTopologicalOrder(/*out*/ gsl::not_null**> pp_nodes) const; // Mark Graph as needing Resolve() to be called GraphBase& SetGraphResolveNeeded() noexcept { @@ -545,7 +541,7 @@ class GraphBase { void AddSourceSinkNodes(); // Add node with specified . - Node* AddNode(const NodeProto& node_proto, + Node* AddNode(const onnx::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); NodeIndex SourceNodeIndex() const noexcept { return source_node_index_; } @@ -613,13 +609,13 @@ class GraphBase { // Returns the inferred shape+type for every output of the node in // output parameter inferredShapes. Lotus::Common::Status InferOutputTypesAndShapes(LotusIR::Node& node, - /*out*/ std::vector& inferred_shapes); + /*out*/ std::vector& inferred_shapes); private: // need custom versions to handle the unique_ptr's in nodes_ LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(GraphBase); - Node* AllocateNode(); + gsl::not_null AllocateNode(); /** Release the node. @@ -709,9 +705,9 @@ class Graph : public GraphBase { void SetDescription(const std::string& description) override; // Add/Remove/Get initial tensors for some graph inputs. - void AddInitializedTensor(const TensorProto& tensor_proto); + void AddInitializedTensor(const onnx::TensorProto& tensor_proto); void RemoveInitializedTensor(const std::string& tensor_name); - bool GetInitializedTensor(const std::string& tensor_name, const TensorProto** value) const; + bool GetInitializedTensor(const std::string& tensor_name, gsl::not_null value) const; const InitializedTensorSet& GetAllInitializedTensors() const noexcept; void CleanAllInitializedTensors() noexcept; @@ -719,7 +715,7 @@ class Graph : public GraphBase { const std::vector& GetValueInfo() const noexcept; // Serialize the into . - const GraphProto& ToGraphProto(); + const onnx::GraphProto& ToGraphProto(); private: LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Graph); @@ -730,7 +726,7 @@ class Graph : public GraphBase { // Constructor: Given a loaded from model file, construct // a object. - Graph(GraphProto* graph_proto, + Graph(onnx::GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, const ILotusOpSchemaCollection* local_registry = nullptr); @@ -754,7 +750,7 @@ class Graph : public GraphBase { Lotus::Common::Status Resolve(bool no_proto_sync_required); Lotus::Common::Status InferAndVerifyTypeMatch(Node& node, - const OpSchema& op); + const onnx::OpSchema& op); // Apply type-inference and type-checking to all inputs and initializers: Lotus::Common::Status TypeCheckInputsAndInitializers(); @@ -783,7 +779,7 @@ class Graph : public GraphBase { // functions in will also be fed into so that // it's consistent with <*this> graph. // This pointer is owned by parent model. - GraphProto* graph_proto_; + onnx::GraphProto* graph_proto_; // The node which refers to <*this> graph (Function). // Node* node_; diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.cc index a843df9e6..ed67fed98 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.cc @@ -1,4 +1,3 @@ -#include "core/common/CommonSTD.h" #include "core/graph/graph_transformer.h" using namespace Lotus; using namespace Lotus::Common; diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.h index 750b1dcf7..d4e870df0 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer.h @@ -53,7 +53,7 @@ class RuleBasedGraphTransformer : public GraphTransformer { // should be stored globally. Otherwise, there will be multiple addresses/pointers // for the same operator or function. To avoid this, we may use OpSignature ID // as the key, which should be name_domain_version. - Lotus::Common::Status Register(const OpSchema* op, std::unique_ptr rule) { + Lotus::Common::Status Register(const onnx::OpSchema* op, std::unique_ptr rule) { op_to_rules_[op].push_back(std::move(rule)); return Lotus::Common::Status::OK(); } @@ -66,7 +66,7 @@ class RuleBasedGraphTransformer : public GraphTransformer { } private: - typedef std::unordered_map>> + typedef std::unordered_map>> RewriteRuleSet; RewriteRuleSet op_to_rules_; diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.cc index e786bfc6c..bd9bb96fd 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.cc @@ -7,16 +7,15 @@ #ifdef _MSC_VER #pragma warning(pop) #endif -#include "core/common/CommonSTD.h" #include #include #include "core/graph/model.h" #include "core/graph/utils.h" #include "core/graph/schema_registry.h" -// #include "gsl/pointers" +#include "gsl/pointers" #include "gsl/gsl_util" - +using namespace onnx; using namespace Lotus; using namespace Lotus::Common; @@ -27,7 +26,7 @@ Model::Model(const std::string& graph_name, bool is_onnx_domain_only, const Mode model_proto_->mutable_graph()->set_name(graph_name); model_metadata_ = model_metadata; for (auto& metadata : model_metadata_) { - auto prop = model_proto_->add_metadata_props(); + const gsl::not_null prop = model_proto_->add_metadata_props(); prop->set_key(metadata.first); prop->set_value(metadata.second); } @@ -138,7 +137,7 @@ ModelProto Model::ToProto() { } void Model::AddImportOpSets(bool is_onnx_domain_only, - /*out*/ std::unordered_map* domain_to_version, + /*out*/ gsl::not_null*> domain_to_version, const ILotusOpSchemaCollection* local_registry) { auto& domain_to_version_range_map = OpSchemaRegistry::DomainToVersionRange::Instance().Map(); Domain_To_Version_Map local_domain_to_version_map = local_registry ? local_registry->DomainToVersionMap() : Domain_To_Version_Map(); @@ -158,7 +157,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only, } auto ignored = domain_to_version->insert({domainToVersionRange.first, max}); - auto opset_id_proto = model_proto_->add_opset_import(); + const gsl::not_null opset_id_proto = model_proto_->add_opset_import(); opset_id_proto->set_domain(domainToVersionRange.first); opset_id_proto->set_version(domainToVersionRange.second.second); } @@ -174,7 +173,7 @@ void Model::AddImportOpSets(bool is_onnx_domain_only, if (domain_to_version_range_map.end() != domain_to_version_range_map.find(local_domain.first)) { auto ignored = domain_to_version->insert({local_domain.first, local_domain.second.second}); - auto opset_id_proto = model_proto_->add_opset_import(); + const gsl::not_null opset_id_proto = model_proto_->add_opset_import(); opset_id_proto->set_domain(local_domain.first); opset_id_proto->set_version(local_domain.second.second); } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.h index 9e14eb127..15ab38633 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/model.h @@ -2,7 +2,7 @@ #include "core/graph/graph.h" -// #include "gsl/pointers" +#include "gsl/pointers" namespace LotusIR { typedef std::unordered_map ModelMetaData; @@ -22,11 +22,11 @@ class Model { // NOTE: after calling this constructor, <*this> model will // hold a copy of . - explicit Model(const ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr); + explicit Model(const onnx::ModelProto& model_proto, const ILotusOpSchemaCollection* local_registry = nullptr); // NOTE: after calling this constructor, <*this> model will // own the . - explicit Model(std::unique_ptr model_proto, const ILotusOpSchemaCollection* local_registry = nullptr); + explicit Model(std::unique_ptr model_proto, const ILotusOpSchemaCollection* local_registry = nullptr); // Get model's IR version. // Return if not specified. @@ -71,7 +71,7 @@ class Model { const Graph* MainGraph() const noexcept; // Get model's serialization proto data. - ModelProto ToProto(); + onnx::ModelProto ToProto(); #ifdef _WIN32 static Lotus::Common::Status Save(Model& model, const std::wstring& file_path); @@ -84,7 +84,7 @@ class Model { static Lotus::Common::Status Save(Model& model, int fd); - static Lotus::Common::Status Load(std::istream& model_istream, ModelProto* p_model_proto); + static Lotus::Common::Status Load(std::istream& model_istream, onnx::ModelProto* p_model_proto); static Lotus::Common::Status Load(const std::string& file_path, /*out*/ std::shared_ptr& p_model, @@ -97,7 +97,7 @@ class Model { static Lotus::Common::Status LoadFromBytes(int count, void* pBytes, /*out*/ std::shared_ptr& p_model, const ILotusOpSchemaCollection* local_registry = nullptr); - static Lotus::Common::Status Load(const ModelProto& model_proto, /*out*/ std::shared_ptr& p_model, + static Lotus::Common::Status Load(const onnx::ModelProto& model_proto, /*out*/ std::shared_ptr& p_model, const ILotusOpSchemaCollection* local_registry = nullptr); private: @@ -106,11 +106,11 @@ class Model { // if is true, then only onnx domain will be contained. // otherwise, ml domain will also be contained. void AddImportOpSets(bool is_onnx_domain_only, - /*out*/ std::unordered_map* domain_to_version, + /*out*/ gsl::not_null*> domain_to_version, const ILotusOpSchemaCollection* local_registry); // Model data. - std::unique_ptr model_proto_; + std::unique_ptr model_proto_; // This is a duplication of . // It gives better accessibility. diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.cc index 6779a513b..940c0b7f3 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.cc @@ -1,9 +1,9 @@ #include -#include "core/common/CommonSTD.h" #include "core/graph/constants.h" #include "core/graph/op.h" #include "core/graph/utils.h" +using namespace onnx; namespace LotusIR { bool TypeUtils::IsValidAttribute(const AttributeProto& attr) { diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.h index d720f17d4..451f2afde 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/op.h @@ -6,12 +6,11 @@ #include "core/common/status.h" #include "core/graph/constants.h" -using namespace onnx; using namespace Lotus::Common; namespace LotusIR { -using AttrType = AttributeProto_AttributeType; -using NodeAttributes = std::unordered_map; +using AttrType = onnx::AttributeProto_AttributeType; +using NodeAttributes = std::unordered_map; // This string array should exactly match the AttrType defined above. /* @@ -44,8 +43,8 @@ static constexpr const char* kAttrTypeStrings[] = class TypeUtils { public: // Get attribute type given attribute proto data. - static Status GetType(const AttributeProto& attr, AttrType& type); - static bool IsValidAttribute(const AttributeProto& attribute); + static Status GetType(const onnx::AttributeProto& attr, AttrType& type); + static bool IsValidAttribute(const onnx::AttributeProto& attribute); }; class MsOpRegistry { diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.cc index e1326ee05..70c2ea8b2 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.cc +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.cc @@ -1,8 +1,7 @@ -#include "core/common/CommonSTD.h" #include "core/graph/tensorutils.h" #include -// #include "gsl/span" +#include "gsl/span" namespace Lotus { namespace Utils { @@ -23,9 +22,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor, return Status(StatusCategory::LOTUS, StatusCode::FAIL, "UnpackTensor: the pre-allocate size does not match the size in proto"); - for (auto& elem : tensor.string_data()) { - *p_data++ = elem; - } + const auto data = gsl::make_span(p_data, expected_size); + + auto& string_data = tensor.string_data(); + std::copy(string_data.cbegin(), string_data.cend(), data.begin()); return Status::OK(); } @@ -57,9 +57,8 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor, return Status(StatusCategory::LOTUS, StatusCode::FAIL, "UnpackTensor: the pre-allocate size does not match the size in proto"); - for (auto& elem : tensor.int32_data()) { - *p_data++ = elem; - } + const auto data = gsl::make_span(p_data, expected_size); + std::copy(tensor.int32_data().cbegin(), tensor.int32_data().cend(), data.begin()); return Status::OK(); } @@ -90,9 +89,10 @@ Status TensorUtils::UnpackTensor(const onnx::TensorProto& tensor, if (tensor.int32_data_size() != expected_size) return Status(StatusCategory::LOTUS, StatusCode::FAIL, "UnpackTensor: the pre-allocate size does not match the size in proto"); - for (auto& elem : tensor.int32_data()) { - *p_data++ = (uint16_t)elem; - } + + const auto data = gsl::make_span(p_data, expected_size); + for (int i = 0; i < expected_size; i++) + data[i] = gsl::narrow_cast(tensor.int32_data()[i]); return Status::OK(); } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.h index b3296384b..39a2e272b 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/tensorutils.h @@ -3,8 +3,8 @@ #include #include -// #include "gsl/pointers" -// #include "gsl/span" +#include "gsl/pointers" +#include "gsl/span" #include "core/common/common.h" #include "core/common/status.h" @@ -36,9 +36,9 @@ class TensorUtils { if (tensor.field_size() != expected_size) \ return Status(StatusCategory::LOTUS, StatusCode::FAIL, \ "UnpackTensor: the pre-allocated size does not match the size in proto"); \ - for (auto elem : tensor.field_name()) { \ - *p_data++ = static_cast(elem); \ - } \ + const auto span = gsl::make_span(p_data, expected_size); \ + auto& data = tensor.field_name(); \ + std::copy(data.cbegin(), data.cend(), span.begin()); \ return Status::OK(); \ } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/utils.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/utils.h index 81c08f4f4..ab86933ea 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/utils.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/utils.h @@ -16,12 +16,12 @@ #include "core/common/status.h" #include "onnx/onnx_pb.h" -// #include "gsl/pointers" +#include "gsl/pointers" namespace Lotus{ using namespace ::Lotus::Common; #ifdef _WIN32 -inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) { +inline Status FileOpenRd(const std::wstring& path, /*out*/ gsl::not_null p_fd) { _wsopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE); if (0 > *p_fd) { return Status(SYSTEM, errno); @@ -29,7 +29,7 @@ inline Status FileOpenRd(const std::wstring& path, /*out*/ int* p_fd) { return Status::OK(); } -inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) { +inline Status FileOpenWr(const std::wstring& path, /*out*/ gsl::not_null p_fd) { _wsopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE); if (0 > *p_fd) { return Status(SYSTEM, errno); @@ -38,7 +38,7 @@ inline Status FileOpenWr(const std::wstring& path, /*out*/ int* p_fd) { } #endif -inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) { +inline Status FileOpenRd(const std::string& path, /*out*/ gsl::not_null p_fd) { #ifdef _WIN32 _sopen_s(p_fd, path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE); #else @@ -50,7 +50,7 @@ inline Status FileOpenRd(const std::string& path, /*out*/ int* p_fd) { return Status::OK(); } -inline Status FileOpenWr(const std::string& path, /*out*/ int* p_fd) { +inline Status FileOpenWr(const std::string& path, /*out*/ gsl::not_null p_fd) { #ifdef _WIN32 _sopen_s(p_fd, path.c_str(), _O_CREAT | O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE); #else diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/inc/op_kernel_author.h b/Source/CNTKv2LibraryDll/proto/onnx/core/inc/op_kernel_author.h index 91e8829ea..f30ec1bcc 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/core/inc/op_kernel_author.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/core/inc/op_kernel_author.h @@ -71,20 +71,20 @@ enum class MLTensorDataType : uint32_t { }; union MLFloat16 { - uint16_t val; + uint16_t val; - MLFloat16(uint16_t x) : val(x) {} - MLFloat16() : val(0) {} + MLFloat16(uint16_t x) : val(x) {} + MLFloat16() : val(0) {} }; inline bool operator==(const MLFloat16& left, const MLFloat16& right) { - return left.val == right.val; + return left.val == right.val; } inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { - return left.val != right.val; + return left.val != right.val; } struct MLMapType {