Merge branch 'msl-dispatch-base'
This commit is contained in:
Коммит
78fccc4d5c
4
main.cpp
4
main.cpp
|
@ -516,6 +516,7 @@ struct CLIArguments
|
|||
bool msl_texture_buffer_native = false;
|
||||
bool msl_multiview = false;
|
||||
bool msl_view_index_from_device_index = false;
|
||||
bool msl_dispatch_base = false;
|
||||
bool glsl_emit_push_constant_as_ubo = false;
|
||||
bool glsl_emit_ubo_as_plain_uniforms = false;
|
||||
bool emit_line_directives = false;
|
||||
|
@ -596,6 +597,7 @@ static void print_help()
|
|||
"\t[--msl-discrete-descriptor-set <index>]\n"
|
||||
"\t[--msl-multiview]\n"
|
||||
"\t[--msl-view-index-from-device-index]\n"
|
||||
"\t[--msl-dispatch-base]\n"
|
||||
"\t[--hlsl]\n"
|
||||
"\t[--reflect]\n"
|
||||
"\t[--shader-model]\n"
|
||||
|
@ -756,6 +758,7 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
|
|||
msl_opts.texture_buffer_native = args.msl_texture_buffer_native;
|
||||
msl_opts.multiview = args.msl_multiview;
|
||||
msl_opts.view_index_from_device_index = args.msl_view_index_from_device_index;
|
||||
msl_opts.dispatch_base = args.msl_dispatch_base;
|
||||
msl_comp->set_msl_options(msl_opts);
|
||||
for (auto &v : args.msl_discrete_descriptor_sets)
|
||||
msl_comp->add_discrete_descriptor_set(v);
|
||||
|
@ -1078,6 +1081,7 @@ static int main_inner(int argc, char *argv[])
|
|||
cbs.add("--msl-multiview", [&args](CLIParser &) { args.msl_multiview = true; });
|
||||
cbs.add("--msl-view-index-from-device-index",
|
||||
[&args](CLIParser &) { args.msl_view_index_from_device_index = true; });
|
||||
cbs.add("--msl-dispatch-base", [&args](CLIParser &) { args.msl_dispatch_base = true; });
|
||||
cbs.add("--extension", [&args](CLIParser &parser) { args.extensions.push_back(parser.next_string()); });
|
||||
cbs.add("--rename-entry-point", [&args](CLIParser &parser) {
|
||||
auto old_name = parser.next_string();
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Input buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `const device SSBO&` at [[buffer(0)]] and
// indexes in_data with gl_GlobalInvocationID.x.
struct SSBO
|
||||
{
|
||||
float4 in_data[1];
|
||||
};
|
||||
|
||||
// Output buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `device SSBO2&` at [[buffer(1)]] and writes
// out_data at the index returned by the atomic add on SSBO3::counter.
struct SSBO2
|
||||
{
|
||||
float4 out_data[1];
|
||||
};
|
||||
|
||||
// Counter buffer layout: a single uint.
// main0 in this shader takes it as `device SSBO3&` at [[buffer(2)]] and bumps
// counter through atomic_fetch_add_explicit.
struct SSBO3
|
||||
{
|
||||
uint counter;
|
||||
};
|
||||
|
||||
// Workgroup size driven by function constant 10.
// _59_tmp is the raw function constant; _59 substitutes 1u when the constant
// was not defined; gl_WorkGroupSize uses _59 for x and 1u for y and z.
constant uint _59_tmp [[function_constant(10)]];
|
||||
constant uint _59 = is_function_constant_defined(_59_tmp) ? _59_tmp : 1u;
|
||||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(_59, 1u, 1u);
|
||||
|
||||
// Compute kernel: filters in_data into out_data, with the global invocation ID
// offset by a dispatch base.
//   _27  - input buffer, [[buffer(0)]]
//   _49  - output buffer, [[buffer(1)]]
//   _52  - counter buffer, [[buffer(2)]]
//   spvDispatchBase - dispatch base, received through [[grid_origin]]
kernel void main0(const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 spvDispatchBase [[grid_origin]])
|
||||
{
|
||||
// Shift the invocation ID by the base workgroup scaled by the workgroup size.
gl_GlobalInvocationID += spvDispatchBase * gl_WorkGroupSize;
|
||||
float4 _33 = _27.in_data[gl_GlobalInvocationID.x];
|
||||
// Keep only elements whose weighted sum exceeds the threshold.
if (dot(_33, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
|
||||
{
|
||||
// Relaxed atomic add; _56 is the counter value before the increment and
// is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
|
||||
_49.out_data[_56] = _33;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Input buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `const device SSBO&` at [[buffer(0)]] and
// indexes in_data with gl_GlobalInvocationID.x.
struct SSBO
|
||||
{
|
||||
float4 in_data[1];
|
||||
};
|
||||
|
||||
// Output buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `device SSBO2&` at [[buffer(1)]] and writes
// out_data at the index returned by the atomic add on SSBO3::counter.
struct SSBO2
|
||||
{
|
||||
float4 out_data[1];
|
||||
};
|
||||
|
||||
// Counter buffer layout: a single uint.
// main0 in this shader takes it as `device SSBO3&` at [[buffer(2)]] and bumps
// counter through atomic_fetch_add_explicit.
struct SSBO3
|
||||
{
|
||||
uint counter;
|
||||
};
|
||||
|
||||
// Compute kernel: filters in_data into out_data, with the global invocation ID
// offset by a dispatch base.
//   spvDispatchBase - dispatch base, read from a constant buffer at [[buffer(29)]]
//   _27  - input buffer, [[buffer(0)]]
//   _49  - output buffer, [[buffer(1)]]
//   _52  - counter buffer, [[buffer(2)]]
kernel void main0(constant uint3& spvDispatchBase [[buffer(29)]], const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
|
||||
{
|
||||
// Shift the invocation ID by the base workgroup times a literal
// uint3(1, 1, 1) multiplier (presumably the fixed workgroup size).
gl_GlobalInvocationID += spvDispatchBase * uint3(1, 1, 1);
|
||||
float4 _33 = _27.in_data[gl_GlobalInvocationID.x];
|
||||
// Keep only elements whose weighted sum exceeds the threshold.
if (dot(_33, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
|
||||
{
|
||||
// Relaxed atomic add; _56 is the counter value before the increment and
// is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
|
||||
_49.out_data[_56] = _33;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Input buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `const device SSBO&` at [[buffer(0)]] and
// indexes in_data with the adjusted global invocation ID.
struct SSBO
|
||||
{
|
||||
float4 in_data[1];
|
||||
};
|
||||
|
||||
// Output buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `device SSBO2&` at [[buffer(1)]] and writes
// out_data at the index returned by the atomic add on SSBO3::counter.
struct SSBO2
|
||||
{
|
||||
float4 out_data[1];
|
||||
};
|
||||
|
||||
// Counter buffer layout: a single uint.
// main0 in this shader takes it as `device SSBO3&` at [[buffer(2)]] and bumps
// counter through atomic_fetch_add_explicit.
struct SSBO3
|
||||
{
|
||||
uint counter;
|
||||
};
|
||||
|
||||
// Workgroup size driven by function constant 10.
// _59_tmp is the raw function constant; _59 substitutes 1u when the constant
// was not defined; gl_WorkGroupSize uses _59 for x and 1u for y and z.
constant uint _59_tmp [[function_constant(10)]];
|
||||
constant uint _59 = is_function_constant_defined(_59_tmp) ? _59_tmp : 1u;
|
||||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(_59, 1u, 1u);
|
||||
|
||||
// Compute kernel: filters in_data into out_data, with both the global
// invocation ID and the workgroup ID offset by a dispatch base.
//   _27  - input buffer, [[buffer(0)]]
//   _49  - output buffer, [[buffer(1)]]
//   _52  - counter buffer, [[buffer(2)]]
//   spvDispatchBase - dispatch base, received through [[grid_origin]]
kernel void main0(const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 spvDispatchBase [[grid_origin]])
|
||||
{
|
||||
// Shift the invocation ID by the base workgroup scaled by the workgroup size.
gl_GlobalInvocationID += spvDispatchBase * gl_WorkGroupSize;
|
||||
// The workgroup ID is shifted by the base directly.
gl_WorkGroupID += spvDispatchBase;
|
||||
uint ident = gl_GlobalInvocationID.x;
|
||||
// workgroup is assigned here but not read again in this kernel.
uint workgroup = gl_WorkGroupID.x;
|
||||
float4 idata = _27.in_data[ident];
|
||||
// Keep only elements whose weighted sum exceeds the threshold.
if (dot(idata, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
|
||||
{
|
||||
// Relaxed atomic add; _56 is the counter value before the increment and
// is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
|
||||
_49.out_data[_56] = idata;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Input buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `const device SSBO&` at [[buffer(0)]] and
// indexes in_data with the adjusted global invocation ID.
struct SSBO
|
||||
{
|
||||
float4 in_data[1];
|
||||
};
|
||||
|
||||
// Output buffer layout: a float4 array declared with one element.
// main0 in this shader takes it as `device SSBO2&` at [[buffer(1)]] and writes
// out_data at the index returned by the atomic add on SSBO3::counter.
struct SSBO2
|
||||
{
|
||||
float4 out_data[1];
|
||||
};
|
||||
|
||||
// Counter buffer layout: a single uint.
// main0 in this shader takes it as `device SSBO3&` at [[buffer(2)]] and bumps
// counter through atomic_fetch_add_explicit.
struct SSBO3
|
||||
{
|
||||
uint counter;
|
||||
};
|
||||
|
||||
// Compute kernel: filters in_data into out_data, with both the global
// invocation ID and the workgroup ID offset by a dispatch base.
//   spvDispatchBase - dispatch base, read from a constant buffer at [[buffer(29)]]
//   _27  - input buffer, [[buffer(0)]]
//   _49  - output buffer, [[buffer(1)]]
//   _52  - counter buffer, [[buffer(2)]]
kernel void main0(constant uint3& spvDispatchBase [[buffer(29)]], const device SSBO& _27 [[buffer(0)]], device SSBO2& _49 [[buffer(1)]], device SSBO3& _52 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
|
||||
{
|
||||
// Shift the invocation ID by the base workgroup times a literal
// uint3(1, 1, 1) multiplier (presumably the fixed workgroup size).
gl_GlobalInvocationID += spvDispatchBase * uint3(1, 1, 1);
|
||||
// The workgroup ID is shifted by the base directly.
gl_WorkGroupID += spvDispatchBase;
|
||||
uint ident = gl_GlobalInvocationID.x;
|
||||
// workgroup is assigned here but not read again in this kernel.
uint workgroup = gl_WorkGroupID.x;
|
||||
float4 idata = _27.in_data[ident];
|
||||
// Keep only elements whose weighted sum exceeds the threshold.
if (dot(idata, float4(1.0, 5.0, 6.0, 2.0)) > 8.19999980926513671875)
|
||||
{
|
||||
// Relaxed atomic add; _56 is the counter value before the increment and
// is used as the output index.
uint _56 = atomic_fetch_add_explicit((device atomic_uint*)&_52.counter, 1u, memory_order_relaxed);
|
||||
_49.out_data[_56] = idata;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
#version 310 es
|
||||
layout(local_size_x_id = 10) in;
|
||||
|
||||
// Read-only std430 storage block at binding 0: an unsized vec4 array that
// main() indexes with gl_GlobalInvocationID.x.
layout(std430, binding = 0) readonly buffer SSBO
|
||||
{
|
||||
vec4 in_data[];
|
||||
};
|
||||
|
||||
// Write-only std430 storage block at binding 1: an unsized vec4 array that
// main() fills at indices handed out by atomicAdd(counter, 1u).
layout(std430, binding = 1) writeonly buffer SSBO2
|
||||
{
|
||||
vec4 out_data[];
|
||||
};
|
||||
|
||||
// Read-write std430 storage block at binding 2: a single uint that main()
// increments with atomicAdd.
layout(std430, binding = 2) buffer SSBO3
|
||||
{
|
||||
uint counter;
|
||||
};
|
||||
|
||||
// Compute entry point: copies each in_data element whose weighted sum passes
// the threshold into out_data, at a slot taken from the atomic counter.
void main()
|
||||
{
|
||||
uint ident = gl_GlobalInvocationID.x;
|
||||
// workgroup is assigned but not read again in main(); presumably it exists
// only so the shader references gl_WorkGroupID.
uint workgroup = gl_WorkGroupID.x;
|
||||
vec4 idata = in_data[ident];
|
||||
if (dot(idata, vec4(1.0, 5.0, 6.0, 2.0)) > 8.2)
|
||||
{
|
||||
// atomicAdd returns the counter value before the increment, which is
// used as the output index.
out_data[atomicAdd(counter, 1u)] = idata;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
#version 310 es
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
// Read-only std430 storage block at binding 0: an unsized vec4 array that
// main() indexes with gl_GlobalInvocationID.x.
layout(std430, binding = 0) readonly buffer SSBO
|
||||
{
|
||||
vec4 in_data[];
|
||||
};
|
||||
|
||||
// Write-only std430 storage block at binding 1: an unsized vec4 array that
// main() fills at indices handed out by atomicAdd(counter, 1u).
layout(std430, binding = 1) writeonly buffer SSBO2
|
||||
{
|
||||
vec4 out_data[];
|
||||
};
|
||||
|
||||
// Read-write std430 storage block at binding 2: a single uint that main()
// increments with atomicAdd.
layout(std430, binding = 2) buffer SSBO3
|
||||
{
|
||||
uint counter;
|
||||
};
|
||||
|
||||
// Compute entry point: copies each in_data element whose weighted sum passes
// the threshold into out_data, at a slot taken from the atomic counter.
void main()
|
||||
{
|
||||
uint ident = gl_GlobalInvocationID.x;
|
||||
// workgroup is assigned but not read again in main(); presumably it exists
// only so the shader references gl_WorkGroupID.
uint workgroup = gl_WorkGroupID.x;
|
||||
vec4 idata = in_data[ident];
|
||||
if (dot(idata, vec4(1.0, 5.0, 6.0, 2.0)) > 8.2)
|
||||
{
|
||||
// atomicAdd returns the counter value before the increment, which is
// used as the output index.
out_data[atomicAdd(counter, 1u)] = idata;
|
||||
}
|
||||
}
|
||||
|
|
@ -1433,6 +1433,10 @@ enum ExtendedDecorations
|
|||
// Marks a buffer block for using explicit offsets (GLSL/HLSL).
|
||||
SPIRVCrossDecorationExplicitOffset,
|
||||
|
||||
// Apply to a variable in the Input storage class; marks it as holding the base group passed to vkCmdDispatchBase().
|
||||
// In MSL, this is used to adjust the WorkgroupId and GlobalInvocationId variables.
|
||||
SPIRVCrossDecorationBuiltInDispatchBase,
|
||||
|
||||
SPIRVCrossDecorationCount
|
||||
};
|
||||
|
||||
|
|
|
@ -107,8 +107,11 @@ void CompilerMSL::build_implicit_builtins()
|
|||
active_input_builtins.get(BuiltInSubgroupGtMask));
|
||||
bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
|
||||
(msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
|
||||
bool need_dispatch_base =
|
||||
msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
|
||||
(active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
|
||||
if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
|
||||
need_multiview || needs_subgroup_invocation_id)
|
||||
need_multiview || need_dispatch_base || needs_subgroup_invocation_id)
|
||||
{
|
||||
bool has_frag_coord = false;
|
||||
bool has_sample_id = false;
|
||||
|
@ -121,6 +124,7 @@ void CompilerMSL::build_implicit_builtins()
|
|||
bool has_subgroup_invocation_id = false;
|
||||
bool has_subgroup_size = false;
|
||||
bool has_view_idx = false;
|
||||
uint32_t workgroup_id_type = 0;
|
||||
|
||||
ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
|
||||
if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin)
|
||||
|
@ -208,6 +212,13 @@ void CompilerMSL::build_implicit_builtins()
|
|||
has_view_idx = true;
|
||||
}
|
||||
}
|
||||
|
||||
// The base workgroup needs to have the same type and vector size
|
||||
// as the workgroup or invocation ID, so keep track of the type that
|
||||
// was used.
|
||||
if (need_dispatch_base && workgroup_id_type == 0 &&
|
||||
(builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
|
||||
workgroup_id_type = var.basetype;
|
||||
});
|
||||
|
||||
if (!has_frag_coord && need_subpass_input)
|
||||
|
@ -457,6 +468,42 @@ void CompilerMSL::build_implicit_builtins()
|
|||
builtin_subgroup_size_id = var_id;
|
||||
mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
|
||||
}
|
||||
|
||||
if (need_dispatch_base)
|
||||
{
|
||||
uint32_t var_id;
|
||||
if (msl_options.supports_msl_version(1, 2))
|
||||
{
|
||||
// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
|
||||
// to convey this information and save a buffer slot.
|
||||
uint32_t offset = ir.increase_bound_by(1);
|
||||
var_id = offset;
|
||||
|
||||
set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
|
||||
set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
|
||||
get_entry_point().interface_variables.push_back(var_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, we need to fall back to a good ol' fashioned buffer.
|
||||
uint32_t offset = ir.increase_bound_by(2);
|
||||
var_id = offset;
|
||||
uint32_t type_id = offset + 1;
|
||||
|
||||
SPIRType var_type = get<SPIRType>(workgroup_id_type);
|
||||
var_type.storage = StorageClassUniform;
|
||||
set<SPIRType>(type_id, var_type);
|
||||
|
||||
set<SPIRVariable>(var_id, type_id, StorageClassUniform);
|
||||
// This should never match anything.
|
||||
set_decoration(var_id, DecorationDescriptorSet, ~(5u));
|
||||
set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
|
||||
set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
|
||||
msl_options.indirect_params_buffer_index);
|
||||
}
|
||||
set_name(var_id, "spvDispatchBase");
|
||||
builtin_dispatch_base_id = var_id;
|
||||
}
|
||||
}
|
||||
|
||||
if (needs_swizzle_buffer_def)
|
||||
|
@ -802,6 +849,8 @@ string CompilerMSL::compile()
|
|||
active_interface_variables.insert(view_mask_buffer_id);
|
||||
if (builtin_layer_id)
|
||||
active_interface_variables.insert(builtin_layer_id);
|
||||
if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
|
||||
active_interface_variables.insert(builtin_dispatch_base_id);
|
||||
|
||||
// Create structs to hold input, output and uniform variables.
|
||||
// Do output first to ensure out. is declared at top of entry function.
|
||||
|
@ -6748,6 +6797,19 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args)
|
|||
ep_args += "]]";
|
||||
}
|
||||
}
|
||||
|
||||
if (var.storage == StorageClassInput &&
|
||||
has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
|
||||
{
|
||||
// This is a special implicit builtin, not corresponding to any SPIR-V builtin,
|
||||
// which holds the base that was passed to vkCmdDispatchBase(). If it's present,
|
||||
// assume we emitted it for a good reason.
|
||||
assert(msl_options.supports_msl_version(1, 2));
|
||||
if (!ep_args.empty())
|
||||
ep_args += ", ";
|
||||
|
||||
ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
|
||||
}
|
||||
});
|
||||
|
||||
// Correct the types of all encountered active builtins. We couldn't do this before
|
||||
|
@ -7023,7 +7085,11 @@ void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
|
|||
default:
|
||||
if (!ep_args.empty())
|
||||
ep_args += ", ";
|
||||
ep_args += type_to_glsl(type, var_id) + " " + r.name;
|
||||
if (!type.pointer)
|
||||
ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
|
||||
type_to_glsl(type, var_id) + "& " + r.name;
|
||||
else
|
||||
ep_args += type_to_glsl(type, var_id) + " " + r.name;
|
||||
ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
|
||||
break;
|
||||
}
|
||||
|
@ -7343,6 +7409,35 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
|
|||
msl_options.device_index, ";");
|
||||
});
|
||||
break;
|
||||
case BuiltInWorkgroupId:
|
||||
if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
|
||||
break;
|
||||
|
||||
// The vkCmdDispatchBase() command lets the client set the base value
|
||||
// of WorkgroupId. Metal has no direct equivalent; we must make this
|
||||
// adjustment ourselves.
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
|
||||
});
|
||||
break;
|
||||
case BuiltInGlobalInvocationId:
|
||||
if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
|
||||
break;
|
||||
|
||||
// GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
|
||||
// This needs to be adjusted too.
|
||||
entry_func.fixup_hooks_in.push_back([=]() {
|
||||
auto &execution = this->get_entry_point();
|
||||
uint32_t workgroup_size_id = execution.workgroup_size.constant;
|
||||
if (workgroup_size_id)
|
||||
statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
|
||||
" * ", to_expression(workgroup_size_id), ";");
|
||||
else
|
||||
statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
|
||||
" * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
|
||||
execution.workgroup_size.z, ");");
|
||||
});
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -198,6 +198,7 @@ public:
|
|||
bool tess_domain_origin_lower_left = false;
|
||||
bool multiview = false;
|
||||
bool view_index_from_device_index = false;
|
||||
bool dispatch_base = false;
|
||||
|
||||
// Enable use of MSL 2.0 indirect argument buffers.
|
||||
// MSL 2.0 must also be enabled.
|
||||
|
@ -225,7 +226,7 @@ public:
|
|||
msl_version = make_msl_version(major, minor, patch);
|
||||
}
|
||||
|
||||
// Returns true when the configured msl_version is at least the version built
// from (major, minor, patch) via make_msl_version(); minor and patch default to 0.
// NOTE(review): this diff view shows the signature twice; the two lines differ
// only by the trailing `const` — presumably the old and new versions of the line.
bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0)
|
||||
bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) const
|
||||
{
|
||||
return msl_version >= make_msl_version(major, minor, patch);
|
||||
}
|
||||
|
@ -276,6 +277,13 @@ public:
|
|||
return msl_options.multiview && !msl_options.view_index_from_device_index;
|
||||
}
|
||||
|
||||
// Provide feedback to calling API to allow it to pass a buffer
|
||||
// containing the dispatch base workgroup ID.
// Returns true when the dispatch_base option is enabled and the target MSL
// version is below 1.2. The result depends only on these two options, not on
// whether the shader actually uses WorkgroupId or GlobalInvocationId.
|
||||
bool needs_dispatch_base_buffer() const
|
||||
{
|
||||
return msl_options.dispatch_base && !msl_options.supports_msl_version(1, 2);
|
||||
}
|
||||
|
||||
// Provide feedback to calling API to allow it to pass an output
|
||||
// buffer if the shader needs it.
|
||||
bool needs_output_buffer() const
|
||||
|
@ -563,6 +571,7 @@ protected:
|
|||
uint32_t builtin_primitive_id_id = 0;
|
||||
uint32_t builtin_subgroup_invocation_id_id = 0;
|
||||
uint32_t builtin_subgroup_size_id = 0;
|
||||
uint32_t builtin_dispatch_base_id = 0;
|
||||
uint32_t swizzle_buffer_id = 0;
|
||||
uint32_t buffer_size_buffer_id = 0;
|
||||
uint32_t view_mask_buffer_id = 0;
|
||||
|
|
|
@ -207,6 +207,8 @@ def cross_compile_msl(shader, spirv, opt, iterations, paths):
|
|||
msl_args.append('--msl-multiview')
|
||||
if '.viewfromdev.' in shader:
|
||||
msl_args.append('--msl-view-index-from-device-index')
|
||||
if '.dispatchbase.' in shader:
|
||||
msl_args.append('--msl-dispatch-base')
|
||||
|
||||
subprocess.check_call(msl_args)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче