[spirv] Optimize floating point matrix scaling codegen (#525)
SPIR-V has a specific OpMatrixTimesScalar for scaling floating point matrices.
This commit is contained in:
Родитель
50517e7f34
Коммит
76796801b8
|
@ -122,6 +122,18 @@ bool isCompoundAssignment(BinaryOperatorKind opcode) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns true if the given SPIR-V opcode is one of the matrix-specific
/// multiplication instructions recognized here: OpMatrixTimesMatrix,
/// OpMatrixTimesVector, or OpMatrixTimesScalar. Every other opcode falls
/// through the switch and yields false.
bool isSpirvMatrixOp(spv::Op opcode) {
|
||||
switch (opcode) {
|
||||
case spv::Op::OpMatrixTimesMatrix:
|
||||
case spv::Op::OpMatrixTimesVector:
|
||||
case spv::Op::OpMatrixTimesScalar:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
|
||||
|
@ -731,8 +743,10 @@ uint32_t SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
|
|||
if (opcode == BO_Assign)
|
||||
return processAssignment(expr->getLHS(), doExpr(expr->getRHS()), false);
|
||||
|
||||
// Try to optimize floatN * float case
|
||||
// Try to optimize floatMxN * float and floatN * float case
|
||||
if (opcode == BO_Mul) {
|
||||
if (const uint32_t result = tryToGenFloatMatrixScale(expr))
|
||||
return result;
|
||||
if (const uint32_t result = tryToGenFloatVectorScale(expr))
|
||||
return result;
|
||||
}
|
||||
|
@ -940,8 +954,10 @@ uint32_t
|
|||
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
|
||||
const auto opcode = expr->getOpcode();
|
||||
|
||||
// Try to optimize floatN *= float case
|
||||
// Try to optimize floatMxN *= float and floatN *= float case
|
||||
if (opcode == BO_MulAssign) {
|
||||
if (const uint32_t result = tryToGenFloatMatrixScale(expr))
|
||||
return result;
|
||||
if (const uint32_t result = tryToGenFloatVectorScale(expr))
|
||||
return result;
|
||||
}
|
||||
|
@ -1370,7 +1386,11 @@ uint32_t SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
|
|||
const uint32_t resultType,
|
||||
uint32_t *lhsResultId,
|
||||
const spv::Op mandateGenOpcode) {
|
||||
if (TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
|
||||
// If the operands are of matrix type, we need to dispatch the operation
|
||||
// onto each element vector iff the operands are not degenerate matrices
|
||||
// and we don't have a matrix specific SPIR-V instruction for the operation.
|
||||
if (!isSpirvMatrixOp(mandateGenOpcode) &&
|
||||
TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
|
||||
return processMatrixBinaryOp(lhs, rhs, opcode);
|
||||
}
|
||||
|
||||
|
@ -1647,6 +1667,68 @@ uint32_t SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
/// Tries to translate a floatMxN * float (or float * floatMxN, or
/// floatMxN *= float) operation into a single SPIR-V scaling instruction.
///
/// Returns the value produced by processBinaryOp() (or by processAssignment()
/// for compound assignments) when the pattern matches. Returns 0 when the
/// result type is not a floating point HLSL matrix, or when neither operand is
/// an implicit CK_HLSLMatrixSplat cast of a scalar.
uint32_t SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
|
||||
const QualType type = expr->getType();
|
||||
// We can only translate floatMxN * float into OpMatrixTimesScalar.
|
||||
// So the result type must be floatMxN.
|
||||
if (!hlsl::IsHLSLMatType(type) ||
|
||||
!hlsl::GetHLSLMatElementType(type)->isFloatingType())
|
||||
return 0;
|
||||
|
||||
const Expr *lhs = expr->getLHS();
|
||||
const Expr *rhs = expr->getRHS();
|
||||
const QualType lhsType = lhs->getType();
|
||||
const QualType rhsType = rhs->getType();
|
||||
|
||||
// Picks the scaling opcode from the shape of the matrix operand: Mx1 and 1xN
|
||||
// matrices use OpVectorTimesScalar, everything else uses
|
||||
// OpMatrixTimesScalar.
|
||||
const auto selectOpcode = [](const QualType ty) {
|
||||
return TypeTranslator::isMx1MatrixType(ty) ||
|
||||
TypeTranslator::is1xNMatrixType(ty)
|
||||
? spv::Op::OpVectorTimesScalar
|
||||
: spv::Op::OpMatrixTimesScalar;
|
||||
};
|
||||
|
||||
// Multiplying a float matrix with a float scalar will be represented in
|
||||
// AST via a binary operation with two float matrices as operands; one of
|
||||
// the operands comes from an implicit cast with kind CK_HLSLMatrixSplat.
|
||||
|
||||
// matrix * scalar
|
||||
if (hlsl::IsHLSLMatType(lhsType)) {
|
||||
if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
|
||||
if (cast->getCastKind() == CK_HLSLMatrixSplat) {
|
||||
const uint32_t matType = typeTranslator.translateType(expr->getType());
|
||||
const spv::Op opcode = selectOpcode(lhsType);
|
||||
// The splat cast is bypassed: its sub-expression (the original scalar)
|
||||
// is passed as the second operand.
|
||||
if (isa<CompoundAssignOperator>(expr)) {
|
||||
// Compound assignment: lhsPtr is handed to processBinaryOp() as an
|
||||
// out-parameter and then forwarded to processAssignment().
|
||||
// NOTE(review): presumably it receives the lhs pointer <result-id> so
|
||||
// that lhs is not evaluated twice -- confirm against processBinaryOp().
|
||||
uint32_t lhsPtr = 0;
|
||||
const uint32_t result =
|
||||
processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
|
||||
matType, &lhsPtr, opcode);
|
||||
return processAssignment(lhs, result, true, lhsPtr);
|
||||
} else {
|
||||
return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
|
||||
matType, nullptr, opcode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scalar * matrix
|
||||
if (hlsl::IsHLSLMatType(rhsType)) {
|
||||
if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
|
||||
if (cast->getCastKind() == CK_HLSLMatrixSplat) {
|
||||
const uint32_t matType = typeTranslator.translateType(expr->getType());
|
||||
const spv::Op opcode = selectOpcode(rhsType);
|
||||
// We need to switch the positions of lhs and rhs here because
|
||||
// OpMatrixTimesScalar requires the first operand to be a matrix and
|
||||
// the second to be a scalar.
|
||||
return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
|
||||
matType, nullptr, opcode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not a recognized float matrix scaling pattern.
return 0;
|
||||
}
|
||||
|
||||
uint32_t SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
|
||||
const uint32_t rhs) {
|
||||
// Assigning to a vector swizzling lhs is tricky if we are neither
|
||||
|
|
|
@ -158,6 +158,11 @@ private:
|
|||
/// floatN * float.
|
||||
uint32_t tryToGenFloatVectorScale(const BinaryOperator *expr);
|
||||
|
||||
/// Translates a floatMxN * float multiplication into SPIR-V instructions and
|
||||
/// returns the <result-id>. Returns 0 if the given binary operation is not
|
||||
/// floatMxN * float.
|
||||
uint32_t tryToGenFloatMatrixScale(const BinaryOperator *expr);
|
||||
|
||||
/// Tries to emit instructions for assigning to the given vector element
|
||||
/// accessing expression. Returns 0 if the trial fails and no instructions
|
||||
/// are generated.
|
||||
|
|
|
@ -155,6 +155,16 @@ bool TypeTranslator::is1xNMatrixType(QualType type) {
|
|||
return rowCount == 1 && colCount > 1;
|
||||
}
|
||||
|
||||
/// Returns true if the given type is an HLSL matrix type with more than one
/// row and exactly one column (Mx1, M > 1). Returns false for any type that is
/// not an HLSL matrix type.
bool TypeTranslator::isMx1MatrixType(QualType type) {
|
||||
if (!hlsl::IsHLSLMatType(type))
|
||||
return false;
|
||||
|
||||
// Filled in by GetHLSLMatRowColCount() below.
uint32_t rowCount = 0, colCount = 0;
|
||||
hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
|
||||
|
||||
return rowCount > 1 && colCount == 1;
|
||||
}
|
||||
|
||||
/// Returns true if the given type is a SPIR-V acceptable matrix type, i.e.,
|
||||
/// with floating point elements and greater than 1 row and column counts.
|
||||
bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {
|
||||
|
|
|
@ -48,6 +48,9 @@ public:
|
|||
/// \brief Returns true if the given type is a 1xN (N > 1) matrix type.
|
||||
static bool is1xNMatrixType(QualType type);
|
||||
|
||||
/// \brief Returns true if the given type is a Mx1 (M > 1) matrix type.
|
||||
static bool isMx1MatrixType(QualType type);
|
||||
|
||||
/// \brief Returns true if the given type is a SPIR-V acceptable matrix type,
|
||||
/// i.e., with floating point elements and greater than 1 row and column
|
||||
/// counts.
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
// Run: %dxc -T vs_6_0 -E main
|
||||
|
||||
// TODO: matrix *= scalar
|
||||
|
||||
void main() {
|
||||
// CHECK-LABEL: %bb_entry = OpLabel
|
||||
|
||||
|
@ -11,18 +9,70 @@ void main() {
|
|||
int3 c;
|
||||
int t;
|
||||
|
||||
float1 e;
|
||||
int1 g;
|
||||
|
||||
float2x3 i;
|
||||
float1x3 k;
|
||||
float2x1 m;
|
||||
float1x1 o;
|
||||
|
||||
// Use OpVectorTimesScalar for floatN * float
|
||||
// CHECK: [[s4:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[a4:%\d+]] = OpLoad %v4float %a
|
||||
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s4]]
|
||||
// CHECK: [[s0:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[a0:%\d+]] = OpLoad %v4float %a
|
||||
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a0]] [[s0]]
|
||||
// CHECK-NEXT: OpStore %a [[mul0]]
|
||||
a *= s;
|
||||
|
||||
// Use normal OpCompositeConstruct and OpIMul for intN * int
|
||||
// CHECK-NEXT: [[t0:%\d+]] = OpLoad %int %t
|
||||
// CHECK-NEXT: [[cc10:%\d+]] = OpCompositeConstruct %v3int [[t0]] [[t0]] [[t0]]
|
||||
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3int [[t0]] [[t0]] [[t0]]
|
||||
// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v3int %c
|
||||
// CHECK-NEXT: [[mul2:%\d+]] = OpIMul %v3int [[c0]] [[cc10]]
|
||||
// CHECK-NEXT: [[mul2:%\d+]] = OpIMul %v3int [[c0]] [[cc0]]
|
||||
// CHECK-NEXT: OpStore %c [[mul2]]
|
||||
c *= t;
|
||||
|
||||
// Vector of size 1
|
||||
// CHECK-NEXT: [[s2:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[e0:%\d+]] = OpLoad %float %e
|
||||
// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[e0]] [[s2]]
|
||||
// CHECK-NEXT: OpStore %e [[mul4]]
|
||||
e *= s;
|
||||
// CHECK-NEXT: [[t2:%\d+]] = OpLoad %int %t
|
||||
// CHECK-NEXT: [[g0:%\d+]] = OpLoad %int %g
|
||||
// CHECK-NEXT: [[mul6:%\d+]] = OpIMul %int [[g0]] [[t2]]
|
||||
// CHECK-NEXT: OpStore %g [[mul6]]
|
||||
g *= t;
|
||||
|
||||
// Use OpMatrixTimesScalar for floatMxN * float
|
||||
// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat2v3float %i
|
||||
// CHECK-NEXT: [[mul8:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i0]] [[s4]]
|
||||
// CHECK-NEXT: OpStore %i [[mul8]]
|
||||
i *= s;
|
||||
|
||||
// Use OpVectorTimesScalar for float1xN * float
|
||||
// Sadly, the AST is constructed differently for 'float1xN *= float' cases.
|
||||
// So we are not able to generate an OpVectorTimesScalar here.
|
||||
// TODO: Minor issue. Fix this later maybe.
|
||||
// CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v3float [[s6]] [[s6]] [[s6]]
|
||||
// CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
|
||||
// CHECK-NEXT: [[mul10:%\d+]] = OpFMul %v3float [[k0]] [[cc1]]
|
||||
// CHECK-NEXT: OpStore %k [[mul10]]
|
||||
k *= s;
|
||||
|
||||
// Use OpVectorTimesScalar for floatMx1 * float
|
||||
// CHECK-NEXT: [[s8:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[m0:%\d+]] = OpLoad %v2float %m
|
||||
// CHECK-NEXT: [[mul12:%\d+]] = OpVectorTimesScalar %v2float [[m0]] [[s8]]
|
||||
// CHECK-NEXT: OpStore %m [[mul12]]
|
||||
m *= s;
|
||||
|
||||
// Matrix of size 1x1
|
||||
// CHECK-NEXT: [[s10:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
|
||||
// CHECK-NEXT: [[mul14:%\d+]] = OpFMul %float [[o0]] [[s10]]
|
||||
// CHECK-NEXT: OpStore %o [[mul14]]
|
||||
o *= s;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
// Run: %dxc -T vs_6_0 -E main
|
||||
|
||||
// TODO: matrix * scalar
|
||||
|
||||
void main() {
|
||||
// CHECK-LABEL: %bb_entry = OpLabel
|
||||
|
||||
|
@ -11,15 +9,23 @@ void main() {
|
|||
int3 c, d;
|
||||
int t;
|
||||
|
||||
float1 e, f;
|
||||
int1 g, h;
|
||||
|
||||
float2x3 i, j;
|
||||
float1x3 k, l;
|
||||
float2x1 m, n;
|
||||
float1x1 o, p;
|
||||
|
||||
// Use OpVectorTimesScalar for floatN * float
|
||||
// CHECK: [[a4:%\d+]] = OpLoad %v4float %a
|
||||
// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s4]]
|
||||
// CHECK-NEXT: [[s0:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s0]]
|
||||
// CHECK-NEXT: OpStore %b [[mul0]]
|
||||
b = a * s;
|
||||
// CHECK-NEXT: [[a5:%\d+]] = OpLoad %v4float %a
|
||||
// CHECK-NEXT: [[s5:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul1:%\d+]] = OpVectorTimesScalar %v4float [[a5]] [[s5]]
|
||||
// CHECK-NEXT: [[s1:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul1:%\d+]] = OpVectorTimesScalar %v4float [[a5]] [[s1]]
|
||||
// CHECK-NEXT: OpStore %b [[mul1]]
|
||||
b = s * a;
|
||||
|
||||
|
@ -36,4 +42,74 @@ void main() {
|
|||
// CHECK-NEXT: [[mul3:%\d+]] = OpIMul %v3int [[cc11]] [[c1]]
|
||||
// CHECK-NEXT: OpStore %d [[mul3]]
|
||||
d = t * c;
|
||||
|
||||
// Vector of size 1
|
||||
// CHECK-NEXT: [[e0:%\d+]] = OpLoad %float %e
|
||||
// CHECK-NEXT: [[s2:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[e0]] [[s2]]
|
||||
// CHECK-NEXT: OpStore %f [[mul4]]
|
||||
f = e * s;
|
||||
// CHECK-NEXT: [[s3:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[e1:%\d+]] = OpLoad %float %e
|
||||
// CHECK-NEXT: [[mul5:%\d+]] = OpFMul %float [[s3]] [[e1]]
|
||||
// CHECK-NEXT: OpStore %f [[mul5]]
|
||||
f = s * e;
|
||||
// CHECK-NEXT: [[g0:%\d+]] = OpLoad %int %g
|
||||
// CHECK-NEXT: [[t2:%\d+]] = OpLoad %int %t
|
||||
// CHECK-NEXT: [[mul6:%\d+]] = OpIMul %int [[g0]] [[t2]]
|
||||
// CHECK-NEXT: OpStore %h [[mul6]]
|
||||
h = g * t;
|
||||
// CHECK-NEXT: [[t3:%\d+]] = OpLoad %int %t
|
||||
// CHECK-NEXT: [[g1:%\d+]] = OpLoad %int %g
|
||||
// CHECK-NEXT: [[mul7:%\d+]] = OpIMul %int [[t3]] [[g1]]
|
||||
// CHECK-NEXT: OpStore %h [[mul7]]
|
||||
h = t * g;
|
||||
|
||||
// Use OpMatrixTimesScalar for floatMxN * float
|
||||
// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat2v3float %i
|
||||
// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul8:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i0]] [[s4]]
|
||||
// CHECK-NEXT: OpStore %j [[mul8]]
|
||||
j = i * s;
|
||||
// CHECK-NEXT: [[i1:%\d+]] = OpLoad %mat2v3float %i
|
||||
// CHECK-NEXT: [[s5:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul9:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i1]] [[s5]]
|
||||
// CHECK-NEXT: OpStore %j [[mul9]]
|
||||
j = s * i;
|
||||
|
||||
// Use OpVectorTimesScalar for float1xN * float
|
||||
// CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
|
||||
// CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul10:%\d+]] = OpVectorTimesScalar %v3float [[k0]] [[s6]]
|
||||
// CHECK-NEXT: OpStore %l [[mul10]]
|
||||
l = k * s;
|
||||
// CHECK-NEXT: [[k1:%\d+]] = OpLoad %v3float %k
|
||||
// CHECK-NEXT: [[s7:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul11:%\d+]] = OpVectorTimesScalar %v3float [[k1]] [[s7]]
|
||||
// CHECK-NEXT: OpStore %l [[mul11]]
|
||||
l = s * k;
|
||||
|
||||
// Use OpVectorTimesScalar for floatMx1 * float
|
||||
// CHECK-NEXT: [[m0:%\d+]] = OpLoad %v2float %m
|
||||
// CHECK-NEXT: [[s8:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul12:%\d+]] = OpVectorTimesScalar %v2float [[m0]] [[s8]]
|
||||
// CHECK-NEXT: OpStore %n [[mul12]]
|
||||
n = m * s;
|
||||
// CHECK-NEXT: [[m1:%\d+]] = OpLoad %v2float %m
|
||||
// CHECK-NEXT: [[s9:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul13:%\d+]] = OpVectorTimesScalar %v2float [[m1]] [[s9]]
|
||||
// CHECK-NEXT: OpStore %n [[mul13]]
|
||||
n = s * m;
|
||||
|
||||
// Matrix of size 1x1
|
||||
// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
|
||||
// CHECK-NEXT: [[s10:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[mul14:%\d+]] = OpFMul %float [[o0]] [[s10]]
|
||||
// CHECK-NEXT: OpStore %p [[mul14]]
|
||||
p = o * s;
|
||||
// CHECK-NEXT: [[s11:%\d+]] = OpLoad %float %s
|
||||
// CHECK-NEXT: [[o1:%\d+]] = OpLoad %float %o
|
||||
// CHECK-NEXT: [[mul15:%\d+]] = OpFMul %float [[s11]] [[o1]]
|
||||
// CHECK-NEXT: OpStore %p [[mul15]]
|
||||
p = s * o;
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче