Fix ONNX compatibility issues when exporting Select.
This commit is contained in:
Родитель
cb99bd6325
Коммит
0642d734c3
|
@ -2555,16 +2555,16 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
|
|||
std::vector<ONNXIR::NodeArg> varInputs;
|
||||
std::vector<ONNXIR::NodeArg> varOutputs;
|
||||
|
||||
varOutputs.push_back({ inputArg });
|
||||
if (input.IsParameter() || input.IsConstant())
|
||||
{
|
||||
auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();
|
||||
varOutputs.push_back({ inputArg });
|
||||
if (input.IsParameter() || input.IsConstant())
|
||||
{
|
||||
auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();
|
||||
|
||||
onnx::TensorProto dstTensor;
|
||||
dstTensor.set_name(inputName);
|
||||
CopyTensor(srcTensor, dstTensor, &inputArgType);
|
||||
onnx::TensorProto dstTensor;
|
||||
dstTensor.set_name(inputName);
|
||||
CopyTensor(srcTensor, dstTensor, &inputArgType);
|
||||
|
||||
graph->AddInitialTensor(dstTensor);
|
||||
graph->AddInitialTensor(dstTensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2575,6 +2575,7 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
|
|||
else if (input.IsOutput())
|
||||
CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
|
||||
}
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::ProcessOutputs(const FunctionPtr& src,
|
||||
std::vector<ONNXIR::NodeArg>& outputs)
|
||||
|
@ -3875,6 +3876,7 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
|
|||
{
|
||||
std::vector<ONNXIR::NodeArg> inputs;
|
||||
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs);
|
||||
assert(inputs.size() == 3);
|
||||
|
||||
std::vector<ONNXIR::NodeArg> outputs;
|
||||
ProcessOutputs(src, outputs);
|
||||
|
@ -3889,6 +3891,18 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
|
|||
ONNXIR::NodeArg absOutputArg(outputName + "_abs_out", nullptr);
|
||||
graph->AddNode(outputName + "_abs", "Abs", "", { inputs[0] }, { absOutputArg });
|
||||
|
||||
// Add a Clip node equivalent to min(abs(flag), 1).
|
||||
ONNXIR::NodeArg clipOutputArg(outputName + "_clip_out", nullptr);
|
||||
ONNXIR::Node* clipNode = graph->AddNode(outputName + "_clip", "Clip", "", { absOutputArg }, { clipOutputArg });
|
||||
clipNode->AddAttribute("min", 0.0f); // Should be unnecesary for ONNX, but currently required by CNTK.
|
||||
clipNode->AddAttribute("max", 1.0f);
|
||||
|
||||
ONNXIR::NodeArg ceilOutputArg(outputName + "_ceil_out", nullptr);
|
||||
graph->AddNode(outputName + "_ceil", "Ceil", "", { clipOutputArg }, { ceilOutputArg });
|
||||
|
||||
ONNXIR::NodeArg mulTrueOutputArg(outputName + "_mul_true_out", nullptr);
|
||||
graph->AddNode(outputName + "_mul_true", "Mul", "", { ceilOutputArg, inputs[1] }, { mulTrueOutputArg });
|
||||
|
||||
ONNXIR::NodeArg oneOutputArg(outputName + "_one_out", nullptr);
|
||||
ONNXIR::Node* oneNode = graph->AddNode(outputName + "_one", "Constant", "", {}, { oneOutputArg });
|
||||
onnx::TensorProto oneTensor;
|
||||
|
@ -3896,20 +3910,17 @@ ONNXIR::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr &src,
|
|||
oneTensor.add_float_data(1.0f);
|
||||
oneNode->AddAttribute("value", oneTensor);
|
||||
|
||||
ONNXIR::NodeArg minOutputArg(outputName + "_min_out", nullptr);
|
||||
graph->AddNode(outputName + "_min", "Min", "", { absOutputArg, oneOutputArg }, { minOutputArg });
|
||||
// ONNX can broadcast only the second input for element-wise operations,
|
||||
// so instead of 1 - flag01 use -(flag01 - 1).
|
||||
ONNXIR::NodeArg subOneOutputArg(outputName + "_sub_one_out", nullptr);
|
||||
ONNXIR::Node* subOneNode = graph->AddNode(outputName + "_sub_one", "Sub", "", { ceilOutputArg, oneOutputArg }, { subOneOutputArg });
|
||||
subOneNode->AddAttribute("broadcast", static_cast<int64_t>(1));
|
||||
|
||||
ONNXIR::NodeArg ceilOutputArg(outputName + "_ceil_out", nullptr);
|
||||
graph->AddNode(outputName + "_ceil", "Ceil", "", { minOutputArg }, { ceilOutputArg });
|
||||
|
||||
ONNXIR::NodeArg mulTrueOutputArg(outputName + "_mul_true_out", nullptr);
|
||||
graph->AddNode(outputName + "_mul_true", "Mul", "", { ceilOutputArg, inputs[1] }, { mulTrueOutputArg });
|
||||
|
||||
ONNXIR::NodeArg oneSubOutputArg(outputName + "_one_sub_out", nullptr);
|
||||
graph->AddNode(outputName + "_one_sub", "Sub", "", { oneOutputArg, ceilOutputArg }, { oneSubOutputArg });
|
||||
ONNXIR::NodeArg negOutputArg(outputName + "_neg_out", nullptr);
|
||||
graph->AddNode(outputName + "_neg", "Neg", "", { subOneOutputArg }, { negOutputArg });
|
||||
|
||||
ONNXIR::NodeArg mulFalseOutputArg(outputName + "_mul_false_out", nullptr);
|
||||
graph->AddNode(outputName + "_mul_false", "Mul", "", { oneSubOutputArg, inputs[2] }, { mulFalseOutputArg });
|
||||
graph->AddNode(outputName + "_mul_false", "Mul", "", { negOutputArg, inputs[2] }, { mulFalseOutputArg });
|
||||
|
||||
ONNXIR::Node* sumNode = graph->AddNode(outputName + "_sum", "Sum", "", { mulTrueOutputArg, mulFalseOutputArg }, { outputs[0] });
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче