update with ONNXIR
Parent: 18723c7d40
Commit: 3a1b8e1347
|
@@ -20,16 +20,16 @@ message Axis {
|
|||
message NDArrayView {
|
||||
|
||||
enum DataType {
|
||||
Unknown = 0;
|
||||
Float = 1;
|
||||
Double = 2;
|
||||
Unknown = 0;
|
||||
Float = 1;
|
||||
Double = 2;
|
||||
Float16 = 4;
|
||||
}
|
||||
|
||||
enum StorageFormat {
|
||||
Dense = 0;
|
||||
SparseCSC = 1;
|
||||
SparseBlockCol = 2;
|
||||
Dense = 0;
|
||||
SparseCSC = 1;
|
||||
SparseBlockCol = 2;
|
||||
}
|
||||
|
||||
DataType data_type = 1;
|
||||
|
@@ -37,16 +37,16 @@ message NDArrayView {
|
|||
NDShape shape = 3;
|
||||
|
||||
message FloatValues {
|
||||
repeated float value = 1 [packed = true];
|
||||
repeated float value = 1 [packed = true];
|
||||
}
|
||||
|
||||
message DoubleValues {
|
||||
repeated double value = 1 [packed = true];
|
||||
repeated double value = 1 [packed = true];
|
||||
}
|
||||
|
||||
oneof values {
|
||||
FloatValues float_values = 4;
|
||||
DoubleValues double_values = 5;
|
||||
FloatValues float_values = 4;
|
||||
DoubleValues double_values = 5;
|
||||
}
|
||||
|
||||
// TODO: bool read_only = 6;
|
||||
|
@@ -66,32 +66,32 @@ message DictionaryValue {
|
|||
uint64 version = 1;
|
||||
|
||||
enum Type {
|
||||
None = 0;
|
||||
Bool = 1;
|
||||
Int = 2;
|
||||
SizeT = 3;
|
||||
Float = 4;
|
||||
Double = 5;
|
||||
String = 6;
|
||||
NDShape = 7;
|
||||
Axis = 8;
|
||||
Vector = 9;
|
||||
Dictionary = 10;
|
||||
NDArrayView = 11;
|
||||
None = 0;
|
||||
Bool = 1;
|
||||
Int = 2;
|
||||
SizeT = 3;
|
||||
Float = 4;
|
||||
Double = 5;
|
||||
String = 6;
|
||||
NDShape = 7;
|
||||
Axis = 8;
|
||||
Vector = 9;
|
||||
Dictionary = 10;
|
||||
NDArrayView = 11;
|
||||
}
|
||||
|
||||
Type value_type = 2;
|
||||
oneof value {
|
||||
bool bool_value = 3;
|
||||
int32 int_value = 4;
|
||||
uint64 size_t_value = 5;
|
||||
float float_value = 6;
|
||||
double double_value = 7;
|
||||
string string_value = 8;
|
||||
NDShape nd_shape_value = 9;
|
||||
Axis axis_value = 10;
|
||||
Vector vector_value = 11;
|
||||
Dictionary dictionary_value = 12;
|
||||
NDArrayView nd_array_view_value = 13;
|
||||
bool bool_value = 3;
|
||||
int32 int_value = 4;
|
||||
uint64 size_t_value = 5;
|
||||
float float_value = 6;
|
||||
double double_value = 7;
|
||||
string string_value = 8;
|
||||
NDShape nd_shape_value = 9;
|
||||
Axis axis_value = 10;
|
||||
Vector vector_value = 11;
|
||||
Dictionary dictionary_value = 12;
|
||||
NDArrayView nd_array_view_value = 13;
|
||||
}
|
||||
}
|
|
@@ -2140,7 +2140,7 @@ ONNXIR::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const FunctionPt
|
|||
inputArg.SetShape(inputShape);
|
||||
|
||||
// this is the output NodeArg of the reshaped node. It has to be named
|
||||
// with the original node's output NodeArg so that LotusIR can make a the connection.
|
||||
// with the original node's output NodeArg so that ONNXIR can make a the connection.
|
||||
onnx::TypeProto typeProto = ToTypeProto(shape, false);
|
||||
ONNXIR::NodeArg outputArg(outputNodeArgName, &typeProto);
|
||||
|
||||
|
@@ -2597,7 +2597,7 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node
|
|||
Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
|
||||
int64_t axisIndex = ConvertAxisToOnnx(axis, src->Inputs()[0]);
|
||||
bool workaroundONNXRT = false;
|
||||
// this code is to workarund a LotusRT bug that fails
|
||||
// this code is to workaround a ONNXRT bug that fails
|
||||
// to take axes attribute into consideration.
|
||||
// we need to convert op attribute to a default ONNX case
|
||||
// where axes is not set (or set to ordered indices).
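Note: the workaround amounts to converting the op's attributes to the default ONNX form, where axes is either unset or equal to the ordered indices 0..rank-1. A standalone sketch of the check that an axes list is already that default (plain C++; the helper name is illustrative, not part of the CNTK/ONNXIR sources):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // True when the axes attribute simply lists 0..rank-1 in order, which is what
    // the runtime assumes when the attribute is absent; the exporter can then emit
    // the default form instead of relying on the runtime to honour a custom order.
    static bool AxesAreDefault(const std::vector<int64_t>& axes, std::size_t rank)
    {
        if (axes.size() != rank)
            return false;
        for (std::size_t i = 0; i < axes.size(); ++i)
            if (axes[i] != static_cast<int64_t>(i))
                return false;
        return true;
    }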
|
||||
|
|
|
@@ -609,7 +609,7 @@ Constant CreateConstantWithRawData(DType *data,const NDShape &shape, const std:
|
|||
}
|
||||
|
||||
std::vector<Variable> CreateRNNConstant(
|
||||
const Node *parentNode, int index, const std::string &name, onnx::TensorProto &valueProto, const DeviceDescriptor& computeDevice)
|
||||
const Node *parentNode, int index, const std::string &name, const onnx::TensorProto &valueProto, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
std::vector<Variable> inputs;
|
||||
auto dataType = valueProto.data_type();
|
||||
|
@@ -962,18 +962,18 @@ std::vector<Variable> CreateRNNConstant(
|
|||
std::vector<FunctionPtr> CreateRNNConstantOp(const Graph* graph, const Node *node, const Node *parentNode, int index,
|
||||
const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
onnx::TensorProto valueProto;
|
||||
if (!graph->GetInitialTensor(node->Name(), valueProto))
|
||||
const onnx::TensorProto *valueProto;
|
||||
if (!graph->GetInitialTensor(node->Name(), &valueProto))
|
||||
{
|
||||
NodeAttributes::const_iterator itValue = node->GetAttributes().find("value");
|
||||
if (itValue == node->GetAttributes().cend())
|
||||
{
|
||||
return std::vector<FunctionPtr>();
|
||||
}
|
||||
valueProto = itValue->second.t();
|
||||
valueProto = &itValue->second.t();
|
||||
}
|
||||
|
||||
std::vector<Variable> constantNodes = CreateRNNConstant(parentNode, index, node->Name(), valueProto, computeDevice);
|
||||
std::vector<Variable> constantNodes = CreateRNNConstant(parentNode, index, node->Name(), *valueProto, computeDevice);
|
||||
std::vector<FunctionPtr> returns;
|
||||
for (auto c : constantNodes)
|
||||
returns.push_back(c);
|
||||
|
@@ -986,11 +986,11 @@ std::vector<Variable> ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const No
|
|||
{
|
||||
string parentONNXOpName = parentNode->OpType();
|
||||
std::string nodeName = nodeArg->Name();
|
||||
onnx::TensorProto valueProto;
|
||||
if (graph->GetInitialTensor(nodeName, valueProto))
|
||||
const onnx::TensorProto *valueProto;
|
||||
if (graph->GetInitialTensor(nodeName, &valueProto))
|
||||
{
|
||||
int index = CalculateNodeArgInputIndex(nodeArg, parentNode);
|
||||
return CreateRNNConstant(parentNode, index, nodeName, valueProto, computeDevice);
|
||||
return CreateRNNConstant(parentNode, index, nodeName, *valueProto, computeDevice);
|
||||
}
|
||||
|
||||
const TensorShapeProto *shapeProto = nodeArg->Shape();
|
||||
|
@@ -1080,10 +1080,10 @@ Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
|
|||
string parentONNXOpName = parentNode->OpType();
|
||||
|
||||
std::string nodeName = nodeArg->Name();
|
||||
onnx::TensorProto valueProto;
|
||||
if (graph->GetInitialTensor(nodeName, valueProto))
|
||||
const onnx::TensorProto *valueProto;
|
||||
if (graph->GetInitialTensor(nodeName, &valueProto))
|
||||
{
|
||||
return CreateConstant(valueProto, nodeName, computeDevice);
|
||||
return CreateConstant(*valueProto, nodeName, computeDevice);
|
||||
}
|
||||
|
||||
auto shapeProto = nodeArg->Shape();
|
||||
|
@@ -2475,8 +2475,8 @@ std::vector<FunctionPtr> ONNXToCNTKHelper::FromONNXNode(const Node *node, ONNXTo
|
|||
else
|
||||
{
|
||||
FunctionPtr cntkFunction = CreateCNTKNode(node, inputs, computeDevice);
|
||||
constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, std::vector<FunctionPtr>({ cntkFunction })));
|
||||
return std::vector<FunctionPtr>({ cntkFunction });
|
||||
constructedNodeMap.insert(ONNXToCNTKMap::value_type(node, std::vector<FunctionPtr>({ cntkFunction })));
|
||||
return std::vector<FunctionPtr>({ cntkFunction });
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2756,7 +2756,15 @@ FunctionPtr ONNXToCNTK::CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor&
|
|||
std::vector<FunctionPtr> functions;
|
||||
for (Node::NodeConstIterator it = itNodeFn->first->InputNodes_begin(); it != itNodeFn->first->InputNodes_end(); ++it)
|
||||
{
|
||||
functions.insert(functions.end(), constructedFunctions[*it].begin(), constructedFunctions[*it].end());
|
||||
// TODO: consulting ONNXIR to see how to do this solidly.
|
||||
// https://msasg.visualstudio.com/DefaultCollection/Shared%20Data/AIToolkits-CNTK/_queries?id=1134732&_a=edit&triage=true
|
||||
std::vector<FunctionPtr> &constructedFuncts = constructedFunctions[*it];
|
||||
for (int index = 0; index < constructedFuncts.size(); index++)
|
||||
{
|
||||
FunctionPtr &constructedFunct = constructedFuncts[index];
|
||||
if (constructedFunct->RootFunction()->OpName() != L"Combine")
|
||||
functions.insert(functions.end(), constructedFunct);
|
||||
}
|
||||
}
|
||||
|
||||
if (functions.empty())
|
||||
|
|
|
@@ -42,4 +42,4 @@ namespace ONNXIR
|
|||
// Note: due to non-deterministic static initialization order, some of the type strings
|
||||
// may have already been added via Op Registrations which use those type strings.
|
||||
static TypeStringsInitializer& _typeStrings = TypeStringsInitializer::InitializeTypeStrings();
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -7,6 +7,9 @@ namespace ONNXIR
|
|||
static const std::string c_noOp = "NoOp";
|
||||
static const std::string c_constantOp = "Constant";
|
||||
static const std::string c_constantValue = "value";
|
||||
static const std::string c_onnxDomain = "";
|
||||
static const std::string c_mlDomain = "ai.onnx.ml";
|
||||
static const std::string c_msDomain = "com.microsoft";
|
||||
|
||||
// Singleton wrapper around allowed data types.
|
||||
// This implements construct on first use which is needed to ensure
|
||||
|
@@ -32,6 +35,7 @@ namespace ONNXIR
|
|||
const std::string c_complex128 = "complex128";
|
||||
const std::string c_string = "string";
|
||||
const std::string c_bool = "bool";
|
||||
const std::string c_undefined = "undefined";
|
||||
|
||||
std::unordered_set<std::string>& GetAllowedDataTypes();
|
||||
~TypesWrapper() = default;
|
||||
|
|
|
@@ -30,6 +30,10 @@ namespace ONNXIR
|
|||
(*m_nodeArgInfo.mutable_type()) = *p_nodeArgType;
|
||||
m_type = OpUtils::ToType(m_nodeArgInfo.type());
|
||||
}
|
||||
else
|
||||
{
|
||||
m_type = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& NodeArg::Name() const
|
||||
|
@@ -429,7 +433,7 @@ namespace ONNXIR
|
|||
}
|
||||
}
|
||||
|
||||
#define ADD_BASIC_ATTR_IMPL(type, field) \
|
||||
#define ADD_BASIC_ATTR_IMPL(type, enumType, field) \
|
||||
bool Node::AddAttribute(const std::string& p_attrName, const type& p_value) \
|
||||
{ \
|
||||
auto it = m_attributes.find(p_attrName); \
|
||||
|
@@ -439,6 +443,7 @@ namespace ONNXIR
|
|||
m_graph->m_graphProtoSyncNeeded = true; \
|
||||
AttributeProto a; \
|
||||
a.set_name(p_attrName); \
|
||||
a.set_type(enumType); \
|
||||
a.set_##field(p_value); \
|
||||
m_attributes.emplace(p_attrName, a); \
|
||||
return true; \
|
||||
|
@@ -449,7 +454,7 @@ namespace ONNXIR
|
|||
} \
|
||||
}; \
|
||||
|
||||
#define ADD_ATTR_IMPL(type, field) \
|
||||
#define ADD_ATTR_IMPL(type, enumType, field) \
|
||||
bool Node::AddAttribute(const std::string& p_attrName, const type& p_value) \
|
||||
{ \
|
||||
auto it = m_attributes.find(p_attrName); \
|
||||
|
@@ -459,6 +464,7 @@ namespace ONNXIR
|
|||
m_graph->m_graphProtoSyncNeeded = true; \
|
||||
AttributeProto a; \
|
||||
a.set_name(p_attrName); \
|
||||
a.set_type(enumType); \
|
||||
*(a.mutable_##field()) = p_value; \
|
||||
m_attributes.emplace(p_attrName, a); \
|
||||
return true; \
|
||||
|
@@ -469,7 +475,7 @@ namespace ONNXIR
|
|||
} \
|
||||
}; \
|
||||
|
||||
#define ADD_LIST_ATTR_IMPL(type, field) \
|
||||
#define ADD_LIST_ATTR_IMPL(type, enumType, field) \
|
||||
bool Node::AddAttribute(const std::string& p_attrName, \
|
||||
const std::vector<type>& p_values) \
|
||||
{ \
|
||||
|
@@ -480,6 +486,7 @@ namespace ONNXIR
|
|||
m_graph->m_graphProtoSyncNeeded = true; \
|
||||
AttributeProto a; \
|
||||
a.set_name(p_attrName); \
|
||||
a.set_type(enumType); \
|
||||
for (const auto& val : p_values) \
|
||||
{ \
|
||||
*(a.mutable_##field()->Add()) = val; \
|
||||
|
@@ -493,16 +500,16 @@ namespace ONNXIR
|
|||
} \
|
||||
}; \
|
||||
|
||||
ADD_BASIC_ATTR_IMPL(float, f)
|
||||
ADD_BASIC_ATTR_IMPL(int64_t, i)
|
||||
ADD_BASIC_ATTR_IMPL(std::string, s)
|
||||
ADD_ATTR_IMPL(TensorProto, t)
|
||||
ADD_ATTR_IMPL(GraphProto, g)
|
||||
ADD_LIST_ATTR_IMPL(float, floats)
|
||||
ADD_LIST_ATTR_IMPL(int64_t, ints)
|
||||
ADD_LIST_ATTR_IMPL(std::string, strings)
|
||||
ADD_LIST_ATTR_IMPL(TensorProto, tensors)
|
||||
ADD_LIST_ATTR_IMPL(GraphProto, graphs)
|
||||
ADD_BASIC_ATTR_IMPL(float, AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT, f)
|
||||
ADD_BASIC_ATTR_IMPL(int64_t, AttributeProto_AttributeType::AttributeProto_AttributeType_INT, i)
|
||||
ADD_BASIC_ATTR_IMPL(std::string, AttributeProto_AttributeType::AttributeProto_AttributeType_STRING, s)
|
||||
ADD_ATTR_IMPL(TensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR, t)
|
||||
ADD_ATTR_IMPL(GraphProto, AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH, g)
|
||||
ADD_LIST_ATTR_IMPL(float, AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS, floats)
|
||||
ADD_LIST_ATTR_IMPL(int64_t, AttributeProto_AttributeType::AttributeProto_AttributeType_INTS, ints)
|
||||
ADD_LIST_ATTR_IMPL(std::string, AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS, strings)
|
||||
ADD_LIST_ATTR_IMPL(TensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_TENSORS, tensors)
|
||||
ADD_LIST_ATTR_IMPL(GraphProto, AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPHS, graphs)
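The only change to these macros is the extra enumType parameter, so every generated Node::AddAttribute overload now stamps the attribute's type into the proto as well as its value. A compilable sketch of the same macro pattern over stand-in types (plain C++; Attr, AttrType and the unused field parameter are illustrative, not the ONNXIR definitions):

    #include <cstdint>
    #include <map>
    #include <string>

    enum class AttrType { FLOAT, INT };   // stand-in for AttributeProto_AttributeType
    struct Attr { std::string name; AttrType type; std::string repr; };

    // Each expansion emits one overload that records the value together with its
    // type tag; "field" is kept only to mirror the real macro's three arguments.
    #define ADD_BASIC_ATTR_IMPL(ctype, enumType, field)                      \
        bool AddAttribute(std::map<std::string, Attr>& attrs,                \
                          const std::string& name, const ctype& value)       \
        {                                                                    \
            if (attrs.count(name) != 0)                                      \
                return false;                                                \
            attrs[name] = Attr{ name, enumType, std::to_string(value) };     \
            return true;                                                     \
        }

    ADD_BASIC_ATTR_IMPL(float, AttrType::FLOAT, f)
    ADD_BASIC_ATTR_IMPL(int64_t, AttrType::INT, i)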
|
||||
|
||||
bool Node::ClearAttribute(const std::string& p_attrName)
|
||||
{
|
||||
|
@@ -547,42 +554,48 @@ namespace ONNXIR
|
|||
return m_graph->GetNode(m_currentNodeIndex);
|
||||
}
|
||||
|
||||
Graph::Graph(const GraphProto& p_graphProto)
|
||||
Graph::Graph(GraphProto* p_graphProto,
|
||||
const std::unordered_map<std::string, int>& p_domainToVersion, bool p_isONNX)
|
||||
: m_graphProto(p_graphProto),
|
||||
m_graphProtoSyncNeeded(false),
|
||||
m_graphResolveNeeded(true),
|
||||
m_numOfNodes(0)
|
||||
{
|
||||
// This is a main graph, and strict type checking needed..
|
||||
m_removedInitializerIndexes.clear();
|
||||
m_domainToVersion = &p_domainToVersion;
|
||||
// This is a main graph.
|
||||
m_graphType |= Type::Main;
|
||||
|
||||
// TODO: add Type::Strict back.
|
||||
|
||||
// Copy initial tensors to a map.
|
||||
for (auto tensor : p_graphProto.initializer())
|
||||
if (!p_isONNX)
|
||||
{
|
||||
m_nameToInitialTensor[tensor.name()] = tensor;
|
||||
m_graphType |= Type::Strict;
|
||||
}
|
||||
|
||||
// Copy initial tensor indexes to a map.
|
||||
for (int i = 0; i < m_graphProto->initializer_size(); ++i)
|
||||
{
|
||||
m_nameToInitialTensorIndex[m_graphProto->initializer()[i].name()] = i;
|
||||
m_nameToInitialTensorPtr[m_graphProto->initializer()[i].name()] = m_graphProto->mutable_initializer(i);
|
||||
}
|
||||
|
||||
// Collect all node arg name, type, shape information in the graph.
|
||||
// type/shape information will be assigned to each node arg when going
|
||||
// thru all nodes later.
|
||||
ArgNameToTypeMap nameToTypeMap;
|
||||
for (auto& graphInput : m_graphProto.input())
|
||||
for (auto& graphInput : m_graphProto->input())
|
||||
{
|
||||
if (graphInput.has_name() && graphInput.has_type())
|
||||
{
|
||||
nameToTypeMap[graphInput.name()] = graphInput.type();
|
||||
}
|
||||
}
|
||||
for (auto& graphOutput : m_graphProto.output())
|
||||
for (auto& graphOutput : m_graphProto->output())
|
||||
{
|
||||
if (graphOutput.has_name() && graphOutput.has_type())
|
||||
{
|
||||
nameToTypeMap[graphOutput.name()] = graphOutput.type();
|
||||
}
|
||||
}
|
||||
for (auto& nodeArg : m_graphProto.value_info())
|
||||
for (auto& nodeArg : m_graphProto->value_info())
|
||||
{
|
||||
if (nodeArg.has_name() && nodeArg.has_type())
|
||||
{
|
||||
|
@@ -592,40 +605,12 @@ namespace ONNXIR
|
|||
|
||||
// Add nodes.
|
||||
AddSourceSinkNodes();
|
||||
for (auto nodeProto : p_graphProto.node())
|
||||
for (auto& nodeProto : p_graphProto->node())
|
||||
{
|
||||
AddNode(nodeProto, nameToTypeMap);
|
||||
}
|
||||
}
|
||||
|
||||
Graph::Graph(const std::string& p_name, bool p_isONNX)
|
||||
: m_graphProtoSyncNeeded(false),
|
||||
m_graphResolveNeeded(true),
|
||||
m_numOfNodes(0)
|
||||
{
|
||||
m_graphProto.set_name(p_name);
|
||||
m_graphType |= Type::Main;
|
||||
if (!p_isONNX)
|
||||
{
|
||||
m_graphType |= Type::Strict;
|
||||
}
|
||||
|
||||
AddSourceSinkNodes();
|
||||
}
|
||||
|
||||
Graph::Graph(const std::string& p_name,
|
||||
const std::string& p_docString)
|
||||
: m_graphProtoSyncNeeded(false),
|
||||
m_graphResolveNeeded(true),
|
||||
m_numOfNodes(0)
|
||||
{
|
||||
m_graphProto.set_name(p_name);
|
||||
m_graphProto.set_doc_string(p_docString);
|
||||
m_graphType |= (Type::Main | Type::Strict);
|
||||
|
||||
AddSourceSinkNodes();
|
||||
}
|
||||
|
||||
Status Graph::VerifyNoDuplicateName(
|
||||
/*out*/ std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
|
||||
/*out*/ std::unordered_map<std::string, NODEINDEX>& p_nodeNameToIndex)
|
||||
|
@@ -851,7 +836,15 @@ namespace ONNXIR
|
|||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
||||
if (this->NumberOfNodes() == p_nodesInTopologicalOrder.size())
|
||||
{
|
||||
return Status::OK();
|
||||
}
|
||||
else
|
||||
{
|
||||
return Status(ONNX, FAIL, "Error: the graph is not acyclic.");
|
||||
}
|
||||
}
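The added branch turns CheckIsAcyclic into a real cycle check: a topological sort of an acyclic graph orders every node, so an order shorter than the node count means a cycle. The same size comparison in a self-contained Kahn's-algorithm sketch (plain C++ over an adjacency list, not the ONNXIR Graph type):

    #include <cstddef>
    #include <queue>
    #include <vector>

    // Kahn's algorithm: if the topological order covers every node the graph is
    // acyclic; otherwise a cycle kept some nodes' in-degree above zero.
    bool IsAcyclic(const std::vector<std::vector<int>>& adj)
    {
        std::vector<int> inDegree(adj.size(), 0);
        for (const auto& outs : adj)
            for (int v : outs)
                ++inDegree[v];

        std::queue<int> ready;
        for (int v = 0; v < static_cast<int>(adj.size()); ++v)
            if (inDegree[v] == 0)
                ready.push(v);

        std::size_t ordered = 0;
        while (!ready.empty())
        {
            int u = ready.front();
            ready.pop();
            ++ordered;
            for (int v : adj[u])
                if (--inDegree[v] == 0)
                    ready.push(v);
        }
        return ordered == adj.size();   // equal sizes <=> no cycle
    }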
|
||||
|
||||
Status Graph::InferAndVerifyTypeMatch(Node* p_node,
|
||||
|
@@ -886,20 +879,20 @@ namespace ONNXIR
|
|||
// If it's fed by callers, it's needed to have type
|
||||
// information defined well.
|
||||
auto initialTensorIter
|
||||
= m_nameToInitialTensor.find(inputDef.Name());
|
||||
if (m_nameToInitialTensor.end()
|
||||
= m_nameToInitialTensorPtr.find(inputDef.Name());
|
||||
if (m_nameToInitialTensorPtr.end()
|
||||
!= initialTensorIter)
|
||||
{
|
||||
// This input is fed with default value by initializer.
|
||||
// Infer its type from initializer tensor.
|
||||
TypeProto initialTensorType;
|
||||
initialTensorType.mutable_tensor_type()->set_elem_type(
|
||||
initialTensorIter->second.data_type());
|
||||
initialTensorIter->second->data_type());
|
||||
inputDef.SetType(OpUtils::ToType(initialTensorType));
|
||||
|
||||
// Set shape accordingly.
|
||||
TensorShapeProto shape;
|
||||
for (auto dim : initialTensorIter->second.dims())
|
||||
for (auto dim : initialTensorIter->second->dims())
|
||||
{
|
||||
shape.add_dim()->set_dim_value(dim);
|
||||
}
|
||||
|
@@ -1075,10 +1068,18 @@ namespace ONNXIR
|
|||
auto& nodeName = node->Name();
|
||||
auto& op_type = node->OpType();
|
||||
auto& domain = node->Domain();
|
||||
auto versionIter = m_domainToVersion->find(domain);
|
||||
if (m_domainToVersion->end() == versionIter)
|
||||
{
|
||||
// The domain referred by this node does not exist either
|
||||
// in <OpSetIdProto> in the <ModelProto> loaded (in the case of model loaded from file) or
|
||||
// in global DomainToVersionRange map (in the case of model constructed from scratch).
|
||||
return Status(ONNX, FAIL, "The op domain (" + domain + ") used by node ("
|
||||
+ nodeName + ") is not supported for this model.");
|
||||
}
|
||||
|
||||
// Get op schema with latest version given op name and domain.
|
||||
// TODO: version may be used when we want to support versioning in run time.
|
||||
node->m_op = OpSchemaRegistry::Schema(op_type, domain);
|
||||
// Get op schema given op name, max inclusive version and domain.
|
||||
node->m_op = OpSchemaRegistry::Schema(op_type, versionIter->second, domain);
|
||||
if (nullptr == node->m_op)
|
||||
{
|
||||
// A op_type refers to nothing.
|
||||
|
@@ -1234,7 +1235,7 @@ namespace ONNXIR
|
|||
RETURN_IF_ERROR(BuildConnections(outputArgs, nodeNameToIndex));
|
||||
RETURN_IF_ERROR(CheckIsAcyclic(m_nodesInTopologicalOrder));
|
||||
RETURN_IF_ERROR(VerifyNodeAndOpMatch(m_nodesInTopologicalOrder, outputArgs));
|
||||
SetGraphInputsOutputs();
|
||||
RETURN_IF_ERROR(SetGraphInputsOutputs());
|
||||
|
||||
m_graphResolveNeeded = false;
|
||||
return Status::OK();
|
||||
|
@@ -1261,47 +1262,72 @@ namespace ONNXIR
|
|||
"Sink node internally in a graph.",
|
||||
emptyArgs,
|
||||
emptyArgs)->Index();
|
||||
AddControlEdge(m_sourceNodeIndex, m_sinkNodeIndex);
|
||||
}
|
||||
|
||||
const std::string& Graph::Name() const
|
||||
{
|
||||
return m_graphProto.name();
|
||||
return m_graphProto->name();
|
||||
}
|
||||
|
||||
void Graph::SetName(const std::string& p_name)
|
||||
{
|
||||
m_graphProto.set_name(p_name);
|
||||
m_graphProto->set_name(p_name);
|
||||
}
|
||||
|
||||
const std::string& Graph::Description() const
|
||||
{
|
||||
return m_graphProto->doc_string();
|
||||
}
|
||||
|
||||
void Graph::SetDescription(const std::string& p_desription)
|
||||
{
|
||||
m_graphProto->set_doc_string(p_desription);
|
||||
}
|
||||
|
||||
void Graph::AddInitialTensor(const TensorProto& p_tensor)
|
||||
{
|
||||
m_nameToInitialTensor[p_tensor.name()] = p_tensor;
|
||||
if (m_nameToInitialTensorPtr.end() != m_nameToInitialTensorPtr.find(p_tensor.name()))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto tensorAdded = m_graphProto->add_initializer();
|
||||
*(tensorAdded) = p_tensor;
|
||||
m_nameToInitialTensorIndex[p_tensor.name()] = m_graphProto->initializer_size() - 1;
|
||||
m_nameToInitialTensorPtr[p_tensor.name()] = tensorAdded;
|
||||
m_graphProtoSyncNeeded = true;
|
||||
m_graphResolveNeeded = true;
|
||||
}
|
||||
|
||||
void Graph::RemoveInitialTensor(const std::string& p_tensorName)
|
||||
{
|
||||
m_nameToInitialTensor.erase(p_tensorName);
|
||||
m_graphProtoSyncNeeded = true;
|
||||
m_graphResolveNeeded = true;
|
||||
auto iter = m_nameToInitialTensorIndex.find(p_tensorName);
|
||||
if (m_nameToInitialTensorIndex.end() != iter)
|
||||
{
|
||||
m_removedInitializerIndexes.push_back(iter->second);
|
||||
m_nameToInitialTensorIndex.erase(p_tensorName);
|
||||
m_nameToInitialTensorPtr.erase(p_tensorName);
|
||||
m_graphProtoSyncNeeded = true;
|
||||
m_graphResolveNeeded = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool Graph::GetInitialTensor(const std::string& p_tensorName,
|
||||
TensorProto& p_value) const
|
||||
const TensorProto** p_value) const
|
||||
{
|
||||
auto iter = m_nameToInitialTensor.find(p_tensorName);
|
||||
if (m_nameToInitialTensor.end() == iter)
|
||||
auto iter = m_nameToInitialTensorPtr.find(p_tensorName);
|
||||
if (m_nameToInitialTensorPtr.end() == iter)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
p_value = iter->second;
|
||||
*p_value = iter->second;
|
||||
return true;
|
||||
}
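GetInitialTensor now hands back a non-owning pointer through an out parameter instead of copying a whole TensorProto; callers (see the CreateRNNConstantOp and CreateLeafVariableOrConstant hunks earlier) test the return value and then dereference. The lookup pattern in isolation (plain C++; Tensor is a stand-in for onnx::TensorProto):

    #include <string>
    #include <unordered_map>

    struct Tensor { };   // stand-in for the protobuf tensor type

    // Returns a pointer into the table instead of a copy; the pointer stays valid
    // as long as the underlying storage does, which is why nothing is duplicated.
    bool GetInitialTensor(const std::unordered_map<std::string, const Tensor*>& table,
                          const std::string& name, const Tensor** out)
    {
        auto it = table.find(name);
        if (it == table.end())
            return false;
        *out = it->second;
        return true;
    }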
|
||||
|
||||
const InitialTensorSet& Graph::GetAllInitialTensors() const
|
||||
{
|
||||
return m_nameToInitialTensor;
|
||||
return m_nameToInitialTensorPtr;
|
||||
}
|
||||
|
||||
const std::vector<const NodeArg*>& Graph::GetInputs() const
|
||||
|
@@ -1470,11 +1496,11 @@ namespace ONNXIR
|
|||
{
|
||||
if (!m_graphProtoSyncNeeded)
|
||||
{
|
||||
return m_graphProto;
|
||||
return *m_graphProto;
|
||||
}
|
||||
|
||||
// Nodes.
|
||||
m_graphProto.clear_node();
|
||||
m_graphProto->clear_node();
|
||||
|
||||
// Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
|
||||
for (auto& nodeIdx : m_nodesInTopologicalOrder)
|
||||
|
@@ -1484,16 +1510,41 @@ namespace ONNXIR
|
|||
{
|
||||
continue;
|
||||
}
|
||||
auto nodeProto = m_graphProto.add_node();
|
||||
auto nodeProto = m_graphProto->add_node();
|
||||
m_nodes[nodeIdx]->ToProto(*nodeProto);
|
||||
}
|
||||
|
||||
// Initial tensors;
|
||||
m_graphProto.clear_initializer();
|
||||
for (auto item : m_nameToInitialTensor)
|
||||
if (m_removedInitializerIndexes.size() > 0)
|
||||
{
|
||||
auto tensor = m_graphProto.add_initializer();
|
||||
*tensor = item.second;
|
||||
// Move initializers.
|
||||
std::sort(m_removedInitializerIndexes.begin(), m_removedInitializerIndexes.end());
|
||||
int lastInUseInitializerIndex = m_graphProto->initializer_size() - 1;
|
||||
int start = 0, end = static_cast<int>(m_removedInitializerIndexes.size()) - 1;
|
||||
int lastRemovedInitializerIndex = m_removedInitializerIndexes[end];
|
||||
|
||||
for (; start <= end; start++)
|
||||
{
|
||||
// Find a lastInUseInitializer.
|
||||
while (start <= end && lastInUseInitializerIndex == lastRemovedInitializerIndex)
|
||||
{
|
||||
m_graphProto->mutable_initializer()->RemoveLast();
|
||||
lastInUseInitializerIndex--;
|
||||
end--;
|
||||
if (start <= end)
|
||||
{
|
||||
lastRemovedInitializerIndex = m_removedInitializerIndexes[end];
|
||||
}
|
||||
}
|
||||
|
||||
if (start <= end)
|
||||
{
|
||||
// Copy the <lastInUseInitializerIndex> initializer in use to the <start> slot which is removed.
|
||||
*m_graphProto->mutable_initializer(m_removedInitializerIndexes[start]) = m_graphProto->initializer(lastInUseInitializerIndex);
|
||||
m_graphProto->mutable_initializer()->RemoveLast();
|
||||
lastInUseInitializerIndex--;
|
||||
}
|
||||
}
|
||||
m_removedInitializerIndexes.clear();
|
||||
}
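The new block compacts the repeated initializer field in place: the removed indexes are sorted, each removed slot is overwritten with the last still-live entry, and the tail is popped with RemoveLast, so nothing is reallocated but ordering is not preserved. A standalone sketch of the same swap-and-pop compaction over a std::vector (plain C++; names are illustrative):

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    // Compact a vector by moving the last still-live element into each removed
    // slot and shrinking, the same idea the RemoveLast loop above applies to the
    // repeated initializer field.
    void Compact(std::vector<std::string>& items, std::vector<int> removed)
    {
        std::sort(removed.begin(), removed.end());
        int last = static_cast<int>(items.size()) - 1;
        int start = 0, end = static_cast<int>(removed.size()) - 1;
        while (start <= end)
        {
            // Drop trailing elements that are themselves scheduled for removal.
            while (start <= end && last == removed[end])
            {
                items.pop_back();
                --last;
                --end;
            }
            if (start <= end)
            {
                items[removed[start]] = items[last];   // fill the hole
                items.pop_back();
                --last;
                ++start;
            }
        }
    }

    int main()
    {
        std::vector<std::string> v{ "a", "b", "c", "d", "e" };
        Compact(v, { 1, 4 });     // remove "b" and "e"
        assert(v.size() == 3);    // "a", "c", "d" survive, order not preserved
    }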
|
||||
|
||||
// Sync graph inputs/outputs/valueInfo.
|
||||
|
@@ -1501,91 +1552,189 @@ namespace ONNXIR
|
|||
|
||||
m_graphProtoSyncNeeded = false;
|
||||
|
||||
return m_graphProto;
|
||||
return *m_graphProto;
|
||||
}
|
||||
|
||||
void Graph::SyncGraphInputsOutputs()
|
||||
{
|
||||
m_graphProto.clear_input();
|
||||
m_graphProto.clear_output();
|
||||
m_graphProto.clear_value_info();
|
||||
m_graphProto->clear_input();
|
||||
m_graphProto->clear_output();
|
||||
m_graphProto->clear_value_info();
|
||||
|
||||
for (auto inputArg : m_graphInputs)
|
||||
{
|
||||
*(m_graphProto.mutable_input()->Add()) = inputArg->ToProto();
|
||||
*(m_graphProto->mutable_input()->Add()) = inputArg->ToProto();
|
||||
}
|
||||
|
||||
for (auto outputArg : m_graphOutputs)
|
||||
{
|
||||
*(m_graphProto.mutable_output()->Add()) = outputArg->ToProto();
|
||||
*(m_graphProto->mutable_output()->Add()) = outputArg->ToProto();
|
||||
}
|
||||
|
||||
for (auto valueInfo : m_valueInfo)
|
||||
{
|
||||
*(m_graphProto.mutable_value_info()->Add()) = valueInfo->ToProto();
|
||||
*(m_graphProto->mutable_value_info()->Add()) = valueInfo->ToProto();
|
||||
}
|
||||
}
|
||||
|
||||
void Graph::SetGraphInputsOutputs()
|
||||
Status Graph::SetGraphInputsOutputs()
|
||||
{
|
||||
// Reset graphInputs/graphOutputs/valueInfo state.
|
||||
m_graphInputs.clear();
|
||||
m_graphOutputs.clear();
|
||||
m_valueInfo.clear();
|
||||
|
||||
std::unordered_map<std::string, const NodeArg*> outputNameToNodeArg;
|
||||
for (auto nodeIter = Nodes_begin();
|
||||
nodeIter != Nodes_end();
|
||||
++nodeIter)
|
||||
// Flag indicates that this graph is loaded from model file.
|
||||
// If it's true, then graph inputs and outputs will keep the same
|
||||
// as what are specified in the model, otherwise, graph inputs
|
||||
// and outputs will be inferred.
|
||||
bool loadedFromModelFile = m_graphProto->input_size() != 0
|
||||
|| m_graphProto->output_size() != 0
|
||||
|| m_graphProto->value_info_size() != 0;
|
||||
|
||||
std::unordered_set<std::string> addedInputNames{};
|
||||
if (loadedFromModelFile)
|
||||
{
|
||||
for (auto& outputDef : (*nodeIter)->OutputDefs())
|
||||
// Collect all graph inputs/outputs specified in original graph proto
|
||||
std::unordered_set<std::string> specifiedGraphInputs;
|
||||
std::unordered_set<std::string> specifiedGraphOutputs;
|
||||
std::unordered_set<std::string> specifiedGraphValueInfo;
|
||||
std::unordered_set<std::string> specifiedInitializers;
|
||||
for (auto& graphInput : m_graphProto->input())
|
||||
{
|
||||
outputNameToNodeArg.insert({ outputDef.Name(), &outputDef });
|
||||
specifiedGraphInputs.insert(graphInput.name());
|
||||
}
|
||||
}
|
||||
|
||||
// Init graph output args with all node output args.
|
||||
auto graphOutputArgs = outputNameToNodeArg;
|
||||
|
||||
std::unordered_set<Node*> innerNodes;
|
||||
for (auto nodeIter = Nodes_begin();
|
||||
nodeIter != Nodes_end();
|
||||
++nodeIter)
|
||||
{
|
||||
if (IsSourceNode((*nodeIter)->Index())
|
||||
|| IsSinkNode((*nodeIter)->Index()))
|
||||
for (auto& graphOutput : m_graphProto->output())
|
||||
{
|
||||
continue;
|
||||
specifiedGraphOutputs.insert(graphOutput.name());
|
||||
}
|
||||
for (auto& graphValueInfo : m_graphProto->value_info())
|
||||
{
|
||||
specifiedGraphValueInfo.insert(graphValueInfo.name());
|
||||
}
|
||||
for (auto& initializer : m_graphProto->initializer())
|
||||
{
|
||||
specifiedInitializers.insert(initializer.name());
|
||||
}
|
||||
|
||||
// Go thru all node's inputs.
|
||||
for (auto& inputArg : (*nodeIter)->InputDefs())
|
||||
std::unordered_map<std::string, const NodeArg*> outputNameToNodeArg;
|
||||
for (auto nodeIter = Nodes_begin();
|
||||
nodeIter != Nodes_end();
|
||||
++nodeIter)
|
||||
{
|
||||
auto outputArgIter = outputNameToNodeArg.find(inputArg.Name());
|
||||
if (outputNameToNodeArg.end()
|
||||
== outputArgIter)
|
||||
for (auto& outputDef : (*nodeIter)->OutputDefs())
|
||||
{
|
||||
// No such outputArg matching this inputArg.
|
||||
// This input arg should be fed when running evaluation.
|
||||
// it should be a graph input or initializer (say, weight).
|
||||
m_graphInputs.push_back(&inputArg);
|
||||
continue;
|
||||
if (specifiedGraphOutputs.erase(outputDef.Name()) >= 1)
|
||||
{
|
||||
m_graphOutputs.push_back(&outputDef);
|
||||
}
|
||||
outputNameToNodeArg.insert({ outputDef.Name(), &outputDef });
|
||||
}
|
||||
}
|
||||
if (specifiedGraphOutputs.size() != 0)
|
||||
{
|
||||
return Status(ONNX, FAIL, "Some graph outputs which don't exist in the graph.");
|
||||
}
|
||||
|
||||
// Remove the output arg name from graph outputs since it's
|
||||
// feeding another node as the node's input.
|
||||
if (graphOutputArgs.erase(outputArgIter->first) >= 1)
|
||||
for (auto nodeIter = Nodes_begin();
|
||||
nodeIter != Nodes_end();
|
||||
++nodeIter)
|
||||
{
|
||||
// Go thru all node's inputs.
|
||||
for (auto& inputArg : (*nodeIter)->InputDefs())
|
||||
{
|
||||
m_valueInfo.push_back(&inputArg);
|
||||
if (!inputArg.Exist())
|
||||
{
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
|
||||
if (specifiedGraphInputs.end() != specifiedGraphInputs.find(inputArg.Name()))
|
||||
{
|
||||
if (addedInputNames.end() == addedInputNames.find(inputArg.Name()))
|
||||
{
|
||||
// The node input is specified as graph input.
|
||||
m_graphInputs.push_back(&inputArg);
|
||||
addedInputNames.insert(inputArg.Name());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto outputArgIter = outputNameToNodeArg.find(inputArg.Name());
|
||||
if (outputNameToNodeArg.end() == outputArgIter
|
||||
&& specifiedInitializers.end() == specifiedInitializers.find(inputArg.Name()))
|
||||
{
|
||||
// The node input is not specified as graph input,
|
||||
// and it's not fed by another node neither.
|
||||
return Status(ONNX, FAIL, "Node input (" + inputArg.Name() + ") should be a graph input.");
|
||||
}
|
||||
|
||||
if (specifiedGraphValueInfo.erase(inputArg.Name()) >= 1)
|
||||
{
|
||||
m_valueInfo.push_back(&inputArg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set graph outputs.
|
||||
for (auto& outputArg : graphOutputArgs)
|
||||
else
|
||||
{
|
||||
m_graphOutputs.push_back(outputArg.second);
|
||||
std::unordered_map<std::string, const NodeArg*> outputNameToNodeArg;
|
||||
for (auto nodeIter = Nodes_begin();
|
||||
nodeIter != Nodes_end();
|
||||
++nodeIter)
|
||||
{
|
||||
for (auto& outputDef : (*nodeIter)->OutputDefs())
|
||||
{
|
||||
outputNameToNodeArg.insert({ outputDef.Name(), &outputDef });
|
||||
}
|
||||
}
|
||||
// Init graph output args with all node output args.
|
||||
auto graphOutputArgs = outputNameToNodeArg;
|
||||
|
||||
std::unordered_set<Node*> innerNodes;
|
||||
for (auto nodeIter = Nodes_begin();
|
||||
nodeIter != Nodes_end();
|
||||
++nodeIter)
|
||||
{
|
||||
// Go thru all node's inputs.
|
||||
for (auto& inputArg : (*nodeIter)->InputDefs())
|
||||
{
|
||||
if (!inputArg.Exist())
|
||||
{
|
||||
// It's an optional input and does not exist in this case.
|
||||
continue;
|
||||
}
|
||||
|
||||
auto outputArgIter = outputNameToNodeArg.find(inputArg.Name());
|
||||
if (outputNameToNodeArg.end() == outputArgIter)
|
||||
{
|
||||
// This input arg should be fed when running evaluation.
|
||||
// it should be a graph input.
|
||||
if (addedInputNames.end() == addedInputNames.find(inputArg.Name()))
|
||||
{
|
||||
// This graph input has not been added into <m_graphInputs>.
|
||||
m_graphInputs.push_back(&inputArg);
|
||||
addedInputNames.insert(inputArg.Name());
|
||||
}
|
||||
}
|
||||
else if (graphOutputArgs.erase(outputArgIter->first) >= 1)
|
||||
{
|
||||
// Remove the output arg name from graph outputs since it's
|
||||
// the input of another node, which we call it intermediate result
|
||||
// and store it in <m_valueinfo>.
|
||||
m_valueInfo.push_back(&inputArg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set graph outputs.
|
||||
for (auto& outputArg : graphOutputArgs)
|
||||
{
|
||||
m_graphOutputs.push_back(outputArg.second);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
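For graphs built from scratch (the else branch), SetGraphInputsOutputs infers the interface: a node input that no node produces must be fed by the caller, and a node output nothing consumes is a graph output, with the intermediates recorded as value_info. A condensed sketch of that inference (plain C++; NodeIO and the function name are illustrative stand-ins):

    #include <string>
    #include <unordered_set>
    #include <vector>

    struct NodeIO { std::vector<std::string> inputs, outputs; };

    // Inputs nobody produces become graph inputs; outputs nobody consumes become
    // graph outputs. Everything else is an intermediate value.
    void InferGraphInputsOutputs(const std::vector<NodeIO>& nodes,
                                 std::vector<std::string>& graphInputs,
                                 std::vector<std::string>& graphOutputs)
    {
        std::unordered_set<std::string> produced;
        for (const auto& n : nodes)
            produced.insert(n.outputs.begin(), n.outputs.end());

        std::unordered_set<std::string> consumed;
        std::unordered_set<std::string> addedInputs;
        for (const auto& n : nodes)
            for (const auto& in : n.inputs)
            {
                consumed.insert(in);
                if (!produced.count(in) && addedInputs.insert(in).second)
                    graphInputs.push_back(in);    // must be fed by the caller
            }

        for (const auto& n : nodes)
            for (const auto& out : n.outputs)
                if (!consumed.count(out))
                    graphOutputs.push_back(out);  // not an intermediate value
    }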
|
||||
|
||||
bool Graph::IsSourceNode(NODEINDEX p_index) const
|
||||
|
|
|
@@ -27,7 +27,7 @@ namespace ONNXIR
|
|||
typedef size_t NODEINDEX;
|
||||
typedef int64_t VERSION;
|
||||
typedef ValueInfoProto NodeArgInfo;
|
||||
typedef std::unordered_map<std::string, TensorProto> InitialTensorSet;
|
||||
typedef std::unordered_map<std::string, const TensorProto*> InitialTensorSet;
|
||||
typedef std::unordered_map<std::string, TypeProto> ArgNameToTypeMap;
|
||||
|
||||
class Graph;
|
||||
|
@@ -347,17 +347,6 @@ namespace ONNXIR
|
|||
NODEINDEX m_currentNodeIndex;
|
||||
};
|
||||
|
||||
// Constructor from scratch.
|
||||
// <p_isONNX> is a special flag to indicate whether it's
|
||||
// going to construct a ONNX graph. With ONNX graph, strict
|
||||
// type checking will be skiped.
|
||||
Graph(const std::string& p_name, bool p_isONNX = false);
|
||||
Graph(const std::string& p_name, const std::string& p_docString);
|
||||
|
||||
// Constructor: Given a <GraphProto> loaded from model file, construct
|
||||
// a <Graph> object.
|
||||
Graph(const GraphProto& p_graphProto);
|
||||
|
||||
// Resolve <*this> graph to ensure it's in a good shape with full
|
||||
// functionality.
|
||||
// 1. Run through all validation rules.
|
||||
|
@@ -374,14 +363,17 @@ namespace ONNXIR
|
|||
const std::string& Name() const;
|
||||
void SetName(const std::string& p_name);
|
||||
|
||||
const std::string& Description() const;
|
||||
void SetDescription(const std::string& p_desription);
|
||||
|
||||
// Add/Remove/Get initial tensors for some graph inputs.
|
||||
void AddInitialTensor(const TensorProto& p_tensor);
|
||||
void RemoveInitialTensor(const std::string& p_tensorName);
|
||||
bool GetInitialTensor(const std::string& p_tensorName,
|
||||
TensorProto& p_value) const;
|
||||
const TensorProto** p_value) const;
|
||||
const InitialTensorSet& GetAllInitialTensors() const;
|
||||
|
||||
// Get graph inputs/outputs.
|
||||
// Get graph inputs/outputs/valueinfos.
|
||||
const std::vector<const NodeArg*>& GetInputs() const;
|
||||
const std::vector<const NodeArg*>& GetOutputs() const;
|
||||
const std::vector<const NodeArg*>& GetValueInfo() const;
|
||||
|
@@ -447,6 +439,15 @@ namespace ONNXIR
|
|||
|
||||
private:
|
||||
|
||||
friend class Model;
|
||||
|
||||
Graph() = delete;
|
||||
|
||||
// Constructor: Given a <GraphProto> loaded from model file, construct
|
||||
// a <Graph> object.
|
||||
Graph(GraphProto* p_graphProto,
|
||||
const std::unordered_map<std::string, int>& p_domainToVersion, bool p_isONNX = true);
|
||||
|
||||
enum Type
|
||||
{
|
||||
// A main graph.
|
||||
|
@@ -495,15 +496,11 @@ namespace ONNXIR
|
|||
const OpSignature* p_op,
|
||||
const std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs);
|
||||
|
||||
// Clean function definition map.
|
||||
// Remove function definitions not refered by any node.
|
||||
void CleanFunctionDefMap(const std::set<std::string>& p_funcDefNames);
|
||||
|
||||
// Add source/sink nodes to <*this> graph.
|
||||
void AddSourceSinkNodes();
|
||||
|
||||
// Set graph inputs/outputs when resolving a graph..
|
||||
void SetGraphInputsOutputs();
|
||||
Status SetGraphInputsOutputs();
|
||||
|
||||
// Sync graph inputs/outputs when serializing to proto.
|
||||
void SyncGraphInputsOutputs();
|
||||
|
@@ -525,12 +522,15 @@ namespace ONNXIR
|
|||
// When serilizing <*this> Graph to a GraphProto, the nodes and
|
||||
// functions in <Graph> will also be fed into <m_graphProto> so that
|
||||
// it's consistent with <*this> graph.
|
||||
GraphProto m_graphProto;
|
||||
// This pointer is owned by parent model.
|
||||
GraphProto* m_graphProto;
|
||||
|
||||
// The node which refers to <*this> graph (Function).
|
||||
Node* m_node;
|
||||
|
||||
InitialTensorSet m_nameToInitialTensor;
|
||||
std::unordered_map<std::string, int> m_nameToInitialTensorIndex;
|
||||
InitialTensorSet m_nameToInitialTensorPtr;
|
||||
std::vector<int> m_removedInitializerIndexes;
|
||||
|
||||
// A flag indicates whether <*this> graph needs to be resolved.
|
||||
bool m_graphResolveNeeded;
|
||||
|
@@ -550,6 +550,8 @@ namespace ONNXIR
|
|||
|
||||
// Graph value_info.
|
||||
std::vector<const NodeArg*> m_valueInfo;
|
||||
|
||||
const std::unordered_map<std::string, int>* m_domainToVersion;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@@ -1,5 +1,3 @@
|
|||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
// 'type' : forcing value to bool 'true' or 'false' (performance warning)
|
||||
|
@@ -17,74 +15,7 @@
|
|||
#include <unistd.h>
|
||||
#endif
|
||||
#include "model.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
#ifdef _WIN32
|
||||
inline Status FileOpenRd(const std::wstring& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
_wsopen_s(p_fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FileOpenWr(const std::wstring& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
_wsopen_s(p_fd, p_path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
inline Status FileOpenRd(const std::string& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
_sopen_s(p_fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
#else
|
||||
*p_fd = open(p_path.c_str(), O_RDONLY);
|
||||
#endif
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FileOpenWr(const std::string& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
_sopen_s(p_fd, p_path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
#else
|
||||
*p_fd = open(p_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
|
||||
#endif
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FileClose(int fd)
|
||||
{
|
||||
int ret = 0;
|
||||
#ifdef _WIN32
|
||||
ret = _close(fd);
|
||||
#else
|
||||
ret = close(fd);
|
||||
#endif
|
||||
if (0 != ret)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
@@ -92,118 +23,116 @@ namespace ONNXIR
|
|||
bool p_isONNX,
|
||||
const ModelMetaData& p_modelMetaData)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_isONNX));
|
||||
m_modelProto.set_ir_version(Version::IR_VERSION);
|
||||
m_modelProto.reset(new ModelProto);
|
||||
m_modelProto->set_ir_version(Version::IR_VERSION);
|
||||
m_modelProto->mutable_graph()->set_name(p_graphName);
|
||||
m_modelMetaData = p_modelMetaData;
|
||||
for (auto& metaData : m_modelMetaData)
|
||||
{
|
||||
auto prop = m_modelProto.add_metadata_props();
|
||||
auto prop = m_modelProto->add_metadata_props();
|
||||
prop->set_key(metaData.first);
|
||||
prop->set_value(metaData.second);
|
||||
}
|
||||
}
|
||||
|
||||
Model::Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString,
|
||||
const std::string& p_producerName,
|
||||
const std::string& p_producerVersion,
|
||||
const std::string& p_domain,
|
||||
VERSION p_modelVersion,
|
||||
const std::string& p_docString,
|
||||
const ModelMetaData& p_modelMetaData)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_graphDocString));
|
||||
m_modelProto.set_ir_version(Version::IR_VERSION);
|
||||
m_modelMetaData = p_modelMetaData;
|
||||
for (auto& metaData : m_modelMetaData)
|
||||
{
|
||||
auto prop = m_modelProto.add_metadata_props();
|
||||
prop->set_key(metaData.first);
|
||||
prop->set_value(metaData.second);
|
||||
}
|
||||
|
||||
m_modelProto.set_producer_name(p_producerName);
|
||||
m_modelProto.set_producer_version(p_producerVersion);
|
||||
m_modelProto.set_domain(p_domain);
|
||||
m_modelProto.set_model_version(p_modelVersion);
|
||||
m_modelProto.set_doc_string(p_docString);
|
||||
// Set m_domainToVersion to contain related domains with latest version.
|
||||
AddImportOpSets(p_isONNX);
|
||||
m_graph.reset(new Graph(m_modelProto->mutable_graph(), m_domainToVersion, p_isONNX));
|
||||
}
|
||||
|
||||
Model::Model(const ModelProto& p_modelProto)
|
||||
: Model(std::unique_ptr<ModelProto>(new ModelProto(p_modelProto)))
|
||||
{
|
||||
m_modelProto = p_modelProto;
|
||||
if (m_modelProto.has_graph())
|
||||
{
|
||||
m_graph.reset(new Graph(m_modelProto.graph()));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& prop : m_modelProto.metadata_props())
|
||||
Model::Model(std::unique_ptr<ModelProto> p_modelProto)
|
||||
{
|
||||
assert(nullptr != p_modelProto);
|
||||
m_modelProto.reset(p_modelProto.release());
|
||||
for (auto& prop : m_modelProto->metadata_props())
|
||||
{
|
||||
m_modelMetaData[prop.key()] = prop.value();
|
||||
}
|
||||
|
||||
if (0 == m_modelProto->opset_import_size())
|
||||
{
|
||||
// Operator sets are not specified in this model.
|
||||
// Will use global operator store instead.
|
||||
AddImportOpSets(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto& opSet : m_modelProto->opset_import())
|
||||
{
|
||||
m_domainToVersion[opSet.domain()] = static_cast<int>(opSet.version());
|
||||
}
|
||||
}
|
||||
|
||||
if (m_modelProto->has_graph())
|
||||
{
|
||||
m_graph.reset(new Graph(m_modelProto->mutable_graph(), m_domainToVersion));
|
||||
}
|
||||
}
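This constructor also decides which operator-set versions the model uses: if the proto carries no opset_import entries it falls back to the registry's latest versions, otherwise the declared <domain, version> pairs win. The selection logic on its own (plain C++; OpSetId and ResolveOpSets are illustrative stand-ins, not the ONNXIR API):

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct OpSetId { std::string domain; int version; };

    // A model that declares no operator sets gets the registry defaults; otherwise
    // the versions stated in the model take precedence.
    std::unordered_map<std::string, int> ResolveOpSets(
        const std::vector<OpSetId>& fromModel,
        const std::unordered_map<std::string, int>& registryLatest)
    {
        if (fromModel.empty())
            return registryLatest;               // nothing declared: use defaults

        std::unordered_map<std::string, int> domainToVersion;
        for (const auto& opSet : fromModel)
            domainToVersion[opSet.domain] = opSet.version;
        return domainToVersion;
    }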
|
||||
|
||||
VERSION Model::IrVersion() const
|
||||
{
|
||||
if (m_modelProto.has_ir_version())
|
||||
if (m_modelProto->has_ir_version())
|
||||
{
|
||||
return m_modelProto.ir_version();
|
||||
return m_modelProto->ir_version();
|
||||
}
|
||||
return c_noVersion;
|
||||
}
|
||||
|
||||
const std::string& Model::ProducerName() const
|
||||
{
|
||||
return m_modelProto.producer_name();
|
||||
return m_modelProto->producer_name();
|
||||
}
|
||||
|
||||
void Model::SetProducerName(const std::string& p_producerName)
|
||||
{
|
||||
m_modelProto.set_producer_name(p_producerName);
|
||||
m_modelProto->set_producer_name(p_producerName);
|
||||
}
|
||||
|
||||
const std::string& Model::ProducerVersion() const
|
||||
{
|
||||
return m_modelProto.producer_version();
|
||||
return m_modelProto->producer_version();
|
||||
}
|
||||
|
||||
void Model::SetProducerVersion(const std::string& p_producerVersion)
|
||||
{
|
||||
m_modelProto.set_producer_version(p_producerVersion);
|
||||
m_modelProto->set_producer_version(p_producerVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::Domain() const
|
||||
{
|
||||
return m_modelProto.domain();
|
||||
return m_modelProto->domain();
|
||||
}
|
||||
|
||||
void Model::SetDomain(const std::string& p_domain)
|
||||
{
|
||||
m_modelProto.set_domain(p_domain);
|
||||
m_modelProto->set_domain(p_domain);
|
||||
}
|
||||
|
||||
VERSION Model::ModelVersion() const
|
||||
{
|
||||
if (m_modelProto.has_model_version())
|
||||
if (m_modelProto->has_model_version())
|
||||
{
|
||||
return m_modelProto.model_version();
|
||||
return m_modelProto->model_version();
|
||||
}
|
||||
return c_noVersion;
|
||||
}
|
||||
|
||||
void Model::SetModelversion(VERSION p_modelVersion)
|
||||
{
|
||||
m_modelProto.set_model_version(p_modelVersion);
|
||||
m_modelProto->set_model_version(p_modelVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::DocString() const
|
||||
{
|
||||
return m_modelProto.doc_string();
|
||||
return m_modelProto->doc_string();
|
||||
}
|
||||
|
||||
void Model::SetDocString(const std::string& p_docString)
|
||||
{
|
||||
m_modelProto.set_doc_string(p_docString);
|
||||
m_modelProto->set_doc_string(p_docString);
|
||||
}
|
||||
|
||||
const ModelMetaData& Model::MetaData() const
|
||||
|
@@ -221,10 +150,29 @@ namespace ONNXIR
|
|||
return m_graph.get();
|
||||
}
|
||||
|
||||
const ModelProto& Model::ToProto()
|
||||
ModelProto Model::ToProto()
|
||||
{
|
||||
*(m_modelProto.mutable_graph()) = m_graph->ToGraphProto();
|
||||
return m_modelProto;
|
||||
*(m_modelProto->mutable_graph()) = m_graph->ToGraphProto();
|
||||
return *m_modelProto;
|
||||
}
|
||||
|
||||
void Model::AddImportOpSets(bool p_isONNX)
|
||||
{
|
||||
auto& domainToVersionRangeMap = OpSchemaRegistry::DomainToVersionRange::Instance().Map();
|
||||
for (auto& domainToVersionRange : domainToVersionRangeMap)
|
||||
{
|
||||
if (p_isONNX && domainToVersionRange.first.compare(c_onnxDomain) != 0)
|
||||
{
|
||||
// Constructing a pure ONNX model.
|
||||
// Only ops in ONNX domain should be used.
|
||||
continue;
|
||||
}
|
||||
|
||||
m_domainToVersion[domainToVersionRange.first] = domainToVersionRange.second.second;
|
||||
auto opSetIdProto = m_modelProto->add_opset_import();
|
||||
opSetIdProto->set_domain(domainToVersionRange.first);
|
||||
opSetIdProto->set_version(domainToVersionRange.second.second);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
|
@@ -259,16 +207,18 @@ namespace ONNXIR
|
|||
|
||||
Status Model::LoadFromBytes(int count, void *pBytes, /*out*/ std::shared_ptr<Model>* p_model)
|
||||
{
|
||||
ModelProto modelProto;
|
||||
bool result = modelProto.ParseFromArray(pBytes, count);
|
||||
std::unique_ptr<ModelProto> modelProto(new ModelProto);
|
||||
bool result = modelProto->ParseFromArray(pBytes, count);
|
||||
if (!result)
|
||||
{
|
||||
return Status(ONNX, INVALID_PROTOBUF, "Protobuf parsing failed.");
|
||||
}
|
||||
|
||||
(*p_model).reset(new Model(modelProto));
|
||||
RETURN_IF_ERROR((*p_model)->MainGraph()->Resolve());
|
||||
|
||||
(*p_model).reset(new Model(std::move(modelProto)));
|
||||
if ((*p_model)->MainGraph() != nullptr)
|
||||
{
|
||||
RETURN_IF_ERROR((*p_model)->MainGraph()->Resolve());
|
||||
}
|
||||
return Status::OK();
|
||||
}
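LoadFromBytes now parses into a heap-allocated ModelProto and moves it into the Model, so the model owns the proto and no large protobuf copy is made; the Model(const ModelProto&) constructor above simply wraps this path with an explicit copy. A minimal sketch of that ownership transfer (plain C++; the types here are stand-ins, not the ONNXIR classes):

    #include <cassert>
    #include <memory>

    struct ModelProto { bool parsed = false; };   // stand-in for the protobuf type

    // The model owns its proto: the caller parses into a unique_ptr and moves it
    // in, so the buffer is handed over rather than duplicated.
    class Model
    {
    public:
        explicit Model(std::unique_ptr<ModelProto> proto) : m_proto(std::move(proto))
        {
            assert(m_proto != nullptr);
        }
    private:
        std::unique_ptr<ModelProto> m_proto;
    };

    int main()
    {
        auto proto = std::make_unique<ModelProto>();
        proto->parsed = true;
        Model model(std::move(proto));   // ownership transferred, no copy
    }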
|
||||
|
||||
|
@@ -297,17 +247,20 @@ namespace ONNXIR
|
|||
new CodedInputStream(raw_input.get()));
|
||||
// Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB.
|
||||
coded_input->SetTotalBytesLimit(INT_MAX, INT_MAX);
|
||||
ModelProto modelProto;
|
||||
bool result = modelProto.ParseFromCodedStream(coded_input.get());
|
||||
std::unique_ptr<ModelProto> modelProto(new ModelProto);
|
||||
bool result = modelProto->ParseFromCodedStream(coded_input.get());
|
||||
coded_input.reset();
|
||||
raw_input.reset();
|
||||
if (!result)
|
||||
{
|
||||
return Status(ONNX, INVALID_PROTOBUF, "Protobuf parsing failed.");
|
||||
}
|
||||
(*p_model).reset(new Model(modelProto));
|
||||
RETURN_IF_ERROR((*p_model)->MainGraph()->Resolve());
|
||||
|
||||
(*p_model).reset(new Model(std::move(modelProto)));
|
||||
if ((*p_model)->MainGraph() != nullptr)
|
||||
{
|
||||
RETURN_IF_ERROR((*p_model)->MainGraph()->Resolve());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@@ -319,7 +272,7 @@ namespace ONNXIR
|
|||
}
|
||||
|
||||
RETURN_IF_ERROR(p_model.MainGraph()->Resolve());
|
||||
auto& modelProto = p_model.ToProto();
|
||||
auto modelProto = p_model.ToProto();
|
||||
bool result = modelProto.SerializeToFileDescriptor(p_fd);
|
||||
if (result)
|
||||
{
|
||||
|
|
|
@@ -20,17 +20,14 @@ namespace ONNXIR
|
|||
bool p_isONNX = false,
|
||||
const ModelMetaData& p_modelMetaData = ModelMetaData());
|
||||
|
||||
Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString,
|
||||
const std::string& p_producerName,
|
||||
const std::string& p_producerVersion,
|
||||
const std::string& p_domain,
|
||||
VERSION p_modelVersion,
|
||||
const std::string& p_modelDocString,
|
||||
const ModelMetaData& p_modelMetaData = ModelMetaData());
|
||||
|
||||
// NOTE: after calling this contructor, <*this> model will
|
||||
// hold a copy of <p_modelProto>.
|
||||
Model(const ModelProto& p_modelProto);
|
||||
|
||||
// NOTE: after calling this constructor, <*this> model will
|
||||
// own the <p_modelProto>.
|
||||
Model(std::unique_ptr<ModelProto> p_modelProto);
|
||||
|
||||
// Get model's IR version.
|
||||
// Return <c_noVersion> if not specified.
|
||||
VERSION IrVersion() const;
|
||||
|
@@ -73,7 +70,7 @@ namespace ONNXIR
|
|||
const Graph* MainGraph() const;
|
||||
|
||||
// Get model's serlization proto data.
|
||||
const ModelProto& ToProto();
|
||||
ModelProto ToProto();
|
||||
|
||||
#ifdef _WIN32
|
||||
static Status Save(Model& p_model, const std::wstring& p_filePath);
|
||||
|
@@ -93,12 +90,23 @@ namespace ONNXIR
|
|||
|
||||
private:
|
||||
|
||||
// Set <m_domainToVersion> and <m_modelProto> to contain related domains
|
||||
// with latest version in OpSchemaRegistry.
|
||||
// if <p_isONNX> is true, then only onnx domain will be contained.
|
||||
// otherwise, ml domain will also be contained.
|
||||
void AddImportOpSets(bool p_isONNX);
|
||||
|
||||
// Model data.
|
||||
ModelProto m_modelProto;
|
||||
std::unique_ptr<ModelProto> m_modelProto;
|
||||
|
||||
// This is a duplication of <m_modelProto.metadata_props()>.
|
||||
// It gives better accessibility.
|
||||
ModelMetaData m_modelMetaData;
|
||||
|
||||
// Operator set used by this model.
|
||||
// It contains <domain, version> pairs.
|
||||
std::unordered_map<std::string, int> m_domainToVersion;
|
||||
|
||||
// Main graph of the model.
|
||||
std::unique_ptr<Graph> m_graph;
|
||||
};
|
||||
|
|
|
@@ -1,5 +1,6 @@
|
|||
#pragma warning(disable : 4503)
|
||||
|
||||
#include "constants.h"
|
||||
#include "op.h"
|
||||
#include "opsignature.h"
|
||||
#include "utils.h"
|
||||
|
@@ -241,8 +242,9 @@ namespace ONNXIR
|
|||
// Increase the highest version when you make BC-breaking changes to the
|
||||
// operator schema on specific domain. Update the lowest version when it's
|
||||
// determined to remove too old version history.
|
||||
m_map[""] = std::make_pair(1, 2);
|
||||
m_map["ai.onnx.ml"] = std::make_pair(1, 1);
|
||||
m_map[c_onnxDomain] = std::make_pair(1, 2);
|
||||
m_map[c_mlDomain] = std::make_pair(1, 1);
|
||||
m_map[c_msDomain] = std::make_pair(1, 1);
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, std::pair<int, int>>&
|
||||
|
|
|
@@ -4,7 +4,6 @@
|
|||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "utils.h"
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456 4189 4996 4503)
|
||||
|
|
|
@@ -1,7 +1,5 @@
|
|||
#include "status.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
namespace Common
|
||||
|
|
|
@@ -7,7 +7,6 @@
|
|||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456 4189 4996 4503)
|
||||
#include "proto/onnx/protobuf/onnx-ml.pb.h"
|
||||
#pragma warning(pop)
|
||||
|
||||
#include "status.h"
|
||||
|
||||
namespace ONNXIR
|
||||
|
|
|
@@ -119,6 +119,8 @@ namespace ONNXIR
|
|||
return t.c_complex64;
|
||||
case TensorProto::DataType::TensorProto_DataType_COMPLEX128:
|
||||
return t.c_complex128;
|
||||
case TensorProto::DataType::TensorProto_DataType_UNDEFINED:
|
||||
return t.c_undefined;
|
||||
}
|
||||
|
||||
assert(false);
|
||||
|
|
|
@@ -1,15 +1,94 @@
|
|||
#ifndef ONNXIR_UTILS_H
|
||||
#define ONNXIR_UTILS_H
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#ifdef _WIN32
|
||||
#include <io.h>
|
||||
#else
|
||||
#include <sys/io.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "status.h"
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456 4189 4996 4503)
|
||||
#include "proto/onnx/protobuf/onnx-ml.pb.h"
|
||||
#pragma warning(pop)
|
||||
|
||||
using namespace onnx;
|
||||
using namespace ONNXIR::Common;
|
||||
|
||||
namespace
|
||||
{
|
||||
#ifdef _WIN32
|
||||
inline Status FileOpenRd(const std::wstring& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
_wsopen_s(p_fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FileOpenWr(const std::wstring& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
_wsopen_s(p_fd, p_path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
inline Status FileOpenRd(const std::string& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
_sopen_s(p_fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
#else
|
||||
*p_fd = open(p_path.c_str(), O_RDONLY);
|
||||
#endif
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FileOpenWr(const std::string& p_path, /*out*/ int* p_fd)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
_sopen_s(p_fd, p_path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
#else
|
||||
*p_fd = open(p_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
|
||||
#endif
|
||||
if (0 > *p_fd)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status FileClose(int fd)
|
||||
{
|
||||
int ret = 0;
|
||||
#ifdef _WIN32
|
||||
ret = _close(fd);
|
||||
#else
|
||||
ret = close(fd);
|
||||
#endif
|
||||
if (0 != ret)
|
||||
{
|
||||
return Status(SYSTEM, errno);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
|
|
@@ -111,9 +111,9 @@ namespace ONNXIR {
|
|||
"the batch_size", AttrType::AttributeProto_AttributeType_INT);
|
||||
|
||||
// Taken from RS4
|
||||
REGISTER_OPERATOR_SCHEMA(Linear)
|
||||
.Description("Linear takes one input data (Tensor<T>) and produces one output "
|
||||
"data (Tensor<T>) where the linear function, f(x)= alpha * x + beta is "
|
||||
REGISTER_OPERATOR_SCHEMA(Affine)
|
||||
.Description("Affine takes one input data (Tensor<T>) and produces one output "
|
||||
"data (Tensor<T>) where the affine function, f(x)= alpha * x + beta is "
|
||||
"applied to the tensor elementwise.")
|
||||
.Input("X", "Input tensor of any shape", "T")
|
||||
.Output("Y", "Output tensor of same shape and type as input X.", "T")
|
||||
|
|
|
@@ -9,8 +9,7 @@ namespace ONNXIR
|
|||
"The value for the elements of the output tensor.",
|
||||
AttrType::AttributeProto_AttributeType_TENSOR)
|
||||
.Output("output",
|
||||
"Output tensor containing the same value of the provided tensor.",
|
||||
"T")
|
||||
"Output tensor containing the same value of the provided tensor.", "T")
|
||||
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
|
||||
"Constrain input and output types to float tensors.");
|
||||
}
|
||||
|
|
|
@@ -2,50 +2,55 @@
|
|||
|
||||
namespace ONNXIR {
|
||||
|
||||
#define REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(OpName) \
|
||||
#define REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(OpName) \
|
||||
REGISTER_OPERATOR_SCHEMA(OpName) \
|
||||
.Description("Computes the elementwise comparison `"#OpName"` between " \
|
||||
"`A` and `B` input tensor. The result is a tensor of type integer " \
|
||||
"in which `0` mean false and `1` mean true.") \
|
||||
.Input("A", "Left input tensor for the operator.", "T1") \
|
||||
.Input("B", "Right input tensor for the operator.", "T1") \
|
||||
.Output("C", "Result tensor of type `int`, 0 mean False and 1 mean True.", "T2") \
|
||||
.Description("Returns the tensor resulted from performing the '"#OpName"' logical operation" \
|
||||
"elementwise on the input tensors A and B. If broadcasting is enabled, the right-hand-side" \
|
||||
"argument will be broadcasted to match the shape of left-hand-side argument. Refer to Add for" \
|
||||
"a detailed description of the broadcasting rules.") \
|
||||
.Input("A", "First operand, should share the type with the second operand.", "T1") \
|
||||
.Input("B", "Second operand. With broadcasting can be of smaller size than A." \
|
||||
"If broadcasting is disabled, it should be of the same size.", "T1") \
|
||||
.Output("C", "Result, has same dimensions as A and type bool.", "T2") \
|
||||
.TypeConstraint("T1", { "tensor(float16)", "tensor(float)", "tensor(double)" }, \
|
||||
"Constrain input to float tensors.") \
|
||||
.TypeConstraint("T2", { "tensor(int32)" }, "Constrain output types to int tensor.") \
|
||||
.TypeConstraint("T2", { "tensor(bool)" }, "Constrain output types to bool tensor.") \
|
||||
.Attr("axis", "If set, defines the broadcast dimensions.", \
|
||||
AttrType::AttributeProto_AttributeType_INT) \
|
||||
.Attr("broadcast", "Enable broadcasting.", \
|
||||
.Attr("broadcast", "Pass 1 to enable broadcasting.", \
|
||||
AttrType::AttributeProto_AttributeType_INT);
|
||||
|
||||
//‘GREATER’, ‘LESS’, ‘EQUALS,
|
||||
REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(Greater)
|
||||
REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(Less)
|
||||
REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(Equal)
|
||||
REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(Less)
|
||||
REGISTER_BINARY_COMPARISON_OPERATOR_SCHEMA(Equal)
|
||||
|
||||
#define REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(OpName) \
|
||||
#define REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(OpName) \
|
||||
REGISTER_OPERATOR_SCHEMA(OpName) \
|
||||
.Description("Computes the elementwise logical operation '"#OpName"' between " \
|
||||
"`A` and `B` input tensor. The result is a tensor of type integer " \
|
||||
"in which `0` mean false and `1` mean true.") \
|
||||
.Input("A", "Left input tensor for the logical operator.", "T") \
|
||||
.Input("B", "Right input tensor for the logical operator.", "T") \
|
||||
.Output("output", "Result tensor of type `int`, 0 mean False and 1 mean True.", "T") \
|
||||
.TypeConstraint("T", { "tensor(int32)" }, "Constrain input and output types to int tensor.") \
|
||||
.Description("Returns the tensor resulted from performing the '"#OpName"' logical operation" \
|
||||
"elementwise on the input tensors A and B. If broadcasting is enabled, the right-hand-side" \
|
||||
"argument will be broadcasted to match the shape of left-hand-side argument. Refer to Add" \
|
||||
" for a detailed description of the broadcasting rules.") \
|
||||
.Input("A", "First operand.", "T") \
|
||||
.Input("B", "Second operand. With broadcasting can be of smaller size than A. If broadcasting" \
|
||||
"is disabled, it should be of the same size.", "T") \
|
||||
.Output("C", "Result, has same dimensions and A and type bool.", "T") \
|
||||
.TypeConstraint("T", { "tensor(bool)" }, "Constrain input and output types to bool tensor.") \
|
||||
.Attr("axis", "If set, defines the broadcast dimensions.", \
|
||||
AttrType::AttributeProto_AttributeType_INT) \
|
||||
.Attr("broadcast", "Enable broadcasting.", \
|
||||
.Attr("broadcast", "Pass 1 to enable broadcasting.", \
|
||||
AttrType::AttributeProto_AttributeType_INT);
|
||||
|
||||
// ‘AND, ‘OR’, ‘XOR’
|
||||
REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(And)
|
||||
REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(Or)
|
||||
REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(Xor)
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Not)
|
||||
// ‘AND, ‘OR’, ‘XOR’
|
||||
REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(And)
|
||||
REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(Or)
|
||||
REGISTER_BINARY_LOGIC_OPERATOR_SCHEMA(Xor)
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Not)
|
||||
.Description("Performs element-wise negation.")
|
||||
.Input("X", "Input tensor of type bool.", "T")
|
||||
.Output("Y", " Output tensor of type bool.", "T")
|
||||
.TypeConstraint("T", { "tensor(int32)" }, "Constrain input and output types to int tensor.");
|
||||
.TypeConstraint("T", { "tensor(bool)" }, "Constrain input and output types to bool tensor.");
|
||||
|
||||
}
|
||||
|
|
|
@ -2,11 +2,25 @@

namespace ONNXIR {

#define REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(OpName) \
#define REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(OpName) \
REGISTER_OPERATOR_SCHEMA(OpName) \
.Description("Elementwise "#OpName" takes one or more input data (Tensor<T>) and produces one " \
"output data (Tensor<T>) where the declared function is applied to the input " \
"tensors elementwise.") \
.Description( \
"Performs element-wise binary "#OpName" (with limited broadcast support)." \
\
"If necessary, the right-hand-side argument will be broadcasted to match the shape of" \
"left-handside argument. When broadcasting is specified, the second tensor can either be of" \
"size 1 (a scalar value) or having its shape as a contiguous subset of the first tensor's" \
"shape. The starting of the mutually equal shape is specified by the argument \"axis\" and if" \
"it is not set, suffix matching is assumed. 1-dim expansion doesn't work yet. " \
\
"For example, the following tensor shapes are supported (with broadcast=1): " \
"shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar" \
"shape(A) = (2, 3, 4, 5), shape(B) = (5,)" \
"shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)" \
"shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1" \
"shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0" \
\
"Attribute broadcast=1 needs to be passed to enable broadcasting") \
.Input("A", "First operand, should share the type with the second operand.", "T") \
.Input("B", "Second operand. With broadcasting can be of smaller size than A. " \
"If broadcasting is disabled it should be of the same size..", "T") \

@ -19,11 +33,11 @@ namespace ONNXIR {
AttrType::AttributeProto_AttributeType_INT);

REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Add)
REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Sub)
REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Mul)
REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Div)
REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Sub)
REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Mul)
REGISTER_ELEMENTWISE_BROADCAST_OPERATOR_SCHEMA(Div)
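The contiguous-subset broadcasting rule spelled out in the new Description above can be made concrete with a small standalone check. This is a sketch only: the helper below is not part of ONNXIR, and it deliberately ignores suffix-matching defaults and the "1-dim expansion doesn't work yet" caveat.

#include <cstdint>
#include <cstddef>
#include <vector>

// Sketch: true if shape_b can broadcast against shape_a starting at `axis`.
// A scalar (empty shape) is always compatible; otherwise the overlapping
// dimensions must match exactly, as in the shape examples in the schema text.
bool CanBroadcast(const std::vector<int64_t>& shape_a,
                  const std::vector<int64_t>& shape_b,
                  size_t axis)
{
    if (shape_b.empty())
        return true;  // B is a scalar
    if (axis + shape_b.size() > shape_a.size())
        return false; // B does not fit inside A starting at axis
    for (size_t i = 0; i < shape_b.size(); ++i)
    {
        if (shape_a[axis + i] != shape_b[i])
            return false;
    }
    return true;
}

// Examples from the description, with shape(A) = (2, 3, 4, 5):
// CanBroadcast({2, 3, 4, 5}, {5}, 3)    -> true  (suffix match)
// CanBroadcast({2, 3, 4, 5}, {3, 4}, 1) -> true  (axis = 1)
// CanBroadcast({2, 3, 4, 5}, {2}, 0)    -> true  (axis = 0)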

#define REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(OpName, output) \
#define REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(OpName, output) \
REGISTER_OPERATOR_SCHEMA(OpName) \
.Description("Element-wise "#OpName" of each of the input tensors. The first input tensor can be " \
"used in-place as the output tensor, in which case the "#OpName" will be done in " \

@ -34,13 +48,13 @@ namespace ONNXIR {
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, \
"Constrain input and output types to float tensors.");

REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Max, "max")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Min, "min")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Sum, "sum")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Mean, "mean")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Max, "max")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Min, "min")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Sum, "sum")
REGISTER_ELEMENTWISE_OPERATOR_SCHEMA(Mean, "mean")

// Taken from ONNX
REGISTER_OPERATOR_SCHEMA(Neg)
// Taken from ONNX
REGISTER_OPERATOR_SCHEMA(Neg)
.Description("Neg takes one input data (Tensor<T>) and produces one output data "
"(Tensor<T>) where each element flipped sign, y = -x, is applied to "
"the tensor elementwise.")

@ -57,7 +71,7 @@ namespace ONNXIR {
.Input("X", "Input tensor of any shape", "T")
.Output("Y", "Output tensor of same shape and type as input X.", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.");

// Take from ONNX
REGISTER_OPERATOR_SCHEMA(Reciprocal)
@ -173,18 +187,18 @@ namespace ONNXIR {
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.")
.Attr("transA",
"Whether A should be transposed",
AttrType::AttributeProto_AttributeType_INT)
"Whether A should be transposed",
AttrType::AttributeProto_AttributeType_INT)
.Attr("transB",
"Whether B should be transposed",
AttrType::AttributeProto_AttributeType_INT)
"Whether B should be transposed",
AttrType::AttributeProto_AttributeType_INT)
.Attr("broadcast",
"Whether C should be broadcasted",
AttrType::AttributeProto_AttributeType_INT)
"Whether C should be broadcasted",
AttrType::AttributeProto_AttributeType_INT)
.Attr("alpha",
"Scalar multiplier for the product of input tensors A * B",
AttrType::AttributeProto_AttributeType_FLOAT)
"Scalar multiplier for the product of input tensors A * B",
AttrType::AttributeProto_AttributeType_FLOAT)
.Attr("beta",
"Scalar multiplier for input tensor C",
AttrType::AttributeProto_AttributeType_FLOAT);
"Scalar multiplier for input tensor C",
AttrType::AttributeProto_AttributeType_FLOAT);
}
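For reference only (not part of the diff): the transA, transB, broadcast, alpha and beta attributes reindented in this hunk describe a GEMM-style computation. Assuming this hunk belongs to the Gemm schema, the operator evaluates roughly

Y = alpha * A' * B' + beta * C

where A' = transpose(A) if transA is nonzero (else A), B' = transpose(B) if transB is nonzero (else B), and C is broadcast to the shape of A' * B' when broadcast is nonzero.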
@ -1,5 +1,5 @@
#include "proto/onnx/core/op.h"

#include "proto/onnx/core/constants.h"

namespace ONNXIR {
REGISTER_OPERATOR_SCHEMA(FC)

@ -402,33 +402,34 @@ namespace ONNXIR {
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "Constrain input and output types to float tensors.")
.Attr("p", "Value of p, default 2.", AttrType::AttributeProto_AttributeType_INT, int64_t(2));

std::function<void(OperatorSchemaSetter&)> LRNDocGenerator() {
return [=](OperatorSchemaSetter& schema) {
schema.Description("Perform local response normalization. "
"NOTE: Only supports Caffe across channel mode. ");
schema.Input("X", "Input tensor of any shape", "T");
schema.Output("Y", "Output tensor of same shape and type as input X.", "T");
schema.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.");
schema.Attr("size", "[default 5]: the number of channels to sum over (for cross "
"channel LRN) or the side length of the square region to sum over (for within "
"channel LRN)", AttrType::AttributeProto_AttributeType_INT, int64_t(5));
schema.Attr("alpha", "Scalar scaling factor. Default is 0.0001",
AttrType::AttributeProto_AttributeType_FLOAT, float(0.0001));
schema.Attr("beta", "Scalar exponent in the LRN. Default is 0.5.",
AttrType::AttributeProto_AttributeType_FLOAT, float(0.5));
schema.Attr("bias", "An offset (must be positive to avoid dividing by 0). Defaults to 1.0.",
AttrType::AttributeProto_AttributeType_FLOAT, float(1.0));
};
}

REGISTER_OPERATOR_SCHEMA(LocalResponseNormalization)
.FillUsing(LRNDocGenerator());

// TODO: to be duplicated.
REGISTER_OPERATOR_SCHEMA(LRN)
.Description("Perform local response normalization. "
"NOTE: Only supports Caffe across channel mode. ")
.Input("input", "Input tensor of any shape", "T")
.Output("output", "Output tensor of same shape and type as input X.", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "Constrain input and output "
" types to float tensors.")
.Attr("size", "[default 5]: the number of channels to sum over (for cross "
"channel LRN) or the side length of the square region to sum over (for within "
"channel LRN)", AttrType::AttributeProto_AttributeType_INT, int64_t(5))
.Attr("alpha", "Scalar scaling factor. Default is 0.0001", AttrType::AttributeProto_AttributeType_FLOAT, float(0.0001))
.Attr("beta", "Scalar exponent in the LRN. Default is 0.5.", AttrType::AttributeProto_AttributeType_FLOAT, float(0.5))
.Attr("bias", "An offset (must be positive to avoid dividing by 0). Defaults to 1.0.",
AttrType::AttributeProto_AttributeType_FLOAT, float(1.0));
.FillUsing(LRNDocGenerator());
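As background for the size/alpha/beta/bias attributes shared by LocalResponseNormalization and LRN above (this formula is context, not part of the change): cross-channel LRN in the Caffe/ONNX convention is usually computed as

Y[n, c, h, w] = X[n, c, h, w] / (bias + (alpha / size) * sum_{c' in window(c)} X[n, c', h, w]^2) ^ beta

where the sum runs over a window of `size` channels centered on channel c.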
REGISTER_OPERATOR_SCHEMA(MVN)
.Description("Perform mean variance normalization.")
.Input("input", "Input tensor of any shape", "T")
.Output("output", "Output tensor of same shape and type as input X.", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "Constrain input and output "
"types to float tensors.")
.Attr("across_channels", "If true, mean and variance are computed across channels. "
"Default is false.", AttrType::AttributeProto_AttributeType_INT, int64_t(0))
.Attr("normalize_variance", "If false, normalize the mean only. Default is true.",
AttrType::AttributeProto_AttributeType_INT, int64_t(1));

// Manually added on 2/14/2018.
REGISTER_OPERATOR_SCHEMA(MeanVarianceNormalization)
.Description("Perform mean variance normalization.")
.Input("input", "Input tensor of any shape", "T")
@ -488,10 +489,10 @@ namespace ONNXIR {
.Output("output", "Result, has same shape and type as X", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "Constrain input and "
"output types to float tensors.")
.Attr("mode", "enum {'NN', 'BILINEAR' }, Nearest neighbor or bilinear upsampling.",
.Attr("mode", "enum {'NEAREST', 'BILINEAR' }, Nearest neighbor or bilinear upsampling.",
AttrType::AttributeProto_AttributeType_STRING)
.Attr("width_scale", "Scale along width dimension", AttrType::AttributeProto_AttributeType_INT)
.Attr("height_scale", "Scale along height dimension", AttrType::AttributeProto_AttributeType_INT);
.Attr("width_scale", "Scale along width dimension", AttrType::AttributeProto_AttributeType_FLOAT)
.Attr("height_scale", "Scale along height dimension", AttrType::AttributeProto_AttributeType_FLOAT);

REGISTER_OPERATOR_SCHEMA(Crop)
.Description("Crop an image to the specified spatial dimensions. If scale is given,"
@ -502,8 +503,8 @@ namespace ONNXIR {
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "Constrain input and "
"output types to float tensors.")
.Attr("border", "A 1-D tensor of values (leftBorder, topBorder, rightBorder, bottomBorder)",
AttrType::AttributeProto_AttributeType_INT)
.Attr("scale", "A 1-D tensor of values (height, width)", AttrType::AttributeProto_AttributeType_INT);
AttrType::AttributeProto_AttributeType_INTS)
.Attr("scale", "A 1-D tensor of values (height, width)", AttrType::AttributeProto_AttributeType_INTS);

// Taken from ONNX
REGISTER_OPERATOR_SCHEMA(Pad)
@ -537,5 +538,28 @@ namespace ONNXIR {
"Constrain input and output types to float tensors.")
.Attr("image", "Image tensor stored as a sequence of floats [C,H,W].", AttrType::AttributeProto_AttributeType_TENSOR);

REGISTER_OPERATOR_SCHEMA(LegacyPadding)
.SetDomain(c_msDomain)
.Description("This operator is designed to support CoreML's pooling operator under IncludeLastPixel padding mode. "
"To simulate CoreML's pooling operator, first, copy kernel shape, strides, and padding "
"amounts from the original pooling operator to this LegacyPadding operator. "
"Second, create a pooling operator under auto_pad=VALID with the kernel and strides used in the original pooling. "
"Third, connect the output of the LegacyPadding operator with the pooling operator we just created. ")
.Input("data", "Input tensor.", "T")
.Output("output", "Tensor after padding.", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.")
.Attr("pads",
"Padding amounts along H- and W-axes, [pad_h, pad_w]. ",
AttrType::AttributeProto_AttributeType_INTS, int64_t(1))
.Attr("kernel_shape",
"The size of the kernel along H- and W-axes, [k_h, k_w]. Notice that the kernel is a 2-D tensor. ",
AttrType::AttributeProto_AttributeType_INTS, int64_t(1))
.Attr("strides",
"Stride along H- and W-axes, [stride_h, stride_w].",
AttrType::AttributeProto_AttributeType_INTS, int64_t(1))
.Attr("value",
"One float, indicates the value to be filled, default is 0",
AttrType::AttributeProto_AttributeType_FLOAT, float(0));
}

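A worked instance of the three-step decomposition described in the LegacyPadding schema above (the numbers are illustrative, not from the diff): a CoreML average pool with kernel [3, 3], strides [2, 2] and IncludeLastPixel padding [1, 1] would be exported as a LegacyPadding node with kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1] (and value=0), followed by an AveragePool node using the same kernel_shape and strides but auto_pad=VALID, with the LegacyPadding output wired into the pool's input.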
@ -2,43 +2,61 @@

namespace ONNXIR {

std::function<void(OperatorSchemaSetter&)> RNNDocGenerator() {
std::function<void(OperatorSchemaSetter&)> RNNDocGeneratorInputX() {
return [=](OperatorSchemaSetter& schema) {
schema.Input("X",
"The input sequences packed (and potentially padded) into one 3-D "
"tensor with the shape of `[seq_length, batch_size, input_size]`.", "T");
schema.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.");
};
}

std::function<void(OperatorSchemaSetter&)> RNNDocGeneratorInputSeqLen() {
return [=](OperatorSchemaSetter& schema) {
schema.Input("sequence_lens",
"Optional tensor specifying lengths of the sequences in a batch. "
"If not specified - assumed all sequences in the batch to have "
"length `seq_length`. It has shape `[batch_size]`.", "T1", true /*optional*/);
schema.TypeConstraint("T1", { "tensor(int32)" }, "Constrain seq_lens to integer tensor.");
};
}

std::function<void(OperatorSchemaSetter&)> RNNDocGeneratorInputInitialH() {
return [=](OperatorSchemaSetter& schema) {
schema.Input("initial_h",
"Optional initial value of the hidden. If not specified - assumed "
"to be 0. It has shape `[num_directions, batch_size, hidden_size]`.", "T", true /*optional*/);
};
}

std::function<void(OperatorSchemaSetter&)> RNNDocGeneratorAttrOutput() {
return [=](OperatorSchemaSetter& schema) {
schema.Attr("direction", "Specify if the RNN is forward, reverse, or bidirectional. "
"Must be one of forward (default), reverse, or bidirectional.",
AttrType::AttributeProto_AttributeType_STRING);
schema.Attr("hidden_size", "Number of neurons in the hidden layer",
AttrType::AttributeProto_AttributeType_INT);
schema.Output("Y",
"A tensor that concats all the intermediate output values of the hidden."
"It has shape `[seq_length, num_directions, batch_size, hidden_size]`.", "T");
schema.Output("Y_h",
"The last output value of the hidden. It has shape "
"`[num_directions, batch_size, hidden_size]`.", "T");
schema.Attr("direction", "Specify if the RNN is forward, reverse, or bidirectional. "
"Must be one of forward (default), reverse, or bidirectional.",
AttrType::AttributeProto_AttributeType_STRING);
schema.Attr("hidden_size", "Number of neurons in the hidden layer",
AttrType::AttributeProto_AttributeType_INT);
schema.Attr("alpha",
"Optional scaling values used by some activation functions.",
AttrType::AttributeProto_AttributeType_FLOATS);
schema.Attr("beta",
"Optional scaling values used by some activation functions.",
AttrType::AttributeProto_AttributeType_FLOATS);
schema.TypeConstraint("T1", { "tensor(int32)" }, "Constrain seq_lens to integer tensor.");
schema.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.");
};
}

// TODO: An attribute "output_sequence" missing here per op specification doc.
// Check with Radu/Sherlock on this later.
std::function<void(OperatorSchemaSetter&)> RNNDocGeneratorActivationArgs() {
return [=](OperatorSchemaSetter& schema) {
schema.Attr("activation_alpha",
"Optional scaling values used by some activation functions.",
AttrType::AttributeProto_AttributeType_FLOATS);
schema.Attr("activation_beta",
"Optional scaling values used by some activation functions.",
AttrType::AttributeProto_AttributeType_FLOATS);
};
}

REGISTER_OPERATOR_SCHEMA(RNN)
.Description(R"DOC(
Computes an one-layer simple RNN. This operator is usually supported
@ -64,6 +82,7 @@ namespace ONNXIR {
Equations:
- Ht = Activation(Wi*Xt + Ri*Ht-1 + Wbi + Rbi)
)DOC")
.FillUsing(RNNDocGeneratorInputX())
.Input("W",
"The weight tensor for input gate. Concatenation of `Wi` and `WBi` "
"(if bidirectional). The tensor has shape "

@ -78,10 +97,13 @@ namespace ONNXIR {
"`[num_directions, 2*hidden_size]`, Optional: If not specified - assumed "
"to be 0.", "T",
true)
.FillUsing(RNNDocGeneratorInputSeqLen())
.FillUsing(RNNDocGeneratorInputInitialH())
.Attr("activations", "One (or two if bidirectional) activation function for "
"input gate. It must be one of tanh and ReLU. Default `tanh`.",
AttrType::AttributeProto_AttributeType_STRINGS)
.FillUsing(RNNDocGenerator());
.FillUsing(RNNDocGeneratorActivationArgs())
.FillUsing(RNNDocGeneratorAttrOutput());

REGISTER_OPERATOR_SCHEMA(GRU)
.Description(R"DOC(

@ -113,6 +135,7 @@ namespace ONNXIR {
- ht = tanh(Wh*Xt + rt*(Rh*Ht-1 + Rbh) + Wbh)
- H = (1 - zt) (.) ht + it (.) Ht-1
)DOC")
.FillUsing(RNNDocGeneratorInputX())
.Input("W",
"The weight tensor for the gates. Concatenation of `W[zrh]` and `WB[zrh]` "
"(if bidirectional) along dimension 0. This tensor has shape "

@ -127,11 +150,14 @@ namespace ONNXIR {
"has shape `[num_directions, 6*hidden_size]`. Optional: If not specified "
"- assumed to be 0", "T",
true /*optional*/)
.FillUsing(RNNDocGeneratorInputSeqLen())
.FillUsing(RNNDocGeneratorInputInitialH())
.Attr("activations", "A list of 3 (or 6 if bidirectional) activation functions "
"for update, reset, and hidden gates. The activation functions must be "
"one of sigmoid and tanh. See the equations for default.",
AttrType::AttributeProto_AttributeType_STRINGS)
.FillUsing(RNNDocGenerator());
.FillUsing(RNNDocGeneratorActivationArgs())
.FillUsing(RNNDocGeneratorAttrOutput());

REGISTER_OPERATOR_SCHEMA(LSTM)

@ -169,6 +195,7 @@ namespace ONNXIR {
- ot = sigmoid(Wo*Xt + Ro*Ht-1 + Po (.) Ct + Wbo + Rbo)
- H = ot (.) tanh(Ct)
)DOC")
.FillUsing(RNNDocGeneratorInputX())
.Input("W",
"The weight tensor for the gates. Concatenation of `W[zrh]` and `WB[zrh]` "
"(if bidirectional) along dimension 0. This tensor has shape "

@ -183,10 +210,13 @@ namespace ONNXIR {
"has shape `[num_directions, 6*hidden_size]`. Optional: If not specified "
"- assumed to be 0", "T",
true /*optional*/)
.FillUsing(RNNDocGeneratorInputSeqLen())
.FillUsing(RNNDocGeneratorInputInitialH())
.Attr("activations", "A list of 3 (or 6 if bidirectional) activation functions "
"for update, reset, and hidden gates. The activation functions must be "
"one of sigmoid and tanh. See the equations for default.",
AttrType::AttributeProto_AttributeType_STRINGS)
.FillUsing(RNNDocGeneratorActivationArgs())
.Attr("clip", "Cell clip threshold. Clipping bounds the elements of a tensor "
"in the range of [-threshold, +threshold] and is applied to the input "
"of activations. No clip if not specified.",

@ -203,5 +233,8 @@ namespace ONNXIR {
"`[num_directions, 3*hidden_size]`. Optional: If not specified - "
"assumed to be 0.", "T",
true /*optional*/)
.FillUsing(RNNDocGenerator());
.FillUsing(RNNDocGeneratorAttrOutput())
.Output("Y_c",
"The last output value of the cell. It has shape "
"`[num_directions, batch_size, hidden_size]`.", "T");
}
@ -67,7 +67,7 @@ namespace ONNXIR {
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.")
.Attr("axis", "Which axis to split on", AttrType::AttributeProto_AttributeType_INT)
.Attr("split", "Number of tensors to output.", AttrType::AttributeProto_AttributeType_INT);
.Attr("split", "Number of tensors to output.", AttrType::AttributeProto_AttributeType_INTS);

// Taken from ONNX
REGISTER_OPERATOR_SCHEMA(Transpose)

@ -89,11 +89,7 @@ namespace ONNXIR {
.Input("axis", "Axis along which to repeat.", "T")
.Output("output", "Repeated output.", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.")
.Attr("axis", "Axis along which to repeat. Default is 0.",
AttrType::AttributeProto_AttributeType_INT, int64_t(0))
.Attr("tiles", "Number of repeated copies to make of the input tensor.",
AttrType::AttributeProto_AttributeType_INT);
"Constrain input and output types to float tensors.");

// Taken from ONNX
REGISTER_OPERATOR_SCHEMA(Concat)
@ -130,14 +126,23 @@ namespace ONNXIR {
.Description("Given data tensor of rank r >= 1, and indices tensor of rank q, gather "
"entries of the outer-most dimension of data indexed by indices, and concatenate "
"them in an output tensor of rank q + (r - 1). "
"Example: data = [ [1.0, 1.2], [2.3, 3.4], [4.5, 5.7] ] "
"indices = [ [0, 1], [1, 2] ] "
"ouput = [ [ [1.0, 1.2], [2.3, 3.4], ], [ [2.3, 3.4], [4.5, 5.7] ] ] ")
"Example 1: data = [ [1.0, 1.2], [2.3, 3.4], [4.5, 5.7], ] "
" indices = [ [0, 1], [1, 2], ] "
" output = [ [ [1.0, 1.2], [2.3, 3.4], ], [ [2.3, 3.4], [4.5, 5.7], ], ]"
"Example 2: data = [ [1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9], ] "
" indices = [0, 2], axis = 1, "
" output = [ [ [1.0, 1.9], [2.3, 3.9], [4.5, 5.9], ], ]")
.Input("data", "Tensor of rank r >= 1.", "T")
.Input("indices", "Tensor of int32/int64 indices, of any rank q.", "T")
.Output("ouput", "Tensor of rank q + (r - 1).", "T")
.Input("indices", "Tensor of int32/int64 indices, of any rank q.", "Tind")
.Output("output", "Tensor of rank q + (r - 1).", "T")
.TypeConstraint("T", { "tensor(float16)", "tensor(float)", "tensor(double)" },
"Constrain input and output types to float tensors.");
"Constrain input types to float tensors.")
.TypeConstraint("Tind", { "tensor(int32)", "tensor(int64)" },
"Constrain indices types to int tensors.")
.Attr("axis",
"Which axis to gather on, defaults to 0. Negative value means counting dimensions "
"from the back. Accepted range in [-r, r-1]",
AttrType::AttributeProto_AttributeType_INT, int64_t(0));

// Taken from ONNX
REGISTER_OPERATOR_SCHEMA(Squeeze)
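A compact way to state the Gather semantics given in the Description above, for the default axis = 0 (explanatory text, not part of the change):

output[i_0, ..., i_{q-1}, j_1, ..., j_{r-1}] = data[indices[i_0, ..., i_{q-1}], j_1, ..., j_{r-1}]

so Example 1 walks each entry of `indices`, pulls the corresponding row of `data`, and stacks the results, which is why the output rank is q + (r - 1).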
@ -1,8 +1,10 @@
#include "proto/onnx/core/constants.h"
#include "proto/onnx/core/op.h"

namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(ArrayFeatureExtractor)
.SetDomain(c_mlDomain)
.Input("X", "Data to be selected from", "T1")
.Input("Y", "Data to be selected from", "T2")
.Output("Z", "Selected data as an array", "T1")

@ -14,6 +16,7 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(Binarizer)
.SetDomain(c_mlDomain)
.Input("X", "Data to be binarized", "T")
.Output("Y", "Binarized output data", "T")
.Description(R"DOC(
@ -23,6 +26,7 @@ namespace ONNXIR {
.Attr("threshold", "Values greater than this are set to 1, else set to 0", AttrType::AttributeProto_AttributeType_FLOAT);

REGISTER_OPERATOR_SCHEMA(CastMap)
.SetDomain(c_mlDomain)
.Input("X", "The input values", "T1")
.Output("Y", "The output values", "T2")
.Description(R"DOC(

@ -35,6 +39,7 @@ namespace ONNXIR {
.Attr("max_map", "if casting from a sparse map, what is the max key in the map", AttrType::AttributeProto_AttributeType_INT);

REGISTER_OPERATOR_SCHEMA(CategoryMapper)
.SetDomain(c_mlDomain)
.Input("X", "Input data", "T1")
.Output("Y", "Output data, if strings are input, then output is INTS, and vice versa.", "T2")
.Description(R"DOC(

@ -56,8 +61,9 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(DictVectorizer)
.Input("X", "The input dictionary", "T")
.Output("Y", "The tensor", "tensor(int64)")
.SetDomain(c_mlDomain)
.Input("X", "The input dictionary", "T1")
.Output("Y", "The tensor", "T2")
.Description(R"DOC(
Uses an index mapping to convert a dictionary to an array.
The output array will be equal in length to the index mapping vector parameter.
@ -69,12 +75,14 @@ namespace ONNXIR {
For example: if the ``string_vocabulary`` parameter is set to ``["a", "c", "b", "z"]``,
then an input of ``{"a": 4, "c": 8}`` will produce an output of ``[4, 8, 0, 0]``.
)DOC")
.TypeConstraint("T", { "map(string, int64)", "map(int64, string)" }, " allowed types.")
.TypeConstraint("T1", { "map(string, int64)", "map(int64, string)", "map(int64, float)", "map(int64, double)", "map(string, float)", "map(string, double)"}, " allowed types.")
.TypeConstraint("T2", { "tensor(int64)", "tensor(float)", "tensor(double)", "tensor(string)" }, " allowed types.")
.Attr("string_vocabulary", "The vocabulary vector of strings", AttrType::AttributeProto_AttributeType_STRINGS)
.Attr("int64_vocabulary", "The vocabulary vector of int64s", AttrType::AttributeProto_AttributeType_INTS);

REGISTER_OPERATOR_SCHEMA(Imputer)
.SetDomain(c_mlDomain)
.Input("X", "Data to be imputed", "T")
.Output("Y", "Imputed output data", "T")
.Description(R"DOC(

@ -90,18 +98,19 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(FeatureVectorizer)
.SetDomain(c_mlDomain)
.Input("X", "ordered input tensors", "T")
.Output("Y", "flattened feature vectors.", "T")
.Output("Y", "flattened output vector.", "T")
.Description(R"DOC(
Concatenates a list of input tensors of floats into one tensor.
Input order in inputs must match inputlist and inputdimensions order.
The size of each input in the input list is expressed in inputdimensions.
)DOC")
.TypeConstraint("T", { "tensor(float)" }, " allowed types.")
.Attr("inputlist", "list of string names of the input features, output features will appear in this order", AttrType::AttributeProto_AttributeType_STRINGS)
.Attr("inputdimensions", "the size of the inputs in the input list", AttrType::AttributeProto_AttributeType_INTS);

REGISTER_OPERATOR_SCHEMA(LabelEncoder)
.SetDomain(c_mlDomain)
.Input("X", "Data to be encoded", "T1")
.Output("Y", "Encoded output data", "T2")
.Description(R"DOC(
@ -117,6 +126,7 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(LinearClassifier)
.SetDomain(c_mlDomain)
.Input("X", "Data to be classified", "T1")
.Output("Y", "Classification outputs (one class per example", "T2")
.Output("Z", "Classification outputs (All classes scores per example,N,E", "tensor(float)")

@ -130,10 +140,11 @@ namespace ONNXIR {
.Attr("post_transform", "post eval transform for score, enum 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT'", AttrType::AttributeProto_AttributeType_STRING)
.Attr("multi_class", "whether to do OvR or multinomial (0=OvR and is default)", AttrType::AttributeProto_AttributeType_INT)
.Attr("classlabels_strings", "class labels if using string labels, size E", AttrType::AttributeProto_AttributeType_STRINGS)
.Attr("classlabels_int64s", "class labels if using int labels, size E", AttrType::AttributeProto_AttributeType_INTS);
.Attr("classlabels_ints", "class labels if using int labels, size E", AttrType::AttributeProto_AttributeType_INTS);

REGISTER_OPERATOR_SCHEMA(LinearRegressor)
.SetDomain(c_mlDomain)
.Input("X", "Data to be regressed", "T")
.Output("Y", "Regression outputs (one per target, per example", "tensor(float)")
.Description(R"DOC(
@ -152,6 +163,7 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(Normalizer)
.SetDomain(c_mlDomain)
.Input("X", "Data to be encoded", "T")
.Output("Y", "encoded output data", "tensor(float)")
.Description(R"DOC(

@ -166,18 +178,22 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(OneHotEncoder)
.SetDomain(c_mlDomain)
.Input("X", "Data to be encoded", "T")
.Output("Y", "encoded output data", "tensor(float)")
.Description(R"DOC(
Replace the inputs with an array of ones and zeros, where the only
one is the zero-based category that was passed in. The total category count
will determine the length of the vector. For example if we pass a
tensor with a single value of 4, and a category count of 8, the
one is the zero-based category that was passed in. The total category count
will determine the length of the vector. For example if we pass a
tensor with a single value of 4, and a category count of 8, the
output will be a tensor with 0,0,0,0,1,0,0,0 .
This operator assumes every input in X is of the same category set
This operator assumes every input in X is of the same category set
(meaning there is only one category count).

If the input is a tensor of float, int32, or double, the data will be cast
to int64s and the cats_int64s category list will be used for the lookups.
)DOC")
.TypeConstraint("T", { "tensor(string)", "tensor(int64)" }, " allowed types.")
.TypeConstraint("T", { "tensor(string)", "tensor(int64)","tensor(int32)", "tensor(float)","tensor(double)" }, "allowed types.")
.Attr("cats_int64s", "list of categories, ints", AttrType::AttributeProto_AttributeType_INTS)
.Attr("cats_strings", "list of categories, strings", AttrType::AttributeProto_AttributeType_STRINGS)
.Attr("zeros", "if true and category is not present, will return all zeros, if false and missing category, operator will return false", AttrType::AttributeProto_AttributeType_INT);
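To tie the attributes above to the example in the Description (the vocabulary below is illustrative, not from the diff): with cats_int64s = [2, 4, 6, 8, 10, 12, 14, 16] (eight categories) an input value of 10 matches position 4, so the encoder emits [0, 0, 0, 0, 1, 0, 0, 0]; a value not present in the list yields all zeros when zeros=1, and is treated as an error otherwise.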
@ -185,6 +201,7 @@ namespace ONNXIR {

// Input: X, output: Y
REGISTER_OPERATOR_SCHEMA(Scaler)
.SetDomain(c_mlDomain)
.Input("X", "Data to be scaled", "T")
.Output("Y", "Scaled output data", "tensor(float)")
.Description(R"DOC(

@ -195,6 +212,7 @@ namespace ONNXIR {
.Attr("offset", "first, offset by thisfirst, offset by this, can be one value or a separate value for each feature", AttrType::AttributeProto_AttributeType_FLOATS);

REGISTER_OPERATOR_SCHEMA(SVMClassifier)
.SetDomain(c_mlDomain)
.Input("X", "Data to be classified", "T1")
.Output("Y", "Classification outputs, one class per example", "T2")
.Output("Z", "Classification outputs, All classes scores per example,N,E*(E-1)/2 if dual scores, or E if probabilities are used.", "tensor(float)")

@ -214,14 +232,15 @@ namespace ONNXIR {
.Attr("rho", "", AttrType::AttributeProto_AttributeType_FLOATS)
.Attr("post_transform", "post eval transform for score, enum 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT'", AttrType::AttributeProto_AttributeType_STRING)
.Attr("classlabels_strings", "class labels if using string labels", AttrType::AttributeProto_AttributeType_STRINGS)
.Attr("classlabels_int64s", "class labels if using int labels", AttrType::AttributeProto_AttributeType_INTS);
.Attr("classlabels_ints", "class labels if using int labels", AttrType::AttributeProto_AttributeType_INTS);

REGISTER_OPERATOR_SCHEMA(SVMRegressor)
.SetDomain(c_mlDomain)
.Input("X", "Input N,F", "T")
.Output("Y", "All target scores, N,E", "tensor(float)")
.Description(R"DOC(
SVM regressor. Also supports oneclass svm.
SVM regressor. Also supports oneclass svm.
)DOC")
.TypeConstraint("T", { "tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)" }, " allowed types.")
.Attr("kernel_type", "enum 'LINEAR', 'POLY', 'RBF', 'SIGMOID', defaults to linear", AttrType::AttributeProto_AttributeType_STRING)

@ -235,18 +254,19 @@ namespace ONNXIR {
.Attr("one_class", "If this regressor is a oneclass svm set this param to 1, otherwise use 0 (default is zero)", AttrType::AttributeProto_AttributeType_INT);

REGISTER_OPERATOR_SCHEMA(TreeEnsembleClassifier)
.SetDomain(c_mlDomain)
.Input("X", "Data to be classified", "T1")
.Output("Y", "Classification outputs (one class per example", "T2")
.Output("Z", "Classification outputs (All classes scores per example,N,E", "tensor(float)")
.Description(R"DOC(
Tree Ensemble classifier. Returns the top class for each input in N.
All args with nodes_ are fields of a tuple of tree nodes, and
All args with nodes_ are fields of a tuple of tree nodes, and
it is assumed they are the same length, and an index i will decode the
tuple across these inputs. Each node id can appear only once
tuple across these inputs. Each node id can appear only once
for each tree id."
All fields prefixed with class_ are tuples of votes at the leaves.
A leaf may have multiple votes, where each vote is weighted by
the associated class_weights index.
the associated class_weights index.
It is expected that either classlabels_strings or classlabels_INTS
will be passed and the class_ids are an index into this list.
Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF.

@ -273,17 +293,18 @@ namespace ONNXIR {

REGISTER_OPERATOR_SCHEMA(TreeEnsembleRegressor)
.SetDomain(c_mlDomain)
.Input("X", "Input N,F", "T")
.Output("Y", "NxE floats", "tensor(float)")
.Description(R"DOC(
Tree Ensemble regressor. Returns the regressed values for each input in N.
All args with nodes_ are fields of a tuple of tree nodes, and
All args with nodes_ are fields of a tuple of tree nodes, and
it is assumed they are the same length, and an index i will decode the
tuple across these inputs. Each node id can appear only once
tuple across these inputs. Each node id can appear only once
for each tree id.
All fields prefixed with target_ are tuples of votes at the leaves.
A leaf may have multiple votes, where each vote is weighted by
the associated target_weights index.
the associated target_weights index.
All trees must have their node ids start at 0 and increment by 1.
Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF
)DOC")

@ -306,14 +327,15 @@ namespace ONNXIR {
.Attr("aggregate_function", "post eval transform for score, enum 'AVERAGE', 'SUM', 'MIN', 'MAX'", AttrType::AttributeProto_AttributeType_STRING)
.Attr("base_values", "base values for regression, added to final score, size must be the same as n_outputs or can be left unassigned (assumed 0)", AttrType::AttributeProto_AttributeType_FLOATS);

REGISTER_OPERATOR_SCHEMA(VecDictionizer)
REGISTER_OPERATOR_SCHEMA(ZipMap)
.SetDomain(c_mlDomain)
.Input("X", "The input values", "tensor(float)")
.Output("Y", "The output map", "T")
.Description(R"DOC(
Makes a map from the input and the attributes.
Makes a map from the input and the attributes.
Assumes input 0 are the values, and the keys are specified by the attributes.
Must provide keys in either classlabels_strings or classlabels_int64s (but not both).
Input 0 may have a batch size larger than 1,
Input 0 may have a batch size larger than 1,
but each input in the batch must be the size of the keys specified by the attributes.
The order of the input and attributes determines the key-value mapping.
)DOC")
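A worked example of the mapping described above (the labels and scores are illustrative, not from the diff): with classlabels_strings = ["cat", "dog", "bird"] and an input row of [0.7, 0.2, 0.1], ZipMap produces the map {"cat": 0.7, "dog": 0.2, "bird": 0.1}; a batch of N rows yields N such maps, and each row must contain exactly as many values as there are keys.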