This commit is contained in:
Vincent Wang 2024-12-05 11:52:00 +08:00
Parent dff068b804
Commit 81d3fe90c7
6 changed files with 65 additions and 59 deletions

View file

@@ -11,16 +11,62 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace {
// The Attention subgraph has 4 MatMul-Add pairs that we want to skip here because AttentionFusion will handle them.
// In such a case, 3 of the MatMul-Add pairs follow LN; the remaining one produces an output that is added to LN's output.
// Use two sets to remember the patterns already seen during the graph iteration so that other MatMul-Add pairs
// in the same pattern can be skipped directly.
struct AttentionPatternCache {
bool IsAttentionPattern(const Graph& graph, const Node& matmul_node, const Node& add_node) {
const Node* parent_node = graph.GetProducerNode(matmul_node.InputDefs()[0]->Name());
if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&add_node) > 0) {
return true;
}
if (parent_node && parent_node->OpType() == "LayerNormalization") {
unsigned int add_count = 0;
unsigned int matmul_count = 0;
unsigned int shape_count = 0;
const Node* ln_add_node = nullptr;
for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
std::string op_type = (*it).OpType();
if (op_type == "Add") {
ln_add_node = &(*it);
add_count++;
} else if (op_type == "MatMul") {
matmul_count++;
} else if (op_type == "Shape") {
shape_count++;
}
}
if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
if (attn_add_node && attn_add_node->OpType() == "Add") {
attn_ln_nodes.insert(parent_node);
attn_add_nodes.insert(attn_add_node);
return true;
}
}
}
return false;
}
std::unordered_set<const Node*> attn_ln_nodes;
std::unordered_set<const Node*> attn_add_nodes;
};
} // namespace
Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
// These two sets are used to skip Attention pattern, which will be handled by AttentionFusion.
// There are 4 MatMul-Add pairs in Attention pattern, 3 of them are following LayerNormalization, the other one
// produces output which is added with LayerNormalization's output, we can skip them directly if we see same
// processed nodes again which are stored in these two sets.
std::unordered_set<const Node*> attn_ln_nodes;
std::unordered_set<const Node*> attn_add_nodes;
// Cache for skipping Attention subgraph pattern.
AttentionPatternCache attn_pattern_cache;
for (auto node_index : node_topology_list) {
auto* node_ptr = graph.GetNode(node_index);
@@ -81,41 +127,11 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
InlinedVector<int64_t> shape_values;
int64_t m = 0, k = 0, n = 0;
if (need_reshape) {
// Skip Attention pattern, AttentionFusion will handle it. In such case, there are 4 MatMul-Add pairs,
// 3 of them are following LN, the other one produces output which is added with LN's output.
const Node* parent_node = graph.GetProducerNode(matmul_input_defs[0]->Name());
if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&next_node) > 0) {
// Only check and skip the Attention pattern here because the input to Attention is normally 4D.
if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
continue;
}
if (parent_node && parent_node->OpType() == "LayerNormalization") {
unsigned int add_count = 0;
unsigned int matmul_count = 0;
unsigned int shape_count = 0;
const Node* ln_add_node = nullptr;
for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
std::string op_type = (*it).OpType();
if (op_type == "Add") {
ln_add_node = &(*it);
add_count++;
} else if (op_type == "MatMul") {
matmul_count++;
} else if (op_type == "Shape") {
shape_count++;
}
}
if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
if (attn_add_node && attn_add_node->OpType() == "Add") {
attn_ln_nodes.insert(parent_node);
attn_add_nodes.insert(attn_add_node);
continue;
}
}
}
// Logically we could use Shape-Concat to produce the shape input for Reshape. To keep it simple, we require
// both inputs to have concrete shapes for now; dynamic shape support can be added in the future.
auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape);
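
The AttentionPatternCache introduced above replaces the two ad-hoc sets previously kept in ApplyImpl: once a LayerNormalization node and its paired Add node have been identified as part of an Attention subgraph, every later MatMul-Add pair hanging off the same nodes is skipped with a set lookup instead of re-walking the LN node's consumers. Below is a minimal, self-contained sketch of that memoization idea in plain C++; the Node struct and the 3-MatMul check are simplified stand-ins, not onnxruntime's Graph/Node API or the full structural test.

#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a graph node; NOT onnxruntime's Node class.
struct Node {
  std::string op_type;
  std::vector<const Node*> consumers;  // nodes that read this node's output
};

// Memoizing pattern check: the structural test runs once per LayerNormalization
// node and the result is cached, so the other MatMul-Add pairs of the same
// pattern hit the fast path.
struct PatternCache {
  bool IsPartOfPattern(const Node* producer) {
    if (producer == nullptr) return false;
    if (seen_.count(producer) > 0) return true;  // fast path: already identified
    if (producer->op_type != "LayerNormalization") return false;

    int matmul_count = 0;  // slow path: inspect the consumers once
    for (const Node* c : producer->consumers) {
      if (c->op_type == "MatMul") ++matmul_count;
    }
    if (matmul_count == 3) {  // simplified stand-in for the full Attention check
      seen_.insert(producer);
      return true;
    }
    return false;
  }

  std::unordered_set<const Node*> seen_;
};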

View file

@@ -48,7 +48,7 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
fused_count++;
LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
modified = true;
} else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph, logger)) {
} else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph)) {
modified = true;
}
}
@@ -454,8 +454,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
return true;
}
bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger) {
ORT_UNUSED_PARAMETER(logger);
bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
InlinedVector<std::reference_wrapper<Node>> contiguous_reshapes{reshape};
InlinedVector<int64_t> shape_value;
while (true) {
@@ -474,19 +473,12 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const lo
break;
}
bool is_concrete_shape = true;
shape_value.clear();
for (const auto& dim : shape->dim()) {
if (dim.has_dim_value()) {
shape_value.emplace_back(dim.dim_value());
} else {
is_concrete_shape = false;
}
}
if (!is_concrete_shape) {
auto tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*shape);
if (tensor_shape.Size() == -1) {
break;
}
shape_value = tensor_shape.AsShapeVector();
contiguous_reshapes.emplace_back(*next_node);
}
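
The hunk above also simplifies the concrete-shape check in FuseContiguousReshapes: instead of looping over shape->dim() by hand, the new code converts the shape proto with utils::GetTensorShapeFromTensorShapeProto and bails out when Size() returns -1. This assumes the usual convention that unknown (symbolic) dimensions are mapped to -1 and that the element count is reported as -1 whenever any dimension is unknown. A standalone illustration of that convention in plain C++ (ElementCount is a hypothetical helper, not the onnxruntime utility):

#include <cstdint>
#include <iostream>
#include <vector>

// Unknown (symbolic) dimensions are represented as -1; if any dimension is
// unknown, the whole element count is reported as -1.
int64_t ElementCount(const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int64_t d : dims) {
    if (d < 0) return -1;  // one unknown dim makes the count unknown
    size *= d;
  }
  return size;
}

int main() {
  std::vector<int64_t> concrete{2, 3, 4};   // fully known: fusion may proceed
  std::vector<int64_t> symbolic{2, -1, 4};  // unknown dim: fusion must stop
  std::cout << ElementCount(concrete) << "\n";  // prints 24
  std::cout << ElementCount(symbolic) << "\n";  // prints -1
  return 0;
}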

View file

@@ -31,7 +31,7 @@ class ReshapeFusion : public GraphTransformer {
// Remove contiguous Reshape/Squeeze/Unsqueeze nodes if the shape info is concrete.
// For some EPs such reshape Ops are not no-ops. For example, on the QNN EP memory is allocated for each output,
// so this fusion can help to reduce memory usage on such devices.
static bool FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger);
static bool FuseContiguousReshapes(Node& reshape, Graph& graph);
};
} // namespace onnxruntime
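
The header comment describes the intent of FuseContiguousReshapes: a run of Reshape/Squeeze/Unsqueeze nodes whose output shapes are all concrete can be collapsed into a single Reshape to the final shape, which saves memory on EPs such as QNN that allocate a buffer for every intermediate output. The following toy sketch shows only the run detection and replacement idea on a flat list of ops; the Op struct is hypothetical, and the real pass instead edits the Graph and gives the surviving Reshape a constant shape initializer.

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical flattened view of a node chain: each op carries the concrete
// shape of its single output, or an empty vector if the shape is unknown.
struct Op {
  std::string type;                // "Reshape", "Squeeze", "Unsqueeze", ...
  std::vector<int64_t> out_shape;  // empty => not concrete
};

static bool IsReshapeLike(const Op& op) {
  return op.type == "Reshape" || op.type == "Squeeze" || op.type == "Unsqueeze";
}

// Collapse each maximal run of reshape-like ops with concrete shapes into a
// single Reshape to the last shape of the run; other ops are kept as-is.
std::vector<Op> CollapseContiguousReshapes(const std::vector<Op>& ops) {
  std::vector<Op> result;
  std::size_t i = 0;
  while (i < ops.size()) {
    if (!IsReshapeLike(ops[i]) || ops[i].out_shape.empty()) {
      result.push_back(ops[i++]);
      continue;
    }
    std::size_t j = i;  // extend the run while the shapes stay concrete
    while (j + 1 < ops.size() && IsReshapeLike(ops[j + 1]) && !ops[j + 1].out_shape.empty()) {
      ++j;
    }
    result.push_back(Op{"Reshape", ops[j].out_shape});  // one Reshape replaces the run
    i = j + 1;
  }
  return result;
}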

View file

@@ -137,7 +137,7 @@ std::unique_ptr<IQnnNodeGroup> ReshapeGemmFusion::TryFusion(
QnnModelWrapper& qnn_model_wrapper, const NodeUnit& gemm_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
[[maybe_unused]] const logging::Logger& logger) {
const logging::Logger& /*logger*/) {
if (gemm_node_unit.OpType() != "Gemm" || gemm_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
return nullptr;
}
@@ -173,13 +173,11 @@ ReshapeGemmFusion::ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const No
node_units_[1] = &gemm_node_unit;
}
Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
ORT_UNUSED_PARAMETER(logger);
Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], true);
}
Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
ORT_UNUSED_PARAMETER(logger);
Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], false);
}
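
The QNN changes also settle on a single style for parameters that are intentionally unused: the parameter name is commented out in the definition, rather than mixing [[maybe_unused]] attributes with ORT_UNUSED_PARAMETER calls in the body. All three spellings seen in this diff have the same effect of silencing unused-parameter warnings; a small side-by-side comparison, assuming ORT_UNUSED_PARAMETER is the usual cast-to-void style of macro:

// 1. Comment out the parameter name (the style adopted in this commit).
int StyleA(int /*unused_value*/) { return 0; }

// 2. C++17 attribute on the parameter.
int StyleB([[maybe_unused]] int unused_value) { return 0; }

// 3. Explicitly discard the value in the body, which is what a macro such as
//    ORT_UNUSED_PARAMETER is assumed to expand to (a cast to void).
int StyleC(int unused_value) {
  (void)unused_value;
  return 0;
}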

View file

@@ -881,7 +881,7 @@ TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) {
"high"); // qnn_context_priority
}
// Create a model with Case + Add (quantized)
// Create a model with Cast + Add (quantized)
// cast_input -> Cast -> Q -> DQ ----
// |
// input2 -> Q -> DQ -> Add -> Q -> DQ -> output

View file

@@ -162,7 +162,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) {
QnnContextBinaryMultiPartitionTestBody(single_ep_node);
}
// Create a model with Case + Add (quantized)
// Create a model with Cast + Add (quantized)
// cast_input -> Cast -> Q -> DQ ----
// |
// input2 -> Q -> DQ -> Add -> Q -> DQ -> output