Start adding Vulkan 1.1 subgroup support to GLSL.

This commit is contained in:
Hans-Kristian Arntzen 2018-04-10 16:13:33 +02:00
Родитель 489e04e09e
Коммит f6c0e53f58
3 изменённых файлов: 375 добавлений и 9 удалений

Просмотреть файл

@ -0,0 +1,101 @@
#version 450
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_clustered : require
#extension GL_KHR_shader_subgroup_quad : require
// Compute shader that touches the built-in variables and functions of every
// GL_KHR_shader_subgroup_* extension required above, one section per extension.
// The values involved are arbitrary: FragColor is simply overwritten by each
// built-in read, and most function results land in locals that are never read.
layout(local_size_x = 1) in;
// Single-float storage buffer used as the sink for the built-in variable reads.
layout(std430, binding = 0) buffer SSBO
{
float FragColor;
};
void main()
{
// basic: built-in variables, execution/memory barriers and election.
FragColor = float(gl_NumSubgroups);
FragColor = float(gl_SubgroupID);
FragColor = float(gl_SubgroupSize);
FragColor = float(gl_SubgroupInvocationID);
subgroupBarrier();
subgroupMemoryBarrier();
subgroupMemoryBarrierBuffer();
subgroupMemoryBarrierShared();
subgroupMemoryBarrierImage();
bool elected = subgroupElect();
// ballot: mask built-ins, broadcasts, ballots and ballot bit queries.
FragColor = float(gl_SubgroupEqMask);
FragColor = float(gl_SubgroupGeMask);
FragColor = float(gl_SubgroupGtMask);
FragColor = float(gl_SubgroupLeMask);
FragColor = float(gl_SubgroupLtMask);
vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u);
vec3 first = subgroupBroadcastFirst(vec3(20.0));
uvec4 ballot_value = subgroupBallot(true);
bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
uint bit_count = subgroupBallotBitCount(ballot_value);
uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
uint lsb = subgroupBallotFindLSB(ballot_value);
uint msb = subgroupBallotFindMSB(ballot_value);
// shuffle: read a value from another invocation by index or XOR mask.
uint shuffled = subgroupShuffle(10u, 8u);
uint shuffled_xor = subgroupShuffleXor(30u, 8u);
// shuffle relative: read a value from an invocation at a relative offset.
uint shuffled_up = subgroupShuffleUp(20u, 4u);
uint shuffled_down = subgroupShuffleDown(20u, 4u);
// vote: subgroup-wide boolean queries.
bool has_all = subgroupAll(true);
bool has_any = subgroupAny(true);
bool has_equal = subgroupAllEqual(true);
// arithmetic: full reductions first ...
vec4 added = subgroupAdd(vec4(20.0));
vec4 multiplied = subgroupMul(vec4(20.0));
vec4 lo = subgroupMin(vec4(20.0));
vec4 hi = subgroupMax(vec4(20.0));
uvec4 anded = subgroupAnd(ballot_value);
uvec4 ored = subgroupOr(ballot_value);
uvec4 xored = subgroupXor(ballot_value);
// ... then inclusive scans ...
added = subgroupInclusiveAdd(added);
multiplied = subgroupInclusiveMul(multiplied);
lo = subgroupInclusiveMin(lo);
hi = subgroupInclusiveMax(hi);
anded = subgroupInclusiveAnd(anded);
ored = subgroupInclusiveOr(ored);
// NOTE(review): the argument is 'ored', not 'xored' - possibly a copy-paste slip.
// Confirm before changing; this shader appears to be a test input.
xored = subgroupInclusiveXor(ored);
// NOTE(review): this result is overwritten by the next statement before it is read.
added = subgroupExclusiveAdd(lo);
// ... then exclusive scans.
added = subgroupExclusiveAdd(multiplied);
multiplied = subgroupExclusiveMul(multiplied);
lo = subgroupExclusiveMin(lo);
hi = subgroupExclusiveMax(hi);
anded = subgroupExclusiveAnd(anded);
ored = subgroupExclusiveOr(ored);
// NOTE(review): as above, the argument is 'ored' rather than 'xored' - confirm intent.
xored = subgroupExclusiveXor(ored);
// clustered: reductions with an explicit cluster size as the second argument.
added = subgroupClusteredAdd(added, 4u);
multiplied = subgroupClusteredMul(multiplied, 4u);
lo = subgroupClusteredMin(lo, 4u);
hi = subgroupClusteredMax(hi, 4u);
anded = subgroupClusteredAnd(anded, 4u);
ored = subgroupClusteredOr(ored, 4u);
xored = subgroupClusteredXor(xored, 4u);
// quad: swaps and broadcast within a quad.
vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0));
vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0));
vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0));
vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u);
}

