Support LogPlus (log_add_exp) export to ONNX

* ONNX supports the similar op ReduceLogSumExp; a conversion to it is added
when exporting.
* Refactored CNTKToONNXHelper::BroadcastInputsIfNeeded into BroadcastInputs
to support more general cases.
Bowen Bao 2018-09-21 16:04:14 -07:00
Parent c2072cc4ab
Commit a36fae88bb
3 changed files: 91 additions and 14 deletions
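
For background, CNTK's LogPlus computes the element-wise log(exp(x) + exp(y)), i.e. numpy's logaddexp. ONNX has no binary equivalent, but ReduceLogSumExp computes log(sum(exp(t))) along given axes, so stacking the two operands along a new axis and reducing over it gives the same result. A minimal numpy check of that identity (illustrative only, not exporter code):

    import numpy as np

    # The identity behind the conversion: for any x, y,
    # logaddexp(x, y) == log(exp(x) + exp(y)) == ReduceLogSumExp over [x, y].
    x, y = 0.3, 1.7
    assert np.isclose(np.logaddexp(x, y), np.log(np.exp(x) + np.exp(y)))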

View file

@@ -116,7 +116,8 @@ private:
     static onnxruntime::Node *AddTransposeNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, const std::vector<int64_t> &perm,
         onnx::TypeProto& transposeOutputArgType, const std::string &outputNodeArgName);

-    static void BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg *> &orderedInputs, const FunctionPtr& src, onnxruntime::Graph* graph);
+    static std::vector<int64_t> BroadcastInputs(std::vector<onnxruntime::NodeArg *> &orderedInputs, const std::set<int64_t>& ignoreAxes,
+        const FunctionPtr& src, onnxruntime::Graph* graph);

     //
     // Insert a reshape node in front of a given node and its output node arg
@@ -4104,17 +4105,15 @@ std::vector<int64_t> GetShapeFromNodeArg(onnxruntime::NodeArg *nodeArg)
     return shape;
 }
-// CNTK splice allows broadcast of inputs before applying concatination.
+// CNTK splice allows broadcast of inputs before applying concatenation.
 // ONNX Concat is limited to matching input shape cases,
-// i.e. inputs' dimensions shall be equal except for the concatination axis.
+// i.e. inputs' dimensions shall be equal except for the concatenation axis.
 // for an example, see test_Concat_With_Broadcast in onnx_op_test.py.
-void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg *> &orderedInputs, const FunctionPtr& src, onnxruntime::Graph* graph)
+// This function broadcasts the inputs for axes excluding ignoreAxes.
+// Returns the broadcasted shape.
+std::vector<int64_t> CNTKToONNXHelper::BroadcastInputs(std::vector<onnxruntime::NodeArg *> &orderedInputs, const std::set<int64_t> &ignoreAxes,
+    const FunctionPtr& src, onnxruntime::Graph* graph)
 {
-    if (src->OpName() != L"Splice")
-        return;
-
-    Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
-    int64_t concatAxis = ConvertAxisToOnnxBroadcastOfOp(axis, src);
-
     std::vector<std::vector<int64_t>> shapes;
     int max_rank = 0;
     for (auto nodeArg : orderedInputs)
@@ -4130,13 +4129,13 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
         for (int index_to_shape_i = 0; index_to_shape_i < shape_i.size(); index_to_shape_i++)
         {
             int onnx_axis = index_to_shape_i + (max_rank - shape_i.size());
-            if (onnx_axis == concatAxis)
-                // only check and update non-concat-axis dimensions
+            if (ignoreAxes.find(onnx_axis) != ignoreAxes.end())
+                // only check and update non-ignoreAxes dimensions
                 continue;
             else if (broadcast_shape[onnx_axis] == 1)
                 broadcast_shape[onnx_axis] = shape_i[index_to_shape_i];
             else if (broadcast_shape[onnx_axis] != shape_i[index_to_shape_i] && shape_i[index_to_shape_i] != 1)
-                LogicError("Invalid splice inputs shape");
+                LogicError("Invalid broadcast inputs shape");
         }
     }
@@ -4152,7 +4151,7 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
         for (int onnx_axis = 0; onnx_axis < shape_i.size(); onnx_axis++)
         {
-            if (onnx_axis != concatAxis && shape_i[onnx_axis] != broadcast_shape[onnx_axis])
+            if (ignoreAxes.find(onnx_axis) == ignoreAxes.end() && shape_i[onnx_axis] != broadcast_shape[onnx_axis])
             {
                 shape_i[onnx_axis] = broadcast_shape[onnx_axis];
                 need_broadcast = true;
@@ -4171,6 +4170,8 @@ void CNTKToONNXHelper::BroadcastInputsIfNeeded(std::vector<onnxruntime::NodeArg
         onnxruntime::Node *node = AddAddNode(*nodeArg, nodeArg2, graph, out_arg_name);
         orderedInputs[i] = const_cast<NodeArg*>(node->OutputDefs()[0]);
     }
+
+    return broadcast_shape;
 }

 onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime::Graph* graph, const std::vector<onnxruntime::NodeArg *>& inputs, const std::vector<onnxruntime::NodeArg *>& outputs)
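
The shape logic above right-aligns each input's shape against the maximum rank and skips the axes in ignoreAxes. A compact Python restatement of just the shape computation (my own sketch under those assumptions, not the exporter's code):

    # Sketch of BroadcastInputs' shape logic: shapes are right-aligned to the
    # maximum rank; every axis not in ignore_axes must be 1 or agree with the
    # running broadcast shape.
    def broadcast_shape(shapes, ignore_axes=()):
        max_rank = max(len(s) for s in shapes)
        out = [1] * max_rank
        for shape in shapes:
            for i, dim in enumerate(shape):
                axis = i + (max_rank - len(shape))  # right-align, as onnx_axis above
                if axis in ignore_axes:
                    continue  # e.g. the Splice concat axis
                if out[axis] == 1:
                    out[axis] = dim
                elif out[axis] != dim and dim != 1:
                    raise ValueError("Invalid broadcast inputs shape")
        return out

    assert broadcast_shape([(2, 1, 4), (3, 1)]) == [2, 3, 4]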
@@ -4263,9 +4264,50 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
     }
     else if (src->OpName() == L"Splice")
     {
-        BroadcastInputsIfNeeded(orderedInputs, src, graph);
+        Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
+        BroadcastInputs(orderedInputs, { ConvertAxisToOnnxBroadcastOfOp(axis, src) }, src, graph);
+
         node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
     }
+    else if (src->OpName() == L"LogPlus")
+    {
+        // CNTK LogPlus is the equivalent of numpy.logaddexp.
+        // ONNX has a different but similar op: ReduceLogSumExp.
+        onnx::TensorProto_DataType tensorType = orderedInputs[0]->TypeAsProto()->tensor_type().elem_type();
+        std::vector<int64_t> broadcastShape = BroadcastInputs(orderedInputs, /*ignoreAxes=*/{}, src, graph);
+
+        // Now both inputs should have the same shape.
+        // Add another axis in front. This will be the axis to be reduced over later.
+        std::vector<int64_t> unsqueezeOutputShape = broadcastShape;
+        unsqueezeOutputShape.insert(unsqueezeOutputShape.begin(), 1);
+        std::vector<int64_t> concatOutputShape = broadcastShape;
+        concatOutputShape.insert(concatOutputShape.begin(), 2);
+
+        auto unsqueezeInputFunc = [&](int inputIndex) -> onnxruntime::NodeArg& {
+            onnx::TypeProto outputArgType = ToTypeProto(unsqueezeOutputShape);
+            outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+            onnxruntime::NodeArg &unsqueezeTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_unsqueeze" + std::to_string(inputIndex) + "_output0"), &outputArgType);
+            onnxruntime::Node* unsqueezeNode = graph->AddNode(nodeName + string("_Unsqueeze") + std::to_string(inputIndex), "Unsqueeze", "", { orderedInputs[inputIndex] }, { &unsqueezeTensorOutputArg });
+            unsqueezeNode->AddAttribute("axes", std::vector<int64_t>(1, 0));
+            return unsqueezeTensorOutputArg;
+        };
+
+        onnxruntime::NodeArg &unsqueezeTensorOutputArg0 = unsqueezeInputFunc(0);
+        onnxruntime::NodeArg &unsqueezeTensorOutputArg1 = unsqueezeInputFunc(1);
+
+        onnx::TypeProto concatOutputArgType = ToTypeProto(concatOutputShape);
+        concatOutputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+        onnxruntime::NodeArg &concatTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_concat_output0"), &concatOutputArgType);
+        onnxruntime::Node* concatNode = graph->AddNode(nodeName + string("_Concat"), "Concat", "", { &unsqueezeTensorOutputArg0, &unsqueezeTensorOutputArg1 },
+            { &concatTensorOutputArg });
+        concatNode->AddAttribute("axis", static_cast<int64_t>(0));
+
+        onnx::TypeProto outputArgType = ToTypeProto(broadcastShape);
+        outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
+        onnxruntime::NodeArg &reduceLogSumExpTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_Output_0"), &outputArgType);
+        node = graph->AddNode(nodeName + string("_reduce_log_sum_exp"), "ReduceLogSumExp", "", { &concatTensorOutputArg }, { &reduceLogSumExpTensorOutputArg });
+
+        // Reduce over the first axis.
+        node->AddAttribute("axes", std::vector<int64_t>(1, 0));
+        node->AddAttribute("keepdims", static_cast<int64_t>(0));
+    }
     else
         node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
 }
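
To sanity-check what the exported Unsqueeze → Concat → ReduceLogSumExp chain computes, here is a small numpy sketch (an illustration under the same broadcast assumption, with names of my choosing, not the exporter's code):

    import numpy as np

    def log_plus_like_export(x, y):
        # BroadcastInputs: make both operands the same shape.
        x, y = np.broadcast_arrays(x, y)
        # Unsqueeze(axes=[0]) on each input, then Concat(axis=0):
        # the result has shape (2,) + broadcast_shape.
        stacked = np.concatenate([x[np.newaxis], y[np.newaxis]], axis=0)
        # ReduceLogSumExp(axes=[0], keepdims=0), written here in the usual
        # numerically stable max-shifted form.
        m = stacked.max(axis=0)
        return m + np.log(np.exp(stacked - m).sum(axis=0))

    x = np.random.rand(2, 1, 4)
    y = np.random.rand(3, 1)
    assert np.allclose(log_plus_like_export(x, y), np.logaddexp(x, y))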

View file

@@ -443,6 +443,10 @@ namespace ONNX
     { L"StraightThrough",{ {
         { L"StraightThrough", "StraightThrough" },
     } } },
+    { L"LogPlus",{ {
+        { L"LogPlus", "LogPlus" },
+    } } },
 };

 // given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute,

View file

@@ -904,6 +904,37 @@ def test_LogSoftmax(tmpdir, dtype):
     model = C.log_softmax(x)
     verify_one_input(model, data, tmpdir, 'LogSoftmax_1')

+#LogAddExp
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_LogAddExp(tmpdir, dtype):
+    shape = (2, 3, 4)
+    data_x = np.random.rand(*shape).astype(np.float32)
+    data_y = np.random.rand(*shape).astype(np.float32)
+    x = C.input_variable(shape)
+    y = C.input_variable(shape)
+    model = C.log_add_exp(x, y)
+    verify_two_input(model, data_x, data_y, tmpdir, 'LogAddExp_0')
+
+@pytest.mark.parametrize("dtype", DType_Config)
+def test_LogAddExp_Broadcast(tmpdir, dtype):
+    shape_x_arr = [(2, 1, 4), (2, 1, 4), (2, 2, 3, 4)]
+    shape_y_arr = [(1, 3, 1), (3, 1), (1, 1)]
+    for i, (shape_x, shape_y) in enumerate(zip(shape_x_arr, shape_y_arr)):
+        data_x = np.random.rand(*shape_x).astype(np.float32)
+        data_y = np.random.rand(*shape_y).astype(np.float32)
+        x = C.input_variable(shape_x)
+        y = C.input_variable(shape_y)
+        model = C.log_add_exp(x, y)
+        verify_two_input(model, data_x, data_y, tmpdir, 'LogAddExp_Broadcast_' + str(i))
+
 #LRN
 @pytest.mark.parametrize("dtype", DType_Config)
 def test_LRN(tmpdir, dtype, device_id):