diff --git a/lib/HLSL/DxilPreparePasses.cpp b/lib/HLSL/DxilPreparePasses.cpp index 3eb6f398b..7aff01ebd 100644 --- a/lib/HLSL/DxilPreparePasses.cpp +++ b/lib/HLSL/DxilPreparePasses.cpp @@ -648,11 +648,13 @@ public: if (OpKind == (uint64_t)DXIL::QuadVoteOpKind::All) { Value *XY = B.CreateAnd(X, Y); - Result = B.CreateAnd(XY, Z); + Value *XYZ = B.CreateAnd(XY, Z); + Result = B.CreateAnd(XYZ, Cond); } else { DXASSERT_NOMSG(OpKind == (uint64_t)DXIL::QuadVoteOpKind::Any); Value *XY = B.CreateOr(X, Y); - Result = B.CreateOr(XY, Z); + Value *XYZ = B.CreateOr(XY, Z); + Result = B.CreateOr(XYZ, Cond); } Value *Res = B.CreateTrunc(Result, Type::getInt1Ty(M.getContext())); CI->replaceAllUsesWith(Res); diff --git a/tools/clang/test/HLSL/ShaderOpArith.xml b/tools/clang/test/HLSL/ShaderOpArith.xml index 748ab56a9..da4abb454 100644 --- a/tools/clang/test/HLSL/ShaderOpArith.xml +++ b/tools/clang/test/HLSL/ShaderOpArith.xml @@ -2930,6 +2930,42 @@ g_TestResults[CS_INDEX] = tr; } } +]]> + + + + RootFlags(0), UAV(u0) + + + + + + + Values; +groupshared uint WaveOffset = 0; +[numthreads(8, 8, 1)] +void main(uint TID: SV_GroupIndex) { + if (TID == 0) + WaveOffset = 0; + GroupMemoryBarrierWithGroupSync(); + uint Offset = 0; + uint QuadElem = WaveGetLaneIndex() & 3; + if (QuadElem == 0) + Offset = WavePrefixSum(4); + Offset = QuadReadLaneAt(Offset, 0) + QuadElem; + uint LocalWaveOffset = 0; + if (WaveGetLaneIndex() == 0) { + InterlockedAdd(WaveOffset, WaveGetLaneCount(), LocalWaveOffset); + } + uint Idx = Offset + WaveReadLaneFirst(LocalWaveOffset); + uint2 ID = {Idx / 8, Idx % 8}; + uint QuadId = Idx / 4; + uint QuadMask = 0x1U << QuadElem; + bool ThreadBool = (QuadMask & QuadId) != 0; + Values[Idx].x = QuadAny(ThreadBool) ? 1 : 2; + Values[Idx].y = QuadAll(ThreadBool) ? 3 : 4; +} ]]> diff --git a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl index 97ee6203c..87d093669 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl @@ -10,7 +10,8 @@ // SM66-NEXT: [[y:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 1) // SM66-NEXT: [[z:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 2) // SM66-NEXT: [[xy:%[a-z0-9]+]] = or i32 [[x]], [[y]] -// SM66-NEXT: [[wide:%[a-z0-9]+]] = or i32 [[xy]], [[z]] +// SM66-NEXT: [[xyz:%[a-z0-9]+]] = or i32 [[xy]], [[z]] +// SM66-NEXT: [[wide:%[a-z0-9]+]] = or i32 [[xyz]], [[cond]] // SM66-NEXT: [[any:%[a-z0-9]+]] = trunc i32 [[wide]] to i1 // SM67: [[cond:%[a-z0-9]+]] = icmp ne i1 %{{[a-z0-9]+}}, false @@ -24,7 +25,8 @@ // SM66-NEXT: [[y:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 1) // SM66-NEXT: [[z:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 2) // SM66-NEXT: [[xy:%[a-z0-9]+]] = and i32 [[x]], [[y]] -// SM66-NEXT: [[wide:%[a-z0-9]+]] = and i32 [[xy]], [[z]] +// SM66-NEXT: [[xyz:%[a-z0-9]+]] = and i32 [[xy]], [[z]] +// SM66-NEXT: [[wide:%[a-z0-9]+]] = and i32 [[xyz]], [[cond]] // SM66-NEXT: [[all:%[a-z0-9]+]] = trunc i32 [[wide]] to i1 // SM67: [[cond:%[a-z0-9]+]] = icmp ne i1 %{{[a-z0-9]+}}, false diff --git a/tools/clang/unittests/HLSL/ExecutionTest.cpp b/tools/clang/unittests/HLSL/ExecutionTest.cpp index 0a66dfd6f..fcd84e995 100644 --- a/tools/clang/unittests/HLSL/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSL/ExecutionTest.cpp @@ -314,6 +314,7 @@ public: TEST_METHOD(SignatureResourcesTest) TEST_METHOD(DynamicResourcesTest) TEST_METHOD(QuadReadTest) + TEST_METHOD(QuadAnyAll); TEST_METHOD(CBufferTestHalf); @@ -9566,6 +9567,83 @@ TEST_F(ExecutionTest, HelperLaneTestWave) { VERIFY_ARE_EQUAL(testPassed, true); } +struct int2 { + int x; + int y; +}; + +bool VerifyQuadAnyAllResults(int2 *Res) { + int Idx = 0; + for ( ; Idx < 4; ++Idx) { + if (Res[Idx].x != 2) return false; + if (Res[Idx].y != 4) return false; + } + for ( ; Idx < 60; ++Idx) { + if (Res[Idx].x != 1) return false; + if (Res[Idx].y != 4) return false; + } + for ( ; Idx < 64; ++Idx) { + if (Res[Idx].x != 1) return false; + if (Res[Idx].y != 3) return false; + } + return true; +} + +TEST_F(ExecutionTest, QuadAnyAll) { + WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); + CComPtr pStream; + ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream); + + std::shared_ptr ShaderOpSet = std::make_shared(); + st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get()); + st::ShaderOp* pShaderOp = ShaderOpSet->GetShaderOp("QuadAnyAll"); + + LPCSTR args = "/Od"; + + if (args[0]) { + for (st::ShaderOpShader& S : pShaderOp->Shaders) + S.Arguments = args; + } + + bool Skipped = true; + D3D_SHADER_MODEL TestShaderModels[] = { D3D_SHADER_MODEL_6_0, D3D_SHADER_MODEL_6_7 }; + for (unsigned i = 0; i < _countof(TestShaderModels); i++) { + D3D_SHADER_MODEL sm = TestShaderModels[i]; + LogCommentFmt(L"\r\nVerifying QuadAny/QuadAll using Wave intrinsics in shader model 6.%1u", ((UINT)sm & 0x0f)); + + if (sm == D3D_SHADER_MODEL_6_7) { + pShaderOp->CS = "CS67"; + } + + CComPtr pDevice; + if (!CreateDevice(&pDevice, sm, false /* skipUnsupported */)) { + continue; + } + + if (IsDeviceBasicAdapter(pDevice)) { + WEX::Logging::Log::Comment(L"QuadAny/All fails on basic render driver."); + continue; + } + + if (!DoesDeviceSupportWaveOps(pDevice)) { + LogCommentFmt(L"Device does not support wave operations in shader model 6.%1u", ((UINT)sm & 0x0f)); + continue; + } + Skipped = false; + + // test compute + std::shared_ptr test = RunShaderOpTestAfterParse(pDevice, m_support, "QuadAnyAll", + CleanUAVBuffer0Buffer, ShaderOpSet); + + MappedData uavData; + test->Test->GetReadBackData("UAVBuffer0", &uavData); + bool Result = VerifyQuadAnyAllResults((int2*)uavData.data()); + VERIFY_IS_TRUE(Result); + } + if (Skipped) + WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); +} + #ifndef _HLK_CONF static void WriteReadBackDump(st::ShaderOp *pShaderOp, st::ShaderOpTest *pTest, char **pReadBackDump) {