Fix ONNX compatibility issues when exporting Select.

This commit is contained in:
Sergii Dymchenko 2018-05-23 11:37:43 -07:00
Родитель cb99bd6325
Коммит 0642d734c3
1 изменённых файлов: 30 добавлений и 19 удалений

Просмотреть файл

@ -2555,16 +2555,16 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
std::vector<ONNXIR::NodeArg> varInputs;
std::vector<ONNXIR::NodeArg> varOutputs;
varOutputs.push_back({ inputArg });
if (input.IsParameter() || input.IsConstant())
{
auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();
varOutputs.push_back({ inputArg });
if (input.IsParameter() || input.IsConstant())
{
auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();
onnx::TensorProto dstTensor;
dstTensor.set_name(inputName);
CopyTensor(srcTensor, dstTensor, &inputArgType);
onnx::TensorProto dstTensor;
dstTensor.set_name(inputName);
CopyTensor(srcTensor, dstTensor, &inputArgType);
graph->AddInitialTensor(dstTensor);
graph->AddInitialTensor(dstTensor);
}
}
}
@ -2575,6 +2575,7 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
else if (input.IsOutput())
CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
}
}
void CNTKToONNXHelper::ProcessOutputs(const FunctionPtr& src,
std::vector<ONNXIR::NodeArg>& outputs)
@ -3875,6 +3876,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
{
std::vector<ONNXIR::NodeArg> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs);
assert(inputs.size() == 3);
std::vector<ONNXIR::NodeArg> outputs;
ProcessOutputs(src, outputs);
@ -3889,6 +3891,18 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
ONNXIR::NodeArg absOutputArg(outputName + "_abs_out", nullptr);
graph->AddNode(outputName + "_abs", "Abs", "", { inputs[0] }, { absOutputArg });
// Add a Clip node equivalent to min(abs(flag), 1).
ONNXIR::NodeArg clipOutputArg(outputName + "_clip_out", nullptr);
ONNXIR::Node* clipNode = graph->AddNode(outputName + "_clip", "Clip", "", { absOutputArg }, { clipOutputArg });
clipNode->AddAttribute("min", 0.0f); // Should be unnecesary for ONNX, but currently required by CNTK.
clipNode->AddAttribute("max", 1.0f);
ONNXIR::NodeArg ceilOutputArg(outputName + "_ceil_out", nullptr);
graph->AddNode(outputName + "_ceil", "Ceil", "", { clipOutputArg }, { ceilOutputArg });
ONNXIR::NodeArg mulTrueOutputArg(outputName + "_mul_true_out", nullptr);
graph->AddNode(outputName + "_mul_true", "Mul", "", { ceilOutputArg, inputs[1] }, { mulTrueOutputArg });
ONNXIR::NodeArg oneOutputArg(outputName + "_one_out", nullptr);
ONNXIR::Node* oneNode = graph->AddNode(outputName + "_one", "Constant", "", {}, { oneOutputArg });
onnx::TensorProto oneTensor;
@ -3896,20 +3910,17 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
oneTensor.add_float_data(1.0f);
oneNode->AddAttribute("value", oneTensor);
ONNXIR::NodeArg minOutputArg(outputName + "_min_out", nullptr);
graph->AddNode(outputName + "_min", "Min", "", { absOutputArg, oneOutputArg }, { minOutputArg });
// ONNX can broadcast only the second input for element-wise operations,
// so instead of 1 - flag01 use -(flag01 - 1).
ONNXIR::NodeArg subOneOutputArg(outputName + "_sub_one_out", nullptr);
ONNXIR::Node* subOneNode = graph->AddNode(outputName + "_sub_one", "Sub", "", { ceilOutputArg, oneOutputArg }, { subOneOutputArg });
subOneNode->AddAttribute("broadcast", static_cast<int64_t>(1));
ONNXIR::NodeArg ceilOutputArg(outputName + "_ceil_out", nullptr);
graph->AddNode(outputName + "_ceil", "Ceil", "", { minOutputArg }, { ceilOutputArg });
ONNXIR::NodeArg mulTrueOutputArg(outputName + "_mul_true_out", nullptr);
graph->AddNode(outputName + "_mul_true", "Mul", "", { ceilOutputArg, inputs[1] }, { mulTrueOutputArg });
ONNXIR::NodeArg oneSubOutputArg(outputName + "_one_sub_out", nullptr);
graph->AddNode(outputName + "_one_sub", "Sub", "", { oneOutputArg, ceilOutputArg }, { oneSubOutputArg });
ONNXIR::NodeArg negOutputArg(outputName + "_neg_out", nullptr);
graph->AddNode(outputName + "_neg", "Neg", "", { subOneOutputArg }, { negOutputArg });
ONNXIR::NodeArg mulFalseOutputArg(outputName + "_mul_false_out", nullptr);
graph->AddNode(outputName + "_mul_false", "Mul", "", { oneSubOutputArg, inputs[2] }, { mulFalseOutputArg });
graph->AddNode(outputName + "_mul_false", "Mul", "", { negOutputArg, inputs[2] }, { mulFalseOutputArg });
ONNXIR::Node* sumNode = graph->AddNode(outputName + "_sum", "Sum", "", { mulTrueOutputArg, mulFalseOutputArg }, { outputs[0] });