update with reviewers' comments

liqun fu 2020-03-27 14:35:31 -07:00
Parent 9a7dd4c82b
Commit 80e9f79ae6
1 changed file with 11 additions and 2 deletions

@@ -5637,7 +5637,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
     }
     else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
     {
-        // in case a Pooling op is created with bother batch and sequence axes, we need to reshape its input and output to match
+        // in case a Pooling op is created with both batch and sequence axes, we need to reshape its input and output to match
         // ONNX spec of [N, C, H, W] shape requirement.
         return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
     }
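Note: the hunk above routes Pooling ops whose input carries both a batch and a sequence axis through CreatePoolingNode, which wraps the op in reshapes so the ONNX node sees the [N, C, H, W] layout. As a rough illustration of the shape handling only (the helper below is hypothetical and not part of the exporter), merging the sequence and batch dimensions amounts to:

// Hypothetical sketch: collapse a [S, N, C, H, W] pooling input into the
// [S*N, C, H, W] rank that the ONNX Pooling spec expects; the exporter then
// restores the original layout with a second Reshape after the pooling node.
#include <cstdint>
#include <vector>

std::vector<int64_t> MergeSequenceAndBatchAxes(const std::vector<int64_t>& dims)
{
    // dims is assumed to be [S, N, C, H, W] with the sequence and batch axes leading.
    std::vector<int64_t> merged{ dims[0] * dims[1] };
    merged.insert(merged.end(), dims.begin() + 2, dims.end());
    return merged;
}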
@@ -7109,6 +7109,14 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
     }
     else
     {
+        if (src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
+        {
+            if (!std::all_of(lowerPad.begin(), lowerPad.end(), [](int64_t pad) {return pad == 0; }) ||
+                !std::all_of(upperPad.begin(), upperPad.end(), [](int64_t pad) {return pad == 0; }))
+            {
+                fprintf(stderr, "Warning: Cannot set upperPad and lowerPad with pooling ops. Padding values will be computed according to kernel and input shapes.");
+            }
+        }
         if (isPooling)
             PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, /*dilation=*/std::vector<size_t>(kernelShape.Rank(), 1),
                 ceilOutDim, /*transpose=*/!isPooling);
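Note: the added guard warns when explicit lowerPad/upperPad values would be silently dropped for pooling ops that carry both batch and sequence axes; padding is instead derived from the kernel and input shapes. As a hedged illustration (this is the generic "same"-style padding arithmetic, not necessarily the exact logic inside PutPadAttrInNode):

// Generic "same"-style padding derived from input, kernel, and stride; shown
// only to illustrate what "computed according to kernel and input shapes"
// typically means. Not taken from the exporter.
#include <algorithm>
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> SameStylePadding(int64_t inputDim, int64_t kernelDim, int64_t stride)
{
    int64_t outputDim = (inputDim + stride - 1) / stride;  // ceil(input / stride)
    int64_t totalPad = std::max<int64_t>(0, (outputDim - 1) * stride + kernelDim - inputDim);
    return { totalPad / 2, totalPad - totalPad / 2 };      // (lower, upper)
}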
@@ -8661,7 +8669,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
     vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
     onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);
 
-    return poolingNode;
+    functionNodes.emplace(src, poolingNode);
+    return postReshape;
 }
 
 onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
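Note: the last hunk fixes which node CreatePoolingNode hands back. The core pooling node is now cached in functionNodes, while callers receive the trailing Reshape, so downstream edges attach to the output with the restored batch/sequence layout. A minimal sketch of that pattern, using stand-in types rather than the CNTK/onnxruntime ones:

#include <unordered_map>

struct Function {};  // stand-in for the CNTK FunctionPtr target
struct Node {};      // stand-in for onnxruntime::Node

// Cache the core op for later lookups, but return the post-reshape node so
// that consumers of this function's output connect to the reshaped value.
Node* FinishWrappedOp(const Function* src, Node* opNode, Node* postReshape,
                      std::unordered_map<const Function*, Node*>& functionNodes)
{
    functionNodes.emplace(src, opNode);
    return postReshape;
}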