Extend Pad Fusion for AveragePool (#21556)
### Description This extends the existing pad_fusion for AveragePool operator i.e. fuse Pad if it is followed by AveragePool operator. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Родитель
530a2d7b41
Коммит
1637f22d39
|
@ -12,7 +12,7 @@ namespace onnxruntime {
|
|||
* It matches following pattern:
|
||||
* Pad
|
||||
* |
|
||||
* Conv/MaxPool
|
||||
* Conv/MaxPool/AveragePool
|
||||
*/
|
||||
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
// if Pad has input axis, don't fuse it.
|
||||
|
@ -28,6 +28,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
|
|||
|
||||
const Node& child_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
namespace onnxruntime {
|
||||
/*
|
||||
* This fusion submerges a Pad operator to it's child
|
||||
* Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition()
|
||||
* Conv or MaxPool or AveragePool operator, if and only if PadFusion::SatisfyCondition()
|
||||
* is true.
|
||||
*/
|
||||
class PadFusion : public RewriteRule {
|
||||
|
|
Загрузка…
Ссылка в новой задаче