[spirv] Optimize floating point matrix scaling codegen (#525)

SPIR-V has a specific OpMatrixTimesScalar for scaling floating
point matrices.
This commit is contained in:
Lei Zhang 2017-08-08 16:18:26 -04:00 коммит произвёл David Peixotto
Родитель 50517e7f34
Коммит 76796801b8
6 изменённых файлов: 242 добавлений и 16 удалений

Просмотреть файл

@ -122,6 +122,18 @@ bool isCompoundAssignment(BinaryOperatorKind opcode) {
}
}
/// Returns true if the given opcode is one of the matrix-specific SPIR-V
/// multiplication instructions: OpMatrixTimesMatrix, OpMatrixTimesVector, or
/// OpMatrixTimesScalar. Returns false for every other opcode.
bool isSpirvMatrixOp(spv::Op opcode) {
  return opcode == spv::Op::OpMatrixTimesMatrix ||
         opcode == spv::Op::OpMatrixTimesVector ||
         opcode == spv::Op::OpMatrixTimesScalar;
}
} // namespace
SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
@ -731,8 +743,10 @@ uint32_t SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
if (opcode == BO_Assign)
return processAssignment(expr->getLHS(), doExpr(expr->getRHS()), false);
// Try to optimize floatN * float case
// Try to optimize floatMxN * float and floatN * float case
if (opcode == BO_Mul) {
if (const uint32_t result = tryToGenFloatMatrixScale(expr))
return result;
if (const uint32_t result = tryToGenFloatVectorScale(expr))
return result;
}
@ -940,8 +954,10 @@ uint32_t
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
const auto opcode = expr->getOpcode();
// Try to optimize floatN *= float case
// Try to optimize floatMxN *= float and floatN *= float case
if (opcode == BO_MulAssign) {
if (const uint32_t result = tryToGenFloatMatrixScale(expr))
return result;
if (const uint32_t result = tryToGenFloatVectorScale(expr))
return result;
}
@ -1370,7 +1386,11 @@ uint32_t SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
const uint32_t resultType,
uint32_t *lhsResultId,
const spv::Op mandateGenOpcode) {
if (TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
// If the operands are of matrix type, we need to dispatch the operation
// onto each element vector iff the operands are not degenerate matrices
// and we don't have a matrix specific SPIR-V instruction for the operation.
if (!isSpirvMatrixOp(mandateGenOpcode) &&
TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
return processMatrixBinaryOp(lhs, rhs, opcode);
}
@ -1647,6 +1667,68 @@ uint32_t SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
return 0;
}
uint32_t SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
const QualType type = expr->getType();
// We can only translate floatMxN * float into OpMatrixTimesScalar.
// So the result type must be floatMxN.
if (!hlsl::IsHLSLMatType(type) ||
!hlsl::GetHLSLMatElementType(type)->isFloatingType())
return 0;
const Expr *lhs = expr->getLHS();
const Expr *rhs = expr->getRHS();
const QualType lhsType = lhs->getType();
const QualType rhsType = rhs->getType();
const auto selectOpcode = [](const QualType ty) {
return TypeTranslator::isMx1MatrixType(ty) ||
TypeTranslator::is1xNMatrixType(ty)
? spv::Op::OpVectorTimesScalar
: spv::Op::OpMatrixTimesScalar;
};
// Multiplying a float matrix with a float scalar will be represented in
// AST via a binary operation with two float matrices as operands; one of
// the operand is from an implicit cast with kind CK_HLSLMatrixSplat.
// matrix * scalar
if (hlsl::IsHLSLMatType(lhsType)) {
if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
if (cast->getCastKind() == CK_HLSLMatrixSplat) {
const uint32_t matType = typeTranslator.translateType(expr->getType());
const spv::Op opcode = selectOpcode(lhsType);
if (isa<CompoundAssignOperator>(expr)) {
uint32_t lhsPtr = 0;
const uint32_t result =
processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
matType, &lhsPtr, opcode);
return processAssignment(lhs, result, true, lhsPtr);
} else {
return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
matType, nullptr, opcode);
}
}
}
}
// scalar * matrix
if (hlsl::IsHLSLMatType(rhsType)) {
if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
if (cast->getCastKind() == CK_HLSLMatrixSplat) {
const uint32_t matType = typeTranslator.translateType(expr->getType());
const spv::Op opcode = selectOpcode(rhsType);
// We need to switch the positions of lhs and rhs here because
// OpMatrixTimesScalar requires the first operand to be a matrix and
// the second to be a scalar.
return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
matType, nullptr, opcode);
}
}
}
return 0;
}
uint32_t SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
const uint32_t rhs) {
// Assigning to a vector swizzling lhs is tricky if we are neither

Просмотреть файл

@ -158,6 +158,11 @@ private:
/// floatN * float.
uint32_t tryToGenFloatVectorScale(const BinaryOperator *expr);
/// Translates a floatMxN * float multiplication into SPIR-V instructions and
/// returns the <result-id>. Returns 0 if the given binary operation is not
/// floatMxN * float.
uint32_t tryToGenFloatMatrixScale(const BinaryOperator *expr);
/// Tries to emit instructions for assigning to the given vector element
/// accessing expression. Returns 0 if the trial fails and no instructions
/// are generated.

Просмотреть файл

@ -155,6 +155,16 @@ bool TypeTranslator::is1xNMatrixType(QualType type) {
return rowCount == 1 && colCount > 1;
}
/// Returns true if the given type is an HLSL matrix type with more than one
/// row and exactly one column (Mx1, M > 1). Non-matrix types yield false.
bool TypeTranslator::isMx1MatrixType(QualType type) {
  if (hlsl::IsHLSLMatType(type)) {
    uint32_t numRows = 0;
    uint32_t numCols = 0;
    hlsl::GetHLSLMatRowColCount(type, numRows, numCols);
    return numCols == 1 && numRows > 1;
  }
  return false;
}
/// Returns true if the given type is a SPIR-V acceptable matrix type, i.e.,
/// with floating point elements and greater than 1 row and column counts.
bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {

Просмотреть файл

@ -48,6 +48,9 @@ public:
/// \brief Returns true if the given type is a 1xN (N > 1) matrix type.
static bool is1xNMatrixType(QualType type);
/// \brief Returns true if the given type is a Mx1 (M > 1) matrix type.
static bool isMx1MatrixType(QualType type);
/// \brief Returns true if the given type is a SPIR-V acceptable matrix type,
/// i.e., with floating point elements and greater than 1 row and column
/// counts.

Просмотреть файл

@ -1,7 +1,5 @@
// Run: %dxc -T vs_6_0 -E main
// TODO: matrix *= scalar
void main() {
// CHECK-LABEL: %bb_entry = OpLabel
@ -11,18 +9,70 @@ void main() {
int3 c;
int t;
float1 e;
int1 g;
float2x3 i;
float1x3 k;
float2x1 m;
float1x1 o;
// Use OpVectorTimesScalar for floatN * float
// CHECK: [[s4:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[a4:%\d+]] = OpLoad %v4float %a
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s4]]
// CHECK: [[s0:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[a0:%\d+]] = OpLoad %v4float %a
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a0]] [[s0]]
// CHECK-NEXT: OpStore %a [[mul0]]
a *= s;
// Use normal OpCompositeConstruct and OpIMul for intN * int
// CHECK-NEXT: [[t0:%\d+]] = OpLoad %int %t
// CHECK-NEXT: [[cc10:%\d+]] = OpCompositeConstruct %v3int [[t0]] [[t0]] [[t0]]
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3int [[t0]] [[t0]] [[t0]]
// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v3int %c
// CHECK-NEXT: [[mul2:%\d+]] = OpIMul %v3int [[c0]] [[cc10]]
// CHECK-NEXT: [[mul2:%\d+]] = OpIMul %v3int [[c0]] [[cc0]]
// CHECK-NEXT: OpStore %c [[mul2]]
c *= t;
// Vector of size 1
// CHECK-NEXT: [[s2:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[e0:%\d+]] = OpLoad %float %e
// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[e0]] [[s2]]
// CHECK-NEXT: OpStore %e [[mul4]]
e *= s;
// CHECK-NEXT: [[t2:%\d+]] = OpLoad %int %t
// CHECK-NEXT: [[g0:%\d+]] = OpLoad %int %g
// CHECK-NEXT: [[mul6:%\d+]] = OpIMul %int [[g0]] [[t2]]
// CHECK-NEXT: OpStore %g [[mul6]]
g *= t;
// Use OpMatrixTimesScalar for floatMxN * float
// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat2v3float %i
// CHECK-NEXT: [[mul8:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i0]] [[s4]]
// CHECK-NEXT: OpStore %i [[mul8]]
i *= s;
// Use OpVectorTimesScalar for float1xN * float
// Sadly, the AST is constructed differently for 'float1xN *= float' cases.
// So we are not able to generate an OpVectorTimesScalar here.
// TODO: Minor issue. Fix this later maybe.
// CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v3float [[s6]] [[s6]] [[s6]]
// CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
// CHECK-NEXT: [[mul10:%\d+]] = OpFMul %v3float [[k0]] [[cc1]]
// CHECK-NEXT: OpStore %k [[mul10]]
k *= s;
// Use OpVectorTimesScalar for floatMx1 * float
// CHECK-NEXT: [[s8:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[m0:%\d+]] = OpLoad %v2float %m
// CHECK-NEXT: [[mul12:%\d+]] = OpVectorTimesScalar %v2float [[m0]] [[s8]]
// CHECK-NEXT: OpStore %m [[mul12]]
m *= s;
// Matrix of size 1x1
// CHECK-NEXT: [[s10:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
// CHECK-NEXT: [[mul14:%\d+]] = OpFMul %float [[o0]] [[s10]]
// CHECK-NEXT: OpStore %o [[mul14]]
o *= s;
}

