From c266429be9b882f1553fcf8c0810e3782b394f97 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Wed, 11 Apr 2018 15:02:02 +0200 Subject: [PATCH] Partially implement subgroup ops for HLSL SM 6.0. Lots of stuff that needs tons of emulation, which I'm not going to bother with. --- .../comp/subgroups.invalid.nofxc.sm60.comp | 67 +++++ .../comp/subgroups.invalid.nofxc.sm60.comp | 93 ++++++ .../comp/subgroups.invalid.nofxc.sm60.comp | 131 +++++++++ spirv_hlsl.cpp | 277 +++++++++++++++++- spirv_hlsl.hpp | 1 + test_shaders.py | 4 +- tests-other/hlsl_wave_mask.cpp | 73 +++++ 7 files changed, 644 insertions(+), 2 deletions(-) create mode 100644 reference/opt/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp create mode 100644 reference/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp create mode 100644 shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp create mode 100644 tests-other/hlsl_wave_mask.cpp diff --git a/reference/opt/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp b/reference/opt/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp new file mode 100644 index 0000000..dabc7df --- /dev/null +++ b/reference/opt/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp @@ -0,0 +1,67 @@ +RWByteAddressBuffer _9 : register(u0, space0); + +static uint4 gl_SubgroupEqMask; +static uint4 gl_SubgroupGeMask; +static uint4 gl_SubgroupGtMask; +static uint4 gl_SubgroupLeMask; +static uint4 gl_SubgroupLtMask; +void comp_main() +{ + _9.Store(0, asuint(float(WaveGetLaneCount()))); + _9.Store(0, asuint(float(WaveGetLaneIndex()))); + _9.Store(0, asuint(float4(gl_SubgroupEqMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupGeMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupGtMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupLeMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupLtMask).x)); + uint4 _75 = WaveActiveBallot(true); + float4 _88 = WaveActiveSum(20.0f.xxxx); + int4 _94 = WaveActiveSum(int4(20, 20, 20, 20)); + float4 _96 = WaveActiveProduct(20.0f.xxxx); + int4 _98 = 
WaveActiveProduct(int4(20, 20, 20, 20)); + float4 _127 = WavePrefixProduct(_96) * _96; + int4 _129 = WavePrefixProduct(_98) * _98; +} + +[numthreads(1, 1, 1)] +void main() +{ + gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96)); + if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0; + if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0; + if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0; + if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0; + gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u); + if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u; + if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u; + if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u; + if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u; + if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u; + if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u; + uint gt_lane_index = WaveGetLaneIndex() + 1; + gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u); + if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u; + if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u; + if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u; + if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u; + if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u; + if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u; + if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u; + uint le_lane_index = WaveGetLaneIndex() + 1; + gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u; + if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u; + if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u; + if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u; + if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u; + if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u; + if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u; + if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u; + gl_SubgroupLtMask = (1u << 
(WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u; + if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u; + if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u; + if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u; + if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u; + if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u; + if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u; + comp_main(); +} diff --git a/reference/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp b/reference/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp new file mode 100644 index 0000000..b87574f --- /dev/null +++ b/reference/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp @@ -0,0 +1,93 @@ +RWByteAddressBuffer _9 : register(u0, space0); + +static uint4 gl_SubgroupEqMask; +static uint4 gl_SubgroupGeMask; +static uint4 gl_SubgroupGtMask; +static uint4 gl_SubgroupLeMask; +static uint4 gl_SubgroupLtMask; +void comp_main() +{ + _9.Store(0, asuint(float(WaveGetLaneCount()))); + _9.Store(0, asuint(float(WaveGetLaneIndex()))); + bool elected = WaveIsFirstLane(); + _9.Store(0, asuint(float4(gl_SubgroupEqMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupGeMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupGtMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupLeMask).x)); + _9.Store(0, asuint(float4(gl_SubgroupLtMask).x)); + float4 broadcasted = WaveReadLaneAt(10.0f.xxxx, 8u); + float3 first = WaveReadLaneFirst(20.0f.xxx); + uint4 ballot_value = WaveActiveBallot(true); + uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w); + bool has_all = WaveActiveAllTrue(true); + bool has_any = WaveActiveAnyTrue(true); + bool has_equal = WaveActiveAllEqualBool(true); + float4 added = WaveActiveSum(20.0f.xxxx); + int4 iadded = WaveActiveSum(int4(20, 20, 20, 20)); + float4 multiplied = WaveActiveProduct(20.0f.xxxx); + int4 imultiplied = WaveActiveProduct(int4(20, 20, 20, 20)); + float4 lo = WaveActiveMin(20.0f.xxxx); + 
float4 hi = WaveActiveMax(20.0f.xxxx); + int4 slo = WaveActiveMin(int4(20, 20, 20, 20)); + int4 shi = WaveActiveMax(int4(20, 20, 20, 20)); + uint4 ulo = WaveActiveMin(uint4(20u, 20u, 20u, 20u)); + uint4 uhi = WaveActiveMax(uint4(20u, 20u, 20u, 20u)); + uint4 anded = WaveActiveBitAnd(ballot_value); + uint4 ored = WaveActiveBitOr(ballot_value); + uint4 xored = WaveActiveBitXor(ballot_value); + added = WavePrefixSum(added) + added; + iadded = WavePrefixSum(iadded) + iadded; + multiplied = WavePrefixProduct(multiplied) * multiplied; + imultiplied = WavePrefixProduct(imultiplied) * imultiplied; + added = WavePrefixSum(multiplied); + multiplied = WavePrefixProduct(multiplied); + iadded = WavePrefixSum(imultiplied); + imultiplied = WavePrefixProduct(imultiplied); + float4 swap_horiz = QuadReadAcrossX(20.0f.xxxx); + float4 swap_vertical = QuadReadAcrossY(20.0f.xxxx); + float4 swap_diagonal = QuadReadAcrossDiagonal(20.0f.xxxx); + float4 quad_broadcast = QuadReadLaneAt(20.0f.xxxx, 3u); +} + +[numthreads(1, 1, 1)] +void main() +{ + gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96)); + if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0; + if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0; + if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0; + if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0; + gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u); + if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u; + if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u; + if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u; + if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u; + if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u; + if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u; + uint gt_lane_index = WaveGetLaneIndex() + 1; + gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u); + if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u; + 
if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u; + if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u; + if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u; + if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u; + if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u; + if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u; + uint le_lane_index = WaveGetLaneIndex() + 1; + gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u; + if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u; + if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u; + if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u; + if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u; + if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u; + if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u; + if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u; + gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u; + if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u; + if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u; + if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u; + if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u; + if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u; + if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u; + comp_main(); +} diff --git a/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp b/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp new file mode 100644 index 0000000..81135e2 --- /dev/null +++ b/shaders-hlsl/comp/subgroups.invalid.nofxc.sm60.comp @@ -0,0 +1,131 @@ +#version 450 +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_clustered : require +#extension GL_KHR_shader_subgroup_quad : require +layout(local_size_x = 1) in; + 
+layout(std430, binding = 0) buffer SSBO +{ + float FragColor; +}; + +void main() +{ + // basic + //FragColor = float(gl_NumSubgroups); + //FragColor = float(gl_SubgroupID); + FragColor = float(gl_SubgroupSize); + FragColor = float(gl_SubgroupInvocationID); + subgroupBarrier(); + subgroupMemoryBarrier(); + subgroupMemoryBarrierBuffer(); + subgroupMemoryBarrierShared(); + subgroupMemoryBarrierImage(); + bool elected = subgroupElect(); + + // ballot + FragColor = float(gl_SubgroupEqMask); + FragColor = float(gl_SubgroupGeMask); + FragColor = float(gl_SubgroupGtMask); + FragColor = float(gl_SubgroupLeMask); + FragColor = float(gl_SubgroupLtMask); + vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u); + vec3 first = subgroupBroadcastFirst(vec3(20.0)); + uvec4 ballot_value = subgroupBallot(true); + //bool inverse_ballot_value = subgroupInverseBallot(ballot_value); + //bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u); + uint bit_count = subgroupBallotBitCount(ballot_value); + //uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value); + //uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value); + //uint lsb = subgroupBallotFindLSB(ballot_value); + //uint msb = subgroupBallotFindMSB(ballot_value); + + // shuffle + //uint shuffled = subgroupShuffle(10u, 8u); + //uint shuffled_xor = subgroupShuffleXor(30u, 8u); + + // shuffle relative + //uint shuffled_up = subgroupShuffleUp(20u, 4u); + //uint shuffled_down = subgroupShuffleDown(20u, 4u); + + // vote + bool has_all = subgroupAll(true); + bool has_any = subgroupAny(true); + bool has_equal = subgroupAllEqual(true); + + // arithmetic + vec4 added = subgroupAdd(vec4(20.0)); + ivec4 iadded = subgroupAdd(ivec4(20)); + vec4 multiplied = subgroupMul(vec4(20.0)); + ivec4 imultiplied = subgroupMul(ivec4(20)); + vec4 lo = subgroupMin(vec4(20.0)); + vec4 hi = subgroupMax(vec4(20.0)); + ivec4 slo = subgroupMin(ivec4(20)); + ivec4 shi = subgroupMax(ivec4(20)); + uvec4 ulo = 
subgroupMin(uvec4(20)); + uvec4 uhi = subgroupMax(uvec4(20)); + uvec4 anded = subgroupAnd(ballot_value); + uvec4 ored = subgroupOr(ballot_value); + uvec4 xored = subgroupXor(ballot_value); + + added = subgroupInclusiveAdd(added); + iadded = subgroupInclusiveAdd(iadded); + multiplied = subgroupInclusiveMul(multiplied); + imultiplied = subgroupInclusiveMul(imultiplied); +#if 0 + lo = subgroupInclusiveMin(lo); + hi = subgroupInclusiveMax(hi); + slo = subgroupInclusiveMin(slo); + shi = subgroupInclusiveMax(shi); + ulo = subgroupInclusiveMin(ulo); + uhi = subgroupInclusiveMax(uhi); + anded = subgroupInclusiveAnd(anded); + ored = subgroupInclusiveOr(ored); + xored = subgroupInclusiveXor(ored); + added = subgroupExclusiveAdd(lo); +#endif + + added = subgroupExclusiveAdd(multiplied); + multiplied = subgroupExclusiveMul(multiplied); + iadded = subgroupExclusiveAdd(imultiplied); + imultiplied = subgroupExclusiveMul(imultiplied); +#if 0 + lo = subgroupExclusiveMin(lo); + hi = subgroupExclusiveMax(hi); + ulo = subgroupExclusiveMin(ulo); + uhi = subgroupExclusiveMax(uhi); + slo = subgroupExclusiveMin(slo); + shi = subgroupExclusiveMax(shi); + anded = subgroupExclusiveAnd(anded); + ored = subgroupExclusiveOr(ored); + xored = subgroupExclusiveXor(ored); +#endif + +#if 0 + // clustered + added = subgroupClusteredAdd(added, 4u); + multiplied = subgroupClusteredMul(multiplied, 4u); + iadded = subgroupClusteredAdd(iadded, 4u); + imultiplied = subgroupClusteredMul(imultiplied, 4u); + lo = subgroupClusteredMin(lo, 4u); + hi = subgroupClusteredMax(hi, 4u); + ulo = subgroupClusteredMin(ulo, 4u); + uhi = subgroupClusteredMax(uhi, 4u); + slo = subgroupClusteredMin(slo, 4u); + shi = subgroupClusteredMax(shi, 4u); + anded = subgroupClusteredAnd(anded, 4u); + ored = subgroupClusteredOr(ored, 4u); + xored = subgroupClusteredXor(xored, 4u); +#endif + + // quad + vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0)); + vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0)); + vec4 
swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0)); + vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u); +} diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 57f350a..1704dae 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -625,6 +625,13 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() break; case BuiltInNumWorkgroups: + case BuiltInSubgroupSize: + case BuiltInSubgroupLocalInvocationId: + case BuiltInSubgroupEqMask: + case BuiltInSubgroupLtMask: + case BuiltInSubgroupLeMask: + case BuiltInSubgroupGtMask: + case BuiltInSubgroupGeMask: // Handled specially. break; @@ -864,6 +871,11 @@ std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClas case BuiltInPointCoord: // Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set. return "float2(0.5f, 0.5f)"; + case BuiltInSubgroupLocalInvocationId: + return "WaveGetLaneIndex()"; + case BuiltInSubgroupSize: + return "WaveGetLaneCount()"; + default: return CompilerGLSL::builtin_to_glsl(builtin, storage); } @@ -928,6 +940,22 @@ void CompilerHLSL::emit_builtin_variables() // Handled specially. break; + case BuiltInSubgroupLocalInvocationId: + case BuiltInSubgroupSize: + if (hlsl_options.shader_model < 60) + SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops."); + break; + + case BuiltInSubgroupEqMask: + case BuiltInSubgroupLtMask: + case BuiltInSubgroupLeMask: + case BuiltInSubgroupGtMask: + case BuiltInSubgroupGeMask: + if (hlsl_options.shader_model < 60) + SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops."); + type = "uint4"; + break; + case BuiltInClipDistance: array_size = clip_distance_count; type = "float"; @@ -940,7 +968,6 @@ void CompilerHLSL::emit_builtin_variables() default: SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin))); - break; } StorageClass storage = active_input_builtins.get(i) ? 
StorageClassInput : StorageClassOutput; @@ -1225,6 +1252,14 @@ void CompilerHLSL::emit_resources() auto input_builtins = active_input_builtins; input_builtins.clear(BuiltInNumWorkgroups); input_builtins.clear(BuiltInPointCoord); + input_builtins.clear(BuiltInSubgroupSize); + input_builtins.clear(BuiltInSubgroupLocalInvocationId); + input_builtins.clear(BuiltInSubgroupEqMask); + input_builtins.clear(BuiltInSubgroupLtMask); + input_builtins.clear(BuiltInSubgroupLeMask); + input_builtins.clear(BuiltInSubgroupGtMask); + input_builtins.clear(BuiltInSubgroupGeMask); + if (!input_variables.empty() || !input_builtins.empty()) { require_input = true; @@ -2106,6 +2141,70 @@ void CompilerHLSL::emit_hlsl_entry_point() case BuiltInNumWorkgroups: case BuiltInPointCoord: + case BuiltInSubgroupSize: + case BuiltInSubgroupLocalInvocationId: + break; + + case BuiltInSubgroupEqMask: + // Emulate these ... + // No 64-bit in HLSL, so have to do it in 32-bit and unroll. + statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));"); + statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;"); + statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;"); + statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;"); + statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;"); + break; + + case BuiltInSubgroupGeMask: + // Emulate these ... + // No 64-bit in HLSL, so have to do it in 32-bit and unroll. 
+ statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);"); + statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;"); + statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;"); + statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;"); + statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;"); + statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;"); + statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;"); + break; + + case BuiltInSubgroupGtMask: + // Emulate these ... + // No 64-bit in HLSL, so have to do it in 32-bit and unroll. + statement("uint gt_lane_index = WaveGetLaneIndex() + 1;"); + statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);"); + statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;"); + statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;"); + statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;"); + statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;"); + statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;"); + statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;"); + statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;"); + break; + + case BuiltInSubgroupLeMask: + // Emulate these ... + // No 64-bit in HLSL, so have to do it in 32-bit and unroll. 
+			statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
+			statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
+			statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
+			statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
+			statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
+			statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
+			statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
+			statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
+			statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
+			break;
+
+		case BuiltInSubgroupLtMask:
+			// Emulate these ...
+			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
+			statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
+			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
+			statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
+			statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
+			statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
+			statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
+			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
 			break;
 
 		case BuiltInClipDistance:
