diff --git a/reference/shaders-msl/vulkan/frag/spec-constant.vk.frag b/reference/shaders-msl/vulkan/frag/spec-constant.vk.frag new file mode 100644 index 0000000..47dabb1 --- /dev/null +++ b/reference/shaders-msl/vulkan/frag/spec-constant.vk.frag @@ -0,0 +1,74 @@ +#include +#include + +using namespace metal; + +constant float a_tmp [[function_constant(1)]]; +constant float a = is_function_constant_defined(a_tmp) ? a_tmp : 1.0; +constant float b_tmp [[function_constant(2)]]; +constant float b = is_function_constant_defined(b_tmp) ? b_tmp : 2.0; +constant int c_tmp [[function_constant(3)]]; +constant int c = is_function_constant_defined(c_tmp) ? c_tmp : 3; +constant int d_tmp [[function_constant(4)]]; +constant int d = is_function_constant_defined(d_tmp) ? d_tmp : 4; +constant uint e_tmp [[function_constant(5)]]; +constant uint e = is_function_constant_defined(e_tmp) ? e_tmp : 5u; +constant uint f_tmp [[function_constant(6)]]; +constant uint f = is_function_constant_defined(f_tmp) ? f_tmp : 6u; +constant bool g_tmp [[function_constant(7)]]; +constant bool g = is_function_constant_defined(g_tmp) ? g_tmp : false; +constant bool h_tmp [[function_constant(8)]]; +constant bool h = is_function_constant_defined(h_tmp) ? h_tmp : true; + +struct main0_out +{ + float4 FragColor [[color(0)]]; +}; + +fragment main0_out main0() +{ + main0_out out = {}; + float t0 = a; + float t1 = b; + uint c0 = (uint(c) + 0u); + int c1 = (-c); + int c2 = (~c); + int c3 = (c + d); + int c4 = (c - d); + int c5 = (c * d); + int c6 = (c / d); + uint c7 = (e / f); + int c8 = (c % d); + uint c9 = (e % f); + int c10 = (c >> d); + uint c11 = (e >> f); + int c12 = (c << d); + int c13 = (c | d); + int c14 = (c ^ d); + int c15 = (c & d); + bool c16 = (g || h); + bool c17 = (g && h); + bool c18 = (!g); + bool c19 = (g == h); + bool c20 = (g != h); + bool c21 = (c == d); + bool c22 = (c != d); + bool c23 = (c < d); + bool c24 = (e < f); + bool c25 = (c > d); + bool c26 = (e > f); + bool c27 = (c <= d); + bool c28 = (e <= f); + bool c29 = (c >= d); + bool c30 = (e >= f); + int c31 = c8 + c3; + int c32 = int(e + 0u); + bool c33 = (c != int(0u)); + bool c34 = (e != 0u); + int c35 = int(g); + uint c36 = uint(g); + float c37 = float(g); + out.FragColor = float4(t0 + t1); + return out; +} + diff --git a/shaders-msl/vulkan/frag/spec-constant.vk.frag b/shaders-msl/vulkan/frag/spec-constant.vk.frag new file mode 100644 index 0000000..3cb75da --- /dev/null +++ b/shaders-msl/vulkan/frag/spec-constant.vk.frag @@ -0,0 +1,67 @@ +#version 310 es +precision mediump float; + +layout(location = 0) out vec4 FragColor; +layout(constant_id = 1) const float a = 1.0; +layout(constant_id = 2) const float b = 2.0; +layout(constant_id = 3) const int c = 3; +layout(constant_id = 4) const int d = 4; +layout(constant_id = 5) const uint e = 5u; +layout(constant_id = 6) const uint f = 6u; +layout(constant_id = 7) const bool g = false; +layout(constant_id = 8) const bool h = true; +// glslang doesn't seem to support partial spec constants or composites yet, so only test the basics. + +void main() +{ + float t0 = a; + float t1 = b; + + uint c0 = uint(c); // OpIAdd with different types. + // FConvert, float-to-double. + int c1 = -c; // SNegate + int c2 = ~c; // OpNot + int c3 = c + d; // OpIAdd + int c4 = c - d; // OpISub + int c5 = c * d; // OpIMul + int c6 = c / d; // OpSDiv + uint c7 = e / f; // OpUDiv + int c8 = c % d; // OpSMod + uint c9 = e % f; // OpUMod + // TODO: OpSRem, any way to access this in GLSL? + int c10 = c >> d; // OpShiftRightArithmetic + uint c11 = e >> f; // OpShiftRightLogical + int c12 = c << d; // OpShiftLeftLogical + int c13 = c | d; // OpBitwiseOr + int c14 = c ^ d; // OpBitwiseXor + int c15 = c & d; // OpBitwiseAnd + // VectorShuffle, CompositeExtract, CompositeInsert, not testable atm. + bool c16 = g || h; // OpLogicalOr + bool c17 = g && h; // OpLogicalAnd + bool c18 = !g; // OpLogicalNot + bool c19 = g == h; // OpLogicalEqual + bool c20 = g != h; // OpLogicalNotEqual + // OpSelect not testable atm. + bool c21 = c == d; // OpIEqual + bool c22 = c != d; // OpINotEqual + bool c23 = c < d; // OpSLessThan + bool c24 = e < f; // OpULessThan + bool c25 = c > d; // OpSGreaterThan + bool c26 = e > f; // OpUGreaterThan + bool c27 = c <= d; // OpSLessThanEqual + bool c28 = e <= f; // OpULessThanEqual + bool c29 = c >= d; // OpSGreaterThanEqual + bool c30 = e >= f; // OpUGreaterThanEqual + // OpQuantizeToF16 not testable atm. + + int c31 = c8 + c3; + + int c32 = int(e); // OpIAdd with different types. + bool c33 = bool(c); // int -> bool + bool c34 = bool(e); // uint -> bool + int c35 = int(g); // bool -> int + uint c36 = uint(g); // bool -> uint + float c37 = float(g); // bool -> float + + FragColor = vec4(t0 + t1); +} diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 7c4e0a4..869d9b2 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -101,6 +101,7 @@ string CompilerMSL::compile() extract_global_variables_from_functions(); // Do not deal with GLES-isms like precision, older extensions and such. + CompilerGLSL::options.vulkan_semantics = true; CompilerGLSL::options.es = false; CompilerGLSL::options.version = 120; CompilerGLSL::options.vertex.fixup_clipspace = false; @@ -127,6 +128,7 @@ string CompilerMSL::compile() buffer = unique_ptr(new ostringstream()); emit_header(); + emit_specialization_constants(); emit_resources(); emit_custom_functions(); emit_function(get(entry_point), 0); @@ -987,6 +989,28 @@ void CompilerMSL::emit_resources() emit_interface_block(stage_uniforms_var_id); } +// Emit declarations for the specialization Metal function constants +void CompilerMSL::emit_specialization_constants() +{ + const vector spec_consts = get_specialization_constants(); + + if (spec_consts.empty()) + return; + + for (auto &sc : spec_consts) + { + string sc_type_name = type_to_glsl(expression_type(sc.id)); + string sc_name = to_name(sc.id); + string sc_tmp_name = to_name(sc.id) + "_tmp"; + + statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", + convert_to_string(sc.constant_id), ")]];"); + statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name, ") ? ", + sc_tmp_name, " : ", constant_expression(get(sc.id)), ";"); + } + statement(""); +} + // Override for MSL-specific syntax instructions void CompilerMSL::emit_instruction(const Instruction &instruction) { @@ -2536,9 +2560,25 @@ string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id) return img_type_name; } -string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &) +string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type) { - return "as_type<" + type_to_glsl(out_type) + ">"; + if ((out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int) || + (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt) || + (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64) || + (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)) + return type_to_glsl(out_type); + + if ((out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float) || + (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float) || + (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt) || + (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int) || + (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double) || + (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double) || + (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64) || + (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)) + return "as_type<" + type_to_glsl(out_type) + ">"; + + return ""; } // Returns an MSL string identifying the name of a SPIR-V builtin. @@ -2703,22 +2743,8 @@ string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma) // Returns the byte size of a struct member. size_t CompilerMSL::get_declared_struct_member_size(const SPIRType &struct_type, uint32_t index) const { - uint32_t type_id = struct_type.member_types[index]; auto dec_mask = get_member_decoration_mask(struct_type.self, index); - return get_declared_type_size(type_id, dec_mask); -} - -// Returns the effective size of a variable type. -size_t CompilerMSL::get_declared_type_size(uint32_t type_id) const -{ - return get_declared_type_size(type_id, get_decoration_mask(type_id)); -} - -// Returns the effective size in bytes of a variable type -// or member type, taking into consideration the decorations mask. -size_t CompilerMSL::get_declared_type_size(uint32_t type_id, uint64_t dec_mask) const -{ - auto &type = get(type_id); + auto &type = get(struct_type.member_types[index]); switch (type.basetype) { @@ -2739,14 +2765,9 @@ size_t CompilerMSL::get_declared_type_size(uint32_t type_id, uint64_t dec_mask) unsigned vecsize = type.vecsize; unsigned columns = type.columns; + // For arrays, we can use ArrayStride to get an easy check. if (!type.array.empty()) - { - // For arrays, we can use ArrayStride to get an easy check if it has been populated. - // ArrayStride is part of the array type not OpMemberDecorate. - auto &dec = meta[type_id].decoration; - if (dec.decoration_flags & (1ull << DecorationArrayStride)) - return dec.array_stride * to_array_size_literal(type, uint32_t(type.array.size()) - 1); - } + return type_struct_member_array_stride(struct_type, index) * type.array.back(); if (columns == 1) // An unpacked 3-element vector is the same size as a 4-element vector. { @@ -2778,16 +2799,7 @@ size_t CompilerMSL::get_declared_type_size(uint32_t type_id, uint64_t dec_mask) // Returns the byte alignment of a struct member. size_t CompilerMSL::get_declared_struct_member_alignment(const SPIRType &struct_type, uint32_t index) const { - uint32_t type_id = struct_type.member_types[index]; - auto dec_mask = get_member_decoration_mask(struct_type.self, index); - return get_declared_type_alignment(type_id, dec_mask); -} - -// Returns the effective alignment in bytes of a variable type -// or member type, taking into consideration the decorations mask. -size_t CompilerMSL::get_declared_type_alignment(uint32_t type_id, uint64_t dec_mask) const -{ - auto &type = get(type_id); + auto &type = get(struct_type.member_types[index]); switch (type.basetype) { @@ -2806,10 +2818,11 @@ size_t CompilerMSL::get_declared_type_alignment(uint32_t type_id, uint64_t dec_m { // Alignment of packed type is the same as the underlying component size. // Alignment of unpacked type is the same as the type size (or one matrix column). + auto dec_mask = get_member_decoration_mask(struct_type.self, index); if (dec_mask & (1ull << DecorationCPacked)) return type.width / 8; else - return get_declared_type_size(type_id, dec_mask) / type.columns; + return get_declared_struct_member_size(struct_type, index) / type.columns; } } } diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 014e3cd..25f1a37 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -168,7 +168,6 @@ protected: bool skip_argument(uint32_t id) const override; void preprocess_op_codes(); - void emit_custom_functions(); void localize_global_variables(); void extract_global_variables_from_functions(); @@ -179,7 +178,9 @@ protected: uint32_t add_interface_block(spv::StorageClass storage); void mark_location_as_used_by_shader(uint32_t location, spv::StorageClass storage); + void emit_custom_functions(); void emit_resources(); + void emit_specialization_constants(); void emit_interface_block(uint32_t ib_var_id); void populate_func_name_overrides(); void populate_var_name_overrides(); @@ -199,10 +200,7 @@ protected: std::string round_fp_tex_coords(std::string tex_coords, bool coord_is_fp); uint32_t get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype); uint32_t get_ordered_member_location(uint32_t type_id, uint32_t index); - size_t get_declared_type_size(uint32_t type_id) const; - size_t get_declared_type_size(uint32_t type_id, uint64_t dec_mask) const; size_t get_declared_struct_member_alignment(const SPIRType &struct_type, uint32_t index) const; - size_t get_declared_type_alignment(uint32_t type_id, uint64_t dec_mask) const; std::string to_component_argument(uint32_t id); void exclude_from_stage_in(SPIRVariable &var); void exclude_member_from_stage_in(const SPIRType &type, uint32_t index); diff --git a/test_shaders.py b/test_shaders.py index ff89de6..74aa6e0 100755 --- a/test_shaders.py +++ b/test_shaders.py @@ -75,8 +75,9 @@ def print_msl_compiler_version(): def validate_shader_msl(shader): msl_path = reference_path(shader[0], shader[1]) try: - subprocess.check_call(['xcrun', '--sdk', 'macosx', 'metal', '-x', 'metal', '-std=osx-metal1.2', '-Werror', msl_path]) -# subprocess.check_call(['xcrun', '--sdk', 'iphoneos', 'metal', '-x', 'metal', '-std=ios-metal1.2', '-Werror', msl_path]) + msl_os = 'macosx' +# msl_os = 'iphoneos' + subprocess.check_call(['xcrun', '--sdk', msl_os, 'metal', '-x', 'metal', '-std=osx-metal1.2', '-Werror', '-Wno-unused-variable', msl_path]) print('Compiled Metal shader: ' + msl_path) # display after so xcrun FNF is silent except OSError as oe: if (oe.errno != os.errno.ENOENT): # Ignore xcrun not found error