[TransposeOptimizer] Support Unsqueeze/Transpose of input consumed by per-axis DQ (#21821)

### Description
Follow-up to: https://github.com/microsoft/onnxruntime/pull/21793

- Support looking past a per-axis DQ to do an in-place Unsqueeze/Transpose
of initializers.
- Support looking past a per-axis DQ to cancel a Transpose or Squeeze.
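
Both cases require remapping the per-axis DQ's `axis` attribute. Below is a minimal Python sketch of the remapping rules (illustrative only; the PR implements them in C++ via the new `DQToLookPast` helper's `SetUnsqueezedInput()`/`SetTransposedInput()`):

```python
def unsqueeze_axis(positive_unsqueeze_axes, axis):
    """Return the new quantization axis after 1-dims are inserted at the given (non-negative) axes."""
    new_axis = axis
    for u in sorted(positive_unsqueeze_axes):
        if u <= new_axis:
            new_axis += 1  # a dim inserted at or before the axis shifts it one slot right
    return new_axis


def transpose_axis(perm_inv, axis):
    """Return the new quantization axis after the DQ's input is permuted (perm_inv = inverse permutation)."""
    return perm_inv[axis]


# Values taken from the test models below:
assert unsqueeze_axis([0, 1, 2], 0) == 3     # Squeeze cancel in model II: axis 0 -> 3
assert transpose_axis([0, 2, 3, 1], 3) == 1  # Transpose cancel in model III: axis -1 (i.e. 3) -> 1
```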

### Test models
For all test models, the transpose optimizer pushes a Transpose through
a Mul's input[0]. The Mul's input[1] is optionally unsqueezed and then
transposed.
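
All four models share the pattern below (abridged from the model-generation script added in this PR); only the subgraph feeding the Mul's input[1] differs:

```
input0 -> Transpose(NCHW->NHWC) -> Q -> DQ -> Sigmoid -> Q -> DQ -> Mul -> Q -> DQ -> Transpose(NHWC->NCHW) -> output0
                               Mul's input[1]: source -> [optional Squeeze or Transpose] -> per-axis DQ
```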

### I. Test in-place unsqueeze and transpose of per-axis quantized weight
Original model has input[1] with shape (3,).
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/37b6f60c-77d2-4bd3-8ca2-58dc7c88a304"
/>
</details>

Optimized model has input[1] with shape (1, 3, 1, 1). The initializer
was unsqueezed and transposed in-place.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/adb72757-a164-400c-bfef-2a05f0e35825"
/>
</details>
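
In Before/After form (the final DQ axis of 1 follows from the remapping rules sketched above; the new unit test asserts the op counts and output values rather than the axis):

```
Before: mul_weight (shape [3]) -> DQ(axis=0) -> Mul
After:  mul_weight (shape [1, 3, 1, 1], updated in-place) -> DQ(axis=1) -> Mul
```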

### II. Test canceling existing Squeeze before per-axis DQ
Original model has input[1] that is squeezed.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/f27e6742-b563-42a9-ad06-bb3178b0ceb8"
/>
</details>

The optimized model unsqueezes and transposes input[1]. The inserted
Unsqueeze cancels the original Squeeze, leaving only the Transpose.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/e56261d4-eba6-4a9f-847b-dcd33548dd07"
/>
</details>
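
The Squeeze-cancel portion in Before/After form (from the new unit test, which asserts the DQ axis remap; the remaining Transpose is visible in the image above):

```
Before: input1 (shape [1, 1, 1, 3]) -> Squeeze(axes=[0, 1, 2]) -> DQ(axis=0) -> Mul
After:  input1 -> DQ(axis=3) -> Mul
```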

### III. Test canceling existing Transpose before per-axis DQ
Original model has input[1] that is transposed.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/f157e04a-572a-479d-8e3b-cf57954df5c0"
/>
</details>

The optimized model transposes input[1], canceling the pre-existing
Transpose.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/63d742ce-3762-4ab2-bdb0-1b507886da9d"
/>
</details>
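
In Before/After form (from the new unit test):

```
Before: input1 -> Transpose(perm=[0, 2, 3, 1]) -> DQ(axis=-1) -> Mul
After:  input1 -> DQ(axis=1) -> Mul
```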

### IV. Test QDQ fix-up of Transpose/Unsqueeze for per-axis quantization
Original model has input[1] that can be broadcast.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/96c0092c-22ec-486d-882e-e2cb59ffe324"
/>
</details>

The main transpose optimization loop inserts a float32 Unsqueeze and
Transpose after the DQ. The QDQ fix-up pass then inserts new per-axis
Q/DQ ops after each inserted node.
<details><summary>click to expand model image</summary>
<img
src="https://github.com/user-attachments/assets/b6f89c11-974d-4b35-922f-11effdf06883"
/>
</details>
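
In Before/After form (from the new unit test):

```
Before: DQ(axis=0) -> Unsqueeze(axes=[0, 1, 2]) -> Transpose(perm=[0, 3, 1, 2]) -> Op
After:  DQ(axis=0) -> Unsqueeze -> Q(axis=3) -> DQ(axis=3) -> Transpose -> Q(axis=1) -> DQ(axis=1) -> Op
```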


### Motivation and Context
Enables the TransposeOptimizer to support more models with per-axis QDQ
nodes. Per-axis quantization can improve model accuracy and is used by
EPs like QNN.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Adrian Lizarraga 2024-09-05 17:26:17 -07:00 committed by GitHub
Parent 23f6604c39
Commit b011f6fbf6
9 changed files with 767 additions and 257 deletions

View file

@ -245,16 +245,174 @@ static std::unique_ptr<api::NodeRef> GetDQIfProducingValue(const api::GraphRef&
: std::unique_ptr<api::NodeRef>();
}
// Forward declarations.
static bool NormalizeAndValidateAxes(std::vector<int64_t>& axes, size_t rank);
static std::optional<std::vector<int64_t>> ReadFromAttrOrInput(const api::GraphRef& graph, api::NodeRef& node,
std::string_view attr_name, size_t inp_index,
int64_t opset);
static int64_t UnsqueezeAxis(gsl::span<const int64_t> positive_unsqueeze_axes, int64_t axis);
// Quantization modes for QuantizeLinear/DequantizeLinear.
enum class QuantizationMode : uint8_t {
kUnknown,
kPerTensor,
kPerAxis,
kBlocked,
};
// Returns true if this optimizer supports the given quantization mode.
constexpr bool IsSupportedQuantizationMode(QuantizationMode mode) {
return (mode == QuantizationMode::kPerTensor) || (mode == QuantizationMode::kPerAxis);
}
// Stores quantization info for a validated Q or DQ node.
struct QuantizationInfo {
QuantizationMode mode;
int64_t norm_axis; // 'normalized' axis is in the range [0, Input0Rank - 1]
};
// Returns quantization information (quantization mode, normalized axis).
// Returns std::nullopt if unable to get the quantization info or if the axis attribute is invalid.
static std::optional<QuantizationInfo> GetQuantizationInfo(const api::GraphRef& graph,
const api::NodeRef& q_or_dq_node) {
const std::vector<std::string_view> inputs = q_or_dq_node.Inputs();
// Need to use the scale input's shape to determine the quantization mode. Can't just use the presence of the axis
// attribute because even per-tensor Q/DQ have a default axis of 1.
std::string_view scale_input = inputs[1];
const std::unique_ptr<api::ValueInfoRef> scale_value_info = graph.GetValueInfo(scale_input);
std::optional<std::vector<int64_t>> scale_shape = scale_value_info->Shape();
if (!scale_shape) {
return std::nullopt;
}
QuantizationInfo quant_info = {};
if (IsScalarOr1Element1DTensor(*scale_shape)) {
// A scalar or tensor scale with shape (1,) indicates per-tensor quantization.
quant_info.mode = QuantizationMode::kPerTensor;
quant_info.norm_axis = 1; // 1 is the default 'axis' even for per-tensor quantization (axis is ignored).
} else {
// This is either per-axis or blocked quantization. A non-zero block_size attribute indicates blocked quantization.
int64_t axis = q_or_dq_node.GetAttributeIntDefault("axis", 1);
const auto input0_info = graph.GetValueInfo(inputs[0]);
auto input0_rank = input0_info->ShapeRank();
if (!input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *input0_rank)) {
// Unable to normalize the DQ's axis.
// TODO(adrianlizarraga): Should look into a logging facility to make it easier to inspect issues.
return std::nullopt;
}
int64_t block_size = q_or_dq_node.GetAttributeIntDefault("block_size", 0);
quant_info.norm_axis = axis;
quant_info.mode = block_size != 0 ? QuantizationMode::kBlocked : QuantizationMode::kPerAxis;
}
return quant_info;
}
/// <summary>
/// Represents a DequantizeLinear node that TransposeInputImpl() or UnsqueezeInput() can look past to transpose or
/// unsqueeze its input.
/// </summary>
class DQToLookPast {
public:
DQToLookPast(std::unique_ptr<api::NodeRef>&& dq_node, QuantizationInfo quant_info)
: dq_node_(std::move(dq_node)), quant_info_(quant_info) {
assert(dq_node_ != nullptr); // Expect dq_node to be valid.
}
DQToLookPast(DQToLookPast&& other) = default;
DQToLookPast& operator=(DQToLookPast&& other) = default;
inline void DisconnectInput0() {
dq_node_->SetInput(0, "");
}
inline void ReconnectInput0(std::string_view input_name) {
dq_node_->SetInput(0, input_name);
}
inline std::string_view GetInput0() const {
return dq_node_->Inputs()[0];
}
inline std::string_view GetOutput() const {
return dq_node_->Outputs()[0];
}
/// <summary>
/// Sets the DQ's new transposed input[0]. The DQ's axis and output shape are updated.
/// </summary>
/// <param name="graph">graph</param>
/// <param name="new_input">name of new transposed input[0]</param>
/// <param name="perm_inv">inverse transpose permutation used to update the DQ's axis</param>
void SetTransposedInput(api::GraphRef& graph, std::string_view new_input, gsl::span<const int64_t> perm_inv) {
if (quant_info_.mode == QuantizationMode::kPerAxis) {
quant_info_.norm_axis = perm_inv[gsl::narrow_cast<size_t>(quant_info_.norm_axis)];
}
this->SetUpdatedInput(graph, new_input);
}
/// <summary>
/// Sets the DQ's new unsqueezed input[0]. The DQ's axis and output shape are updated.
/// The provided unsqueeze axes must all be positive (i.e., normalized).
/// </summary>
/// <param name="graph">graph</param>
/// <param name="new_input">name of new unsqueezed input[0]</param>
/// <param name="perm_inv">positive unsqueeze axes used to update the DQ's axis</param>
void SetUnsqueezedInput(api::GraphRef& graph, std::string_view new_input,
gsl::span<const int64_t> positive_unsqueeze_axes) {
if (quant_info_.mode == QuantizationMode::kPerAxis) {
quant_info_.norm_axis = UnsqueezeAxis(positive_unsqueeze_axes, quant_info_.norm_axis);
}
this->SetUpdatedInput(graph, new_input);
}
/// <summary>
/// Static function that returns/moves the DQ node from a std::optional DQToLookPast.
/// The std::optional DQToLookPast is set to std::nullopt to prevent accidental reuse.
/// </summary>
/// <param name="dq_to_look_past">The DQToLookPast to steal the node from</param>
/// <returns>The dq_node unique_ptr moved from dq_to_look_past</returns>
static std::unique_ptr<api::NodeRef> TakeDQNode(std::optional<DQToLookPast>& dq_to_look_past) {
std::unique_ptr<api::NodeRef> node;
if (dq_to_look_past) {
node = std::move(dq_to_look_past->dq_node_);
dq_to_look_past = std::nullopt;
}
return node;
}
private:
// Called by SetTransposedInput() and SetUnsqueezedInput() to update the DQ's input,
// axis, and output shape.
void SetUpdatedInput(api::GraphRef& graph, std::string_view new_input) {
dq_node_->SetInput(0, new_input);
dq_node_->SetAttributeInt("axis", quant_info_.norm_axis);
auto new_shape = *graph.GetValueInfo(new_input)->Shape();
graph.GetValueInfo(dq_node_->Outputs()[0])->SetShape(&new_shape);
}
std::unique_ptr<api::NodeRef> dq_node_;
QuantizationInfo quant_info_;
};
/// <summary>
/// Return a DequantizeLinear node if its input is a constant initializer and it has a single consumer.
/// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput.
/// </summary>
/// <param name="graph">Current graph</param>
/// <param name="value_name">Value to check if produced by a DQ node who's input is a constant initializer</param>
/// <returns>NodeRef for DQ node if it meets the requirements.</returns>
static std::unique_ptr<api::NodeRef> GetDQWithConstInitializerInputAndSingleConsumer(const api::GraphRef& graph,
std::string_view value_name) {
std::unique_ptr<api::NodeRef> result;
/// <returns>DQToLookPast for DQ node if it meets the requirements, or std::nullopt otherwise.</returns>
static std::optional<DQToLookPast> GetDQWithConstInitializerInputAndSingleConsumer(const api::GraphRef& graph,
std::string_view value_name) {
std::optional<DQToLookPast> result;
std::optional<QuantizationInfo> quant_info;
auto dq_node = GetDQIfProducingValue(graph, value_name);
if (dq_node) {
@ -267,10 +425,10 @@ static std::unique_ptr<api::NodeRef> GetDQWithConstInitializerInputAndSingleCons
break;
}
// For now keep it simple and don't support per-axis quantization as that would require updating the axis of
// the DQ node during TransposeInputImpl and UnsqueezeInput.
auto dq_scale = graph.GetConstant(dq_node->Inputs()[1]);
if (!dq_scale || dq_scale->NumElements() != 1) {
// Get the quantization mode (per-tensor, per-channel) and the normalized quantization axis.
// To keep things simple, do not support blocked quantization for now (added in opset 21).
quant_info = GetQuantizationInfo(graph, *dq_node);
if (!quant_info || !IsSupportedQuantizationMode(quant_info->mode)) {
break;
}
@ -285,20 +443,13 @@ static std::unique_ptr<api::NodeRef> GetDQWithConstInitializerInputAndSingleCons
break;
}
result = std::move(dq_node);
result = DQToLookPast(std::move(dq_node), *quant_info);
} while (false);
}
return result;
}
// Forward declarations for utils used by MakeQDQNodeUnit
static bool NormalizeAndValidateAxes(std::vector<int64_t>& axes, size_t rank);
static std::optional<std::vector<int64_t>> ReadFromAttrOrInput(const api::GraphRef& graph, api::NodeRef& node,
std::string_view attr_name, size_t inp_index,
int64_t opset);
static int64_t UnsqueezeAxis(gsl::span<const int64_t> sorted_positive_unsqueeze_axes, int64_t axis);
/// <summary>
/// Insert a Q -> DQ pair after the node following the DQ by using scale and zp info from the preceding DQ node.
/// DQ -> next node => DQ -> next node -> Q -> DQ.
@ -324,34 +475,21 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
const bool is_unsqueeze = next_node.OpType() == "Unsqueeze";
const auto scale_input = dq_inputs[1];
const auto scale_value_info = graph.GetValueInfo(scale_input);
std::optional<std::string_view> zp_input;
std::optional<std::unique_ptr<api::ValueInfoRef>> zp_value_info;
auto scale_shape = scale_value_info->Shape();
if (!scale_shape) {
// axis potentially needs updating due to the transpose or unsqueeze but we don't have the required info to do it.
return false;
}
if (dq_inputs.size() > 2) {
zp_input = dq_inputs[2];
zp_value_info = graph.GetValueInfo(zp_input.value());
}
// DQ uses per-axis quantization if its scale input is not a scalar and not a tensor with shape (1,).
// Note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
// so we have to check the scale input's shape.
const bool update_dq_axis = !IsScalarOr1Element1DTensor(*scale_shape);
int64_t axis = dq_node.GetAttributeIntDefault("axis", 1);
std::optional<QuantizationInfo> dq_quant_info = GetQuantizationInfo(graph, dq_node);
if (!dq_quant_info || !IsSupportedQuantizationMode(dq_quant_info->mode)) {
return false; // Can't get the quantization mode/axis or is a quantization mode that is not supported.
}
if (update_dq_axis) {
const auto dq_input0_info = graph.GetValueInfo(dq_inputs[0]);
auto dq_input0_rank = dq_input0_info->ShapeRank();
if (!dq_input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *dq_input0_rank)) {
return false; // Unable to normalize the DQ's axis.
}
int64_t axis = dq_quant_info->norm_axis;
// Have to update the axis for newly inserted Q/DQ after a Transpose or Unsqueeze if using per-axis quantization.
if (dq_quant_info->mode == QuantizationMode::kPerAxis) {
if (is_transpose) {
auto perm = GetPermAttrIfValid(next_node);
assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid
@ -360,15 +498,19 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
auto axes = ReadFromAttrOrInput(graph, next_node, "axes", /*inp_index*/ 1, /*opset*/ 13);
assert(axes.has_value()); // 'axes' are required for Unsqueeze
const auto dq_output_info = graph.GetValueInfo(dq_node.Outputs()[0]);
std::optional<size_t> dq_output_rank = dq_output_info->ShapeRank();
if (!dq_output_rank.has_value()) {
return false; // Need to know the rank of the input to the Unsqueeze to normalize unsqueeze axes
}
// Normalize negative unsqueeze axes by adding output rank.
// Unsqueeze output rank = input_rank + axes.size()
// Unsqueeze's input rank is the same as the DQ's input[0] rank.
if (!NormalizeAndValidateAxes(*axes, *dq_input0_rank + axes->size())) {
// Unsqueeze output rank = unsqueeze input rank + axes.size()
if (!NormalizeAndValidateAxes(*axes, *dq_output_rank + axes->size())) {
return false;
}
// Need to update axis if Unsqueeze inserts a 1 before the axis dim.
std::sort(axes->begin(), axes->end());
axis = UnsqueezeAxis(*axes, axis);
}
}
@ -567,7 +709,7 @@ static std::optional<std::vector<int64_t>> ReadFromAttrOrInput(const api::GraphR
}
// Computes inverse permutation. Unsafe if perm is not a valid permutation.
std::vector<int64_t> InvertPerm(const std::vector<int64_t>& perm) {
std::vector<int64_t> InvertPerm(gsl::span<const int64_t> perm) {
size_t rank = perm.size();
std::vector<int64_t> perm_inv(rank);
for (size_t i = 0; i < rank; ++i) {
@ -754,15 +896,18 @@ static std::vector<int64_t> SqueezePerm(const std::vector<int64_t>& axes, const
}
// Computes a new axis value for an unsqueezed version of a tensor. Incorrect if any axes
// values are negative, duplicated, or are not sorted in increasing order.
// values are negative or duplicated.
//
// Ex: axes = [0, 1, 2], axis = 0, new_axis = 3
// axes = [0, 1, 3], axis = 1, new_axis = 4
static int64_t UnsqueezeAxis(gsl::span<const int64_t> sorted_positive_unsqueeze_axes, int64_t axis) {
static int64_t UnsqueezeAxis(gsl::span<const int64_t> positive_unsqueeze_axes, int64_t axis) {
assert(axis >= 0);
int64_t new_axis = axis;
for (int64_t unsqueeze_axis : sorted_positive_unsqueeze_axes) {
std::vector<int64_t> sorted_axes(positive_unsqueeze_axes.begin(), positive_unsqueeze_axes.end());
std::sort(sorted_axes.begin(), sorted_axes.end());
for (int64_t unsqueeze_axis : sorted_axes) {
if (unsqueeze_axis <= new_axis) {
new_axis += 1;
}
@ -812,12 +957,6 @@ static std::vector<int64_t> SortedAxesForTransposedInput(const std::vector<int64
return new_axes;
}
static void UpdateDQNodeInputAndShape(api::GraphRef& graph, api::NodeRef& dq, std::string_view new_input) {
dq.SetInput(0, new_input);
auto new_shape = *graph.GetValueInfo(new_input)->Shape();
graph.GetValueInfo(dq.Outputs()[0])->SetShape(&new_shape);
}
/////// </Helper Utils> ///////
/////// <Core Helpers> ///////
@ -834,27 +973,27 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
std::unique_ptr<api::TensorRef> constant = ctx.graph.GetLocalConstant(input);
// allow a constant initializer coming via a DQ node with a single consumer
std::unique_ptr<api::NodeRef> dq_node;
std::optional<DQToLookPast> dq_to_look_past;
std::string_view constant_dq_input;
if (!constant) {
// look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist
// to enable directly making changes to the initializer. any nodes added for other consumers of the initializer
// in 'Case 1' are prior to the DQ so we don't break up any QDQ node units.
dq_node = GetDQWithConstInitializerInputAndSingleConsumer(ctx.graph, input);
if (dq_node) {
dq_to_look_past = GetDQWithConstInitializerInputAndSingleConsumer(ctx.graph, input);
if (dq_to_look_past) {
// underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input
constant_dq_input = dq_node->Inputs()[0];
constant_dq_input = dq_to_look_past->GetInput0();
constant = ctx.graph.GetLocalConstant(constant_dq_input);
// remove the DQ node as a consumer of the initializer while we modify things
dq_node->SetInput(0, "");
dq_to_look_past->DisconnectInput0();
}
}
// Clear the input, which also removes this node's input as a consumer of the value.
// NOTE: the node may have multiple inputs consuming the value.
node.SetInput(i, "");
auto value_to_modify = dq_node ? constant_dq_input : input;
auto value_to_modify = dq_to_look_past ? constant_dq_input : input;
auto consumers = ctx.graph.GetValueConsumers(value_to_modify);
// Case 1: input is a constant with a known list of consumer nodes
@ -873,8 +1012,8 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
auto new_shape = UnsqueezeShape(constant->Shape(), axes);
ctx.graph.ReshapeInitializer(value_to_modify, new_shape);
if (dq_node) {
UpdateDQNodeInputAndShape(ctx.graph, *dq_node, constant_dq_input);
if (dq_to_look_past) {
dq_to_look_past->SetUnsqueezedInput(ctx.graph, constant_dq_input, axes);
}
node.SetInput(i, input); // restore the original connection
@ -886,10 +1025,14 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
// look past a DQ node for a Squeeze to cancel
if (inp_node && inp_node->OpType() == "DequantizeLinear") {
dq_node = std::move(inp_node);
auto dq_input = dq_node->Inputs()[0];
inp_node = ctx.graph.GetNodeProducingOutput(dq_input);
consumers = ctx.graph.GetValueConsumers(dq_input);
std::optional<QuantizationInfo> dq_quant_info = GetQuantizationInfo(ctx.graph, *inp_node);
if (dq_quant_info && IsSupportedQuantizationMode(dq_quant_info->mode)) {
dq_to_look_past = std::make_optional<DQToLookPast>(std::move(inp_node), *dq_quant_info);
auto dq_input = dq_to_look_past->GetInput0();
inp_node = ctx.graph.GetNodeProducingOutput(dq_input);
consumers = ctx.graph.GetValueConsumers(dq_input);
}
}
if (inp_node != nullptr && inp_node->IsOp("Squeeze")) {
@ -897,9 +1040,9 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
std::optional<std::vector<int64_t>> squeeze_axes = std::nullopt;
squeeze_axes = ReadFromAttrOrInput(ctx, *inp_node, "axes", /*inp_index*/ 1, /*opset*/ 13);
if (squeeze_axes != std::nullopt && *squeeze_axes == axes) {
if (dq_node) {
UpdateDQNodeInputAndShape(ctx.graph, *dq_node, inp_node_inputs[0]);
node.SetInput(i, dq_node->Outputs()[0]);
if (dq_to_look_past) {
dq_to_look_past->SetUnsqueezedInput(ctx.graph, inp_node_inputs[0], axes);
node.SetInput(i, dq_to_look_past->GetOutput());
} else {
node.SetInput(i, inp_node_inputs[0]);
}
@ -907,7 +1050,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
// Remove the Squeeze node if possible
// if there's a DQ node the `consumers` list still includes it so allow for that.
// in that case SetUnsqueezedInput() already updated the input of the DQ node so it's safe to remove it.
if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) {
if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_to_look_past ? 1 : 0)) {
ctx.graph.RemoveNode(*inp_node);
if (ctx.opset >= 13 && !ctx.graph.HasValueConsumers(inp_node_inputs[1])) {
@ -922,8 +1065,8 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons
}
// any DQ node special casing doesn't apply anymore, so go back to the original inp_node
if (dq_node) {
inp_node = std::move(dq_node);
if (dq_to_look_past) {
inp_node = DQToLookPast::TakeDQNode(dq_to_look_past);
}
// Case 3: Add an Unsqueeze node.
@ -989,20 +1132,20 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
std::unique_ptr<api::TensorRef> constant = graph.GetLocalConstant(input);
// allow a constant initializer coming via a DQ node with a single consumer
std::unique_ptr<api::NodeRef> dq_node;
std::optional<DQToLookPast> dq_to_look_past;
std::string_view constant_dq_input;
if (!constant) {
// look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist
// to enable directly making changes to the initializer. any nodes added for other consumers of the initializer
// in 'Case 1' are prior to the DQ so we don't break up any QDQ node units.
dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, input);
if (dq_node) {
dq_to_look_past = GetDQWithConstInitializerInputAndSingleConsumer(graph, input);
if (dq_to_look_past) {
// underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input
constant_dq_input = dq_node->Inputs()[0];
constant_dq_input = dq_to_look_past->GetInput0();
constant = graph.GetLocalConstant(constant_dq_input);
// remove the DQ node as a consumer of the initializer while we modify things
dq_node->SetInput(0, "");
dq_to_look_past->DisconnectInput0();
}
}
@ -1010,7 +1153,7 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
// NOTE: the node may have multiple inputs consuming the value.
node.SetInput(i, "");
auto constant_to_modify = dq_node ? constant_dq_input : input;
auto constant_to_modify = dq_to_look_past ? constant_dq_input : input;
auto consumers = graph.GetValueConsumers(constant_to_modify);
// Case 1: input is a constant with a known list of consumer nodes
@ -1018,13 +1161,14 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
// we modify the initializer in-place and need to reconnect things up when we're done. this helper will
// do that when it goes out of scope. if we have manually reconnected, input or constant_dq_input is
// set to an empty string.
auto reconnect_nodes = gsl::finally([i, &node, &dq_node, &input, &constant_dq_input] {
auto reconnect_nodes = gsl::finally([i, &node, &dq_to_look_past, &input, &constant_dq_input] {
if (!input.empty()) {
node.SetInput(i, input);
}
if (!constant_dq_input.empty()) {
dq_node->SetInput(0, constant_dq_input);
assert(dq_to_look_past);
dq_to_look_past->ReconnectInput0(constant_dq_input);
}
});
@ -1038,16 +1182,19 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
// Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if
// there are no other consumers.
if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast<int64_t>(perm.size())) {
auto& node_to_update = dq_node ? *dq_node : node;
Permute1DConstant(graph, node_to_update, *constant, i, constant_to_modify, perm);
// A quantized (roi/scales/sizes) input for Resize or a quantized pads input for Pad would be unlikely.
// Even if it occurred, HandleResize()/HandlePad() permute these kinds of inputs directly and do not try to
// call TransposeInput() on them. Also, WrapTransposesAroundNode() does not call TransposeInput() on non-const
// inputs of this kind. So, we should not have a DQ to look past at this point.
//
// In the event that we decide to handle a DQ in the future, note that the DQ axis should not be
// changed (remains 0), but all DQ inputs should be permuted.
assert(!dq_to_look_past);
Permute1DConstant(graph, node, *constant, i, constant_to_modify, perm);
// unset updated input so reconnect_nodes doesn't change it back
if (dq_node) {
constant_dq_input = "";
} else {
input = "";
}
input = "";
return;
}
@ -1063,8 +1210,8 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
graph.TransposeInitializer(constant_to_modify, perm);
if (dq_node) {
UpdateDQNodeInputAndShape(graph, *dq_node, constant_to_modify);
if (dq_to_look_past) {
dq_to_look_past->SetTransposedInput(graph, constant_to_modify, perm_inv);
constant_dq_input = ""; // DQ input was already updated so we don't need reconnect_nodes to handle it
}
@ -1076,10 +1223,14 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
// Look past a DQ for the Transpose
if (inp_node && inp_node->OpType() == "DequantizeLinear") {
dq_node = std::move(inp_node);
auto dq_input = dq_node->Inputs()[0];
inp_node = graph.GetNodeProducingOutput(dq_input);
consumers = graph.GetValueConsumers(dq_input);
std::optional<QuantizationInfo> dq_quant_info = GetQuantizationInfo(graph, *inp_node);
if (dq_quant_info && IsSupportedQuantizationMode(dq_quant_info->mode)) {
dq_to_look_past = std::make_optional<DQToLookPast>(std::move(inp_node), *dq_quant_info);
std::string_view dq_input = dq_to_look_past->GetInput0();
inp_node = graph.GetNodeProducingOutput(dq_input);
consumers = graph.GetValueConsumers(dq_input);
}
}
if (inp_node != nullptr && inp_node->IsOp("Transpose")) {
@ -1089,9 +1240,9 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
if (*perm2 == perm_inv) {
std::string_view pre_transpose_value = inp_node->Inputs()[0];
if (dq_node) {
UpdateDQNodeInputAndShape(graph, *dq_node, pre_transpose_value);
node.SetInput(i, dq_node->Outputs()[0]);
if (dq_to_look_past) {
dq_to_look_past->SetTransposedInput(graph, pre_transpose_value, perm_inv);
node.SetInput(i, dq_to_look_past->GetOutput());
} else {
node.SetInput(i, pre_transpose_value);
}
@ -1099,14 +1250,14 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
// Remove the Transpose node if possible
// if there's a DQ node the `consumers` list still includes it so allow for that.
// in that case SetTransposedInput() already updated the input of the DQ node so it's safe to remove it.
if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) {
if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_to_look_past ? 1 : 0)) {
graph.RemoveNode(*inp_node);
}
return;
}
if (!dq_node) {
if (!dq_to_look_past) {
// Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove
// the other Transpose.
const std::vector<int64_t>& perm_combined = ComposePerm(*perm2, perm);
@ -1130,8 +1281,8 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t
}
// any DQ node special casing doesn't apply anymore, so go back to the original inp_node
if (dq_node) {
inp_node = std::move(dq_node);
if (dq_to_look_past) {
inp_node = DQToLookPast::TakeDQNode(dq_to_look_past);
consumers = graph.GetValueConsumers(input);
}
@ -1301,8 +1452,8 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t
// - the value is a constant initializer
// - the value is the output of a DQ node whose input is a constant initializer
// - UnsqueezeInput/TransposeInput can look past the DQ to update the constant initializer directly
// - DQ node is currently ignored if it uses per-channel quantization
// - supporting per-channel quantization requires modifying the scales and zero point data, which can be done
// - DQ node is currently ignored if it uses blocked quantization (per-tensor and per-axis are supported).
// - supporting blocked quantization requires modifying the scales and zero point data, which can be done
// if/when there's a use-case to justify the development cost.
// - the input was originally connected to a shared constant initializer that was updated in place by UnsqueezeInput
// or TransposeInput, and usage by this node had Squeeze/Transpose nodes inserted to counteract the effect of the
@ -1321,8 +1472,8 @@ static bool IsConstant(const api::GraphRef& graph, std::string_view value_name)
// look past a DQ node
if (producer_node->OpType() == "DequantizeLinear") {
std::unique_ptr<api::NodeRef> dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, value_name);
if (dq_node != nullptr) {
std::optional<DQToLookPast> dq_to_look_past = GetDQWithConstInitializerInputAndSingleConsumer(graph, value_name);
if (dq_to_look_past) {
// DQ node pointing to a constant initializer
return true;
}
@ -1363,7 +1514,7 @@ static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_vi
if (dq_input_node != nullptr) {
if (dq_input_node->OpType() == "Squeeze") {
auto squeeze_input_node = graph.GetNodeProducingOutput(dq_input_node->Inputs()[0]);
if (squeeze_input_node->OpType() == "Transpose") {
if (squeeze_input_node != nullptr && squeeze_input_node->OpType() == "Transpose") {
// we only want to set this if it is a Transpose as otherwise we're invalidating the cost given it is
// rank based and the Squeeze will change that.
producer_node = std::move(squeeze_input_node);
@ -2733,37 +2884,23 @@ static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_
const auto q_domain = q_node.Domain();
const auto scale_input = q_node_inputs[1];
const auto scale_value_info = ctx.graph.GetValueInfo(scale_input);
std::optional<std::string_view> zp_input;
std::optional<std::unique_ptr<api::ValueInfoRef>> zp_value_info;
auto scale_shape = scale_value_info->Shape();
if (!scale_shape) {
// Axis potentially needs updating due to the transpose but we don't have the required info to do it.
return false;
}
if (q_node_inputs.size() > 2) {
zp_input = q_node_inputs[2];
zp_value_info = ctx.graph.GetValueInfo(zp_input.value());
}
// Q uses per-axis quantization if its scale input is not a scalar and not a tensor with shape (1,).
// Note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
// so we have to check the scale input's shape.
const bool update_axis = !IsScalarOr1Element1DTensor(*scale_shape);
int64_t axis = q_node.GetAttributeIntDefault("axis", 1);
std::optional<QuantizationInfo> q_quant_info = GetQuantizationInfo(ctx.graph, q_node);
if (!q_quant_info || !IsSupportedQuantizationMode(q_quant_info->mode)) {
return false; // Can't get quantization mode/axis or is a quantization mode that is not supported.
}
if (update_axis) {
int64_t axis = q_quant_info->norm_axis;
if (q_quant_info->mode == QuantizationMode::kPerAxis) {
// Have to update the axis for newly inserted Q/DQ before a Transpose if using per-channel quantization.
auto perm = GetPermAttrIfValid(transpose_node);
assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid
const auto q_input0_info = ctx.graph.GetValueInfo(q_node_inputs[0]);
std::optional<size_t> q_input0_rank = q_input0_info->ShapeRank();
if (!q_input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *q_input0_rank)) {
return false; // Unable to normalize the Q's axis.
}
assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid
axis = (*perm)[gsl::narrow_cast<size_t>(axis)]; // Note: do not invert permutation.
}

View file

@ -3,6 +3,7 @@
#pragma once
#include <gsl/gsl>
#include <unordered_map>
#include <vector>
@ -59,7 +60,7 @@ struct OptimizerCtx {
/// <returns>{0}</returns>
inline std::vector<size_t> FirstInput(OptimizerCtx&, api::NodeRef&) { return {0}; }
std::vector<int64_t> InvertPerm(const std::vector<int64_t>& perm);
std::vector<int64_t> InvertPerm(gsl::span<const int64_t> perm);
// Transpose all inputs and all outputs
bool HandleSimpleNode(HandlerArgs& args);

View file

@ -4848,20 +4848,135 @@ TEST(TransposeOptimizerTests, ConstantFoldTransposeAndSqueezeOutputCorrectness)
testing::ContainerEq(fetches[1].Get<Tensor>().DataAsSpan<float>()));
}
// Tests the fix-up of a QDQ NodeUnit containing a per-channel DQ followed by an Unsqueeze.
// Before: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Op
// After: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Q (axis = 3) -> DQ (axis = 3) -> Op
TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) {
// Test model contains a Mul with a broadcastable/constant/per-channel DQ input. When a transpose is pushed through
// the Mul, the contant DQ input is Unsqueezed.
auto model_uri = ORT_TSTR("testdata/transpose_optimization_unsqueeze_dq_axis.qdq.onnx");
// Utility to get the axis attribute for a Q or DQ node.
static void GetQOrDQAxis(const Node& q_or_dq_node, /*out*/ int64_t& axis) {
const NodeAttributes& attrs = q_or_dq_node.GetAttributes();
auto axis_attr_it = attrs.find("axis");
axis = 1;
if (axis_attr_it != attrs.end()) {
auto axis_attr = axis_attr_it->second;
ASSERT_TRUE(axis_attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
axis = axis_attr.i();
}
}
// Tests the fix-up of a QDQ NodeUnit containing a per-axis DQ followed by an Unsqueeze and Transpose.
// Before: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Transpose (perm = [0, 3, 1, 2]) -> Op
// After: DQ (axis = 0) -> Unsqueeze -> Q(axis = 3) -> DQ(axis = 3) -> Transpose -> Q(axis = 1) -> DQ(axis = 1) -> Op
TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerAxisDQUnsqueezeTranspose) {
// Model contains a Mul with a broadcastable/per-axis DQ input[1]. When a transpose is pushed through
// the Mul's input[0], input[1]'s input is unsqueezed and transposed.
auto model_uri = ORT_TSTR("testdata/transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx");
RandomValueGenerator random{123};
std::vector<int64_t> input_dims{1, 3, 4, 4};
std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
std::vector<int8_t> input1_data = {0, 1, 2};
auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators();
OrtValue input0;
OrtValue input1;
CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0);
CreateMLValue<int8_t>(allocators[0], {3}, input1_data, &input1);
NameMLValMap feeds{{"input0", input0}, {"input1", input1}};
std::vector<std::string> output_names{"output0"};
std::vector<OrtValue> fetches_orig;
std::vector<OrtValue> fetches;
SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
so.graph_optimization_level = TransformerLevel::Default; // off
// get results with no modifications to the model
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
}
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
Graph& graph = session.GetMutableGraph();
CPUAllocator allocator;
namespace alias_oto = onnx_transpose_optimization;
auto api_graph = MakeApiGraph(graph,
TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
/*new_node_ep*/ nullptr);
alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph);
ASSERT_EQ(result.error_msg, std::nullopt);
ASSERT_TRUE(result.graph_modified);
ASSERT_TRUE(graph.GraphResolveNeeded());
ASSERT_STATUS_OK(graph.Resolve());
// Use this hack to save model for viewing if needed
// ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
// ToPathString("transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.debug.onnx")));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Unsqueeze"], 1) << "1 Unsqueeze node added to broadcastable Mul weight.";
EXPECT_EQ(op_to_count["Transpose"], 1) << "2 Transposes at the I/O cancel. 1 Transpose inserted above Mul weight.";
// Get the Unsqueeze and Transpose nodes.
Node* unsqueeze_node = nullptr;
Node* transpose_node = nullptr;
for (auto& node : graph.Nodes()) {
const std::string& op_type = node.OpType();
if (op_type == "Unsqueeze") {
unsqueeze_node = &node;
} else if (op_type == "Transpose") {
transpose_node = &node;
}
}
// DQ axis starts as 0
ASSERT_TRUE(unsqueeze_node != nullptr);
const auto& dq_before_unsqueeze = *(unsqueeze_node->InputNodesBegin());
int64_t dq_before_unsqueeze_axis = 1;
GetQOrDQAxis(dq_before_unsqueeze, dq_before_unsqueeze_axis);
EXPECT_EQ(dq_before_unsqueeze_axis, 0);
// Axis changes to 3 after Unsqueeze
const auto& q_after_unsqueeze = *(unsqueeze_node->OutputNodesBegin());
int64_t q_after_unsqueeze_axis = 1;
GetQOrDQAxis(q_after_unsqueeze, q_after_unsqueeze_axis);
EXPECT_EQ(q_after_unsqueeze_axis, 3);
// Axis changes to 1 after Transpose
ASSERT_TRUE(transpose_node != nullptr);
const auto& q_after_transpose = *(transpose_node->OutputNodesBegin());
int64_t q_after_transpose_axis = 1;
GetQOrDQAxis(q_after_transpose, q_after_transpose_axis);
EXPECT_EQ(q_after_transpose_axis, 1);
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
}
ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
// Tests the in-place unsqueeze and transpose of a constant consumed by a per-axis DQ.
TEST(TransposeOptimizerTests, InPlaceUnsqueezeTransposePerAxisDQ) {
// Model contains a Mul with a constant/broadcastable/per-axis DQ input[1].
// When a transpose is pushed through the Mul's input[0], input[1]'s input is unsqueezed and transposed in-place.
auto model_uri = ORT_TSTR("testdata/transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx");
RandomValueGenerator random{123};
std::vector<int64_t> input_dims{1, 3, 4, 4};
std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators();
OrtValue input0;
CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input0_data, &input0);
CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0);
NameMLValMap feeds{{"input0", input0}};
@ -4885,7 +5000,6 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) {
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
// We call the ONNX transpose optimizer directly to use a custom cost check function.
Graph& graph = session.GetMutableGraph();
CPUAllocator allocator;
@ -4894,22 +5008,7 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) {
TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
/*new_node_ep*/ nullptr);
// Use a custom optimization cost check that aggressively pushes channel-last or channel-first transposes.
auto custom_cost_fn =
[](const alias_oto::api::GraphRef& /* graph */,
const alias_oto::api::NodeRef& /* node */,
const std::vector<int64_t>& perm,
const std::unordered_set<std::string>& /* outputs_leading_to_transpose */) -> alias_oto::CostCheckResult {
if (perm == alias_oto::ChannelFirstToLastPerm(perm.size()) ||
perm == alias_oto::ChannelLastToFirstPerm(perm.size())) {
return alias_oto::CostCheckResult::kPushTranspose;
}
return alias_oto::CostCheckResult::kFallThrough;
};
alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph, /*provider_type*/ "", custom_cost_fn);
alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph);
ASSERT_EQ(result.error_msg, std::nullopt);
ASSERT_TRUE(result.graph_modified);
ASSERT_TRUE(graph.GraphResolveNeeded());
@ -4917,11 +5016,180 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) {
// Use this hack to save model for viewing if needed
// ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
// ToPathString("transpose_optimization_unsqueeze_dq_axis.qdq.updated.onnx")));
// ToPathString("updated_model_inplace_peraxis.onnx")));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Unsqueeze"], 1) << "1 Unsqueeze node added to broadcastable Mul weight.";
EXPECT_EQ(op_to_count["Transpose"], 1) << "2 Transposes at the I/O cancel. 1 Transpose inserted above Mul weight.";
EXPECT_EQ(op_to_count["Unsqueeze"], 0) << "per-axis DQ constant was unsqueezed in-place.";
EXPECT_EQ(op_to_count["Transpose"], 0) << "2 pre-existing Transposes at the I/O cancel. "
<< "per-axis DQ constant was transposed in-place.";
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
}
ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
// Tests the canceling of a pre-existing Transpose before a per-axis DQ during call to TransposeInputImpl.
// Before: input1 -> Transpose(perm = [0, 2, 3, 1]) -> DQ (axis = -1) -> Mul
// After : input1 -> DQ (axis = 1) -> Mul
TEST(TransposeOptimizerTests, CancelTransposeBeforePerAxisDQ) {
auto model_uri = ORT_TSTR("testdata/transpose_optimizer_cancel_transpose_per_axis_dq.onnx");
RandomValueGenerator random{123};
std::vector<int64_t> input_dims{1, 3, 4, 4};
std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
std::vector<int8_t> input1_data = {0, 1, 2};
auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators();
OrtValue input0;
OrtValue input1;
CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0);
CreateMLValue<int8_t>(allocators[0], {1, 3, 1, 1}, input1_data, &input1);
NameMLValMap feeds{{"input0", input0}, {"input1", input1}};
std::vector<std::string> output_names{"output0"};
std::vector<OrtValue> fetches_orig;
std::vector<OrtValue> fetches;
SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
so.graph_optimization_level = TransformerLevel::Default; // off
// get results with no modifications to the model
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
}
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
Graph& graph = session.GetMutableGraph();
CPUAllocator allocator;
namespace alias_oto = onnx_transpose_optimization;
auto api_graph = MakeApiGraph(graph,
TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
/*new_node_ep*/ nullptr);
alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph);
ASSERT_EQ(result.error_msg, std::nullopt);
ASSERT_TRUE(result.graph_modified);
ASSERT_TRUE(graph.GraphResolveNeeded());
ASSERT_STATUS_OK(graph.Resolve());
// Use this hack to save model for viewing if needed
// ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
// ToPathString("updated_model_peraxis_transpose_cancel.onnx")));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Transpose"], 0) << "2 Transposes at the I/O cancel. "
<< "Transpose inserted above Mul weight cancels.";
// Get the DQ above Mul's input[1]
Node* dq_node = nullptr;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "DequantizeLinear" && node.Name() == "dq_mul_input_1") {
dq_node = &node;
break;
}
}
// DQ axis changed from -1 (3) to 1 due to Transpose above DQ being canceled.
ASSERT_TRUE(dq_node != nullptr);
int64_t dq_axis = 1;
GetQOrDQAxis(*dq_node, dq_axis);
EXPECT_EQ(dq_axis, 1);
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
}
ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
// Tests the canceling of a pre-existing Squeeze before a per-axis DQ during call to UnsqueezeInput.
// Before: input1 (shape = [1, 1, 1, 3]) -> Squeeze(axes = [0, 1, 2]) -> DQ (axis = 0) -> Mul
// After : input1 -> DQ (axis = 3) -> Mul
TEST(TransposeOptimizerTests, CancelSqueezeBeforePerAxisDQ) {
auto model_uri = ORT_TSTR("testdata/transpose_optimizer_cancel_squeeze_per_axis_dq.onnx");
RandomValueGenerator random{123};
std::vector<int64_t> input_dims{1, 3, 4, 4};
std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);
std::vector<int8_t> input1_data = {0, 1, 2};
auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators();
OrtValue input0;
OrtValue input1;
CreateMLValue<float>(allocators[0], input_dims, input0_data, &input0);
CreateMLValue<int8_t>(allocators[0], {1, 1, 1, 3}, input1_data, &input1);
NameMLValMap feeds{{"input0", input0}, {"input1", input1}};
std::vector<std::string> output_names{"output0"};
std::vector<OrtValue> fetches_orig;
std::vector<OrtValue> fetches;
SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
so.graph_optimization_level = TransformerLevel::Default; // off
// get results with no modifications to the model
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
}
{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
Graph& graph = session.GetMutableGraph();
CPUAllocator allocator;
namespace alias_oto = onnx_transpose_optimization;
auto api_graph = MakeApiGraph(graph,
TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
/*new_node_ep*/ nullptr);
alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph);
ASSERT_EQ(result.error_msg, std::nullopt);
ASSERT_TRUE(result.graph_modified);
ASSERT_TRUE(graph.GraphResolveNeeded());
ASSERT_STATUS_OK(graph.Resolve());
// Use this hack to save model for viewing if needed
// ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
// ToPathString("updated_model_peraxis_squeeze_cancel.onnx")));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Squeeze"], 0) << "Canceled by unsqueezed input consumed by per-axis DQ";
EXPECT_EQ(op_to_count["Unsqueeze"], 0) << "No Unsqueeze inserted because it cancels with pre-existing Squeeze.";
// Get the DQ above Mul's input[1]
Node* dq_node = nullptr;
for (auto& node : graph.Nodes()) {
if (node.OpType() == "DequantizeLinear" && node.Name() == "dq_mul_input_1") {
dq_node = &node;
break;
}
}
// DQ axis changed from 0 to 3 due to Squeeze above DQ being canceled.
ASSERT_TRUE(dq_node != nullptr);
int64_t dq_axis = 1;
GetQOrDQAxis(*dq_node, dq_axis);
EXPECT_EQ(dq_axis, 3);
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));

View file

@ -1,103 +0,0 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import numpy as np
import onnx
if __name__ == "__main__":
"""
Creates a QDQ model with a per-channel DQ weight that is Unsqueezed and Transposed by the Transpose optimizer.
"""
input0_shape = (1, 3, 4, 4)
input0 = onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)
output0 = onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None)
scale_1 = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1")
zp_128 = onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128")
scale_inv_255 = onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255")
zp_0 = onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0")
mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8)
mul_weight_scales_data = np.array([1.0, 1.0, 1.0], dtype=np.float32)
mul_weight_zps_data = np.array([0, 0, 0], dtype=np.int8)
mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight")
mul_weight_scales = onnx.numpy_helper.from_array(mul_weight_scales_data, "mul_weight_scales")
mul_weight_zps = onnx.numpy_helper.from_array(mul_weight_zps_data, "mul_weight_zps")
# Transpose to channel-last
tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1))
# Q_0
q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node")
# DQ_0
dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node")
# Sigmoid
sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node")
# Q_1
q1_node = onnx.helper.make_node(
"QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node"
)
# DQ_1
dq1_node = onnx.helper.make_node(
"DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node"
)
# DQ_weight
dq_weight_node = onnx.helper.make_node(
"DequantizeLinear",
["mul_weight", "mul_weight_scales", "mul_weight_zps"],
["dq_weight_out"],
name="dq_weight_node",
axis=0,
)
# Mul
mul_node = onnx.helper.make_node("Mul", ["dq1_out", "dq_weight_out"], ["mul_out"], name="mul_node")
# Q_2
q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node")
# DQ_2
dq2_node = onnx.helper.make_node(
"DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node"
)
# Transpose to channel-first
tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["output0"], name="tp1_node", perm=(0, 3, 1, 2))
graph = onnx.helper.make_graph(
[
tp0_node,
q0_node,
dq0_node,
sigmoid_node,
q1_node,
dq1_node,
dq_weight_node,
mul_node,
q2_node,
dq2_node,
tp1_node,
],
"transpose_opt_unsqueeze_dq_axis",
[input0],
[output0],
initializer=[scale_1, zp_128, scale_inv_255, zp_0, mul_weight, mul_weight_scales, mul_weight_zps],
)
opset_imports = [
onnx.helper.make_opsetid("", 19),
]
qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports)
print("[INFO]: Running onnx.checker on qdq model")
qdq_model = onnx.shape_inference.infer_shapes(qdq_model)
onnx.checker.check_model(qdq_model, True)
qdq_model_path = "transpose_optimization_unsqueeze_dq_axis.qdq.onnx"
print(f"[INFO]: Saving {qdq_model_path}")
onnx.save_model(qdq_model, qdq_model_path)

View file

@ -0,0 +1,207 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import numpy as np
import onnx
def subgraph_1d_const_input_dq(inputs, initializers, nodes) -> str:
"""
Creates mul_weight -> DQ. mul_weight is a constant of rank 1.
"""
mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8)
mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight")
initializers.append(mul_weight)
dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["mul_weight", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=0,
)
)
return dq_output_name
def subgraph_1d_input_dq(inputs, initializers, nodes) -> str:
"""
Creates input1 -> DQ. input1 is a graph input of rank 1.
"""
input1_shape = (3,)
inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape))
dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["input1", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=0,
)
)
return dq_output_name
def subgraph_4d_input_squeeze_dq(inputs, initializers, nodes) -> str:
"""
Creates input1 -> Squeeze -> DQ. input1 is a graph input of rank 4.
"""
input1_shape = (1, 1, 1, 3)
inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape))
axes_data = np.array([0, 1, 2], dtype=np.int64)
initializers.append(onnx.numpy_helper.from_array(axes_data, "axes_const"))
nodes.append(onnx.helper.make_node("Squeeze", ["input1", "axes_const"], ["squeeze_out"], name="squeeze_node"))
dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["squeeze_out", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=0,
)
)
return dq_output_name
def subgraph_4d_input_transpose_dq(inputs, initializers, nodes) -> str:
"""
Creates input1 -> Transpose -> DQ. input1 is a graph input of rank 4.
"""
input1_shape = (1, 3, 1, 1)
inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape))
perm = [0, 2, 3, 1] # To channel-last
nodes.append(onnx.helper.make_node("Transpose", ["input1"], ["tp_out_"], perm=perm, name="transpose_"))
dq_output_name = "mul_input_1"
nodes.append(
onnx.helper.make_node(
"DequantizeLinear",
["tp_out_", "mul_weight_scales", "mul_weight_zps"],
[dq_output_name],
name="dq_mul_input_1",
axis=-1,
)
)
return dq_output_name
def make_model(model_path: str, build_mul_input_1_subgraph):
"""
Creates a QDQ model with a per-axis DQ input that is Unsqueezed and Transposed by the Transpose optimizer.
"""
input0_shape = (1, 3, 4, 4)
inputs = [onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)]
outputs = [onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None)]
mul_weight_scales_data = np.array([1.0, 1.0, 1.0], dtype=np.float32)
mul_weight_zps_data = np.array([0, 0, 0], dtype=np.int8)
initializers = [
onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1"),
onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128"),
onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255"),
onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0"),
onnx.numpy_helper.from_array(mul_weight_scales_data, "mul_weight_scales"),
onnx.numpy_helper.from_array(mul_weight_zps_data, "mul_weight_zps"),
]
nodes = []
# Transpose to channel-last
tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1))
nodes.append(tp0_node)
# Q_0
q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node")
nodes.append(q0_node)
# DQ_0
dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node")
nodes.append(dq0_node)
# Sigmoid
sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node")
nodes.append(sigmoid_node)
# Q_1
q1_node = onnx.helper.make_node(
"QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node"
)
nodes.append(q1_node)
# DQ_1
dq1_node = onnx.helper.make_node(
"DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node"
)
nodes.append(dq1_node)
# DQ for mul input[1]
mul_input_1_name = build_mul_input_1_subgraph(inputs, initializers, nodes)
# Mul
mul_node = onnx.helper.make_node("Mul", ["dq1_out", mul_input_1_name], ["mul_out"], name="mul_node")
nodes.append(mul_node)
# Q_2
q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node")
nodes.append(q2_node)
# DQ_2
dq2_node = onnx.helper.make_node(
"DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node"
)
nodes.append(dq2_node)
# Transpose to channel-first
tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["output0"], name="tp1_node", perm=(0, 3, 1, 2))
nodes.append(tp1_node)
graph = onnx.helper.make_graph(
nodes,
"transpose_opt_unsqueeze_dq_axis",
inputs,
outputs,
initializer=initializers,
)
opset_imports = [
onnx.helper.make_opsetid("", 19),
]
qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports)
print("[INFO]: Running onnx.checker on qdq model")
qdq_model = onnx.shape_inference.infer_shapes(qdq_model)
onnx.checker.check_model(qdq_model, True)
print(f"[INFO]: Saving {model_path}")
onnx.save_model(qdq_model, model_path)
if __name__ == "__main__":
make_model(
"transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx",
subgraph_1d_input_dq,
)
make_model(
"transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx",
subgraph_1d_const_input_dq,
)
make_model(
"transpose_optimizer_cancel_squeeze_per_axis_dq.onnx",
subgraph_4d_input_squeeze_dq,
)
make_model(
"transpose_optimizer_cancel_transpose_per_axis_dq.onnx",
subgraph_4d_input_transpose_dq,
)

Binary data
onnxruntime/test/testdata/transpose_optimizer_cancel_squeeze_per_axis_dq.onnx vendored Normal file

Binary file not shown.

Binary data
onnxruntime/test/testdata/transpose_optimizer_cancel_transpose_per_axis_dq.onnx vendored Normal file

Binary file not shown.

Binary data
onnxruntime/test/testdata/transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx vendored Normal file

Binary file not shown.