diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index bf33a25..9a9bc44 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -2691,6 +2691,7 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) if (i + 1 < length) funexpr += ", "; } + funexpr += static_func_args(callee, length); funexpr += ")"; if (get(result_type).basetype != SPIRType::Void) @@ -3682,6 +3683,26 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) } } +// Returns a string expression of function arguments beyond the specified index. +// This is used when a function call uses fewer arguments than the function defines. +// This situation may occur if the function signature has been dynamically modified +// to extract static global variables referenced from within the function and convert +// them to function arguments. This is necessary for shader languages that do not +// support global access to shader input content from within a function (eg. Metal). +// Each additional function args uses the name of the global var. Function nesting +// will modify the functions and calls all the way up the nesting chain. +string CompilerGLSL::static_func_args(const SPIRFunction &func, uint32_t index) +{ + string static_args; + auto& args = func.arguments; + uint32_t arg_cnt = (uint32_t)args.size(); + for (uint32_t arg_idx = index; arg_idx < arg_cnt; arg_idx++) { + if (arg_idx > 0) static_args += ", "; + static_args += to_expression(args[arg_idx].id); + } + return static_args; +} + string CompilerGLSL::to_member_name(const SPIRType &type, uint32_t index) { auto &memb = meta[type.self].members; diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 3831740..c148f6a 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -262,6 +262,7 @@ protected: const char *index_to_swizzle(uint32_t index); std::string remap_swizzle(uint32_t result_type, uint32_t input_components, uint32_t expr); std::string declare_temporary(uint32_t type, uint32_t id); + std::string static_func_args(const SPIRFunction &func, uint32_t index); std::string to_expression(uint32_t id); std::string to_member_name(const SPIRType &type, uint32_t index); std::string type_to_glsl_constructor(const SPIRType &type); diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 1a1d1bd..c2b8849 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -162,6 +162,93 @@ void CompilerMSL::localize_global_variables() iter++; } } + + // For any global variable accessed directly by a function, + // extract that variable and add it as an argument to that function. + extract_global_variables_from_functions(); +} + +// For any global variable accessed directly by a function, +// extract that variable and add it as an argument to that function. +void CompilerMSL::extract_global_variables_from_functions() +{ + + // Uniforms + std::set global_var_ids; + for (auto &id : ids) + { + if (id.get_type() == TypeVariable) + { + auto &var = id.get(); + if (var.storage == StorageClassUniform || + var.storage == StorageClassUniformConstant || + var.storage == StorageClassPushConstant) + global_var_ids.insert(var.self); + } + } + + std::set added_arg_ids; + std::set processed_func_ids; + extract_global_variables_from_functions(execution.entry_point, added_arg_ids, global_var_ids, processed_func_ids); +} + +// MSL does not support the use of global variables for shader input content. +// For any global variable accessed directly by the specified function, extract that variable, +// add it as an argument to that function, and the arg to the added_arg_ids collection. +void CompilerMSL::extract_global_variables_from_functions(uint32_t func_id, + std::set& added_arg_ids, + std::set& global_var_ids, + std::set& processed_func_ids) +{ + // Avoid processing a function more than once + if ( processed_func_ids.find(func_id) != processed_func_ids.end() ) + return; + + processed_func_ids.insert(func_id); + + auto& func = get(func_id); + + // Recursively establish global args added to functions on which we depend. + for (auto block : func.blocks) + { + auto &b = get(block); + for (auto &i : b.ops) + { + auto ops = stream(i); + auto op = static_cast(i.op); + + switch (op) { + case OpAccessChain: { + uint32_t base_id = ops[2]; + if ( global_var_ids.find(base_id) != global_var_ids.end() ) + added_arg_ids.insert(base_id); + break; + } + case OpFunctionCall: { + uint32_t inner_func_id = ops[2]; + std::set inner_func_args; + extract_global_variables_from_functions(inner_func_id, inner_func_args, global_var_ids, processed_func_ids); + added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end()); + break; + } + + default: + break; + } + } + } + + // Add the global variables as arguments to the function + if (func_id != execution.entry_point) { + uint32_t next_id = increase_bound_by((uint32_t)added_arg_ids.size()); + for (uint32_t arg_id : added_arg_ids) { + uint32_t type_id = get(arg_id).basetype; + func.add_parameter(type_id, next_id); + set(next_id, type_id, StorageClassFunction); + set_name(next_id, get_name(arg_id)); + next_id++; + } + } } // Adds any interface structure variables needed by this shader @@ -444,8 +531,8 @@ void CompilerMSL::emit_resources() auto &var = id.get(); auto &type = get(var.basetype); - if (type.pointer && (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant || - type.storage == StorageClassPushConstant) && + if (type.pointer && (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || + var.storage == StorageClassPushConstant) && !is_builtin_variable(var) && (meta[type.self].decoration.decoration_flags & ((1ull << DecorationBlock) | (1ull << DecorationBufferBlock)))) { @@ -528,15 +615,24 @@ void CompilerMSL::emit_function_prototype(SPIRFunction &func, bool is_decl) { add_local_variable_name(arg.id); - decl += "thread " + argument_decl(arg); - if (&arg != &func.arguments.back()) - decl += ", "; - - // Hold a pointer to the parameter so we can invalidate the readonly field if needed. + bool is_uniform = false; auto *var = maybe_get(arg.id); - if (var) - var->parameter = &arg; - } + if (var) { + var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed. + + // Check if this arg is one of the synthetic uniform args + // created to handle uniform access inside the function + auto &var_type = get(var->basetype); + is_uniform = (var_type.storage == StorageClassUniform || + var_type.storage == StorageClassUniformConstant || + var_type.storage == StorageClassPushConstant); + } + + decl += (is_uniform ? "constant " : "thread "); + decl += argument_decl(arg); + if (&arg != &func.arguments.back()) + decl += ", "; + } decl += ")"; statement(decl, (is_decl ? ";" : "")); diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 363f188..66c2119 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -111,6 +111,11 @@ protected: void extract_builtins(); void add_builtin(spv::BuiltIn builtin_type); void localize_global_variables(); + void extract_global_variables_from_functions(); + void extract_global_variables_from_functions(uint32_t func_id, + std::set& added_arg_ids, + std::set& global_var_ids, + std::set& processed_func_ids); void add_interface_structs(); void bind_vertex_attributes(std::set &bindings); uint32_t add_interface_struct(spv::StorageClass storage, uint32_t vtx_binding = 0);