diff --git a/lib/HLSL/DxilPreparePasses.cpp b/lib/HLSL/DxilPreparePasses.cpp
index 3eb6f398b..7aff01ebd 100644
--- a/lib/HLSL/DxilPreparePasses.cpp
+++ b/lib/HLSL/DxilPreparePasses.cpp
@@ -648,11 +648,13 @@ public:
if (OpKind == (uint64_t)DXIL::QuadVoteOpKind::All) {
Value *XY = B.CreateAnd(X, Y);
- Result = B.CreateAnd(XY, Z);
+ Value *XYZ = B.CreateAnd(XY, Z);
+ Result = B.CreateAnd(XYZ, Cond); // All: fold Cond itself into the vote alongside X, Y and Z. NOTE(review): X/Y/Z look like quad reads of Cond from the other lanes - confirm.
} else {
DXASSERT_NOMSG(OpKind == (uint64_t)DXIL::QuadVoteOpKind::Any);
Value *XY = B.CreateOr(X, Y);
- Result = B.CreateOr(XY, Z);
+ Value *XYZ = B.CreateOr(XY, Z);
+ Result = B.CreateOr(XYZ, Cond); // Any: likewise OR Cond itself in with X, Y and Z.
}
Value *Res = B.CreateTrunc(Result, Type::getInt1Ty(M.getContext()));
CI->replaceAllUsesWith(Res);
diff --git a/tools/clang/test/HLSL/ShaderOpArith.xml b/tools/clang/test/HLSL/ShaderOpArith.xml
index 748ab56a9..da4abb454 100644
--- a/tools/clang/test/HLSL/ShaderOpArith.xml
+++ b/tools/clang/test/HLSL/ShaderOpArith.xml
@@ -2930,6 +2930,42 @@
g_TestResults[CS_INDEX] = tr;
}
}
+]]>
+
+
+
+ RootFlags(0), UAV(u0)
+
+
+
+
+
+
+ Values;
+groupshared uint WaveOffset = 0;
+[numthreads(8, 8, 1)]
+void main(uint TID: SV_GroupIndex) { // Writes a QuadAny-derived value to .x and a QuadAll-derived value to .y of Values[Idx].
+ if (TID == 0)
+ WaveOffset = 0;
+ GroupMemoryBarrierWithGroupSync();
+ uint Offset = 0;
+ uint QuadElem = WaveGetLaneIndex() & 3; // Low two bits of the lane index, i.e. a value in 0..3.
+ if (QuadElem == 0)
+ Offset = WavePrefixSum(4); // Executed only by lanes whose QuadElem is 0.
+ Offset = QuadReadLaneAt(Offset, 0) + QuadElem; // Offset taken from quad lane 0, plus this lane's QuadElem.
+ uint LocalWaveOffset = 0;
+ if (WaveGetLaneIndex() == 0) {
+ InterlockedAdd(WaveOffset, WaveGetLaneCount(), LocalWaveOffset); // Lane 0 adds the wave's lane count to the groupshared counter.
+ }
+ uint Idx = Offset + WaveReadLaneFirst(LocalWaveOffset);
+ uint2 ID = {Idx / 8, Idx % 8}; // NOTE(review): ID is never read in this shader.
+ uint QuadId = Idx / 4;
+ uint QuadMask = 0x1U << QuadElem;
+ bool ThreadBool = (QuadMask & QuadId) != 0; // True when bit number QuadElem of QuadId is set.
+ Values[Idx].x = QuadAny(ThreadBool) ? 1 : 2; // 1 when QuadAny returns true, otherwise 2.
+ Values[Idx].y = QuadAll(ThreadBool) ? 3 : 4; // 3 when QuadAll returns true, otherwise 4.
+}
]]>
diff --git a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl
index 97ee6203c..87d093669 100644
--- a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl
+++ b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/wave/vote/any-all.hlsl
@@ -10,7 +10,8 @@
// SM66-NEXT: [[y:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 1)
// SM66-NEXT: [[z:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 2)
// SM66-NEXT: [[xy:%[a-z0-9]+]] = or i32 [[x]], [[y]]
-// SM66-NEXT: [[wide:%[a-z0-9]+]] = or i32 [[xy]], [[z]]
+// SM66-NEXT: [[xyz:%[a-z0-9]+]] = or i32 [[xy]], [[z]]
+// SM66-NEXT: [[wide:%[a-z0-9]+]] = or i32 [[xyz]], [[cond]]
// SM66-NEXT: [[any:%[a-z0-9]+]] = trunc i32 [[wide]] to i1
// SM67: [[cond:%[a-z0-9]+]] = icmp ne i1 %{{[a-z0-9]+}}, false
@@ -24,7 +25,8 @@
// SM66-NEXT: [[y:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 1)
// SM66-NEXT: [[z:%[a-zA-Z0-9]+]] = call i32 @dx.op.quadOp.i32(i32 123, i32 [[cond]], i8 2)
// SM66-NEXT: [[xy:%[a-z0-9]+]] = and i32 [[x]], [[y]]
-// SM66-NEXT: [[wide:%[a-z0-9]+]] = and i32 [[xy]], [[z]]
+// SM66-NEXT: [[xyz:%[a-z0-9]+]] = and i32 [[xy]], [[z]]
+// SM66-NEXT: [[wide:%[a-z0-9]+]] = and i32 [[xyz]], [[cond]]
// SM66-NEXT: [[all:%[a-z0-9]+]] = trunc i32 [[wide]] to i1
// SM67: [[cond:%[a-z0-9]+]] = icmp ne i1 %{{[a-z0-9]+}}, false
diff --git a/tools/clang/unittests/HLSL/ExecutionTest.cpp b/tools/clang/unittests/HLSL/ExecutionTest.cpp
index 0a66dfd6f..fcd84e995 100644
--- a/tools/clang/unittests/HLSL/ExecutionTest.cpp
+++ b/tools/clang/unittests/HLSL/ExecutionTest.cpp
@@ -314,6 +314,7 @@ public:
TEST_METHOD(SignatureResourcesTest)
TEST_METHOD(DynamicResourcesTest)
TEST_METHOD(QuadReadTest)
+ TEST_METHOD(QuadAnyAll);
TEST_METHOD(CBufferTestHalf);
@@ -9566,6 +9567,83 @@ TEST_F(ExecutionTest, HelperLaneTestWave) {
VERIFY_ARE_EQUAL(testPassed, true);
}
+struct int2 { // Two-int element; the QuadAnyAll test casts the UAVBuffer0 read-back to an array of these.
+ int x; // Compared against 1 or 2 by VerifyQuadAnyAllResults.
+ int y; // Compared against 3 or 4 by VerifyQuadAnyAllResults.
+};
+
+bool VerifyQuadAnyAllResults(int2 *Res) {
+  // Checks the 64 entries of Res, four per quad. Returns true only when:
+  //   quad 0      (entries 0-3):   x == 2 and y == 4
+  //   quads 1-14  (entries 4-59):  x == 1 and y == 4
+  //   quad 15     (entries 60-63): x == 1 and y == 3
+  for (int Entry = 0; Entry < 64; ++Entry) {
+    const int Quad = Entry / 4;
+    const int WantX = (Quad == 0) ? 2 : 1;
+    const int WantY = (Quad == 15) ? 3 : 4;
+    if (Res[Entry].x != WantX)
+      return false;
+    if (Res[Entry].y != WantY)
+      return false;
+  }
+  return true;
+}
+
+TEST_F(ExecutionTest, QuadAnyAll) {
+  // Runs the "QuadAnyAll" ShaderOp from ShaderOpArith.xml with /Od on SM 6.0 and
+  // SM 6.7 devices, verifying UAVBuffer0; Skipped if no device could run it.
+  WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
+  CComPtr<IStream> pStream;
+  ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
+
+  auto ShaderOpSet = std::make_shared<st::ShaderOpSet>();
+  st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get());
+  st::ShaderOp* pShaderOp = ShaderOpSet->GetShaderOp("QuadAnyAll");
+  // Report a missing op as a verify failure rather than dereferencing null below.
+  VERIFY_IS_TRUE(pShaderOp != nullptr);
+
+  LPCSTR args = "/Od";
+  for (st::ShaderOpShader& S : pShaderOp->Shaders)
+    S.Arguments = args;
+
+  bool Skipped = true;
+  const D3D_SHADER_MODEL TestShaderModels[] = { D3D_SHADER_MODEL_6_0, D3D_SHADER_MODEL_6_7 };
+  for (D3D_SHADER_MODEL sm : TestShaderModels) {
+    LogCommentFmt(L"\r\nVerifying QuadAny/QuadAll using Wave intrinsics in shader model 6.%1u", ((UINT)sm & 0x0f));
+
+    // The SM 6.7 pass switches the op's compute shader to "CS67".
+    if (sm == D3D_SHADER_MODEL_6_7) {
+      pShaderOp->CS = "CS67";
+    }
+
+    CComPtr<ID3D12Device> pDevice;
+    if (!CreateDevice(&pDevice, sm, false /* skipUnsupported */)) {
+      continue;
+    }
+
+    if (IsDeviceBasicAdapter(pDevice)) {
+      WEX::Logging::Log::Comment(L"QuadAny/All fails on basic render driver.");
+      continue;
+    }
+
+    if (!DoesDeviceSupportWaveOps(pDevice)) {
+      LogCommentFmt(L"Device does not support wave operations in shader model 6.%1u", ((UINT)sm & 0x0f));
+      continue;
+    }
+    Skipped = false;
+
+    // test compute
+    auto test = RunShaderOpTestAfterParse(pDevice, m_support, "QuadAnyAll", CleanUAVBuffer0Buffer, ShaderOpSet);
+
+    MappedData uavData;
+    test->Test->GetReadBackData("UAVBuffer0", &uavData);
+    bool Result = VerifyQuadAnyAllResults((int2*)uavData.data());
+    VERIFY_IS_TRUE(Result);
+  }
+  if (Skipped)
+    WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
+}
+
#ifndef _HLK_CONF
static void WriteReadBackDump(st::ShaderOp *pShaderOp, st::ShaderOpTest *pTest,
char **pReadBackDump) {