@@ -3528,6 +3627,176 @@ void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
 	register_read(ops[1], ops[2], should_forward(ops[2]));
 }
 
+void CompilerHLSL::emit_subgroup_op(const Instruction &i)
+{
+	if (hlsl_options.shader_model < 60)
+		SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
+	// Translate one OpGroupNonUniform* instruction into the matching HLSL Wave*/Quad* call, or throw if none exists.
+	const uint32_t *ops = stream(i);
+	auto op = static_cast<Op>(i.op);
+
+	uint32_t result_type = ops[0];
+	uint32_t id = ops[1];
+
+	auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
+	if (scope != ScopeSubgroup)
+		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
+	// An inclusive scan is emitted as the WavePrefix* result combined with the lane's own operand (ops[4]).
+	const auto make_inclusive_Sum = [&](const string &expr) -> string {
+		return join(expr, " + ", to_expression(ops[4]));
+	};
+
+	const auto make_inclusive_Product = [&](const string &expr) -> string {
+		return join(expr, " * ", to_expression(ops[4]));
+	};
+	// Dummy expansions so GROUP_OP still compiles for ops instantiated with supports_scan == false.
+#define make_inclusive_BitAnd(expr) ""
+#define make_inclusive_BitOr(expr) ""
+#define make_inclusive_BitXor(expr) ""
+#define make_inclusive_Min(expr) ""
+#define make_inclusive_Max(expr) ""
+
+	switch (op)
+	{
+	case OpGroupNonUniformElect:
+		emit_op(result_type, id, "WaveIsFirstLane()", true);
+		break;
+
+	case OpGroupNonUniformBroadcast:
+		emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
+		break;
+
+	case OpGroupNonUniformBroadcastFirst:
+		emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
+		break;
+
+	case OpGroupNonUniformBallot:
+		emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
+		break;
+
+	case OpGroupNonUniformInverseBallot:
+		SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
+		break;
+
+	case OpGroupNonUniformBallotBitExtract:
+		SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
+		break;
+
+	case OpGroupNonUniformBallotFindLSB:
+		SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
+		break;
+
+	case OpGroupNonUniformBallotFindMSB:
+		SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
+		break;
+
+	case OpGroupNonUniformBallotBitCount:
+	{
+		auto operation = static_cast<GroupOperation>(ops[3]);
+		if (operation == GroupOperationReduce)
+		{
+			bool forward = should_forward(ops[4]);
+			auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(", to_enclosed_expression(ops[4]), ".y)");
+			auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(", to_enclosed_expression(ops[4]), ".w)");
+			emit_op(result_type, id, join(left, " + ", right), forward);
+			inherit_expression_dependencies(id, ops[4]);
+		}
+		else if (operation == GroupOperationInclusiveScan)
+			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
+		else if (operation == GroupOperationExclusiveScan)
+			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
+		else
+			SPIRV_CROSS_THROW("Invalid BitCount operation.");
+		break;
+	}
+
+	case OpGroupNonUniformShuffle:
+		SPIRV_CROSS_THROW("Cannot trivially implement Shuffle in HLSL.");
+	case OpGroupNonUniformShuffleXor:
+		SPIRV_CROSS_THROW("Cannot trivially implement ShuffleXor in HLSL.");
+	case OpGroupNonUniformShuffleUp:
+		SPIRV_CROSS_THROW("Cannot trivially implement ShuffleUp in HLSL.");
+	case OpGroupNonUniformShuffleDown:
+		SPIRV_CROSS_THROW("Cannot trivially implement ShuffleDown in HLSL.");
+
+	case OpGroupNonUniformAll:
+		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
+		break;
+
+	case OpGroupNonUniformAny:
+		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
+		break;
+
+	case OpGroupNonUniformAllEqual:
+	{
+		auto &type = get<SPIRType>(result_type);
+		emit_unary_func_op(result_type, id, ops[3],
+		                   type.basetype == SPIRType::Boolean ? "WaveActiveAllEqualBool" : "WaveActiveAllEqual");
+		break;
+	}
+	// Stamps out one case per arithmetic/bitwise op: reduce -> WaveActive*, exclusive scan -> WavePrefix*.
+#define GROUP_OP(op, hlsl_op, supports_scan) \
+case OpGroupNonUniform##op: \
+	{ \
+		auto operation = static_cast<GroupOperation>(ops[3]); \
+		if (operation == GroupOperationReduce) \
+			emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
+		else if (operation == GroupOperationInclusiveScan && supports_scan) \
+		{ \
+			bool forward = should_forward(ops[4]); \
+			emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
+			inherit_expression_dependencies(id, ops[4]); \
+		} \
+		else if (operation == GroupOperationExclusiveScan && supports_scan) \
+			emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
+		else if (operation == GroupOperationClusteredReduce) \
+			SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
+		else \
+			SPIRV_CROSS_THROW("Invalid group operation."); \
+		break; \
+	}
+	GROUP_OP(FAdd, Sum, true)
+	GROUP_OP(FMul, Product, true)
+	GROUP_OP(FMin, Min, false)
+	GROUP_OP(FMax, Max, false)
+	GROUP_OP(IAdd, Sum, true)
+	GROUP_OP(IMul, Product, true)
+	GROUP_OP(SMin, Min, false)
+	GROUP_OP(SMax, Max, false)
+	GROUP_OP(UMin, Min, false)
+	GROUP_OP(UMax, Max, false)
+	GROUP_OP(BitwiseAnd, BitAnd, false)
+	GROUP_OP(BitwiseOr, BitOr, false)
+	GROUP_OP(BitwiseXor, BitXor, false)
+#undef GROUP_OP
+
+	case OpGroupNonUniformQuadSwap:
+	{
+		uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
+		if (direction == 0)
+			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
+		else if (direction == 1)
+			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
+		else if (direction == 2)
+			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
+		else
+			SPIRV_CROSS_THROW("Invalid quad swap direction.");
+		break;
+	}
+
+	case OpGroupNonUniformQuadBroadcast:
+	{
+		emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
+		break;
+	}
+
+	default:
+		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
+	}
+
+	register_control_dependent_expression(id);
+}
+
 void CompilerHLSL::emit_instruction(const Instruction &instruction)
 {
 	auto ops = stream(instruction);
@@ -4004,6 +4273,12 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
 			semantics = get<SPIRConstant>(ops[2]).scalar();
 		}
 
