[spirv] Legalization: support structured/byte buffer in structs (#970)

We need to change the type of these struct fields to have an extra
level of pointer.

A local resource always has void as its layout rule (because local
resource is not in the Uniform storage class). So in the TypeTranslator,
when we are trying to translate a structured/byte buffer resource that
has void layout rule, we know it must be a local resource. Then we
apply an extra level of pointer to it. Because of TypeTranslator is
recursive, that automatically handles both stand-alone local resources
and the ones in structs. 

In the SPIRVEmitter, we need to have a way to tell whether a resource
is a local resource or not because if it is a local resource, we need to
OpLoad once to get the pointer to the aliased-to global resource.
That's why we have the containsAlias field in SpirvEvalInfo. We set it to
true in getTypeForPotentialAliasVar() for local resources. And do an
extra OpLoad to get the pointer in SPIRVEmitter if it is true.
This commit is contained in:
Lei Zhang 2018-01-08 11:47:55 -05:00 коммит произвёл GitHub
Родитель f9d613b795
Коммит 415e190a8b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 222 добавлений и 58 удалений

Просмотреть файл

@ -517,12 +517,10 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
info.setResultId(id);
if (isAlias)
// No need to dereference to get the pointer. Alias function returns
// themselves are already pointers to values.
info.setValTypeId(0);
else
// All other cases should be normal rvalues.
// No need to dereference to get the pointer. Alias function returns
// themselves are already pointers to values. All other cases should be
// normal rvalues.
if (!isAlias)
info.setRValue();
// Create alias counter variable if suitable
@ -1882,25 +1880,18 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
// This method is only intended to be used to create SPIR-V variables in the
// Function or Private storage class.
assert(!varDecl->isExceptionVariable() || varDecl->isStaticDataMember());
assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember());
}
const QualType type = getTypeOrFnRetType(decl);
// Whether we should generate this decl as an alias variable.
bool genAlias = false;
// All texture/structured/byte buffers use GLSL std430 rules.
LayoutRule rule = LayoutRule::GLSLStd430;
if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
// For ConstantBuffer and TextureBuffer
if (buffer->isConstantBufferView())
genAlias = true;
// ConstantBuffer uses GLSL std140 rules.
// TODO: do we actually want to include constant/texture buffers
// in this method?
if (buffer->isCBuffer())
rule = LayoutRule::GLSLStd140;
} else if (TypeTranslator::isAKindOfStructuredOrByteBuffer(type)) {
} else if (TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(type)) {
genAlias = true;
}
@ -1910,18 +1901,8 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
if (genAlias) {
needsLegalization = true;
const uint32_t valType = typeTranslator.translateType(type, rule);
// All constant/texture/structured/byte buffers are in the Uniform
// storage class.
const auto ptrType =
theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
if (info)
info->setStorageClass(spv::StorageClass::Uniform)
.setLayoutRule(rule)
.setValTypeId(ptrType);
return ptrType;
info->setContainsAliasComponent(true);
}
return typeTranslator.translateType(type);

Просмотреть файл

