Export custom attributes on Times node to ONNX file (#3709)
This commit is contained in:
Родитель
13f3b5fd3c
Коммит
a4621a8a5d
|
@ -7312,6 +7312,80 @@ std::vector<int64_t> CNTKToONNXHelper::BroadcastInputs(std::vector<onnxruntime::
|
|||
return broadcast_shape;
|
||||
}
|
||||
|
||||
// Forward declaration
|
||||
static std::string SerializeDictionaryValueToString(const DictionaryValue& val);
|
||||
|
||||
static std::string SerializeVectorToString(const std::vector<DictionaryValue>& vals)
|
||||
{
|
||||
bool first = true;
|
||||
std::string str = "[";
|
||||
for (const auto& v : vals)
|
||||
{
|
||||
if (first)
|
||||
{
|
||||
first = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
str += ",";
|
||||
}
|
||||
str += SerializeDictionaryValueToString(v);
|
||||
}
|
||||
str += "]";
|
||||
return str;
|
||||
}
|
||||
|
||||
static std::string SerializeDictionaryToString(const Dictionary& dict)
|
||||
{
|
||||
bool first = true;
|
||||
std::string str = "{";
|
||||
for (const auto& kv : dict)
|
||||
{
|
||||
auto key = Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(kv.first));
|
||||
if (first)
|
||||
{
|
||||
first = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
str += ",";
|
||||
}
|
||||
str += "\"" + key + "\":" + SerializeDictionaryValueToString(kv.second);
|
||||
}
|
||||
str += "}";
|
||||
return str;
|
||||
}
|
||||
|
||||
// Valid DictionaryValues that can be converted to string:
|
||||
// Bool, Int, SzieT, Float, Double, String, and Vector or Dictionary of aforementioned types.
|
||||
static std::string SerializeDictionaryValueToString(const DictionaryValue& val)
|
||||
{
|
||||
auto value_type = val.ValueType();
|
||||
switch (value_type)
|
||||
{
|
||||
case DictionaryValue::Type::Bool:
|
||||
return val.Value<bool>() ? "true" : "false";
|
||||
case DictionaryValue::Type::Int:
|
||||
return std::to_string(val.Value<int>());
|
||||
case DictionaryValue::Type::SizeT:
|
||||
return std::to_string(val.Value<size_t>());
|
||||
case DictionaryValue::Type::Float:
|
||||
return std::to_string(val.Value<float>());
|
||||
case DictionaryValue::Type::Double:
|
||||
return std::to_string(val.Value<double>());
|
||||
case DictionaryValue::Type::String:
|
||||
return "\"" + Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(val.Value<std::wstring>())) + "\"";
|
||||
case DictionaryValue::Type::Dictionary:
|
||||
return SerializeDictionaryToString(val.Value<Dictionary>());
|
||||
case DictionaryValue::Type::Vector:
|
||||
return SerializeVectorToString(val.Value<std::vector<DictionaryValue>>());
|
||||
default:
|
||||
// skip all other cases;
|
||||
return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime::Graph* graph, const std::vector<onnxruntime::NodeArg *>& inputs, const std::vector<onnxruntime::NodeArg *>& outputs)
|
||||
{
|
||||
onnxruntime::Node* node = nullptr;
|
||||
|
@ -7348,6 +7422,12 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
|
|||
bool input1HasBatchAxis = (input1Rank - reductionRank) == 2;
|
||||
bool input2HasBatchAxis = (input2Rank - reductionRank) == 2;
|
||||
|
||||
// some Times nodes may carry custom attributes that need to be saved in ONNX as NodeProto.doc_string
|
||||
std::string customAttrsStr =
|
||||
src->GetCustomAttributes().Size() == 0 ?
|
||||
"" :
|
||||
"{\"custom_attributes\":" + SerializeDictionaryToString(src->GetCustomAttributes()) + "}";
|
||||
|
||||
if (reductionRank > 1 || py_api_output_rank_argument > 1) // We need to insert reshape.
|
||||
{
|
||||
onnx::TypeProto matMulInput1Reshape, matMulInput2Reshape, matMulOutputShape;
|
||||
|
@ -7372,17 +7452,17 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
|
|||
// these dimensions were reduced in ReduceRank
|
||||
onnxruntime::NodeArg &matMulOutputNodeArg =
|
||||
graph->GetOrCreateNodeArg(nodeName + string("_reshape"), &matMulOutputShape);
|
||||
graph->AddNode(nodeName, "MatMul", "", {&inputOutput1Arg, &inputOutput2Arg}, {&matMulOutputNodeArg});
|
||||
graph->AddNode(nodeName, "MatMul", customAttrsStr, {&inputOutput1Arg, &inputOutput2Arg}, {&matMulOutputNodeArg});
|
||||
// node = graph->AddNode(nodeName + "_reshape", "Reshape", "", { input, &shapeInputArg }, { output });
|
||||
std::vector<int64_t> finalOutputShape = ToINTS(*outputs[0]->TypeAsProto());
|
||||
node = AddReshapeNodeImpl(graph, nodeName + "_output_reshape", &matMulOutputNodeArg, outputs[0], finalOutputShape);
|
||||
}
|
||||
else
|
||||
node = &graph->AddNode(nodeName, ToOPName(src), "", {&inputOutput1Arg, &inputOutput2Arg}, outputs);
|
||||
node = &graph->AddNode(nodeName, ToOPName(src), customAttrsStr, {&inputOutput1Arg, &inputOutput2Arg}, outputs);
|
||||
}
|
||||
else
|
||||
{
|
||||
node = &graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
|
||||
node = &graph->AddNode(nodeName, ToOPName(src), customAttrsStr, orderedInputs, outputs);
|
||||
}
|
||||
}
|
||||
else if (src->OpName() == L"LayerNormalization")
|
||||
|
|
Загрузка…
Ссылка в новой задаче