Support logPlus (log_add_exp) export to ONNX

* ONNX supports the similar op ReduceLogSumExp; a conversion is added when exporting.
* Refactored CNTKToONNXHelper::BroadcastInputsIfNeeded to support more general cases.
Parent: c2072cc4ab
Commit: a36fae88bb
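The conversion relies on the identity log_add_exp(x, y) == ReduceLogSumExp(Concat(Unsqueeze(x, 0), Unsqueeze(y, 0)), axes=[0]). A quick numpy sanity check of that identity, broadcasting included (illustration only, not part of the commit):

```python
import numpy as np

# CNTK LogPlus / numpy.logaddexp computes log(exp(x) + exp(y)) elementwise.
# ONNX has no direct equivalent, but stacking the broadcast inputs on a new
# leading axis and applying ReduceLogSumExp over that axis is the same thing.
x = np.random.rand(2, 1, 4)
y = np.random.rand(3, 1)

bx, by = np.broadcast_arrays(x, y)                  # broadcast to common shape (2, 3, 4)
stacked = np.stack([bx, by], axis=0)                # Unsqueeze + Concat on new axis 0
reduced = np.log(np.sum(np.exp(stacked), axis=0))   # ReduceLogSumExp, keepdims=0

assert np.allclose(reduced, np.logaddexp(x, y))
```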
@@ -116,7 +116,8 @@ private:
     static onnxruntime::Node *AddTransposeNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, const std::vector<int64_t> &perm,
         onnx::TypeProto& transposeOutputArgType, const std::string &outputNodeArgName);
 
-    static void BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg *> &orderedInputs, const FunctionPtr& src, onnxruntime::Graph* graph);
+    static std::vector<int64_t> BroadcastInputs(std::vector<onnxruntime::NodeArg *> &orderedInputs, const std::set<int64_t>& ignoreAxes,
+        const FunctionPtr& src, onnxruntime::Graph* graph);
 
     //
     // Insert a reshape node in front of a given node and its output node arg
@@ -4104,17 +4105,15 @@ std::vector<int64_t> GetShapeFromNodeArg(onnxruntime::NodeArg *nodeArg)
     return shape;
 }
 
-// CNTK splice allows broadcast of inputs before applying concatination.
+// CNTK splice allows broadcast of inputs before applying concatenation.
 // ONNX Concat is limited to matching input shape cases
-// i.e. inputs' dimensions shall be the equal except for the concatination axis.
+// i.e. inputs' dimensions shall be equal except for the concatenation axis.
 // for an example, see test_Concat_With_Broadcast in onnx_op_test.py.
-void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg *> &orderedInputs, const FunctionPtr& src, onnxruntime::Graph* graph)
+// This function broadcasts the inputs for axes excluding ignoreAxes.
+// Returns the broadcasted shape.
+std::vector<int64_t> CNTKToONNXHelper::BroadcastInputs(std::vector<onnxruntime::NodeArg *> &orderedInputs, const std::set<int64_t> &ignoreAxes,
+    const FunctionPtr& src, onnxruntime::Graph* graph)
 {
-    if (src->OpName() != L"Splice")
-        return;
-
-    Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
-    int64_t concatAxis = ConvertAxisToOnnxBroadcastOfOp(axis, src);
     std::vector<std::vector<int64_t>> shapes;
     int max_rank = 0;
     for (auto nodeArg : orderedInputs)
@@ -4130,13 +4129,13 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
         for (int index_to_shape_i = 0; index_to_shape_i < shape_i.size(); index_to_shape_i++)
         {
             int onnx_axis = index_to_shape_i + (max_rank - shape_i.size());
-            if (onnx_axis == concatAxis)
-                // only check and update no-concat_axis dimensions
+            if (ignoreAxes.find(onnx_axis) != ignoreAxes.end())
+                // only check and update non-ignoreAxes dimensions
                 continue;
             else if (broadcast_shape[onnx_axis] == 1)
                 broadcast_shape[onnx_axis] = shape_i[index_to_shape_i];
             else if (broadcast_shape[onnx_axis] != shape_i[index_to_shape_i] && shape_i[index_to_shape_i] != 1)
-                LogicError("Invalid splice inputs shape");
+                LogicError("Invalid broadcast inputs shape");
         }
     }
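The shape rule BroadcastInputs implements is essentially numpy-style broadcasting with a set of exempt axes: shapes are right-aligned to the maximum rank, and every axis not in ignoreAxes must either match or be 1. A minimal Python sketch of that rule (hypothetical helper, not part of the commit):

```python
# Hypothetical sketch of the BroadcastInputs shape rule (illustration only).
def broadcast_shape(shapes, ignore_axes=frozenset()):
    max_rank = max(len(s) for s in shapes)
    out = [1] * max_rank
    for s in shapes:
        for i, dim in enumerate(s):
            axis = i + (max_rank - len(s))  # right-align shapes, numpy style
            if axis in ignore_axes:         # e.g. the Splice/Concat axis
                continue
            if out[axis] == 1:
                out[axis] = dim
            elif out[axis] != dim and dim != 1:
                raise ValueError("Invalid broadcast inputs shape")
    return out

# (2,1,4) and (1,3,1) broadcast to (2,3,4), as in test_LogAddExp_Broadcast below.
assert broadcast_shape([(2, 1, 4), (1, 3, 1)]) == [2, 3, 4]
```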
@@ -4152,7 +4151,7 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
 
         for (int onnx_axis = 0; onnx_axis < shape_i.size(); onnx_axis++)
         {
-            if (onnx_axis != concatAxis && shape_i[onnx_axis] != broadcast_shape[onnx_axis])
+            if (ignoreAxes.find(onnx_axis) == ignoreAxes.end() && shape_i[onnx_axis] != broadcast_shape[onnx_axis])
             {
                 shape_i[onnx_axis] = broadcast_shape[onnx_axis];
                 need_broadcast = true;
@@ -4171,6 +4170,8 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
             onnxruntime::Node *node = AddAddNode(*nodeArg, nodeArg2, graph, out_arg_name);
             orderedInputs[i] = const_cast<NodeArg*>(node->OutputDefs()[0]);
         }
+
+    return broadcast_shape;
 }
 
 onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime::Graph* graph, const std::vector<onnxruntime::NodeArg *>& inputs, const std::vector<onnxruntime::NodeArg *>& outputs)
@@ -4263,9 +4264,50 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
     }
     else if (src->OpName() == L"Splice")
     {
-        BroadcastInputsIfNeeded(orderedInputs, src, graph);
+        Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
+        BroadcastInputs(orderedInputs, { ConvertAxisToOnnxBroadcastOfOp(axis, src) }, src, graph);
         node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
     }
+    else if (src->OpName() == L"LogPlus")
+    {
+        // CNTK LogPlus is equivalent to numpy.logaddexp.
+        // ONNX has a different but similar op: ReduceLogSumExp.
+        onnx::TensorProto_DataType tensorType = orderedInputs[0]->TypeAsProto()->tensor_type().elem_type();
+        std::vector<int64_t> broadcastShape = BroadcastInputs(orderedInputs, /*ignoreAxes=*/{}, src, graph);
+        // Now both inputs should have the same shape.
+        // Add another axis in front. This will be the axis to be reduced over later.
+        std::vector<int64_t> unsqueezeOutputShape = broadcastShape;
+        unsqueezeOutputShape.insert(unsqueezeOutputShape.begin(), 1);
+        std::vector<int64_t> concatOutputShape = broadcastShape;
+        concatOutputShape.insert(concatOutputShape.begin(), 2);
+
+        auto unsqueezeInputFunc = [&](int inputIndex) -> onnxruntime::NodeArg& {
+            onnx::TypeProto outputArgType = ToTypeProto(unsqueezeOutputShape);
+            outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+            onnxruntime::NodeArg &unsqueezeTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_unsqueeze" + std::to_string(inputIndex) + "_output0"), &outputArgType);
+            onnxruntime::Node* unsqueezeNode = graph->AddNode(nodeName + string("_Unsqueeze") + std::to_string(inputIndex), "Unsqueeze", "", { orderedInputs[inputIndex] }, { &unsqueezeTensorOutputArg });
+            unsqueezeNode->AddAttribute("axes", std::vector<int64_t>(1, 0));
+            return unsqueezeTensorOutputArg;
+        };
+
+        onnxruntime::NodeArg &unsqueezeTensorOutputArg0 = unsqueezeInputFunc(0);
+        onnxruntime::NodeArg &unsqueezeTensorOutputArg1 = unsqueezeInputFunc(1);
+
+        onnx::TypeProto concatOutputArgType = ToTypeProto(concatOutputShape);
+        concatOutputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+        onnxruntime::NodeArg &concatTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_concat_output0"), &concatOutputArgType);
+        onnxruntime::Node* concatNode = graph->AddNode(nodeName + string("_Concat"), "Concat", "", { &unsqueezeTensorOutputArg0, &unsqueezeTensorOutputArg1 },
+            { &concatTensorOutputArg });
+        concatNode->AddAttribute("axis", static_cast<int64_t>(0));
+
+        onnx::TypeProto outputArgType = ToTypeProto(broadcastShape);
+        outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+        onnxruntime::NodeArg &reduceLogSumExpTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_Output_0"), &outputArgType);
+        node = graph->AddNode(nodeName + string("_reduce_log_sum_exp"), "ReduceLogSumExp", "", { &concatTensorOutputArg }, { &reduceLogSumExpTensorOutputArg });
+        // Reduce over the first axis.
+        node->AddAttribute("axes", std::vector<int64_t>(1, 0));
+        node->AddAttribute("keepdims", static_cast<int64_t>(0));
+    }
     else
         node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
 }
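For reference, the subgraph this branch emits can be reproduced by hand with the onnx helper API. A sketch (illustration only; it assumes an opset such as 12, where Unsqueeze and ReduceLogSumExp still take axes as an attribute, matching the exporter's use of AddAttribute above):

```python
import onnx
from onnx import TensorProto, helper

# Unsqueeze both inputs on a new axis 0, Concat them there, then
# ReduceLogSumExp over axis 0 with keepdims=0 -- the LogPlus decomposition.
shape = [2, 3, 4]
nodes = [
    helper.make_node("Unsqueeze", ["x"], ["x0"], axes=[0]),
    helper.make_node("Unsqueeze", ["y"], ["y0"], axes=[0]),
    helper.make_node("Concat", ["x0", "y0"], ["xy"], axis=0),
    helper.make_node("ReduceLogSumExp", ["xy"], ["z"], axes=[0], keepdims=0),
]
graph = helper.make_graph(
    nodes, "log_add_exp",
    [helper.make_tensor_value_info(n, TensorProto.FLOAT, shape) for n in ("x", "y")],
    [helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)])
onnx.checker.check_model(model)
```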
@@ -443,6 +443,10 @@ namespace ONNX
         { L"StraightThrough",{ {
             { L"StraightThrough", "StraightThrough" },
         } } },
+        { L"LogPlus",{ {
+            { L"LogPlus", "LogPlus" },
+        } } },
+
     };
 
     // given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,
@@ -904,6 +904,37 @@ def test_LogSoftmax(tmpdir, dtype):
     model = C.log_softmax(x)
     verify_one_input(model, data, tmpdir, 'LogSoftmax_1')
 
+#LogAddExp
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_LogAddExp(tmpdir, dtype):
+    shape = (2,3,4)
+
+    data_x = np.random.rand(*shape).astype(np.float32)
+    data_y = np.random.rand(*shape).astype(np.float32)
+
+    x = C.input_variable(shape)
+    y = C.input_variable(shape)
+
+    model = C.log_add_exp(x, y)
+
+    verify_two_input(model, data_x, data_y, tmpdir, 'LogAddExp_0')
+
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_LogAddExp_Broadcast(tmpdir, dtype):
+    shape_x_arr = [(2,1,4), (2,1,4), (2,2,3,4)]
+    shape_y_arr = [(1,3,1), (3,1), (1,1)]
+
+    for i, (shape_x, shape_y) in enumerate(zip(shape_x_arr, shape_y_arr)):
+        data_x = np.random.rand(*shape_x).astype(np.float32)
+        data_y = np.random.rand(*shape_y).astype(np.float32)
+
+        x = C.input_variable(shape_x)
+        y = C.input_variable(shape_y)
+
+        model = C.log_add_exp(x, y)
+
+        verify_two_input(model, data_x, data_y, tmpdir, 'LogAddExp_Broadcast_' + str(i))
+
 #LRN
 @pytest.mark.parametrize("dtype", DType_Config)
 def test_LRN(tmpdir, dtype, device_id):
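The tests above compare the CNTK model against the exported ONNX model via verify_two_input. An independent numpy reference for what they should compute (not part of the test suite):

```python
import numpy as np

# numpy.logaddexp applies numpy's own broadcasting rules, which the
# exporter's BroadcastInputs is meant to reproduce.
data_x = np.random.rand(2, 1, 4).astype(np.float32)
data_y = np.random.rand(1, 3, 1).astype(np.float32)
expected = np.logaddexp(data_x, data_y)
assert expected.shape == (2, 3, 4)
```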