update with reviewers' comments
This commit is contained in:
Родитель
9a7dd4c82b
Коммит
80e9f79ae6
|
@ -5637,7 +5637,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
|
||||||
}
|
}
|
||||||
else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
|
else if (src->OpName() == L"Pooling" && src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
|
||||||
{
|
{
|
||||||
// in case a Pooling op is created with bother batch and sequence axes, we need to reshape its input and output to match
|
// in case a Pooling op is created with both batch and sequence axes, we need to reshape its input and output to match
|
||||||
// ONNX spec of [N, C, H, W] shape requirement.
|
// ONNX spec of [N, C, H, W] shape requirement.
|
||||||
return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
|
return CreatePoolingNode(src, graph, functionNodes, variableNodes, scanLoops, createLoopIndex);
|
||||||
}
|
}
|
||||||
|
@ -7109,6 +7109,14 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
if (src->Inputs()[0].HasBatchAxis() && src->Inputs()[0].HasSequenceAxis())
|
||||||
|
{
|
||||||
|
if (!std::all_of(lowerPad.begin(), lowerPad.end(), [](int64_t pad) {return pad == 0; }) ||
|
||||||
|
!std::all_of(upperPad.begin(), upperPad.end(), [](int64_t pad) {return pad == 0; }))
|
||||||
|
{
|
||||||
|
fprintf(stderr, "Warning: Cannot set upperPad and lowerPad with pooling ops. Padding values will be computed according to kernel and input shapes.");
|
||||||
|
}
|
||||||
|
}
|
||||||
if (isPooling)
|
if (isPooling)
|
||||||
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, /*dilation=*/std::vector<size_t>(kernelShape.Rank(), 1),
|
PutPadAttrInNode(node, autoPadding, kernelShape, inputShape, strides, /*dilation=*/std::vector<size_t>(kernelShape.Rank(), 1),
|
||||||
ceilOutDim, /*transpose=*/!isPooling);
|
ceilOutDim, /*transpose=*/!isPooling);
|
||||||
|
@ -8661,7 +8669,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePoolingNode(const FunctionPtr& src,
|
||||||
vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
|
vector<int64_t> newDimOutputFromPooling = ToINTS(*outputs[0]->TypeAsProto());
|
||||||
onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);
|
onnxruntime::Node* postReshape = AddReshapeNode(*poolingOutputArg, newDimOutputFromPooling, outputs[0]->Name(), graph);
|
||||||
|
|
||||||
return poolingNode;
|
functionNodes.emplace(src, poolingNode);
|
||||||
|
return postReshape;
|
||||||
}
|
}
|
||||||
|
|
||||||
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
|
onnxruntime::Node* CNTKToONNXHelper::CreateConvolutionNode(const FunctionPtr& src,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче