resolve comments
Parent: dff068b804
Commit: 81d3fe90c7
@@ -11,16 +11,62 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;

namespace onnxruntime {

namespace {

// The Attention subgraph has 4 MatMul-Add pairs that we want to skip here because AttentionFusion will handle them.
// In such a case, 3 of the MatMul-Add pairs follow LN, and the remaining one produces an output that is added to LN's output.
// Use two sets to remember patterns we have already seen during the graph iteration so that we can skip the
// other MatMul-Add pairs of the same pattern directly.
struct AttentionPatternCache {
  bool IsAttentionPattern(const Graph& graph, const Node& matmul_node, const Node& add_node) {
    const Node* parent_node = graph.GetProducerNode(matmul_node.InputDefs()[0]->Name());
    if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&add_node) > 0) {
      return true;
    }

    if (parent_node && parent_node->OpType() == "LayerNormalization") {
      unsigned int add_count = 0;
      unsigned int matmul_count = 0;
      unsigned int shape_count = 0;
      const Node* ln_add_node = nullptr;
      for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
        std::string op_type = (*it).OpType();
        if (op_type == "Add") {
          ln_add_node = &(*it);
          add_count++;
        } else if (op_type == "MatMul") {
          matmul_count++;
        } else if (op_type == "Shape") {
          shape_count++;
        }
      }

      if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
        size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
        const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
        if (attn_add_node && attn_add_node->OpType() == "Add") {
          attn_ln_nodes.insert(parent_node);
          attn_add_nodes.insert(attn_add_node);
          return true;
        }
      }
    }

    return false;
  }

  std::unordered_set<const Node*> attn_ln_nodes;
  std::unordered_set<const Node*> attn_add_nodes;
};

}  // namespace

Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  // These two sets are used to skip the Attention pattern, which will be handled by AttentionFusion.
  // There are 4 MatMul-Add pairs in the Attention pattern: 3 of them follow LayerNormalization, and the remaining one
  // produces an output that is added to LayerNormalization's output. We can skip such pairs directly when we see
  // already-processed nodes again, which are stored in these two sets.
  std::unordered_set<const Node*> attn_ln_nodes;
  std::unordered_set<const Node*> attn_add_nodes;
  // Cache for skipping the Attention subgraph pattern.
  AttentionPatternCache attn_pattern_cache;

  for (auto node_index : node_topology_list) {
    auto* node_ptr = graph.GetNode(node_index);
@@ -81,41 +127,11 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
    InlinedVector<int64_t> shape_values;
    int64_t m = 0, k = 0, n = 0;
    if (need_reshape) {
      // Skip the Attention pattern; AttentionFusion will handle it. In such a case, there are 4 MatMul-Add pairs:
      // 3 of them follow LN, and the remaining one produces an output that is added to LN's output.
      const Node* parent_node = graph.GetProducerNode(matmul_input_defs[0]->Name());
      if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&next_node) > 0) {
      // Only check and skip the Attention pattern here because normally the input to Attention is 4D.
      if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
        continue;
      }

      if (parent_node && parent_node->OpType() == "LayerNormalization") {
        unsigned int add_count = 0;
        unsigned int matmul_count = 0;
        unsigned int shape_count = 0;
        const Node* ln_add_node = nullptr;
        for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
          std::string op_type = (*it).OpType();
          if (op_type == "Add") {
            ln_add_node = &(*it);
            add_count++;
          } else if (op_type == "MatMul") {
            matmul_count++;
          } else if (op_type == "Shape") {
            shape_count++;
          }
        }

        if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
          size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
          const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
          if (attn_add_node && attn_add_node->OpType() == "Add") {
            attn_ln_nodes.insert(parent_node);
            attn_add_nodes.insert(attn_add_node);
            continue;
          }
        }
      }

      // Logically we could use Shape-Concat to produce the shape input for Reshape. To keep it simple, we require
      // both inputs to have a concrete shape for now; dynamic shape support can be added in the future.
      auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape);
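Note on the change above: both the removed inline block and the new AttentionPatternCache rely on the same consumer-counting heuristic around LayerNormalization. The following is a minimal standalone sketch of that heuristic, not ORT code; ToyNode, its fields, and LooksLikeAttentionRoot are hypothetical stand-ins for Graph/Node so the check can be read and compiled in isolation.

// Standalone sketch (hypothetical ToyNode type, not ORT's Node/Graph API):
// a LayerNormalization node is treated as the root of an attention block when
// its consumers are exactly one Add, three MatMuls, and otherwise only Shape nodes.
#include <cstddef>
#include <string>
#include <vector>

struct ToyNode {
  std::string op_type;                    // e.g. "Add", "MatMul", "Shape"
  std::vector<const ToyNode*> consumers;  // nodes consuming this node's output
};

bool LooksLikeAttentionRoot(const ToyNode& ln) {
  std::size_t add_count = 0, matmul_count = 0, shape_count = 0;
  for (const ToyNode* consumer : ln.consumers) {
    if (consumer->op_type == "Add") ++add_count;
    else if (consumer->op_type == "MatMul") ++matmul_count;
    else if (consumer->op_type == "Shape") ++shape_count;
  }
  // Mirrors the counts checked by AttentionPatternCache::IsAttentionPattern:
  // 1 Add, 3 MatMuls (e.g. Q/K/V projections), and all remaining consumers are Shape.
  return add_count == 1 && matmul_count == 3 && shape_count == ln.consumers.size() - 4;
}

The real implementation additionally verifies that the Add's other input is produced by another Add (the residual path) and remembers both nodes in attn_ln_nodes/attn_add_nodes so the other MatMul-Add pairs of the same block are skipped without re-checking.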
@@ -48,7 +48,7 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
      fused_count++;
      LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
      modified = true;
    } else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph, logger)) {
    } else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph)) {
      modified = true;
    }
  }
@@ -454,8 +454,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
  return true;
}

bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger) {
  ORT_UNUSED_PARAMETER(logger);
bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
  InlinedVector<std::reference_wrapper<Node>> contiguous_reshapes{reshape};
  InlinedVector<int64_t> shape_value;
  while (true) {

@@ -474,19 +473,12 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const lo
      break;
    }

    bool is_concrete_shape = true;
    shape_value.clear();
    for (const auto& dim : shape->dim()) {
      if (dim.has_dim_value()) {
        shape_value.emplace_back(dim.dim_value());
      } else {
        is_concrete_shape = false;
      }
    }
    if (!is_concrete_shape) {
    auto tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*shape);
    if (tensor_shape.Size() == -1) {
      break;
    }

    shape_value = tensor_shape.AsShapeVector();
    contiguous_reshapes.emplace_back(*next_node);
  }
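Note: the removed dim-by-dim loop and the new GetTensorShapeFromTensorShapeProto check serve the same purpose, which is to bail out when the shape is not fully concrete; the new code treats Size() == -1 as "contains an unknown dim". A minimal standalone sketch of that idea follows; ToyDim and ConcreteShape are hypothetical names, not ORT code.

// Standalone sketch, not ORT code: a shape is usable for this fusion only when
// every dimension has a known value.
#include <cstdint>
#include <optional>
#include <vector>

using ToyDim = std::optional<int64_t>;  // std::nullopt models a symbolic/unknown dim

std::optional<std::vector<int64_t>> ConcreteShape(const std::vector<ToyDim>& dims) {
  std::vector<int64_t> values;
  values.reserve(dims.size());
  for (const ToyDim& dim : dims) {
    if (!dim.has_value()) {
      return std::nullopt;  // plays the role of the `tensor_shape.Size() == -1` early exit
    }
    values.push_back(*dim);
  }
  return values;  // plays the role of `tensor_shape.AsShapeVector()`
}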
@@ -31,7 +31,7 @@ class ReshapeFusion : public GraphTransformer {
  // Remove contiguous Reshape/Squeeze/Unsqueeze nodes if the shape info is concrete.
  // For some EPs such reshape ops are not no-ops; on the QNN EP, for example, memory is allocated for each output,
  // so this fusion can help reduce memory usage on such devices.
  static bool FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger);
  static bool FuseContiguousReshapes(Node& reshape, Graph& graph);
};

}  // namespace onnxruntime
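Note: the header comment above is the motivation for FuseContiguousReshapes. A minimal sketch of the idea, with hypothetical names and not ORT code: a chain of shape-only ops with concrete shapes can be collapsed into a single Reshape to the last shape, so the intermediate outputs that EPs like QNN would otherwise allocate disappear.

// Standalone sketch, not ORT code: only the final output shape of a contiguous
// Reshape/Squeeze/Unsqueeze chain matters for the fused replacement.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct ToyShapeOp {
  std::vector<int64_t> output_shape;  // concrete shape produced by this op
};

int64_t ElementCount(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}

// The fused single Reshape targets the last shape in the chain; every intermediate
// shape (and its buffer, which matters on EPs like QNN) is dropped. Shape-only ops
// never change the element count, which this sketch asserts.
std::vector<int64_t> FusedReshapeTarget(const std::vector<ToyShapeOp>& chain) {
  assert(!chain.empty());
  for (const ToyShapeOp& op : chain) {
    assert(ElementCount(op.output_shape) == ElementCount(chain.front().output_shape));
  }
  return chain.back().output_shape;
}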
@@ -137,7 +137,7 @@ std::unique_ptr<IQnnNodeGroup> ReshapeGemmFusion::TryFusion(
    QnnModelWrapper& qnn_model_wrapper, const NodeUnit& gemm_node_unit,
    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
    const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
    [[maybe_unused]] const logging::Logger& logger) {
    const logging::Logger& /*logger*/) {
  if (gemm_node_unit.OpType() != "Gemm" || gemm_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
    return nullptr;
  }

@@ -173,13 +173,11 @@ ReshapeGemmFusion::ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const No
  node_units_[1] = &gemm_node_unit;
}

Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
  ORT_UNUSED_PARAMETER(logger);
Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
  return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], true);
}

Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
  ORT_UNUSED_PARAMETER(logger);
Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
  return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], false);
}
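Note: the signature changes in this hunk only swap how the unused logger parameter is silenced. A standalone illustration of the two conventions follows; the function names are hypothetical, and ORT_UNUSED_PARAMETER is assumed to boil down to a void-cast of its argument.

#include <iostream>

// Two equivalent ways to accept a parameter that the body does not use,
// without triggering unused-parameter warnings.
void WithExplicitMarker(int value, int unused_flag) {
  (void)unused_flag;  // what a helper macro such as ORT_UNUSED_PARAMETER conceptually does
  std::cout << value << '\n';
}

void WithUnnamedParameter(int value, int /*unused_flag*/) {  // name survives only as a comment
  std::cout << value << '\n';
}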
@@ -881,7 +881,7 @@ TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) {
                                   "high");  // qnn_context_priority
}

// Create a model with Case + Add (quantized)
// Create a model with Cast + Add (quantized)
// cast_input -> Cast -> Q -> DQ ----
//                                  |
// input2 -> Q -> DQ -> Add -> Q -> DQ -> output
@@ -162,7 +162,7 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) {
  QnnContextBinaryMultiPartitionTestBody(single_ep_node);
}

// Create a model with Case + Add (quantized)
// Create a model with Cast + Add (quantized)
// cast_input -> Cast -> Q -> DQ ----
//                                  |
// input2 -> Q -> DQ -> Add -> Q -> DQ -> output