[spirv] Add support for matrix swizzling (#514)

This PR handles two formats for indexing matrices: _mXX and _XX.
The operator[] format will be handled in the next PR.
This commit is contained in:
Lei Zhang 2017-08-03 13:35:57 -04:00 коммит произвёл David Peixotto
Родитель 660ab29b70
Коммит 9a2c5cc89d
6 изменённых файлов: 343 добавлений и 6 удалений

Просмотреть файл

@ -207,7 +207,6 @@ public:
}
}
}
// TODO: enlarge the queue upon seeing a function call.
// Translate all functions reachable from the entry function.
// The queue can grow in the meanwhile; so need to keep evaluating
@ -894,6 +893,10 @@ public:
return doHLSLVectorElementExpr(vecElemExpr);
}
if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
return doExtMatrixElementExpr(matElemExpr);
}
if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
return doCallExpr(funcCall);
}
@ -1051,6 +1054,70 @@ public:
return rhs;
}
/// Attempts to translate an assignment whose left-hand side is a matrix
/// element access (an ExtMatrixElementExpr, e.g. mat._m00_m11 or mat._11_22).
///
/// \param lhs the expression being assigned to.
/// \param rhs the <result-id> of the value being assigned; a scalar when one
///            element is written, otherwise a composite with one component per
///            element written.
/// \returns 0 without emitting anything if lhs is not an ExtMatrixElementExpr;
///          otherwise rhs, after emitting one OpStore per written element.
uint32_t tryToAssignToMatrixElements(const Expr *lhs, uint32_t rhs) {
  const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
  if (matElemExpr == nullptr)
    return 0;

  const Expr *matExpr = matElemExpr->getBase();
  const uint32_t matPtr = doExpr(matExpr);
  const QualType scalarType = hlsl::GetHLSLMatElementType(matExpr->getType());
  const uint32_t scalarTypeId = typeTranslator.translateType(scalarType);

  uint32_t numRows = 0, numCols = 0;
  hlsl::GetHLSLMatRowColCount(matExpr->getType(), numRows, numCols);

  // Every element named on the lhs is written separately:
  //   1. OpCompositeExtract pulls the matching component out of rhs
  //   2. OpAccessChain builds a pointer to the destination element
  //   3. OpStore performs the write
  const auto accessor = matElemExpr->getEncodedElementAccess();
  for (uint32_t elemIdx = 0; elemIdx < accessor.Count; ++elemIdx) {
    uint32_t row = 0, col = 0;
    accessor.GetPosition(elemIdx, &row, &col);

    // A dimension of size one contributes no index: a matrix with a single
    // row or column is indexed like a vector, and a 1x1 matrix needs no index
    // at all.
    llvm::SmallVector<uint32_t, 2> indexIds;
    if (numRows > 1)
      indexIds.push_back(theBuilder.getConstantInt32(row));
    if (numCols > 1)
      indexIds.push_back(theBuilder.getConstantInt32(col));

    // With a single written element, rhs is expected to be the scalar itself;
    // otherwise take the elemIdx-th component of rhs.
    const uint32_t srcElem =
        accessor.Count > 1
            ? theBuilder.createCompositeExtract(scalarTypeId, rhs, {elemIdx})
            : rhs;

    // TODO: select storage type based on the underlying variable
    const uint32_t elemPtrType =
        theBuilder.getPointerType(scalarTypeId, spv::StorageClass::Function);

    // For a 1x1 matrix there are no indices and the base pointer is already
    // the destination; otherwise point at the element via an access chain.
    const uint32_t dstElemPtr =
        indexIds.empty()
            ? matPtr
            : theBuilder.createAccessChain(elemPtrType, matPtr, indexIds);

    theBuilder.createStore(dstElemPtr, srcElem);
  }

  // TODO: OK, this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}
/// Generates the necessary instructions for assigning rhs to lhs. If lhsPtr
/// is not zero, it will be used as the pointer from lhs instead of evaluating
/// lhs again.
@ -1060,6 +1127,10 @@ public:
if (const uint32_t result = tryToAssignToVectorElements(lhs, rhs)) {
return result;
}
// Assigning to matrix swizzling should be handled differently.
if (const uint32_t result = tryToAssignToMatrixElements(lhs, rhs)) {
return result;
}
// Normal assignment procedure
if (lhsPtr == 0)
@ -1474,6 +1545,63 @@ public:
return theBuilder.createVectorShuffle(type, baseVal, baseVal, selectors);
}
/// Translates a matrix element access expression (an ExtMatrixElementExpr,
/// e.g. mat._m00_m11 or mat._11_22) that is being read.
///
/// \param expr the matrix element access expression.
/// \returns the <result-id> of the loaded/extracted scalar when exactly one
///          element is accessed; otherwise the <result-id> of a vector built
///          from all accessed elements, in accessor order.
uint32_t doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
  const Expr *baseExpr = expr->getBase();
  const uint32_t base = doExpr(baseExpr);
  const auto accessor = expr->getEncodedElementAccess();
  const uint32_t elemType = typeTranslator.translateType(
      hlsl::GetHLSLMatElementType(baseExpr->getType()));

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);

  // Construct a temporary vector out of all elements accessed.
  // When the base is a glvalue:
  //   1. Create access chain for each element using OpAccessChain
  //   2. Load each element using OpLoad
  // When the base is not a glvalue:
  //   1. Extract each element using OpCompositeExtract
  // Finally, create the vector using OpCompositeConstruct.
  llvm::SmallVector<uint32_t, 4> elements;
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0, elem = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, we are indexing into a vector
    // then. Only one index is needed for such cases. A 1x1 matrix needs none.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    if (baseExpr->isGLValue()) {
      // OpAccessChain takes <id>s of constants rather than literal numbers,
      // so replace each literal index with its constant's <id> in place.
      for (uint32_t &index : indices)
        index = theBuilder.getConstantInt32(index);

      // TODO: select storage type based on the underlying variable
      const uint32_t ptrType =
          theBuilder.getPointerType(elemType, spv::StorageClass::Function);
      if (!indices.empty()) {
        // Point at the element via access chain
        elem = theBuilder.createAccessChain(ptrType, base, indices);
      } else {
        // The matrix is of size 1x1. No need to use access chain, base should
        // be the source pointer.
        elem = base;
      }
      elem = theBuilder.createLoad(elemType, elem);
    } else if (indices.empty()) {
      // The matrix is of size 1x1, so the base value is already the element.
      // OpCompositeExtract requires a composite operand and cannot be used
      // here.
      elem = base;
    } else { // e.g., (mat1 + mat2)._m11
      // OpCompositeExtract takes literal indices, so use them unchanged.
      elem = theBuilder.createCompositeExtract(elemType, base, indices);
    }
    elements.push_back(elem);
  }

  // A single accessed element is returned as a scalar, not a 1-element vector.
  if (elements.size() == 1)
    return elements.front();

  const uint32_t vecType = theBuilder.getVecType(elemType, elements.size());
  return theBuilder.createCompositeConstruct(vecType, elements);
}
/// Returns true if the given expression will be translated into a vector
/// shuffle instruction in SPIR-V.
///
@ -1522,11 +1650,13 @@ public:
switch (expr->getCastKind()) {
case CastKind::CK_LValueToRValue: {
const uint32_t fromValue = doExpr(subExpr);
if (isVectorShuffle(subExpr)) {
// By reaching here, it means the vector element accessing operation is
// an lvalue. If we generated a vector shuffle for it and trying to use
// it as a rvalue, we cannot do the load here as normal. Need the upper
// nodes in the AST tree to handle it properly.
if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr)) {
// By reaching here, it means the vector/matrix element accessing
// operation is an lvalue. For vector element accessing, if we generated
// a vector shuffle for it and trying to use it as a rvalue, we cannot
// do the load here as normal. Need the upper nodes in the AST tree to
// handle it properly. For matrix element accessing, load should have
// already happened after creating access chain for each element.
return fromValue;
}

Просмотреть файл

@ -0,0 +1,30 @@
// Run: %dxc -T vs_6_0 -E main
// Matrix element access (_mXX and _XX forms) on a float1x1 matrix.
// The expected output below contains no OpAccessChain: every read is an
// OpLoad directly from %mat and the write is an OpStore directly to %mat.
// Note: vec3 is declared but not referenced by any statement in this test.
void main() {
// CHECK-LABEL: %bb_entry = OpLabel
float1x1 mat;
float3 vec3;
float2 vec2;
float scalar;
// 1 element (from lvalue)
// CHECK: [[load0:%\d+]] = OpLoad %float %mat
// CHECK-NEXT: OpStore %scalar [[load0]]
scalar = mat._m00; // Used as rvalue
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
// CHECK-NEXT: OpStore %mat [[load1]]
mat._11 = scalar; // Used as lvalue
// >1 elements (from lvalue)
// The same element is selected twice, so two loads from %mat feed the
// constructed two-component vector.
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float %mat
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float %mat
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v2float [[load2]] [[load3]]
// CHECK-NEXT: OpStore %vec2 [[cc0]]
vec2 = mat._11_11; // Used as rvalue
// The following statements will trigger errors:
// invalid format for vector swizzle
// (kept commented out so that the test compiles)
// scalar = (mat + mat)._m00;
// vec2 = (mat * mat)._11_11;
}

Просмотреть файл

@ -0,0 +1,44 @@
// Run: %dxc -T vs_6_0 -E main
// Matrix element access (_mXX and _XX forms) on a float1x3 matrix.
// Each expected OpAccessChain below carries a single index, the column.
// As the expected indices show, _mXX positions are zero-based (_m02 -> 2)
// while _XX positions are one-based (_12 -> 1).
void main() {
// CHECK-LABEL: %bb_entry = OpLabel
float1x3 mat;
float3 vec3;
float2 vec2;
float scalar;
// 1 element (from lvalue)
// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
// CHECK-NEXT: OpStore %scalar [[load0]]
scalar = mat._m02; // Used as rvalue
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
// CHECK-NEXT: OpStore [[access1]] [[load1]]
mat._12 = scalar; // Used as lvalue
// > 1 elements (from lvalue)
// Elements are loaded one by one, in the order written in the accessor, and
// then combined into a vector.
// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
// CHECK-NEXT: [[load4:%\d+]] = OpLoad %float [[access4]]
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3float [[load2]] [[load3]] [[load4]]
// CHECK-NEXT: OpStore %vec3 [[cc0]]
vec3 = mat._11_13_12; // Used as rvalue
// Writing several elements: each rhs component is extracted and stored
// through its own access chain.
// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v2float %vec2
// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
// CHECK-NEXT: OpStore [[access5]] [[ce0]]
// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
// CHECK-NEXT: OpStore [[access6]] [[ce1]]
mat._m00_m02 = vec2; // Used as lvalue
// The following statements will trigger errors:
// invalid format for vector swizzle
// (kept commented out so that the test compiles)
// scalar = (mat + mat)._m02;
// vec2 = (mat * mat)._11_12;
}

Просмотреть файл

@ -0,0 +1,61 @@
// Run: %dxc -T vs_6_0 -E main
// Matrix element access (_mXX and _XX forms) on a float3x1 matrix.
// Each expected OpAccessChain below carries a single index, the row.
// As the expected indices show, _mXX positions are zero-based (_m20 -> 2)
// while _XX positions are one-based (_21 -> 1).
void main() {
// CHECK-LABEL: %bb_entry = OpLabel
float3x1 mat;
float3 vec3;
float2 vec2;
float scalar;
// 1 element (from lvalue)
// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
// CHECK-NEXT: OpStore %scalar [[load0]]
scalar = mat._m20; // Used as rvalue
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
// CHECK-NEXT: OpStore [[access1]] [[load1]]
mat._21 = scalar; // Used as lvalue
// > 1 elements (from lvalue)
// Elements are loaded one by one, in the order written in the accessor, and
// then combined into a vector.
// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
// CHECK-NEXT: [[load4:%\d+]] = OpLoad %float [[access4]]
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3float [[load2]] [[load3]] [[load4]]
// CHECK-NEXT: OpStore %vec3 [[cc0]]
vec3 = mat._11_31_21; // Used as rvalue
// Writing several elements: each rhs component is extracted and stored
// through its own access chain.
// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v2float %vec2
// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
// CHECK-NEXT: OpStore [[access5]] [[ce0]]
// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
// CHECK-NEXT: OpStore [[access6]] [[ce1]]
mat._m00_m20 = vec2; // Used as lvalue
// 1 element (from rvalue)
// The base is a computed value rather than a variable, so the element is
// taken with OpCompositeExtract and a literal index; no access chain or
// load of the element is expected.
// CHECK-NEXT: [[load5:%\d+]] = OpLoad %v3float %mat
// CHECK-NEXT: [[load6:%\d+]] = OpLoad %v3float %mat
// CHECK-NEXT: [[add0:%\d+]] = OpFAdd %v3float [[load5]] [[load6]]
// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[add0]] 2
// CHECK-NEXT: OpStore %scalar [[ce2]]
// Codegen: construct a temporary vector first out of (mat + mat) and
// then extract the value
scalar = (mat + mat)._m20;
// > 1 element (from rvalue)
// CHECK-NEXT: [[load7:%\d+]] = OpLoad %v3float %mat
// CHECK-NEXT: [[load8:%\d+]] = OpLoad %v3float %mat
// CHECK-NEXT: [[mul0:%\d+]] = OpFMul %v3float [[load7]] [[load8]]
// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[mul0]] 0
// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[mul0]] 1
// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v2float [[ce3]] [[ce4]]
// CHECK-NEXT: OpStore %vec2 [[cc1]]
// Codegen: construct a temporary vector first out of (mat * mat) and
// then extract the value
vec2 = (mat * mat)._11_21;
}

Просмотреть файл

@ -0,0 +1,58 @@
// Run: %dxc -T vs_6_0 -E main
// Matrix element access (_mXX and _XX forms) on a float2x3 matrix.
// Each expected OpAccessChain below carries two indices: row, then column.
// As the expected indices show, _mXX positions are zero-based
// (_m12 -> 1 2) while _XX positions are one-based (_12 -> 0 1).
void main() {
// CHECK-LABEL: %bb_entry = OpLabel
float2x3 mat;
float3 vec3;
float2 vec2;
float scalar;
// 1 element (from lvalue)
// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1 %int_2
// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
// CHECK-NEXT: OpStore %scalar [[load0]]
scalar = mat._m12; // Used as rvalue
// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
// CHECK-NEXT: OpStore [[access1]] [[load1]]
mat._12 = scalar; // Used as lvalue
// >1 elements (from lvalue)
// Elements are loaded one by one, in the order written in the accessor, and
// then combined into a vector.
// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_2
// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v2float [[load2]] [[load3]]
// CHECK-NEXT: OpStore %vec2 [[cc0]]
vec2 = mat._m01_m02; // Used as rvalue
// Writing several elements: each rhs component is extracted and stored
// through its own access chain.
// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v3float %vec3
// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1 %int_0
// CHECK-NEXT: OpStore [[access4]] [[ce0]]
// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
// CHECK-NEXT: OpStore [[access5]] [[ce1]]
// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[rhs0]] 2
// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_0
// CHECK-NEXT: OpStore [[access6]] [[ce2]]
mat._21_12_11 = vec3; // Used as lvalue
// 1 element (from rvalue)
// The base is a computed value rather than a variable, so the element is
// taken with OpCompositeExtract and literal row/column indices. The
// instructions computing the temporary matrix are not matched one by one;
// only its final OpCompositeConstruct is.
// CHECK: [[cc1:%\d+]] = OpCompositeConstruct %mat2v3float {{%\d+}} {{%\d+}}
// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[cc1]] 1 2
// CHECK-NEXT: OpStore %scalar [[ce3]]
// Codegen: construct a temporary matrix first out of (mat + mat) and
// then extract the value
scalar = (mat + mat)._m12;
// > 1 element (from rvalue)
// CHECK: [[cc2:%\d+]] = OpCompositeConstruct %mat2v3float {{%\d+}} {{%\d+}}
// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[cc2]] 0 1
// CHECK-NEXT: [[ce5:%\d+]] = OpCompositeExtract %float [[cc2]] 0 2
// CHECK-NEXT: [[cc3:%\d+]] = OpCompositeConstruct %v2float [[ce4]] [[ce5]]
// CHECK-NEXT: OpStore %vec2 [[cc3]]
// Codegen: construct a temporary matrix first out of (mat * mat) and
// then extract the value
vec2 = (mat * mat)._m01_m02;
}

Просмотреть файл

@ -153,6 +153,20 @@ TEST_F(FileTest, OpVectorSize1Swizzle) {
runFileTest("op.vector.swizzle.size1.hlsl");
}
// For matrix accessing operators
// One test per matrix shape; each passes the named HLSL test file to
// runFileTest.
// General matrix (more than one row and more than one column)
TEST_F(FileTest, OpMatrixAccessMxN) {
runFileTest("op.matrix.access.mxn.hlsl");
}
// Matrix with a single column
TEST_F(FileTest, OpMatrixAccessMx1) {
runFileTest("op.matrix.access.mx1.hlsl");
}
// Matrix with a single row
TEST_F(FileTest, OpMatrixAccess1xN) {
runFileTest("op.matrix.access.1xn.hlsl");
}
// Matrix with a single element
TEST_F(FileTest, OpMatrixAccess1x1) {
runFileTest("op.matrix.access.1x1.hlsl");
}
// For casting
TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }