Fixed writing to RWTexture generating a load and store for each component (#3304)

This commit is contained in:
Adam Yang 2020-12-07 22:51:53 -08:00 коммит произвёл GitHub
Родитель ef84d9e045
Коммит 4d40670ef3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 238 добавлений и 29 удалений

Просмотреть файл

@ -16,6 +16,8 @@
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h" #include "llvm/IR/Constants.h"
#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilResourceProperties.h"
namespace llvm { namespace llvm {
class Type; class Type;
@ -128,6 +130,7 @@ namespace dxilutil {
bool IsIntegerOrFloatingPointType(llvm::Type *Ty); bool IsIntegerOrFloatingPointType(llvm::Type *Ty);
// Returns true if type contains HLSL Object type (resource) // Returns true if type contains HLSL Object type (resource)
bool ContainsHLSLObjectType(llvm::Type *Ty); bool ContainsHLSLObjectType(llvm::Type *Ty);
std::pair<bool, DxilResourceProperties> GetHLSLResourceProperties(llvm::Type *Ty);
bool IsHLSLResourceType(llvm::Type *Ty); bool IsHLSLResourceType(llvm::Type *Ty);
bool IsHLSLObjectType(llvm::Type *Ty); bool IsHLSLObjectType(llvm::Type *Ty);
bool IsHLSLRayQueryType(llvm::Type *Ty); bool IsHLSLRayQueryType(llvm::Type *Ty);

Просмотреть файл

@ -627,64 +627,100 @@ uint8_t GetResourceComponentCount(llvm::Type *Ty) {
} }
bool IsHLSLResourceType(llvm::Type *Ty) { bool IsHLSLResourceType(llvm::Type *Ty) {
return GetHLSLResourceProperties(Ty).first;
}
static DxilResourceProperties MakeResourceProperties(hlsl::DXIL::ResourceKind Kind, bool UAV, bool ROV, bool Cmp) {
DxilResourceProperties Ret = {};
Ret.Basic.IsROV = ROV;
Ret.Basic.SamplerCmpOrHasCounter = Cmp;
Ret.Basic.IsUAV = UAV;
Ret.Basic.ResourceKind = (uint8_t)Kind;
return Ret;
}
std::pair<bool, DxilResourceProperties> GetHLSLResourceProperties(llvm::Type *Ty)
{
using RetType = std::pair<bool, DxilResourceProperties>;
RetType FalseRet(false, DxilResourceProperties{});
if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) { if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
if (!ST->hasName()) if (!ST->hasName())
return false; return FalseRet;
StringRef name = ST->getName(); StringRef name = ST->getName();
ConsumePrefix(name, "class."); ConsumePrefix(name, "class.");
ConsumePrefix(name, "struct."); ConsumePrefix(name, "struct.");
if (name == "SamplerState") if (name == "SamplerState")
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Sampler, false, false, false));
if (name == "SamplerComparisonState") if (name == "SamplerComparisonState")
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Sampler, false, false, /*cmp or counter*/true));
if (name.startswith("AppendStructuredBuffer<")) if (name.startswith("AppendStructuredBuffer<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::StructuredBuffer, false, false, /*cmp or counter*/true));
if (name.startswith("ConsumeStructuredBuffer<")) if (name.startswith("ConsumeStructuredBuffer<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::StructuredBuffer, false, false, /*cmp or counter*/true));
if (name == "RaytracingAccelerationStructure") if (name == "RaytracingAccelerationStructure")
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::RTAccelerationStructure, false, false, false));
if (ConsumePrefix(name, "FeedbackTexture2D")) { if (ConsumePrefix(name, "FeedbackTexture2D")) {
ConsumePrefix(name, "Array"); hlsl::DXIL::ResourceKind kind = hlsl::DXIL::ResourceKind::Invalid;
return name.startswith("<"); if (ConsumePrefix(name, "Array"))
kind = hlsl::DXIL::ResourceKind::FeedbackTexture2DArray;
else
kind = hlsl::DXIL::ResourceKind::FeedbackTexture2D;
if (name.startswith("<"))
return RetType(true, MakeResourceProperties(kind, false, false, false));
} }
ConsumePrefix(name, "RasterizerOrdered"); bool ROV = ConsumePrefix(name, "RasterizerOrdered");
ConsumePrefix(name, "RW"); bool UAV = ConsumePrefix(name, "RW");
if (name == "ByteAddressBuffer") if (name == "ByteAddressBuffer")
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::RawBuffer, UAV, ROV, false));
if (name.startswith("Buffer<")) if (name.startswith("Buffer<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::TypedBuffer, UAV, ROV, false));
if (name.startswith("StructuredBuffer<")) if (name.startswith("StructuredBuffer<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::StructuredBuffer, UAV, ROV, false));
if (ConsumePrefix(name, "Texture")) { if (ConsumePrefix(name, "Texture")) {
if (name.startswith("1D<")) if (name.startswith("1D<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture1D, UAV, ROV, false));
if (name.startswith("1DArray<")) if (name.startswith("1DArray<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture1DArray, UAV, ROV, false));
if (name.startswith("2D<")) if (name.startswith("2D<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture2D, UAV, ROV, false));
if (name.startswith("2DArray<")) if (name.startswith("2DArray<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture2DArray, UAV, ROV, false));
if (name.startswith("3D<")) if (name.startswith("3D<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture3D, UAV, ROV, false));
if (name.startswith("Cube<")) if (name.startswith("Cube<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::TextureCube, UAV, ROV, false));
if (name.startswith("CubeArray<")) if (name.startswith("CubeArray<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::TextureCubeArray, UAV, ROV, false));
if (name.startswith("2DMS<")) if (name.startswith("2DMS<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture2DMS, UAV, ROV, false));
if (name.startswith("2DMSArray<")) if (name.startswith("2DMSArray<"))
return true; return RetType(true, MakeResourceProperties(hlsl::DXIL::ResourceKind::Texture2DMSArray, UAV, ROV, false));
return false; return FalseRet;
} }
} }
return false; return FalseRet;
} }
bool IsHLSLObjectType(llvm::Type *Ty) { bool IsHLSLObjectType(llvm::Type *Ty) {

Просмотреть файл

@ -18,6 +18,9 @@
#include "CGObjCRuntime.h" #include "CGObjCRuntime.h"
#include "CGOpenMPRuntime.h" #include "CGOpenMPRuntime.h"
#include "CGHLSLRuntime.h" // HLSL Change #include "CGHLSLRuntime.h" // HLSL Change
#include "dxc/HLSL/HLOperations.h" // HLSL Change
#include "dxc/DXIL/DxilUtil.h" // HLSL Change
#include "dxc/DXIL/DxilResource.h" // HLSL Change
#include "CGRecordLayout.h" #include "CGRecordLayout.h"
#include "CodeGenModule.h" #include "CodeGenModule.h"
#include "TargetInfo.h" #include "TargetInfo.h"
@ -1757,6 +1760,36 @@ void CodeGenFunction::EmitStoreThroughBitfieldLValue(RValue Src, LValue Dst,
} }
} }
// HLSL Change - begin
static bool IsHLSubscriptOfTypedBuffer(llvm::Value *V) {
llvm::CallInst *CI = nullptr;
llvm::Function *F = nullptr;
if ((CI = dyn_cast<llvm::CallInst>(V)) &&
(F = CI->getCalledFunction()) &&
hlsl::GetHLOpcodeGroup(F) == hlsl::HLOpcodeGroup::HLSubscript)
{
for (llvm::Value *arg : CI->arg_operands()) {
llvm::Type *Ty = arg->getType();
if (Ty->isPointerTy()) {
std::pair<bool, hlsl::DxilResourceProperties> Result =
hlsl::dxilutil::GetHLSLResourceProperties(Ty->getPointerElementType());
if (Result.first &&
Result.second.isUAV() &&
// These are the types of buffers that are typed.
(hlsl::DxilResource::IsAnyTexture(Result.second.getResourceKind()) ||
Result.second.getResourceKind() == hlsl::DXIL::ResourceKind::TypedBuffer))
{
return true;
}
}
}
}
return false;
}
// HLSL Change - end
void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src, void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
LValue Dst) { LValue Dst) {
// This access turns into a read/modify/write of the vector. Load the input // This access turns into a read/modify/write of the vector. Load the input
@ -1791,6 +1824,25 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
} }
Builder.CreateStore(Vec, VecDstPtr); Builder.CreateStore(Vec, VecDstPtr);
} else { } else {
// If the vector pointer comes from subscripting a typed rw buffer (Buffer<>, Texture*<>, etc.),
// insert the elements from the load.
//
// This is to avoid the final DXIL producing a load+store for each component later down the line,
// as there's no clean way to associate the geps+store with each other.
//
if (IsHLSubscriptOfTypedBuffer(VecDstPtr)) {
llvm::Value *vec = Load;
for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
if (llvm::Constant *Elt = Elts->getAggregateElement(i)) {
llvm::Value *SrcElt = Builder.CreateExtractElement(SrcVal, i);
vec = Builder.CreateInsertElement(vec, SrcElt, Elt);
}
}
Builder.CreateStore(vec, VecDstPtr);
}
// Otherwise just do a gep + store for each component that we're writing to.
else {
for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) { for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
if (llvm::Constant *Elt = Elts->getAggregateElement(i)) { if (llvm::Constant *Elt = Elts->getAggregateElement(i)) {
llvm::Value *EltGEP = Builder.CreateGEP(VecDstPtr, {Zero, Elt}); llvm::Value *EltGEP = Builder.CreateGEP(VecDstPtr, {Zero, Elt});
@ -1799,6 +1851,7 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
} }
} }
} }
}
} else { } else {
// If the Src is a scalar (not a vector) it must be updating one element. // If the Src is a scalar (not a vector) it must be updating one element.
llvm::Value *EltGEP = Builder.CreateGEP( llvm::Value *EltGEP = Builder.CreateGEP(

Просмотреть файл

@ -0,0 +1,43 @@
// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
// CHECK: @dx.op.textureLoad.f32(i32 66
// CHECK: call void @dx.op.textureStore.f32(i32 67
// CHECK-NOT: @dx.op.textureLoad.f32(i32 66
// CHECK-NOT: call void @dx.op.textureStore.f32(i32 67
struct PS_INPUT
{
float4 pos : SV_POSITION;
float2 vPos : TEXCOORD0;
float2 sPos : TEXCOORD1;
};
RWTexture2D<float4> t_output : register(u0);
float4 psMain(in PS_INPUT I)
{
return float4(0, 0, 0, 1);
}
[numthreads(8, 8, 1)]
void main(uint2 dtid : SV_DispatchThreadID)
{
PS_INPUT I = {
float4(0,0, 0, 1),
float2(dtid),
float2(dtid)
};
uint2 uspos = uint2(I.pos.xy);
{
uint w, h;
t_output.GetDimensions(w, h);
if (any(uspos >= uint2(w, h)))
return;
}
float3 prev = t_output[uspos].rgb;
t_output[uspos].rgb = float3(psMain(I).rgb + prev);
}

Просмотреть файл

@ -0,0 +1,36 @@
// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
// CHECK: @dx.op.bufferLoad.f32(i32 68
// CHECK: call void @dx.op.bufferStore.f32(i32 69
// CHECK-NOT: @dx.op.bufferLoad.f32(i32 68
// CHECK-NOT: call void @dx.op.bufferStore.f32(i32 69
struct PS_INPUT
{
float4 pos : SV_POSITION;
float2 vPos : TEXCOORD0;
float2 sPos : TEXCOORD1;
};
RWBuffer<float4> t_output : register(u0);
float4 psMain(in PS_INPUT I)
{
return float4(0, 0, 0, 1);
}
[numthreads(8, 8, 1)]
void main(uint2 dtid : SV_DispatchThreadID)
{
PS_INPUT I = {
float4(0,0, 0, 1),
float2(dtid),
float2(dtid)
};
uint2 uspos = uint2(I.pos.xy);
float3 prev = t_output[uspos.x].rgb;
t_output[uspos.x].rgb = float3(psMain(I).rgb + prev);
}

Просмотреть файл

@ -0,0 +1,38 @@
// RUN: %dxc -E main -T cs_6_4 %s | FileCheck %s -check-prefix=CHECK64
// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s -check-prefix=CHECK60
// CHECK64-DAG: call void @dx.op.rawBufferStore
// CHECK64-NOT: call void @dx.op.rawBufferLoad
// CHECK60-DAG: call void @dx.op.bufferStore
// CHECK60-NOT: call void @dx.op.bufferLoad
struct PS_INPUT
{
float4 pos : SV_POSITION;
float2 vPos : TEXCOORD0;
float2 sPos : TEXCOORD1;
};
RWStructuredBuffer<float4> t_output : register(u0);
float4 psMain(in PS_INPUT I)
{
return float4(0, 0, 0, 1);
}
[numthreads(8, 8, 1)]
void main(uint2 dtid : SV_DispatchThreadID)
{
PS_INPUT I = {
float4(0,0, 0, 1),
float2(dtid),
float2(dtid)
};
uint2 uspos = uint2(I.pos.xy);
t_output[uspos.x].rgb = float3(1, 2, 3);
}