@ -660,17 +660,22 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
return info.setRValue();
}
uint32_t valType = 0;
if (valType = info.getValTypeId()) {
if (loadIfAliasVarRef(expr, info)) {
// We are loading an alias variable as a whole here. This is likely for
// wholesale assignments or function returns. Need to load the pointer.
//
// Note: legalization specific code
// TODO: It seems we should not set rvalue here since info is still
// holding a pointer. But it fails structured buffer assignment because
// of double loadIfGLValue() calls if we do not. Fix it.
return info.setRValue();
}
uint32_t valType = 0;
// TODO: Ouch. Very hacky. We need special path to get the value type if
// we are loading a whole ConstantBuffer/TextureBuffer since the normal
// type translation path won't work.
else if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
} else {
valType =
@ -684,18 +689,34 @@ SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) {
auto info = doExpr(expr);
loadIfAliasVarRef(expr, info);
return info;
}
if (const auto valTypeId = info.getValTypeId()) {
return info
// Load the pointer of the aliased-to-variable
.setResultId(theBuilder.createLoad(valTypeId, info))
// Set the value's <type-id> to zero to indicate that we've performed
// dereference over the pointer-to-pointer and now should fallback to
// the normal path
.setValTypeId(0);
bool SPIRVEmitter::loadIfAliasVarRef(const Expr *varExpr, SpirvEvalInfo &info) {
if (info.containsAliasComponent() &&
TypeTranslator::isAKindOfStructuredOrByteBuffer(varExpr->getType())) {
// Aliased-to variables are all in the Uniform storage class with GLSL
// std430 layout rules.
const auto ptrType = typeTranslator.translateType(varExpr->getType());
// Load the pointer of the aliased-to-variable if the expression has a
// pointer to pointer type. That is, the expression itself is a lvalue.
// (Note that we translate alias function return values as pointer types,
// not pointer to pointer types.)
if (varExpr->isGLValue())
info.setResultId(theBuilder.createLoad(ptrType, info));
info.setStorageClass(spv::StorageClass::Uniform)
.setLayoutRule(LayoutRule::GLSLStd430)
// Set to false to indicate that we've performed dereference over the
// pointer-to-pointer and now should fallback to the normal path
.setContainsAliasComponent(false);
return true;
}
return info;
return false;
}
uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
@ -977,7 +998,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
else
storeValue(varId, loadIfGLValue(init), decl->getType());
// Update counter variable associatd with local variables
// Update counter variable associated with local variables
tryToAssignCounterVar(decl, init);
}
@ -1427,7 +1448,7 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
if (const auto *retVal = stmt->getRetValue()) {
// Update counter variable associatd with function returns
// Update counter variable associated with function returns
tryToAssignCounterVar(curFunction, retVal);
const auto retInfo = doExpr(retVal);
@ -1559,7 +1580,7 @@ SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
// For other binary operations, we need to evaluate lhs before rhs.
if (opcode == BO_Assign) {
if (const auto *dstDecl = getReferencedDef(expr->getLHS()))
// Update counter variable associatd with lhs of assignments
// Update counter variable associated with lhs of assignments
tryToAssignCounterVar(dstDecl, expr->getRHS());
return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
@ -4645,6 +4666,18 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
const auto thisBaseType = thisBase->getType();
const Expr *base = collectArrayStructIndices(thisBase, indices);
if (thisBaseType != base->getType() &&
TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) {
// The immediate base is a kind of structured or byte buffer. It should
// be an alias variable. Break the normal index collecting chain.
// Return the immediate base as the base so that we can apply other
// hacks for legalization over it.
//
// Note: legalization specific code
indices->clear();
base = thisBase;
}
// If the base is a StructureType, we need to push an addtional index 0
// here. This is because we created an additional OpTypeRuntimeArray
// in the structure.
@ -7174,7 +7207,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
if (const auto *init = varDecl->getInit()) {
storeValue(varInfo, doExpr(init), varDecl->getType());
// Update counter variable associatd with global variables
// Update counter variable associated with global variables
tryToAssignCounterVar(varDecl, init);
} else {
const auto typeId = typeTranslator.translateType(varDecl->getType());

Просмотреть файл

@ -118,6 +118,13 @@ private:
/// Note: legalization specific code
SpirvEvalInfo loadIfAliasVarRef(const Expr *expr);
/// Loads the pointer of the aliased-to-variable and ajusts aliasVarInfo
/// accordingly if aliasVarExpr is referencing an alias variable. Returns true
/// if aliasVarInfo is changed, false otherwise.
///
/// Note: legalization specific code
bool loadIfAliasVarRef(const Expr *aliasVarExpr, SpirvEvalInfo &aliasVarInfo);
private:
/// Translates the given frontend binary operator into its SPIR-V equivalent
/// taking consideration of the operand type.

Просмотреть файл

@ -79,8 +79,8 @@ public:
/// Handly implicit conversion to test whether the <result-id> is valid.
operator bool() const { return resultId != 0; }
inline SpirvEvalInfo &setValTypeId(uint32_t id);
uint32_t getValTypeId() const { return valTypeId; }
inline SpirvEvalInfo &setContainsAliasComponent(bool);
bool containsAliasComponent() const { return containsAlias; }
inline SpirvEvalInfo &setStorageClass(spv::StorageClass sc);
spv::StorageClass getStorageClass() const { return storageClass; }
@ -99,14 +99,15 @@ public:
private:
uint32_t resultId;
/// The value's <type-id> for this variable.
/// Indicates whether this evaluation result contains alias variables
///
/// This field should only be non-zero for original alias variables, which is
/// of pointer-to-pointer type. After dereferencing the alias variable, this
/// should be set to zero to let CodeGen fall back to normal handling path.
/// This field should only be true for stand-alone alias variables, which is
/// of pointer-to-pointer type, or struct variables containing alias fields.
/// After dereferencing the alias variable, this should be set to false to let
/// CodeGen fall back to normal handling path.
///
/// Note: legalization specific code
uint32_t valTypeId;
bool containsAlias;
spv::StorageClass storageClass;
LayoutRule layoutRule;
@ -117,9 +118,9 @@ private:
};
SpirvEvalInfo::SpirvEvalInfo(uint32_t id)
: resultId(id), valTypeId(0), storageClass(spv::StorageClass::Function),
layoutRule(LayoutRule::Void), isRValue_(false), isConstant_(false),
isRelaxedPrecision_(false) {}
: resultId(id), containsAlias(false),
storageClass(spv::StorageClass::Function), layoutRule(LayoutRule::Void),
isRValue_(false), isConstant_(false), isRelaxedPrecision_(false) {}
SpirvEvalInfo &SpirvEvalInfo::setResultId(uint32_t id) {
resultId = id;
@ -132,8 +133,8 @@ SpirvEvalInfo SpirvEvalInfo::substResultId(uint32_t newId) const {
return info;
}
SpirvEvalInfo &SpirvEvalInfo::setValTypeId(uint32_t id) {
valTypeId = id;
SpirvEvalInfo &SpirvEvalInfo::setContainsAliasComponent(bool contains) {
containsAlias = contains;
return *this;
}

Просмотреть файл

@ -488,6 +488,22 @@ bool TypeTranslator::isAKindOfStructuredOrByteBuffer(QualType type) {
return false;
}
bool TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
if (const RecordType *recordType = type->getAs<RecordType>()) {
StringRef name = recordType->getDecl()->getName();
if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer")
return true;
for (const auto *field : recordType->getDecl()->fields()) {
if (isOrContainsAKindOfStructuredOrByteBuffer(field->getType()))
return true;
}
}
return false;
}
bool TypeTranslator::isStructuredBuffer(QualType type) {
const auto *recordType = type->getAs<RecordType>();
if (!recordType)
@ -836,10 +852,20 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) {
if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
auto &context = *theBuilder.getSPIRVContext();
// StructureBuffer<S> will be translated into an OpTypeStruct with one
// field, which is an OpTypeRuntimeArray of OpTypeStruct (S).
// If layout rule is void, it means these resource types are used for
// declaring local resources, which should be created as alias variables.
// The aliased-to variable should surely be in the Uniform storage class,
// which has layout decorations.
bool asAlias = false;
if (rule == LayoutRule::Void) {
asAlias = true;
rule = LayoutRule::GLSLStd430;
}
auto &context = *theBuilder.getSPIRVContext();
const auto s = hlsl::GetHLSLResourceResultType(type);
const uint32_t structType = translateType(s, rule);
std::string structName;
@ -864,16 +890,36 @@ uint32_t TypeTranslator::translateResourceType(QualType type, LayoutRule rule) {
decorations.push_back(Decoration::getNonWritable(context, 0));
decorations.push_back(Decoration::getBufferBlock(context));
const std::string typeName = "type." + name.str() + "." + structName;
return theBuilder.getStructType(raType, typeName, {}, decorations);
const auto valType =
theBuilder.getStructType(raType, typeName, {}, decorations);
if (asAlias) {
// All structured buffers are in the Uniform storage class.
return theBuilder.getPointerType(valType, spv::StorageClass::Uniform);
} else {
return valType;
}
}
// ByteAddressBuffer types.
if (name == "ByteAddressBuffer") {
return theBuilder.getByteAddressBufferType(/*isRW*/ false);
const auto bufferType = theBuilder.getByteAddressBufferType(/*isRW*/ false);
if (rule == LayoutRule::Void) {
// All byte address buffers are in the Uniform storage class.
return theBuilder.getPointerType(bufferType, spv::StorageClass::Uniform);
} else {
return bufferType;
}
}
// RWByteAddressBuffer types.
if (name == "RWByteAddressBuffer") {
return theBuilder.getByteAddressBufferType(/*isRW*/ true);
const auto bufferType = theBuilder.getByteAddressBufferType(/*isRW*/ true);
if (rule == LayoutRule::Void) {
// All byte address buffers are in the Uniform storage class.
return theBuilder.getPointerType(bufferType, spv::StorageClass::Uniform);
} else {
return bufferType;
}
}
// Buffer and RWBuffer types

Просмотреть файл

@ -95,6 +95,11 @@ public:
/// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
static bool isAKindOfStructuredOrByteBuffer(QualType type);
/// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
/// (RW)ByteAddressBuffer, {Append|Consume}StructuredBuffer, or a struct
/// containing one of the above.
static bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type);
/// \brief Returns true if the given type is the HLSL Buffer type.
static bool isBuffer(QualType type);

Просмотреть файл

@ -0,0 +1,86 @@
// Run: %dxc -T ps_6_0 -E main
struct Basic {
float3 a;
float4 b;
};
// CHECK: %S = OpTypeStruct %_ptr_Uniform_type_AppendStructuredBuffer_v4float %_ptr_Uniform_type_AppendStructuredBuffer_v4float
struct S {
AppendStructuredBuffer<float4> append;
ConsumeStructuredBuffer<float4> consume;
};
// CHECK: %T = OpTypeStruct %_ptr_Uniform_type_StructuredBuffer_Basic %_ptr_Uniform_type_RWStructuredBuffer_Basic
struct T {
StructuredBuffer<Basic> ro;
RWStructuredBuffer<Basic> rw;
};
// CHECK: %Combine = OpTypeStruct %S %T %_ptr_Uniform_type_ByteAddressBuffer %_ptr_Uniform_type_RWByteAddressBuffer
struct Combine {
S s;
T t;
ByteAddressBuffer ro;
RWByteAddressBuffer rw;
};
StructuredBuffer<Basic> gSBuffer;
RWStructuredBuffer<Basic> gRWSBuffer;
AppendStructuredBuffer<float4> gASBuffer;
ConsumeStructuredBuffer<float4> gCSBuffer;
ByteAddressBuffer gBABuffer;
RWByteAddressBuffer gRWBABuffer;
float4 foo(Combine comb);
float4 main() : SV_Target {
Combine c;
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_v4float %c %int_0 %int_0
// CHECK-NEXT: OpStore [[ptr]] %gASBuffer
c.s.append = gASBuffer;
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_AppendStructuredBuffer_v4float %c %int_0 %int_1
// CHECK-NEXT: OpStore [[ptr]] %gCSBuffer
c.s.consume = gCSBuffer;
T t;
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_StructuredBuffer_Basic %t %int_0
// CHECK-NEXT: OpStore [[ptr]] %gSBuffer
t.ro = gSBuffer;
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_RWStructuredBuffer_Basic %t %int_1
// CHECK-NEXT: OpStore [[ptr]] %gRWSBuffer
t.rw = gRWSBuffer;
// CHECK: [[val:%\d+]] = OpLoad %T %t
// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_T %c %int_1
// CHECK-NEXT: OpStore [[ptr]] [[val]]
c.t = t;
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %c %int_2
// CHECK-NEXT: OpStore [[ptr]] %gBABuffer
c.ro = gBABuffer;
// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_RWByteAddressBuffer %c %int_3
// CHECK-NEXT: OpStore [[ptr]] %gRWBABuffer
c.rw = gRWBABuffer;
// CHECK: [[val:%\d+]] = OpLoad %Combine %c
// CHECK-NEXT: OpStore %param_var_comb [[val]]
return foo(c);
}
float4 foo(Combine comb) {
// TODO: add support for associated counters of struct fields
// comb.s.append.Append(float4(1, 2, 3, 4));
// float4 val = comb.s.consume.Consume();
// comb.t.rw[5].a = 4.2;
// CHECK: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_ByteAddressBuffer %comb %int_2
// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_ByteAddressBuffer [[ptr1]]
// CHECK-NEXT: [[idx:%\d+]] = OpShiftRightLogical %uint %uint_5 %uint_2
// CHECK-NEXT: {{%\d+}} = OpAccessChain %_ptr_Uniform_uint [[ptr2]] %uint_0 [[idx]]
uint val = comb.ro.Load(5);
// CHECK: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function__ptr_Uniform_type_StructuredBuffer_Basic %comb %int_1 %int_0
// CHECK-NEXT: [[ptr2:%\d+]] = OpLoad %_ptr_Uniform_type_StructuredBuffer_Basic [[ptr1]]
// CHECK-NEXT: {{%\d+}} = OpAccessChain %_ptr_Uniform_v4float [[ptr2]] %int_0 %uint_0 %int_1
return comb.t.ro[0].b;
}

Просмотреть файл

@ -1003,6 +1003,11 @@ TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
// The generated SPIR-V needs legalization.
/*runValidation=*/false);
}
TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
runFileTest("spirv.legal.sbuffer.struct.hlsl", Expect::Success,
// The generated SPIR-V needs legalization.
/*runValidation=*/false);
}
TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
runFileTest("spirv.legal.cbuffer.hlsl");
}