From f6c0e53f58302be8ef696007efd0f4cb043ab6b1 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Tue, 10 Apr 2018 16:13:33 +0200 Subject: [PATCH] Start adding Vulkan 1.1 subgroup support to GLSL. --- .../vulkan/comp/subgroups.nocompat.vk.comp | 101 +++++++ spirv_glsl.cpp | 282 +++++++++++++++++- spirv_glsl.hpp | 1 + 3 files changed, 375 insertions(+), 9 deletions(-) create mode 100644 shaders/vulkan/comp/subgroups.nocompat.vk.comp diff --git a/shaders/vulkan/comp/subgroups.nocompat.vk.comp b/shaders/vulkan/comp/subgroups.nocompat.vk.comp new file mode 100644 index 0000000..a5593f3 --- /dev/null +++ b/shaders/vulkan/comp/subgroups.nocompat.vk.comp @@ -0,0 +1,101 @@ +#version 450 +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_clustered : require +#extension GL_KHR_shader_subgroup_quad : require +layout(local_size_x = 1) in; + +layout(std430, binding = 0) buffer SSBO +{ + float FragColor; +}; + +void main() +{ + // basic + FragColor = float(gl_NumSubgroups); + FragColor = float(gl_SubgroupID); + FragColor = float(gl_SubgroupSize); + FragColor = float(gl_SubgroupInvocationID); + subgroupBarrier(); + subgroupMemoryBarrier(); + subgroupMemoryBarrierBuffer(); + subgroupMemoryBarrierShared(); + subgroupMemoryBarrierImage(); + bool elected = subgroupElect(); + + // ballot + FragColor = float(gl_SubgroupEqMask); + FragColor = float(gl_SubgroupGeMask); + FragColor = float(gl_SubgroupGtMask); + FragColor = float(gl_SubgroupLeMask); + FragColor = float(gl_SubgroupLtMask); + vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u); + vec3 first = subgroupBroadcastFirst(vec3(20.0)); + uvec4 ballot_value = subgroupBallot(true); + bool inverse_ballot_value = subgroupInverseBallot(ballot_value); + bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u); + uint bit_count = subgroupBallotBitCount(ballot_value); + uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value); + uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value); + uint lsb = subgroupBallotFindLSB(ballot_value); + uint msb = subgroupBallotFindMSB(ballot_value); + + // shuffle + uint shuffled = subgroupShuffle(10u, 8u); + uint shuffled_xor = subgroupShuffleXor(30u, 8u); + + // shuffle relative + uint shuffled_up = subgroupShuffleUp(20u, 4u); + uint shuffled_down = subgroupShuffleDown(20u, 4u); + + // vote + bool has_all = subgroupAll(true); + bool has_any = subgroupAny(true); + bool has_equal = subgroupAllEqual(true); + + // arithmetic + vec4 added = subgroupAdd(vec4(20.0)); + vec4 multiplied = subgroupMul(vec4(20.0)); + vec4 lo = subgroupMin(vec4(20.0)); + vec4 hi = subgroupMax(vec4(20.0)); + uvec4 anded = subgroupAnd(ballot_value); + uvec4 ored = subgroupOr(ballot_value); + uvec4 xored = subgroupXor(ballot_value); + + added = subgroupInclusiveAdd(added); + multiplied = subgroupInclusiveMul(multiplied); + lo = subgroupInclusiveMin(lo); + hi = subgroupInclusiveMax(hi); + anded = subgroupInclusiveAnd(anded); + ored = subgroupInclusiveOr(ored); + xored = subgroupInclusiveXor(ored); + added = subgroupExclusiveAdd(lo); + + added = subgroupExclusiveAdd(multiplied); + multiplied = subgroupExclusiveMul(multiplied); + lo = subgroupExclusiveMin(lo); + hi = subgroupExclusiveMax(hi); + anded = subgroupExclusiveAnd(anded); + ored = subgroupExclusiveOr(ored); + xored = subgroupExclusiveXor(ored); + + // clustered + added = subgroupClusteredAdd(added, 4u); + multiplied = subgroupClusteredMul(multiplied, 4u); + lo = subgroupClusteredMin(lo, 4u); + hi = subgroupClusteredMax(hi, 4u); + anded = subgroupClusteredAnd(anded, 4u); + ored = subgroupClusteredOr(ored, 4u); + xored = subgroupClusteredXor(xored, 4u); + + // quad + vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0)); + vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0)); + vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0)); + vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u); +} diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 7818ff0..1f9b9cc 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -4466,6 +4466,142 @@ void CompilerGLSL::emit_spv_amd_gcn_shader_op(uint32_t result_type, uint32_t id, } } +void CompilerGLSL::emit_subgroup_op(const Instruction &i) +{ + const uint32_t *ops = stream(i); + auto op = static_cast(i.op); + + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics."); + + switch (op) + { + case OpGroupNonUniformElect: + require_extension_internal("GL_KHR_shader_subgroup_basic"); + break; + + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformBroadcastFirst: + case OpGroupNonUniformBallot: + case OpGroupNonUniformInverseBallot: + case OpGroupNonUniformBallotBitExtract: + case OpGroupNonUniformBallotBitCount: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + require_extension_internal("GL_KHR_shader_subgroup_ballot"); + break; + + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + require_extension_internal("GL_KHR_shader_subgroup_shuffle"); + break; + + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + require_extension_internal("GL_KHR_shader_subgroup_shuffle_relative"); + break; + + case OpGroupNonUniformAll: + case OpGroupNonUniformAny: + case OpGroupNonUniformAllEqual: + require_extension_internal("GL_KHR_shader_subgroup_vote"); + break; + + case OpGroupNonUniformFAdd: + case OpGroupNonUniformFMul: + case OpGroupNonUniformFMin: + case OpGroupNonUniformFMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case OpGroupNonUniformBitwiseXor: + { + auto operation = static_cast(ops[3]); + if (operation == GroupOperationClusteredReduce) + { + require_extension_internal("GL_KHR_shader_subgroup_clustered"); + } + else if (operation == GroupOperationExclusiveScan || + operation == GroupOperationInclusiveScan || + operation == GroupOperationReduce) + { + require_extension_internal("GL_KHR_shader_subgroup_arithmetic"); + } + else + SPIRV_CROSS_THROW("Invalid group operation."); + break; + } + + case OpGroupNonUniformQuadSwap: + case OpGroupNonUniformQuadBroadcast: + require_extension_internal("GL_KHR_shader_subgroup_quad"); + break; + + default: + SPIRV_CROSS_THROW("Invalid opcode for subgroup."); + } + + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + + auto scope = static_cast(get(ops[2]).scalar()); + if (scope != ScopeSubgroup) + SPIRV_CROSS_THROW("Only subgroup scope is supported."); + + switch (op) + { + case OpGroupNonUniformElect: + emit_op(result_type, id, "subgroupElect()", true); + break; + + case OpGroupNonUniformBroadcast: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBroadcast"); + break; + + case OpGroupNonUniformBroadcastFirst: + emit_unary_func_op(result_type, id, ops[3], "subgroupBroadcastFirst"); + break; + + case OpGroupNonUniformBallot: + emit_unary_func_op(result_type, id, ops[3], "subgroupBallot"); + break; + + case OpGroupNonUniformInverseBallot: + emit_unary_func_op(result_type, id, ops[3], "subgroupInverseBallot"); + break; + + case OpGroupNonUniformBallotBitExtract: + emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBallotBitExtract"); + break; + + case OpGroupNonUniformBallotBitCount: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + case OpGroupNonUniformAll: + case OpGroupNonUniformAny: + case OpGroupNonUniformAllEqual: + case OpGroupNonUniformFAdd: + case OpGroupNonUniformFMul: + case OpGroupNonUniformFMin: + case OpGroupNonUniformFMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case OpGroupNonUniformBitwiseXor: + case OpGroupNonUniformQuadSwap: + case OpGroupNonUniformQuadBroadcast: + emit_op(result_type, id, "subgroupRandom()", false); + return; + + default: + SPIRV_CROSS_THROW("Invalid opcode for subgroup."); + } + + register_control_dependent_expression(id); +} + string CompilerGLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type) { if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int) @@ -4640,6 +4776,60 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) return "gl_ViewID_OVR"; } + case BuiltInNumSubgroups: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_basic"); + return "gl_NumSubgroups"; + + case BuiltInSubgroupId: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_basic"); + return "gl_SubgroupID"; + + case BuiltInSubgroupSize: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_basic"); + return "gl_SubgroupSize"; + + case BuiltInSubgroupLocalInvocationId: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_basic"); + return "gl_SubgroupInvocationID"; + + case BuiltInSubgroupEqMask: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_ballot"); + return "gl_SubgroupEqMask"; + + case BuiltInSubgroupGeMask: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_ballot"); + return "gl_SubgroupGeMask"; + + case BuiltInSubgroupGtMask: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_ballot"); + return "gl_SubgroupGtMask"; + + case BuiltInSubgroupLeMask: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_ballot"); + return "gl_SubgroupLeMask"; + + case BuiltInSubgroupLtMask: + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup."); + require_extension_internal("GL_KHR_shader_subgroup_ballot"); + return "gl_SubgroupLtMask"; + default: return join("gl_BuiltIn_", convert_to_string(builtin)); } @@ -7150,14 +7340,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) case OpControlBarrier: case OpMemoryBarrier: { - if (get_entry_point().model == ExecutionModelTessellationControl) - { - // Control shaders only have barriers, and it implies memory barriers. - if (opcode == OpControlBarrier) - statement("barrier();"); - break; - } - + uint32_t execution_scope = 0; uint32_t memory; uint32_t semantics; @@ -7168,10 +7351,26 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) } else { + execution_scope = get(ops[0]).scalar(); memory = get(ops[1]).scalar(); semantics = get(ops[2]).scalar(); } + if (execution_scope == ScopeSubgroup || memory == ScopeSubgroup) + { + if (!options.vulkan_semantics) + SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics."); + require_extension_internal("GL_KHR_shader_subgroup_basic"); + } + + if (execution_scope != ScopeSubgroup && get_entry_point().model == ExecutionModelTessellationControl) + { + // Control shaders only have barriers, and it implies memory barriers. + if (opcode == OpControlBarrier) + statement("barrier();"); + break; + } + // We only care about these flags, acquire/release and friends are not relevant to GLSL. semantics = mask_relevant_memory_semantics(semantics); @@ -7228,6 +7427,33 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) else if (semantics != 0) statement("groupMemoryBarrier();"); } + else if (memory == ScopeSubgroup) + { + const uint32_t all_barriers = MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask | + MemorySemanticsImageMemoryMask; + + if (semantics & (MemorySemanticsCrossWorkgroupMemoryMask | MemorySemanticsSubgroupMemoryMask)) + { + // These are not relevant for GLSL, but assume it means memoryBarrier(). + // memoryBarrier() does everything, so no need to test anything else. + statement("subgroupMemoryBarrier();"); + } + else if ((semantics & all_barriers) == all_barriers) + { + // Short-hand instead of emitting 3 barriers. + statement("subgroupMemoryBarrier();"); + } + else + { + // Pick out individual barriers. + if (semantics & MemorySemanticsWorkgroupMemoryMask) + statement("subgroupMemoryBarrierShared();"); + if (semantics & MemorySemanticsUniformMemoryMask) + statement("subgroupMemoryBarrierBuffer();"); + if (semantics & MemorySemanticsImageMemoryMask) + statement("subgroupMemoryBarrierImage();"); + } + } else { const uint32_t all_barriers = MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask | @@ -7259,7 +7485,12 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) } if (opcode == OpControlBarrier) - statement("barrier();"); + { + if (execution_scope == ScopeSubgroup) + statement("subgroupBarrier();"); + else + statement("barrier();"); + } break; } @@ -7296,6 +7527,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; } + // Legacy sub-group stuff ... case OpSubgroupBallotKHR: { uint32_t result_type = ops[0]; @@ -7441,6 +7673,35 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; } + // Vulkan 1.1 sub-group stuff ... + case OpGroupNonUniformElect: + case OpGroupNonUniformBroadcast: + case OpGroupNonUniformBroadcastFirst: + case OpGroupNonUniformBallot: + case OpGroupNonUniformInverseBallot: + case OpGroupNonUniformBallotBitExtract: + case OpGroupNonUniformBallotBitCount: + case OpGroupNonUniformBallotFindLSB: + case OpGroupNonUniformBallotFindMSB: + case OpGroupNonUniformShuffle: + case OpGroupNonUniformShuffleXor: + case OpGroupNonUniformShuffleUp: + case OpGroupNonUniformShuffleDown: + case OpGroupNonUniformAll: + case OpGroupNonUniformAny: + case OpGroupNonUniformAllEqual: + case OpGroupNonUniformFAdd: + case OpGroupNonUniformFMul: + case OpGroupNonUniformFMin: + case OpGroupNonUniformFMax: + case OpGroupNonUniformBitwiseAnd: + case OpGroupNonUniformBitwiseOr: + case OpGroupNonUniformBitwiseXor: + case OpGroupNonUniformQuadSwap: + case OpGroupNonUniformQuadBroadcast: + emit_subgroup_op(instruction); + break; + default: statement("// unimplemented op ", instruction.op); break; @@ -9203,6 +9464,9 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block) else emit_block_chain(get(block.merge_block)); } + + // Forget about control dependent expressions now. + block.invalidate_expressions.clear(); } void CompilerGLSL::begin_scope() diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 13b68d2..c896787 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -217,6 +217,7 @@ protected: virtual void emit_header(); virtual void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id); virtual void emit_texture_op(const Instruction &i); + virtual void emit_subgroup_op(const Instruction &i); virtual std::string type_to_glsl(const SPIRType &type, uint32_t id = 0); virtual std::string builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage); virtual void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,