Просмотреть файл

@ -4466,6 +4466,142 @@ void CompilerGLSL::emit_spv_amd_gcn_shader_op(uint32_t result_type, uint32_t id,
}
}
// Emits GLSL for a Vulkan 1.1 subgroup instruction (OpGroupNonUniform*).
//
// The instruction words are read as: ops[0] = result type, ops[1] = result id,
// ops[2] = execution scope (must be a constant equal to ScopeSubgroup), followed
// by the opcode-specific operands.
//
// Steps:
//  1. Reject the instruction unless Vulkan semantics are enabled.
//  2. Request the GL_KHR_shader_subgroup_* extension matching the opcode.
//  3. Emit the corresponding subgroup*() call.
//  4. Register the result as a control dependent expression.
//
// Throws (SPIRV_CROSS_THROW) on non-Vulkan semantics, an opcode this function
// does not handle, a non-subgroup scope, an unsupported group operation, or an
// unknown quad swap direction.
void CompilerGLSL::emit_subgroup_op(const Instruction &i)
{
	const uint32_t *ops = stream(i);
	auto op = static_cast<Op>(i.op);

	if (!options.vulkan_semantics)
		SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics.");

	// Pass 1: pull in the GLSL extension that provides the function we are about to emit.
	switch (op)
	{
	case OpGroupNonUniformElect:
		require_extension_internal("GL_KHR_shader_subgroup_basic");
		break;

	case OpGroupNonUniformBroadcast:
	case OpGroupNonUniformBroadcastFirst:
	case OpGroupNonUniformBallot:
	case OpGroupNonUniformInverseBallot:
	case OpGroupNonUniformBallotBitExtract:
	case OpGroupNonUniformBallotBitCount:
	case OpGroupNonUniformBallotFindLSB:
	case OpGroupNonUniformBallotFindMSB:
		require_extension_internal("GL_KHR_shader_subgroup_ballot");
		break;

	case OpGroupNonUniformShuffle:
	case OpGroupNonUniformShuffleXor:
		require_extension_internal("GL_KHR_shader_subgroup_shuffle");
		break;

	case OpGroupNonUniformShuffleUp:
	case OpGroupNonUniformShuffleDown:
		require_extension_internal("GL_KHR_shader_subgroup_shuffle_relative");
		break;

	case OpGroupNonUniformAll:
	case OpGroupNonUniformAny:
	case OpGroupNonUniformAllEqual:
		require_extension_internal("GL_KHR_shader_subgroup_vote");
		break;

	case OpGroupNonUniformFAdd:
	case OpGroupNonUniformFMul:
	case OpGroupNonUniformFMin:
	case OpGroupNonUniformFMax:
	case OpGroupNonUniformBitwiseAnd:
	case OpGroupNonUniformBitwiseOr:
	case OpGroupNonUniformBitwiseXor:
	{
		// For arithmetic ops the group operation (ops[3]) decides which extension is needed.
		auto operation = static_cast<GroupOperation>(ops[3]);
		if (operation == GroupOperationClusteredReduce)
		{
			require_extension_internal("GL_KHR_shader_subgroup_clustered");
		}
		else if (operation == GroupOperationExclusiveScan || operation == GroupOperationInclusiveScan ||
		         operation == GroupOperationReduce)
		{
			require_extension_internal("GL_KHR_shader_subgroup_arithmetic");
		}
		else
			SPIRV_CROSS_THROW("Invalid group operation.");
		break;
	}

	case OpGroupNonUniformQuadSwap:
	case OpGroupNonUniformQuadBroadcast:
		require_extension_internal("GL_KHR_shader_subgroup_quad");
		break;

	default:
		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
	}

	uint32_t result_type = ops[0];
	uint32_t id = ops[1];

	auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
	if (scope != ScopeSubgroup)
		SPIRV_CROSS_THROW("Only subgroup scope is supported.");

	// Arithmetic ops share one operand layout: ops[3] = group operation, ops[4] = value,
	// and for clustered reductions ops[5] = cluster size. Only the GLSL name differs.
	auto emit_arithmetic_op = [&](const char *reduce_name, const char *inclusive_name, const char *exclusive_name,
	                              const char *clustered_name) {
		switch (static_cast<GroupOperation>(ops[3]))
		{
		case GroupOperationReduce:
			emit_unary_func_op(result_type, id, ops[4], reduce_name);
			break;
		case GroupOperationInclusiveScan:
			emit_unary_func_op(result_type, id, ops[4], inclusive_name);
			break;
		case GroupOperationExclusiveScan:
			emit_unary_func_op(result_type, id, ops[4], exclusive_name);
			break;
		case GroupOperationClusteredReduce:
			emit_binary_func_op(result_type, id, ops[4], ops[5], clustered_name);
			break;
		default:
			SPIRV_CROSS_THROW("Invalid group operation.");
		}
	};

	// Pass 2: emit the actual call.
	switch (op)
	{
	case OpGroupNonUniformElect:
		emit_op(result_type, id, "subgroupElect()", true);
		break;

	case OpGroupNonUniformBroadcast:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBroadcast");
		break;

	case OpGroupNonUniformBroadcastFirst:
		emit_unary_func_op(result_type, id, ops[3], "subgroupBroadcastFirst");
		break;

	case OpGroupNonUniformBallot:
		emit_unary_func_op(result_type, id, ops[3], "subgroupBallot");
		break;

	case OpGroupNonUniformInverseBallot:
		emit_unary_func_op(result_type, id, ops[3], "subgroupInverseBallot");
		break;

	case OpGroupNonUniformBallotBitExtract:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupBallotBitExtract");
		break;

	case OpGroupNonUniformBallotBitCount:
	{
		// Bit count carries a group operation in ops[3]; the ballot value is ops[4].
		auto operation = static_cast<GroupOperation>(ops[3]);
		if (operation == GroupOperationReduce)
			emit_unary_func_op(result_type, id, ops[4], "subgroupBallotBitCount");
		else if (operation == GroupOperationInclusiveScan)
			emit_unary_func_op(result_type, id, ops[4], "subgroupBallotInclusiveBitCount");
		else if (operation == GroupOperationExclusiveScan)
			emit_unary_func_op(result_type, id, ops[4], "subgroupBallotExclusiveBitCount");
		else
			SPIRV_CROSS_THROW("Invalid BitCount operation.");
		break;
	}

	case OpGroupNonUniformBallotFindLSB:
		emit_unary_func_op(result_type, id, ops[3], "subgroupBallotFindLSB");
		break;

	case OpGroupNonUniformBallotFindMSB:
		emit_unary_func_op(result_type, id, ops[3], "subgroupBallotFindMSB");
		break;

	case OpGroupNonUniformShuffle:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffle");
		break;

	case OpGroupNonUniformShuffleXor:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleXor");
		break;

	case OpGroupNonUniformShuffleUp:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleUp");
		break;

	case OpGroupNonUniformShuffleDown:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupShuffleDown");
		break;

	case OpGroupNonUniformAll:
		emit_unary_func_op(result_type, id, ops[3], "subgroupAll");
		break;

	case OpGroupNonUniformAny:
		emit_unary_func_op(result_type, id, ops[3], "subgroupAny");
		break;

	case OpGroupNonUniformAllEqual:
		emit_unary_func_op(result_type, id, ops[3], "subgroupAllEqual");
		break;

	case OpGroupNonUniformFAdd:
		emit_arithmetic_op("subgroupAdd", "subgroupInclusiveAdd", "subgroupExclusiveAdd", "subgroupClusteredAdd");
		break;

	case OpGroupNonUniformFMul:
		emit_arithmetic_op("subgroupMul", "subgroupInclusiveMul", "subgroupExclusiveMul", "subgroupClusteredMul");
		break;

	case OpGroupNonUniformFMin:
		emit_arithmetic_op("subgroupMin", "subgroupInclusiveMin", "subgroupExclusiveMin", "subgroupClusteredMin");
		break;

	case OpGroupNonUniformFMax:
		emit_arithmetic_op("subgroupMax", "subgroupInclusiveMax", "subgroupExclusiveMax", "subgroupClusteredMax");
		break;

	case OpGroupNonUniformBitwiseAnd:
		emit_arithmetic_op("subgroupAnd", "subgroupInclusiveAnd", "subgroupExclusiveAnd", "subgroupClusteredAnd");
		break;

	case OpGroupNonUniformBitwiseOr:
		emit_arithmetic_op("subgroupOr", "subgroupInclusiveOr", "subgroupExclusiveOr", "subgroupClusteredOr");
		break;

	case OpGroupNonUniformBitwiseXor:
		emit_arithmetic_op("subgroupXor", "subgroupInclusiveXor", "subgroupExclusiveXor", "subgroupClusteredXor");
		break;

	case OpGroupNonUniformQuadSwap:
	{
		// ops[3] = value, ops[4] = constant direction: 0 horizontal, 1 vertical, 2 diagonal.
		uint32_t direction = get<SPIRConstant>(ops[4]).scalar();
		if (direction == 0)
			emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapHorizontal");
		else if (direction == 1)
			emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapVertical");
		else if (direction == 2)
			emit_unary_func_op(result_type, id, ops[3], "subgroupQuadSwapDiagonal");
		else
			SPIRV_CROSS_THROW("Invalid quad swap direction.");
		break;
	}

	case OpGroupNonUniformQuadBroadcast:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "subgroupQuadBroadcast");
		break;

	default:
		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
	}

	// Every subgroup result is registered as control dependent, with no early-out.
	register_control_dependent_expression(id);
}
string CompilerGLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
{
if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
@ -4640,6 +4776,60 @@ string CompilerGLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
return "gl_ViewID_OVR";
}
case BuiltInNumSubgroups:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_basic");
return "gl_NumSubgroups";
case BuiltInSubgroupId:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_basic");
return "gl_SubgroupID";
case BuiltInSubgroupSize:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_basic");
return "gl_SubgroupSize";
case BuiltInSubgroupLocalInvocationId:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_basic");
return "gl_SubgroupInvocationID";
case BuiltInSubgroupEqMask:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_ballot");
return "gl_SubgroupEqMask";
case BuiltInSubgroupGeMask:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_ballot");
return "gl_SubgroupGeMask";
case BuiltInSubgroupGtMask:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_ballot");
return "gl_SubgroupGtMask";
case BuiltInSubgroupLeMask:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_ballot");
return "gl_SubgroupLeMask";
case BuiltInSubgroupLtMask:
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Need Vulkan semantics for subgroup.");
require_extension_internal("GL_KHR_shader_subgroup_ballot");
return "gl_SubgroupLtMask";
default:
return join("gl_BuiltIn_", convert_to_string(builtin));
}
@ -7150,14 +7340,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
case OpControlBarrier:
case OpMemoryBarrier:
{
if (get_entry_point().model == ExecutionModelTessellationControl)
{
// Control shaders only have barriers, and it implies memory barriers.
if (opcode == OpControlBarrier)
statement("barrier();");
break;
}
uint32_t execution_scope = 0;
uint32_t memory;
uint32_t semantics;
@ -7168,10 +7351,26 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
}
else
{
execution_scope = get<SPIRConstant>(ops[0]).scalar();
memory = get<SPIRConstant>(ops[1]).scalar();
semantics = get<SPIRConstant>(ops[2]).scalar();
}
if (execution_scope == ScopeSubgroup || memory == ScopeSubgroup)
{
if (!options.vulkan_semantics)
SPIRV_CROSS_THROW("Can only use subgroup operations in Vulkan semantics.");
require_extension_internal("GL_KHR_shader_subgroup_basic");
}
if (execution_scope != ScopeSubgroup && get_entry_point().model == ExecutionModelTessellationControl)
{
// Control shaders only have barriers, and it implies memory barriers.
if (opcode == OpControlBarrier)
statement("barrier();");
break;
}
// We only care about these flags, acquire/release and friends are not relevant to GLSL.
semantics = mask_relevant_memory_semantics(semantics);
@ -7228,6 +7427,33 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
else if (semantics != 0)
statement("groupMemoryBarrier();");
}
else if (memory == ScopeSubgroup)
{
const uint32_t all_barriers = MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask |
MemorySemanticsImageMemoryMask;
if (semantics & (MemorySemanticsCrossWorkgroupMemoryMask | MemorySemanticsSubgroupMemoryMask))
{
// These are not relevant for GLSL, but assume it means subgroupMemoryBarrier().
// subgroupMemoryBarrier() does everything, so no need to test anything else.
statement("subgroupMemoryBarrier();");
}
else if ((semantics & all_barriers) == all_barriers)
{
// Short-hand instead of emitting 3 barriers.
statement("subgroupMemoryBarrier();");
}
else
{
// Pick out individual barriers.
if (semantics & MemorySemanticsWorkgroupMemoryMask)
statement("subgroupMemoryBarrierShared();");
if (semantics & MemorySemanticsUniformMemoryMask)
statement("subgroupMemoryBarrierBuffer();");
if (semantics & MemorySemanticsImageMemoryMask)
statement("subgroupMemoryBarrierImage();");
}
}
else
{
const uint32_t all_barriers = MemorySemanticsWorkgroupMemoryMask | MemorySemanticsUniformMemoryMask |
@ -7259,7 +7485,12 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
}
if (opcode == OpControlBarrier)
statement("barrier();");
{
if (execution_scope == ScopeSubgroup)
statement("subgroupBarrier();");
else
statement("barrier();");
}
break;
}
@ -7296,6 +7527,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
break;
}
// Legacy sub-group stuff ...
case OpSubgroupBallotKHR:
{
uint32_t result_type = ops[0];
@ -7441,6 +7673,35 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction)
break;
}
// Vulkan 1.1 sub-group stuff ...
case OpGroupNonUniformElect:
case OpGroupNonUniformBroadcast:
case OpGroupNonUniformBroadcastFirst:
case OpGroupNonUniformBallot:
case OpGroupNonUniformInverseBallot:
case OpGroupNonUniformBallotBitExtract:
case OpGroupNonUniformBallotBitCount:
case OpGroupNonUniformBallotFindLSB:
case OpGroupNonUniformBallotFindMSB:
case OpGroupNonUniformShuffle:
case OpGroupNonUniformShuffleXor:
case OpGroupNonUniformShuffleUp:
case OpGroupNonUniformShuffleDown:
case OpGroupNonUniformAll:
case OpGroupNonUniformAny:
case OpGroupNonUniformAllEqual:
case OpGroupNonUniformFAdd:
case OpGroupNonUniformFMul:
case OpGroupNonUniformFMin:
case OpGroupNonUniformFMax:
case OpGroupNonUniformBitwiseAnd:
case OpGroupNonUniformBitwiseOr:
case OpGroupNonUniformBitwiseXor:
case OpGroupNonUniformQuadSwap:
case OpGroupNonUniformQuadBroadcast:
emit_subgroup_op(instruction);
break;
default:
statement("// unimplemented op ", instruction.op);
break;
@ -9203,6 +9464,9 @@ void CompilerGLSL::emit_block_chain(SPIRBlock &block)
else
emit_block_chain(get<SPIRBlock>(block.merge_block));
}
// Forget about control dependent expressions now.
block.invalidate_expressions.clear();
}
void CompilerGLSL::begin_scope()

Просмотреть файл

@ -217,6 +217,7 @@ protected:
virtual void emit_header();
virtual void emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id);
virtual void emit_texture_op(const Instruction &i);
// Emits GLSL for a Vulkan 1.1 subgroup (OpGroupNonUniform*) instruction.
// The GLSL implementation throws unless Vulkan semantics are enabled.
virtual void emit_subgroup_op(const Instruction &i);
virtual std::string type_to_glsl(const SPIRType &type, uint32_t id = 0);
virtual std::string builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage);
virtual void emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,