Support LogPlus (log_add_exp) export to ONNX
* ONNX supports the similar op ReduceLogSumExp; a conversion to it is added when exporting.
* Refactored CNTKToONNXHelper::BroadcastInputsIfNeeded into BroadcastInputs to support more generalized cases.
Parent: c2072cc4ab
Commit: a36fae88bb
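The conversion hinges on the identity log(exp(x) + exp(y)) == log-sum-exp over a new axis along which x and y are stacked, which is exactly the Unsqueeze/Concat/ReduceLogSumExp pattern emitted below. A minimal numpy sketch of the equivalence (shapes illustrative only, not part of the commit):

    import numpy as np

    x = np.random.rand(2, 3, 4)
    y = np.random.rand(2, 3, 4)

    # Elementwise form: what CNTK LogPlus / numpy.logaddexp computes.
    direct = np.logaddexp(x, y)

    # Exporter's form: stack both inputs along a new leading axis
    # (Unsqueeze + Concat), then log-sum-exp over that axis
    # (ReduceLogSumExp with keepdims=0).
    stacked = np.stack([x, y], axis=0)
    m = stacked.max(axis=0)  # max-subtraction for numerical stability
    reduced = np.log(np.exp(stacked - m).sum(axis=0)) + m

    assert np.allclose(direct, reduced)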
@@ -116,7 +116,8 @@ private:
     static onnxruntime::Node *AddTransposeNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, const std::vector<int64_t> &perm,
         onnx::TypeProto& transposeOutputArgType, const std::string &outputNodeArgName);
 
-    static void BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg *> &orderedInputs, const FunctionPtr& src, onnxruntime::Graph* graph);
+    static std::vector<int64_t> BroadcastInputs(std::vector<onnxruntime::NodeArg *> &orderedInputs, const std::set<int64_t>& ignoreAxes,
+        const FunctionPtr& src, onnxruntime::Graph* graph);
 
     //
     // Insert a reshape node in front of a given node and its output node arg
@@ -4104,17 +4105,15 @@ std::vector<int64_t> GetShapeFromNodeArg(onnxruntime::NodeArg *nodeArg)
     return shape;
 }
 
-// CNTK splice allows broadcast of inputs before applying concatination.
+// CNTK splice allows broadcast of inputs before applying concatenation.
 // ONNX Concat is limited to matching input shape cases
-// i.e. inputs' dimensions shall be the equal except for the concatination axis.
+// i.e. inputs' dimensions shall be equal except for the concatenation axis.
 // for an example, see test_Concat_With_Broadcast in onnx_op_test.py.
-void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg *> &orderedInputs, const FunctionPtr& src, onnxruntime::Graph* graph)
+// This function broadcasts the inputs for all axes excluding ignoreAxes.
+// Returns the broadcasted shape.
+std::vector<int64_t> CNTKToONNXHelper::BroadcastInputs(std::vector<onnxruntime::NodeArg *> &orderedInputs, const std::set<int64_t> &ignoreAxes,
+    const FunctionPtr& src, onnxruntime::Graph* graph)
 {
-    if (src->OpName() != L"Splice")
-        return;
-
-    Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
-    int64_t concatAxis = ConvertAxisToOnnxBroadcastOfOp(axis, src);
     std::vector<std::vector<int64_t>> shapes;
     int max_rank = 0;
     for (auto nodeArg : orderedInputs)
@@ -4130,13 +4129,13 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
         for (int index_to_shape_i = 0; index_to_shape_i < shape_i.size(); index_to_shape_i++)
         {
             int onnx_axis = index_to_shape_i + (max_rank - shape_i.size());
-            if (onnx_axis == concatAxis)
-                // only check and update no-concat_axis dimensions
+            if (ignoreAxes.find(onnx_axis) != ignoreAxes.end())
+                // only check and update dimensions of axes not in ignoreAxes
                 continue;
             else if (broadcast_shape[onnx_axis] == 1)
                 broadcast_shape[onnx_axis] = shape_i[index_to_shape_i];
             else if (broadcast_shape[onnx_axis] != shape_i[index_to_shape_i] && shape_i[index_to_shape_i] != 1)
-                LogicError("Invalid splice inputs shape");
+                LogicError("Invalid broadcast inputs shape");
         }
     }
 
@@ -4152,7 +4151,7 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
 
         for (int onnx_axis = 0; onnx_axis < shape_i.size(); onnx_axis++)
         {
-            if (onnx_axis != concatAxis && shape_i[onnx_axis] != broadcast_shape[onnx_axis])
+            if (ignoreAxes.find(onnx_axis) == ignoreAxes.end() && shape_i[onnx_axis] != broadcast_shape[onnx_axis])
             {
                 shape_i[onnx_axis] = broadcast_shape[onnx_axis];
                 need_broadcast = true;
@@ -4171,6 +4170,8 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
         onnxruntime::Node *node = AddAddNode(*nodeArg, nodeArg2, graph, out_arg_name);
         orderedInputs[i] = const_cast<NodeArg*>(node->OutputDefs()[0]);
     }
 
+    return broadcast_shape;
+
 }
 
 onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime::Graph* graph, const std::vector<onnxruntime::NodeArg *>& inputs, const std::vector<onnxruntime::NodeArg *>& outputs)
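For reference, the shape-unification loop of BroadcastInputs can be sketched in a few lines of Python. broadcast_shape_ignoring and its example shapes are hypothetical, mirroring only the logic visible in this diff:

    # Hypothetical sketch of BroadcastInputs' shape computation, not CNTK code.
    def broadcast_shape_ignoring(shapes, ignore_axes):
        # Right-align every shape to the maximum rank, then unify each
        # axis that is not in ignore_axes under numpy broadcast rules.
        max_rank = max(len(s) for s in shapes)
        out = [1] * max_rank
        for s in shapes:
            for i, dim in enumerate(s):
                axis = i + (max_rank - len(s))
                if axis in ignore_axes:
                    continue  # e.g. the Splice concat axis: left untouched
                if out[axis] == 1:
                    out[axis] = dim
                elif out[axis] != dim and dim != 1:
                    raise ValueError("Invalid broadcast inputs shape")
        return out

    # Splice case: the concat axis (1) is ignored, the rest must broadcast.
    print(broadcast_shape_ignoring([(2, 1, 4), (1, 3, 4)], ignore_axes={1}))  # [2, 1, 4]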
@@ -4263,9 +4264,50 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
     }
     else if (src->OpName() == L"Splice")
     {
-        BroadcastInputsIfNeeded(orderedInputs, src, graph);
+        Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
+        BroadcastInputs(orderedInputs, { ConvertAxisToOnnxBroadcastOfOp(axis, src) }, src, graph);
         node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
     }
+    else if (src->OpName() == L"LogPlus")
+    {
+        // CNTK LogPlus is the equivalent of numpy.logaddexp.
+        // ONNX has a different but similar op: ReduceLogSumExp.
+        onnx::TensorProto_DataType tensorType = orderedInputs[0]->TypeAsProto()->tensor_type().elem_type();
+        std::vector<int64_t> broadcastShape = BroadcastInputs(orderedInputs, /*ignoreAxes=*/{}, src, graph);
+        // Now both inputs should have the same shape.
+        // Add another axis in front. This will be the axis to be reduced over later.
+        std::vector<int64_t> unsqueezeOutputShape = broadcastShape;
+        unsqueezeOutputShape.insert(unsqueezeOutputShape.begin(), 1);
+        std::vector<int64_t> concatOutputShape = broadcastShape;
+        concatOutputShape.insert(concatOutputShape.begin(), 2);
+
+        auto unsqueezeInputFunc = [&](int inputIndex) -> onnxruntime::NodeArg& {
+            onnx::TypeProto outputArgType = ToTypeProto(unsqueezeOutputShape);
+            outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+            onnxruntime::NodeArg &unsqueezeTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_unsqueeze" + std::to_string(inputIndex) + "_output0"), &outputArgType);
+            onnxruntime::Node* unsqueezeNode = graph->AddNode(nodeName + string("_Unsqueeze") + std::to_string(inputIndex), "Unsqueeze", "", { orderedInputs[inputIndex] }, { &unsqueezeTensorOutputArg });
+            unsqueezeNode->AddAttribute("axes", std::vector<int64_t>(1, 0));
+            return unsqueezeTensorOutputArg;
+        };
+
+        onnxruntime::NodeArg &unsqueezeTensorOutputArg0 = unsqueezeInputFunc(0);
+        onnxruntime::NodeArg &unsqueezeTensorOutputArg1 = unsqueezeInputFunc(1);
+
+        onnx::TypeProto concatOutputArgType = ToTypeProto(concatOutputShape);
+        concatOutputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+        onnxruntime::NodeArg &concatTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_concat_output0"), &concatOutputArgType);
+        onnxruntime::Node* concatNode = graph->AddNode(nodeName + string("_Concat"), "Concat", "", { &unsqueezeTensorOutputArg0, &unsqueezeTensorOutputArg1 },
+            { &concatTensorOutputArg });
+        concatNode->AddAttribute("axis", static_cast<int64_t>(0));
+
+        onnx::TypeProto outputArgType = ToTypeProto(broadcastShape);
+        outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+        onnxruntime::NodeArg &reduceLogSumExpTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_Output_0"), &outputArgType);
+        node = graph->AddNode(nodeName + string("_reduce_log_sum_exp"), "ReduceLogSumExp", "", { &concatTensorOutputArg }, { &reduceLogSumExpTensorOutputArg });
+        // Reduce over the first axis.
+        node->AddAttribute("axes", std::vector<int64_t>(1, 0));
+        node->AddAttribute("keepdims", static_cast<int64_t>(0));
+    }
     else
         node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
 }
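Outside the exporter, the emitted three-node pattern can be reproduced with the onnx Python helpers. This sketch is not part of the commit and assumes an opset where Unsqueeze still takes axes as an attribute (pre-13):

    import onnx
    from onnx import helper, TensorProto

    shape = [2, 3, 4]
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, shape)
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, shape)
    Z = helper.make_tensor_value_info('Z', TensorProto.FLOAT, shape)

    nodes = [
        # Give each input a new leading axis of size 1 ...
        helper.make_node('Unsqueeze', ['X'], ['X_u'], axes=[0]),
        helper.make_node('Unsqueeze', ['Y'], ['Y_u'], axes=[0]),
        # ... join them into one tensor with leading dimension 2 ...
        helper.make_node('Concat', ['X_u', 'Y_u'], ['XY'], axis=0),
        # ... and reduce that axis away with log-sum-exp.
        helper.make_node('ReduceLogSumExp', ['XY'], ['Z'], axes=[0], keepdims=0),
    ]
    graph = helper.make_graph(nodes, 'logplus_pattern', [X, Y], [Z])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 9)])
    onnx.checker.check_model(model)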
@@ -443,6 +443,10 @@ namespace ONNX
         { L"StraightThrough",{ {
             { L"StraightThrough", "StraightThrough" },
         } } },
+        { L"LogPlus",{ {
+            { L"LogPlus", "LogPlus" },
+        } } },
+
     };
 
     // given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,
@@ -904,6 +904,37 @@ def test_LogSoftmax(tmpdir, dtype):
     model = C.log_softmax(x)
     verify_one_input(model, data, tmpdir, 'LogSoftmax_1')
 
+#LogAddExp
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_LogAddExp(tmpdir, dtype):
+    shape = (2, 3, 4)
+
+    data_x = np.random.rand(*shape).astype(dtype)
+    data_y = np.random.rand(*shape).astype(dtype)
+
+    x = C.input_variable(shape, dtype=dtype)
+    y = C.input_variable(shape, dtype=dtype)
+
+    model = C.log_add_exp(x, y)
+
+    verify_two_input(model, data_x, data_y, tmpdir, 'LogAddExp_0')
+
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_LogAddExp_Broadcast(tmpdir, dtype):
+    shape_x_arr = [(2, 1, 4), (2, 1, 4), (2, 2, 3, 4)]
+    shape_y_arr = [(1, 3, 1), (3, 1), (1, 1)]
+
+    for i, (shape_x, shape_y) in enumerate(zip(shape_x_arr, shape_y_arr)):
+        data_x = np.random.rand(*shape_x).astype(dtype)
+        data_y = np.random.rand(*shape_y).astype(dtype)
+
+        x = C.input_variable(shape_x, dtype=dtype)
+        y = C.input_variable(shape_y, dtype=dtype)
+
+        model = C.log_add_exp(x, y)
+
+        verify_two_input(model, data_x, data_y, tmpdir, 'LogAddExp_Broadcast_' + str(i))
+
 #LRN
 @pytest.mark.parametrize("dtype", DType_Config)
 def test_LRN(tmpdir, dtype, device_id):
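As a sanity check on the shape pairs exercised by test_LogAddExp_Broadcast, plain numpy broadcasting (which log_add_exp follows) yields the expected output shapes:

    import numpy as np

    for sx, sy in [((2, 1, 4), (1, 3, 1)), ((2, 1, 4), (3, 1)), ((2, 2, 3, 4), (1, 1))]:
        out = np.logaddexp(np.zeros(sx), np.zeros(sy))
        print(sx, sy, '->', out.shape)
    # (2, 1, 4) (1, 3, 1) -> (2, 3, 4)
    # (2, 1, 4) (3, 1)    -> (2, 3, 4)
    # (2, 2, 3, 4) (1, 1) -> (2, 2, 3, 4)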