Export custom attributes on Times node to ONNX file (#3709)

This commit is contained in:
KeDengMS 2019-07-05 13:33:45 -07:00 коммит произвёл GitHub
Родитель 13f3b5fd3c
Коммит a4621a8a5d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 83 добавлений и 3 удалений

Просмотреть файл

@ -7312,6 +7312,80 @@ std::vector<int64_t> CNTKToONNXHelper::BroadcastInputs(std::vector<onnxruntime::
return broadcast_shape;
}
// Forward declaration
static std::string SerializeDictionaryValueToString(const DictionaryValue& val);
static std::string SerializeVectorToString(const std::vector<DictionaryValue>& vals)
{
bool first = true;
std::string str = "[";
for (const auto& v : vals)
{
if (first)
{
first = false;
}
else
{
str += ",";
}
str += SerializeDictionaryValueToString(v);
}
str += "]";
return str;
}
static std::string SerializeDictionaryToString(const Dictionary& dict)
{
bool first = true;
std::string str = "{";
for (const auto& kv : dict)
{
auto key = Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(kv.first));
if (first)
{
first = false;
}
else
{
str += ",";
}
str += "\"" + key + "\":" + SerializeDictionaryValueToString(kv.second);
}
str += "}";
return str;
}
// Valid DictionaryValues that can be converted to string:
// Bool, Int, SzieT, Float, Double, String, and Vector or Dictionary of aforementioned types.
static std::string SerializeDictionaryValueToString(const DictionaryValue& val)
{
auto value_type = val.ValueType();
switch (value_type)
{
case DictionaryValue::Type::Bool:
return val.Value<bool>() ? "true" : "false";
case DictionaryValue::Type::Int:
return std::to_string(val.Value<int>());
case DictionaryValue::Type::SizeT:
return std::to_string(val.Value<size_t>());
case DictionaryValue::Type::Float:
return std::to_string(val.Value<float>());
case DictionaryValue::Type::Double:
return std::to_string(val.Value<double>());
case DictionaryValue::Type::String:
return "\"" + Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(val.Value<std::wstring>())) + "\"";
case DictionaryValue::Type::Dictionary:
return SerializeDictionaryToString(val.Value<Dictionary>());
case DictionaryValue::Type::Vector:
return SerializeVectorToString(val.Value<std::vector<DictionaryValue>>());
default:
// skip all other cases;
return "";
}
return "";
}
onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime::Graph* graph, const std::vector<onnxruntime::NodeArg *>& inputs, const std::vector<onnxruntime::NodeArg *>& outputs)
{
onnxruntime::Node* node = nullptr;
@ -7348,6 +7422,12 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
bool input1HasBatchAxis = (input1Rank - reductionRank) == 2;
bool input2HasBatchAxis = (input2Rank - reductionRank) == 2;
// some Times nodes may carry custom attributes that need to be saved in ONNX as NodeProto.doc_string
std::string customAttrsStr =
src->GetCustomAttributes().Size() == 0 ?
"" :
"{\"custom_attributes\":" + SerializeDictionaryToString(src->GetCustomAttributes()) + "}";
if (reductionRank > 1 || py_api_output_rank_argument > 1) // We need to insert reshape.
{
onnx::TypeProto matMulInput1Reshape, matMulInput2Reshape, matMulOutputShape;
@ -7372,17 +7452,17 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
// these dimensions were reduced in ReduceRank
onnxruntime::NodeArg &matMulOutputNodeArg =
graph->GetOrCreateNodeArg(nodeName + string("_reshape"), &matMulOutputShape);
graph->AddNode(nodeName, "MatMul", "", {&inputOutput1Arg, &inputOutput2Arg}, {&matMulOutputNodeArg});
graph->AddNode(nodeName, "MatMul", customAttrsStr, {&inputOutput1Arg, &inputOutput2Arg}, {&matMulOutputNodeArg});
// node = graph->AddNode(nodeName + "_reshape", "Reshape", "", { input, &shapeInputArg }, { output });
std::vector<int64_t> finalOutputShape = ToINTS(*outputs[0]->TypeAsProto());
node = AddReshapeNodeImpl(graph, nodeName + "_output_reshape", &matMulOutputNodeArg, outputs[0], finalOutputShape);
}
else
node = &graph->AddNode(nodeName, ToOPName(src), "", {&inputOutput1Arg, &inputOutput2Arg}, outputs);
node = &graph->AddNode(nodeName, ToOPName(src), customAttrsStr, {&inputOutput1Arg, &inputOutput2Arg}, outputs);
}
else
{
node = &graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
node = &graph->AddNode(nodeName, ToOPName(src), customAttrsStr, orderedInputs, outputs);
}
}
else if (src->OpName() == L"LayerNormalization")