Просмотреть файл

@ -1,7 +1,5 @@
// Run: %dxc -T vs_6_0 -E main
// TODO: matrix * scalar
void main() {
// CHECK-LABEL: %bb_entry = OpLabel
@ -11,15 +9,23 @@ void main() {
int3 c, d;
int t;
float1 e, f;
int1 g, h;
float2x3 i, j;
float1x3 k, l;
float2x1 m, n;
float1x1 o, p;
// Use OpVectorTimesScalar for floatN * float
// CHECK: [[a4:%\d+]] = OpLoad %v4float %a
// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s4]]
// CHECK-NEXT: [[s0:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s0]]
// CHECK-NEXT: OpStore %b [[mul0]]
b = a * s;
// CHECK-NEXT: [[a5:%\d+]] = OpLoad %v4float %a
// CHECK-NEXT: [[s5:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul1:%\d+]] = OpVectorTimesScalar %v4float [[a5]] [[s5]]
// CHECK-NEXT: [[s1:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul1:%\d+]] = OpVectorTimesScalar %v4float [[a5]] [[s1]]
// CHECK-NEXT: OpStore %b [[mul1]]
b = s * a;
@ -36,4 +42,74 @@ void main() {
// CHECK-NEXT: [[mul3:%\d+]] = OpIMul %v3int [[cc11]] [[c1]]
// CHECK-NEXT: OpStore %d [[mul3]]
d = t * c;
// Vector of size 1
// CHECK-NEXT: [[e0:%\d+]] = OpLoad %float %e
// CHECK-NEXT: [[s2:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[e0]] [[s2]]
// CHECK-NEXT: OpStore %f [[mul4]]
f = e * s;
// CHECK-NEXT: [[s3:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[e1:%\d+]] = OpLoad %float %e
// CHECK-NEXT: [[mul5:%\d+]] = OpFMul %float [[s3]] [[e1]]
// CHECK-NEXT: OpStore %f [[mul5]]
f = s * e;
// CHECK-NEXT: [[g0:%\d+]] = OpLoad %int %g
// CHECK-NEXT: [[t2:%\d+]] = OpLoad %int %t
// CHECK-NEXT: [[mul6:%\d+]] = OpIMul %int [[g0]] [[t2]]
// CHECK-NEXT: OpStore %h [[mul6]]
h = g * t;
// CHECK-NEXT: [[t3:%\d+]] = OpLoad %int %t
// CHECK-NEXT: [[g1:%\d+]] = OpLoad %int %g
// CHECK-NEXT: [[mul7:%\d+]] = OpIMul %int [[t3]] [[g1]]
// CHECK-NEXT: OpStore %h [[mul7]]
h = t * g;
// Use OpMatrixTimesScalar for floatMxN * float
// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat2v3float %i
// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul8:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i0]] [[s4]]
// CHECK-NEXT: OpStore %j [[mul8]]
j = i * s;
// CHECK-NEXT: [[i1:%\d+]] = OpLoad %mat2v3float %i
// CHECK-NEXT: [[s5:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul9:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i1]] [[s5]]
// CHECK-NEXT: OpStore %j [[mul9]]
j = s * i;
// Use OpVectorTimesScalar for float1xN * float
// CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
// CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul10:%\d+]] = OpVectorTimesScalar %v3float [[k0]] [[s6]]
// CHECK-NEXT: OpStore %l [[mul10]]
l = k * s;
// CHECK-NEXT: [[k1:%\d+]] = OpLoad %v3float %k
// CHECK-NEXT: [[s7:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul11:%\d+]] = OpVectorTimesScalar %v3float [[k1]] [[s7]]
// CHECK-NEXT: OpStore %l [[mul11]]
l = s * k;
// Use OpVectorTimesScalar for floatMx1 * float
// CHECK-NEXT: [[m0:%\d+]] = OpLoad %v2float %m
// CHECK-NEXT: [[s8:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul12:%\d+]] = OpVectorTimesScalar %v2float [[m0]] [[s8]]
// CHECK-NEXT: OpStore %n [[mul12]]
n = m * s;
// CHECK-NEXT: [[m1:%\d+]] = OpLoad %v2float %m
// CHECK-NEXT: [[s9:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul13:%\d+]] = OpVectorTimesScalar %v2float [[m1]] [[s9]]
// CHECK-NEXT: OpStore %n [[mul13]]
n = s * m;
// Matrix of size 1x1
// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
// CHECK-NEXT: [[s10:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[mul14:%\d+]] = OpFMul %float [[o0]] [[s10]]
// CHECK-NEXT: OpStore %p [[mul14]]
p = o * s;
// CHECK-NEXT: [[s11:%\d+]] = OpLoad %float %s
// CHECK-NEXT: [[o1:%\d+]] = OpLoad %float %o
// CHECK-NEXT: [[mul15:%\d+]] = OpFMul %float [[s11]] [[o1]]
// CHECK-NEXT: OpStore %p [[mul15]]
p = s * o;
}