[spirv] Add support for matrix swizzling (#514)
This PR handles two formats for indexing matrices: _mXX and _XX. The operator[] format will be handled in the next PR.
This commit is contained in:
Родитель
660ab29b70
Коммит
9a2c5cc89d
|
@ -207,7 +207,6 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
// TODO: enlarge the queue upon seeing a function call.
|
||||
|
||||
// Translate all functions reachable from the entry function.
|
||||
// The queue can grow in the meanwhile; so need to keep evaluating
|
||||
|
@ -894,6 +893,10 @@ public:
|
|||
return doHLSLVectorElementExpr(vecElemExpr);
|
||||
}
|
||||
|
||||
if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
|
||||
return doExtMatrixElementExpr(matElemExpr);
|
||||
}
|
||||
|
||||
if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
|
||||
return doCallExpr(funcCall);
|
||||
}
|
||||
|
@ -1051,6 +1054,70 @@ public:
|
|||
return rhs;
|
||||
}
|
||||
|
||||
/// Tries to emit instructions for assigning to the given matrix element
|
||||
/// accessing expression. Returns 0 if the trial fails and no instructions
|
||||
/// are generated.
|
||||
uint32_t tryToAssignToMatrixElements(const Expr *lhs, uint32_t rhs) {
|
||||
const auto *lhsExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
|
||||
if (!lhsExpr)
|
||||
return 0;
|
||||
|
||||
const Expr *baseMat = lhsExpr->getBase();
|
||||
const uint32_t base = doExpr(baseMat);
|
||||
const QualType elemType = hlsl::GetHLSLMatElementType(baseMat->getType());
|
||||
const uint32_t elemTypeId = typeTranslator.translateType(elemType);
|
||||
|
||||
uint32_t rowCount = 0, colCount = 0;
|
||||
hlsl::GetHLSLMatRowColCount(baseMat->getType(), rowCount, colCount);
|
||||
|
||||
// For each lhs element written to:
|
||||
// 1. Extract the corresponding rhs element using OpCompositeExtract
|
||||
// 2. Create access chain for the lhs element using OpAccessChain
|
||||
// 3. Write using OpStore
|
||||
|
||||
const auto accessor = lhsExpr->getEncodedElementAccess();
|
||||
for (uint32_t i = 0; i < accessor.Count; ++i) {
|
||||
uint32_t row = 0, col = 0;
|
||||
accessor.GetPosition(i, &row, &col);
|
||||
|
||||
llvm::SmallVector<uint32_t, 2> indices;
|
||||
// If the matrix only has one row/column, we are indexing into a vector
|
||||
// then. Only one index is needed for such cases.
|
||||
if (rowCount > 1)
|
||||
indices.push_back(row);
|
||||
if (colCount > 1)
|
||||
indices.push_back(col);
|
||||
|
||||
for (uint32_t i = 0; i < indices.size(); ++i)
|
||||
indices[i] = theBuilder.getConstantInt32(indices[i]);
|
||||
|
||||
// If we are writing to only one element, the rhs should already be a
|
||||
// scalar value.
|
||||
uint32_t rhsElem = rhs;
|
||||
if (accessor.Count > 1)
|
||||
rhsElem = theBuilder.createCompositeExtract(elemTypeId, rhs, {i});
|
||||
|
||||
// TODO: select storage type based on the underlying variable
|
||||
const uint32_t ptrType =
|
||||
theBuilder.getPointerType(elemTypeId, spv::StorageClass::Function);
|
||||
|
||||
// If the lhs is actually a matrix of size 1x1, we don't need the access
|
||||
// chain. base is already the dest pointer.
|
||||
uint32_t lhsElemPtr = base;
|
||||
if (!indices.empty()) {
|
||||
// Load the element via access chain
|
||||
lhsElemPtr = theBuilder.createAccessChain(ptrType, base, indices);
|
||||
}
|
||||
|
||||
theBuilder.createStore(lhsElemPtr, rhsElem);
|
||||
}
|
||||
|
||||
// TODO: OK, this return value is incorrect for compound assignments, for
|
||||
// which cases we should return lvalues. Should at least emit errors if
|
||||
// this return value is used (can be checked via ASTContext.getParents).
|
||||
return rhs;
|
||||
}
|
||||
|
||||
/// Generates the necessary instructions for assigning rhs to lhs. If lhsPtr
|
||||
/// is not zero, it will be used as the pointer from lhs instead of evaluating
|
||||
/// lhs again.
|
||||
|
@ -1060,6 +1127,10 @@ public:
|
|||
if (const uint32_t result = tryToAssignToVectorElements(lhs, rhs)) {
|
||||
return result;
|
||||
}
|
||||
// Assigning to matrix swizzling should be handled differently.
|
||||
if (const uint32_t result = tryToAssignToMatrixElements(lhs, rhs)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Normal assignment procedure
|
||||
if (lhsPtr == 0)
|
||||
|
@ -1474,6 +1545,63 @@ public:
|
|||
return theBuilder.createVectorShuffle(type, baseVal, baseVal, selectors);
|
||||
}
|
||||
|
||||
/// Translates a matrix element access expression (ExtMatrixElementExpr) into
/// the <result-id> of the accessed value.
///
/// Every selected element is fetched individually: through OpAccessChain +
/// OpLoad when the base expression is a glvalue, or through
/// OpCompositeExtract when the base is a computed value, e.g.
/// (mat1 + mat2)._m11.
///
/// \param expr the matrix element access expression to translate.
/// \returns the single element when exactly one is selected; otherwise a
/// vector built from all selected elements with OpCompositeConstruct.
uint32_t doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
  const Expr *baseExpr = expr->getBase();
  const uint32_t base = doExpr(baseExpr);
  const auto accessor = expr->getEncodedElementAccess();
  const uint32_t elemType = typeTranslator.translateType(
      hlsl::GetHLSLMatElementType(baseExpr->getType()));

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);

  // Whether the element is read via access chain + load (glvalue base) or
  // extracted from a value (any other base).
  const bool baseIsGLValue = baseExpr->isGLValue();

  llvm::SmallVector<uint32_t, 4> elements;
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, we are indexing into a vector
    // then. Only one index is needed for such cases. A 1x1 matrix needs none.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    uint32_t elem = 0;
    if (baseIsGLValue) {
      // OpAccessChain takes its indices as <id>s of constants, not literals.
      for (uint32_t &index : indices)
        index = theBuilder.getConstantInt32(index);

      // TODO: select storage type based on the underlying variable
      const uint32_t ptrType =
          theBuilder.getPointerType(elemType, spv::StorageClass::Function);

      // For a 1x1 matrix there is nothing to index: base is already the
      // source pointer. Otherwise build a pointer to the addressed element.
      const uint32_t elemPtr =
          indices.empty()
              ? base
              : theBuilder.createAccessChain(ptrType, base, indices);
      elem = theBuilder.createLoad(elemType, elemPtr);
    } else if (indices.empty()) {
      // 1x1 matrix value: as in the glvalue path above, base stands for the
      // element itself. Emitting OpCompositeExtract with no indices on it
      // would not be valid, so use the value directly.
      elem = base;
    } else { // e.g., (mat1 + mat2)._m11
      // OpCompositeExtract takes literal indices, so they are used as-is.
      elem = theBuilder.createCompositeExtract(elemType, base, indices);
    }
    elements.push_back(elem);
  }

  // A single selected element is returned as-is rather than wrapped in a
  // vector.
  if (elements.size() == 1)
    return elements.front();

  const uint32_t vecType = theBuilder.getVecType(elemType, elements.size());
  return theBuilder.createCompositeConstruct(vecType, elements);
}
|
||||
|
||||
/// Returns true if the given expression will be translated into a vector
|
||||
/// shuffle instruction in SPIR-V.
|
||||
///
|
||||
|
@ -1522,11 +1650,13 @@ public:
|
|||
switch (expr->getCastKind()) {
|
||||
case CastKind::CK_LValueToRValue: {
|
||||
const uint32_t fromValue = doExpr(subExpr);
|
||||
if (isVectorShuffle(subExpr)) {
|
||||
// By reaching here, it means the vector element accessing operation is
|
||||
// an lvalue. If we generated a vector shuffle for it and trying to use
|
||||
// it as a rvalue, we cannot do the load here as normal. Need the upper
|
||||
// nodes in the AST tree to handle it properly.
|
||||
if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr)) {
|
||||
// By reaching here, it means the vector/matrix element accessing
|
||||
// operation is an lvalue. For vector element accessing, if we generated
|
||||
// a vector shuffle for it and trying to use it as a rvalue, we cannot
|
||||
// do the load here as normal. Need the upper nodes in the AST tree to
|
||||
// handle it properly. For matrix element accessing, load should have
|
||||
// already happened after creating access chain for each element.
|
||||
return fromValue;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
// Run: %dxc -T vs_6_0 -E main
|
||||
|
||||
void main() {
|
||||
// CHECK-LABEL: %bb_entry = OpLabel
|
||||
|
||||
float1x1 mat;
|
||||
float3 vec3;
|
||||
float2 vec2;
|
||||
float scalar;
|
||||
|
||||
// 1 element (from lvalue)
|
||||
// CHECK: [[load0:%\d+]] = OpLoad %float %mat
|
||||
// CHECK-NEXT: OpStore %scalar [[load0]]
|
||||
scalar = mat._m00; // Used as rvalue
|
||||
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
|
||||
// CHECK-NEXT: OpStore %mat [[load1]]
|
||||
mat._11 = scalar; // Used as lvalue
|
||||
|
||||
// >1 elements (from lvalue)
|
||||
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float %mat
|
||||
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float %mat
|
||||
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v2float [[load2]] [[load3]]
|
||||
// CHECK-NEXT: OpStore %vec2 [[cc0]]
|
||||
vec2 = mat._11_11; // Used as rvalue
|
||||
|
||||
// The following statements will trigger errors:
|
||||
// invalid format for vector swizzle
|
||||
// scalar = (mat + mat)._m00;
|
||||
// vec2 = (mat * mat)._11_11;
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
// Run: %dxc -T vs_6_0 -E main
|
||||
|
||||
void main() {
|
||||
// CHECK-LABEL: %bb_entry = OpLabel
|
||||
|
||||
float1x3 mat;
|
||||
float3 vec3;
|
||||
float2 vec2;
|
||||
float scalar;
|
||||
|
||||
// 1 element (from lvalue)
|
||||
// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
|
||||
// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
|
||||
// CHECK-NEXT: OpStore %scalar [[load0]]
|
||||
scalar = mat._m02; // Used as rvalue
|
||||
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
|
||||
// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
|
||||
// CHECK-NEXT: OpStore [[access1]] [[load1]]
|
||||
mat._12 = scalar; // Used as lvalue
|
||||
|
||||
// > 1 elements (from lvalue)
|
||||
// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
|
||||
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
|
||||
// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
|
||||
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
|
||||
// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
|
||||
// CHECK-NEXT: [[load4:%\d+]] = OpLoad %float [[access4]]
|
||||
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3float [[load2]] [[load3]] [[load4]]
|
||||
// CHECK-NEXT: OpStore %vec3 [[cc0]]
|
||||
vec3 = mat._11_13_12; // Used as rvalue
|
||||
// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v2float %vec2
|
||||
// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
|
||||
// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
|
||||
// CHECK-NEXT: OpStore [[access5]] [[ce0]]
|
||||
// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
|
||||
// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
|
||||
// CHECK-NEXT: OpStore [[access6]] [[ce1]]
|
||||
mat._m00_m02 = vec2; // Used as lvalue
|
||||
|
||||
// The following statements will trigger errors:
|
||||
// invalid format for vector swizzle
|
||||
// scalar = (mat + mat)._m02;
|
||||
// vec2 = (mat * mat)._11_12;
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
// Run: %dxc -T vs_6_0 -E main
|
||||
|
||||
void main() {
|
||||
// CHECK-LABEL: %bb_entry = OpLabel
|
||||
|
||||
float3x1 mat;
|
||||
float3 vec3;
|
||||
float2 vec2;
|
||||
float scalar;
|
||||
|
||||
// 1 element (from lvalue)
|
||||
// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
|
||||
// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
|
||||
// CHECK-NEXT: OpStore %scalar [[load0]]
|
||||
scalar = mat._m20; // Used as rvalue
|
||||
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
|
||||
// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
|
||||
// CHECK-NEXT: OpStore [[access1]] [[load1]]
|
||||
mat._21 = scalar; // Used as lvalue
|
||||
|
||||
// > 1 elements (from lvalue)
|
||||
// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
|
||||
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
|
||||
// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
|
||||
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
|
||||
// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
|
||||
// CHECK-NEXT: [[load4:%\d+]] = OpLoad %float [[access4]]
|
||||
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3float [[load2]] [[load3]] [[load4]]
|
||||
// CHECK-NEXT: OpStore %vec3 [[cc0]]
|
||||
vec3 = mat._11_31_21; // Used as rvalue
|
||||
// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v2float %vec2
|
||||
// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
|
||||
// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
|
||||
// CHECK-NEXT: OpStore [[access5]] [[ce0]]
|
||||
// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
|
||||
// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
|
||||
// CHECK-NEXT: OpStore [[access6]] [[ce1]]
|
||||
mat._m00_m20 = vec2; // Used as lvalue
|
||||
|
||||
// 1 element (from rvalue)
|
||||
// CHECK-NEXT: [[load5:%\d+]] = OpLoad %v3float %mat
|
||||
// CHECK-NEXT: [[load6:%\d+]] = OpLoad %v3float %mat
|
||||
// CHECK-NEXT: [[add0:%\d+]] = OpFAdd %v3float [[load5]] [[load6]]
|
||||
// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[add0]] 2
|
||||
// CHECK-NEXT: OpStore %scalar [[ce2]]
|
||||
// Codegen: construct a temporary vector first out of (mat + mat) and
|
||||
// then extract the value
|
||||
scalar = (mat + mat)._m20;
|
||||
|
||||
// > 1 element (from rvalue)
|
||||
// CHECK-NEXT: [[load7:%\d+]] = OpLoad %v3float %mat
|
||||
// CHECK-NEXT: [[load8:%\d+]] = OpLoad %v3float %mat
|
||||
// CHECK-NEXT: [[mul0:%\d+]] = OpFMul %v3float [[load7]] [[load8]]
|
||||
// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[mul0]] 0
|
||||
// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[mul0]] 1
|
||||
// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v2float [[ce3]] [[ce4]]
|
||||
// CHECK-NEXT: OpStore %vec2 [[cc1]]
|
||||
// Codegen: construct a temporary vector first out of (mat * mat) and
|
||||
// then extract the value
|
||||
vec2 = (mat * mat)._11_21;
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
// Run: %dxc -T vs_6_0 -E main
|
||||
|
||||
void main() {
|
||||
// CHECK-LABEL: %bb_entry = OpLabel
|
||||
|
||||
float2x3 mat;
|
||||
float3 vec3;
|
||||
float2 vec2;
|
||||
float scalar;
|
||||
|
||||
// 1 element (from lvalue)
|
||||
// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1 %int_2
|
||||
// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
|
||||
// CHECK-NEXT: OpStore %scalar [[load0]]
|
||||
scalar = mat._m12; // Used as rvalue
|
||||
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
|
||||
// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
|
||||
// CHECK-NEXT: OpStore [[access1]] [[load1]]
|
||||
mat._12 = scalar; // Used as lvalue
|
||||
|
||||
// >1 elements (from lvalue)
|
||||
// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
|
||||
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
|
||||
// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_2
|
||||
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
|
||||
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v2float [[load2]] [[load3]]
|
||||
// CHECK-NEXT: OpStore %vec2 [[cc0]]
|
||||
vec2 = mat._m01_m02; // Used as rvalue
|
||||
// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v3float %vec3
|
||||
// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
|
||||
// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1 %int_0
|
||||
// CHECK-NEXT: OpStore [[access4]] [[ce0]]
|
||||
// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
|
||||
// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
|
||||
// CHECK-NEXT: OpStore [[access5]] [[ce1]]
|
||||
// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[rhs0]] 2
|
||||
// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_0
|
||||
// CHECK-NEXT: OpStore [[access6]] [[ce2]]
|
||||
mat._21_12_11 = vec3; // Used as lvalue
|
||||
|
||||
// 1 element (from rvalue)
|
||||
// CHECK: [[cc1:%\d+]] = OpCompositeConstruct %mat2v3float {{%\d+}} {{%\d+}}
|
||||
// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[cc1]] 1 2
|
||||
// CHECK-NEXT: OpStore %scalar [[ce3]]
|
||||
// Codegen: construct a temporary matrix first out of (mat + mat) and
|
||||
// then extract the value
|
||||
scalar = (mat + mat)._m12;
|
||||
|
||||
// > 1 element (from rvalue)
|
||||
// CHECK: [[cc2:%\d+]] = OpCompositeConstruct %mat2v3float {{%\d+}} {{%\d+}}
|
||||
// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[cc2]] 0 1
|
||||
// CHECK-NEXT: [[ce5:%\d+]] = OpCompositeExtract %float [[cc2]] 0 2
|
||||
// CHECK-NEXT: [[cc3:%\d+]] = OpCompositeConstruct %v2float [[ce4]] [[ce5]]
|
||||
// CHECK-NEXT: OpStore %vec2 [[cc3]]
|
||||
// Codegen: construct a temporary matrix first out of (mat * mat) and
|
||||
// then extract the value
|
||||
vec2 = (mat * mat)._m01_m02;
|
||||
}
|
|
@ -153,6 +153,20 @@ TEST_F(FileTest, OpVectorSize1Swizzle) {
|
|||
runFileTest("op.vector.swizzle.size1.hlsl");
|
||||
}
|
||||
|
||||
// For matrix accessing operators
|
||||
TEST_F(FileTest, OpMatrixAccessMxN) {
|
||||
runFileTest("op.matrix.access.mxn.hlsl");
|
||||
}
|
||||
TEST_F(FileTest, OpMatrixAccessMx1) {
|
||||
runFileTest("op.matrix.access.mx1.hlsl");
|
||||
}
|
||||
TEST_F(FileTest, OpMatrixAccess1xN) {
|
||||
runFileTest("op.matrix.access.1xn.hlsl");
|
||||
}
|
||||
TEST_F(FileTest, OpMatrixAccess1x1) {
|
||||
runFileTest("op.matrix.access.1x1.hlsl");
|
||||
}
|
||||
|
||||
// For casting
|
||||
TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
|
||||
TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }
|
||||
|
|
Загрузка…
Ссылка в новой задаче