+		if (memory == ScopeSubgroup)
+		{
+			// No Wave-barriers in HLSL.
+			break;
+		}
+
 		// We only care about these flags, acquire/release and friends are not relevant to GLSL.
 		semantics = mask_relevant_memory_semantics(semantics);
 
diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp
index 80682f7..8a539a4 100644
--- a/spirv_hlsl.hpp
+++ b/spirv_hlsl.hpp
@@ -157,6 +157,7 @@ private:
 	void write_access_chain(const SPIRAccessChain &chain, uint32_t value);
 	void emit_store(const Instruction &instruction);
 	void emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op);
+	void emit_subgroup_op(const Instruction &i) override;
 
 	void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index, const std::string &qualifier,
 	                        uint32_t base_offset = 0) override;
diff --git a/test_shaders.py b/test_shaders.py
index 217e81c..b6cc792 100755
--- a/test_shaders.py
+++ b/test_shaders.py
@@ -155,7 +155,9 @@ def validate_shader_hlsl(shader):
         sys.exit(1)
 
 def shader_to_sm(shader):
-    if '.sm51.' in shader:
+    if '.sm60.' in shader:
+        return '60'
+    elif '.sm51.' in shader:
         return '51'
     elif '.sm20.' in shader:
         return '20'
diff --git a/tests-other/hlsl_wave_mask.cpp b/tests-other/hlsl_wave_mask.cpp
new file mode 100644
index 0000000..de11dd9
--- /dev/null
+++ b/tests-other/hlsl_wave_mask.cpp
@@ -0,0 +1,73 @@
+// Ad-hoc test that the wave op masks work as expected.
+#include <glm/glm.hpp>
+#include <assert.h>
+
+using namespace glm;
+
+static uvec4 gl_SubgroupEqMask;
+static uvec4 gl_SubgroupGeMask;
+static uvec4 gl_SubgroupGtMask;
+static uvec4 gl_SubgroupLeMask;
+static uvec4 gl_SubgroupLtMask;
+using uint4 = uvec4;
+// Recomputes all five masks for lane `wave_index` using the same statements the HLSL backend emits in main().
+static void test_main(unsigned wave_index)
+{
+	const auto WaveGetLaneIndex = [&]() { return wave_index; };
+
+	gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));
+	if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;
+	if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;
+	if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;
+	if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;
+	gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);
+	if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;
+	if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;
+	if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;
+	if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;
+	if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;
+	if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;
+	uint gt_lane_index = WaveGetLaneIndex() + 1;
+	gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);
+	if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;
+	if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;
+	if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;
+	if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;
+	if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;
+	if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;
+	if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;
+	uint le_lane_index = WaveGetLaneIndex() + 1;
+	gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;
+	if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;
+	if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;
+	if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;
+	if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;
+	if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;
+	if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;
+	if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;
+	gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;
+	if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;
+	if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;
+	if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;
+	if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;
+	if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;
+	if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;
+}
+// For each of the 128 lane indices, checks every bit of every mask against the plain comparison it encodes.
+int main()
+{
+	for (unsigned subgroup_id = 0; subgroup_id < 128; subgroup_id++)
+	{
+		test_main(subgroup_id);
+
+		for (unsigned bit = 0; bit < 128; bit++)
+		{
+			assert(bool(gl_SubgroupEqMask[bit / 32] & (1u << (bit & 31))) == (bit == subgroup_id));
+			assert(bool(gl_SubgroupGtMask[bit / 32] & (1u << (bit & 31))) == (bit > subgroup_id));
+			assert(bool(gl_SubgroupGeMask[bit / 32] & (1u << (bit & 31))) == (bit >= subgroup_id));
+			assert(bool(gl_SubgroupLtMask[bit / 32] & (1u << (bit & 31))) == (bit < subgroup_id));
+			assert(bool(gl_SubgroupLeMask[bit / 32] & (1u << (bit & 31))) == (bit <= subgroup_id));
+		}
+	}
+}
+