update with reviewers' comments
This commit is contained in:
Родитель
9a7dd4c82b
Коммит
80e9f79ae6
|
@ -5637,7 +5637,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
|
|||
}
|
||||
else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
|
||||
{
|
||||
// in case a Pooling op is created with bother batch and sequence axes, we need to reshape its input and output to match
|
||||
// in case a Pooling op is created with both batch and sequence axes, we need to reshape its input and output to match
|
||||
// ONNX spec of [N, C, H, W] shape requirement.
|
||||
return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
|
||||
}
|
||||
|
@ -7109,6 +7109,14 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
|
|||
}
|
||||
else
|
||||
{
|
||||
if (src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
|
||||
{
|
||||
if (!std::all_of(lowerPad.begin(), lowerPad.end(), [](int64_t pad) {return pad == 0; }) ||
|
||||
!std::all_of(upperPad.begin(), upperPad.end(), [](int64_t pad) {return pad == 0; }))
|
||||
{
|
||||
fprintf(stderr, "Warning: Cannot set upperPad and lowerPad with pooling ops. Padding values will be computed according to kernel and input shapes.");
|
||||
}
|
||||
}
|
||||
if (isPooling)
|
||||
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, /*dilation=*/std::vector<size_t>(kernelShape.Rank(), 1),
|
||||
ceilOutDim, /*transpose=*/!isPooling);
|
||||
|
@ -8661,7 +8669,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
|
|||
vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
|
||||
onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);
|
||||
|
||||
return poolingNode;
|
||||
functionNodes.emplace(src, poolingNode);
|
||||
return postReshape;
|
||||
}
|
||||
|
||||
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
|
||||
|
|
Загрузка…
Ссылка в новой задаче