diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index f794faf4..5cea101f 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -494,6 +494,20 @@ SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxStoreStruct( SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetRelaxLogicalPointer( spv_validator_options options, bool val); +// Records whether or not the validator should relax the rules because it is +// expected that the optimizations will make the code legal. +// +// When relaxed, it will allow the following: +// 1) It will allow relaxed logical pointers. Setting this option will also +// set that option. +// 2) Pointers that are pass as parameters to function calls do not have to +// match the storage class of the formal parameter. +// 3) Pointers that are actaul parameters on function calls do not have to point +// to the same type pointed as the formal parameter. The types just need to +// logically match. +SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetBeforeHlslLegalization( + spv_validator_options options, bool val); + // Records whether the validator should use "relaxed" block layout rules. // Relaxed layout rules are described by Vulkan extension // VK_KHR_relaxed_block_layout, and they affect uniform blocks, storage blocks, diff --git a/include/spirv-tools/libspirv.hpp b/include/spirv-tools/libspirv.hpp index 27b87e7f..da27404d 100644 --- a/include/spirv-tools/libspirv.hpp +++ b/include/spirv-tools/libspirv.hpp @@ -116,6 +116,21 @@ class ValidatorOptions { spvValidatorOptionsSetRelaxLogicalPointer(options_, val); } + // Records whether or not the validator should relax the rules because it is + // expected that the optimizations will make the code legal. + // + // When relaxed, it will allow the following: + // 1) It will allow relaxed logical pointers. Setting this option will also + // set that option. + // 2) Pointers that are pass as parameters to function calls do not have to + // match the storage class of the formal parameter. + // 3) Pointers that are actaul parameters on function calls do not have to + // point to the same type pointed as the formal parameter. The types just + // need to logically match. + void SetBeforeHlslLegalization(bool val) { + spvValidatorOptionsSetBeforeHlslLegalization(options_, val); + } + private: spv_validator_options options_; }; diff --git a/source/spirv_validator_options.cpp b/source/spirv_validator_options.cpp index 9b522fb5..01aa7974 100644 --- a/source/spirv_validator_options.cpp +++ b/source/spirv_validator_options.cpp @@ -90,6 +90,12 @@ void spvValidatorOptionsSetRelaxLogicalPointer(spv_validator_options options, options->relax_logical_pointer = val; } +void spvValidatorOptionsSetBeforeHlslLegalization(spv_validator_options options, + bool val) { + options->before_hlsl_legalization = val; + options->relax_logical_pointer = val; +} + void spvValidatorOptionsSetRelaxBlockLayout(spv_validator_options options, bool val) { options->relax_block_layout = val; diff --git a/source/spirv_validator_options.h b/source/spirv_validator_options.h index 5b27de6e..b7da5d8e 100644 --- a/source/spirv_validator_options.h +++ b/source/spirv_validator_options.h @@ -45,7 +45,8 @@ struct spv_validator_options_t { relax_block_layout(false), uniform_buffer_standard_layout(false), scalar_block_layout(false), - skip_block_layout(false) {} + skip_block_layout(false), + before_hlsl_legalization(false) {} validator_universal_limits_t universal_limits_; bool relax_struct_store; @@ -54,6 +55,7 @@ struct spv_validator_options_t { bool uniform_buffer_standard_layout; bool scalar_block_layout; bool skip_block_layout; + bool before_hlsl_legalization; }; #endif // SOURCE_SPIRV_VALIDATOR_OPTIONS_H_ diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp index 1c1e77c6..1d93378c 100644 --- a/source/val/validate_composites.cpp +++ b/source/val/validate_composites.cpp @@ -513,74 +513,6 @@ spv_result_t ValidateVectorShuffle(ValidationState_t& _, return SPV_SUCCESS; } -// Returns true if |lhs| and |rhs| logically match. -// 1. Must both be either OpTypeArray or OpTypeStruct -// 2. If OpTypeArray, then -// * Length must be the same -// * Element type must match or logically match -// 3. If OpTypeStruct, then -// * Both have same number of elements -// * Element N for both structs must match or logically match -bool LogicallyMatch(ValidationState_t& _, const Instruction* lhs, - const Instruction* rhs) { - if (lhs->opcode() != rhs->opcode()) { - return false; - } - - if (lhs->opcode() == SpvOpTypeArray) { - // Size operands must match. - if (lhs->GetOperandAs(2u) != rhs->GetOperandAs(2u)) { - return false; - } - - // Elements must match or logically match. - const auto lhs_ele_id = lhs->GetOperandAs(1u); - const auto rhs_ele_id = rhs->GetOperandAs(1u); - if (lhs_ele_id == rhs_ele_id) { - return true; - } - - const auto lhs_ele = _.FindDef(lhs_ele_id); - const auto rhs_ele = _.FindDef(rhs_ele_id); - if (!lhs_ele || !rhs_ele) { - return false; - } - return LogicallyMatch(_, lhs_ele, rhs_ele); - } else if (lhs->opcode() == SpvOpTypeStruct) { - // Number of elements must match. - if (lhs->operands().size() != rhs->operands().size()) { - return false; - } - - for (size_t i = 1u; i < lhs->operands().size(); ++i) { - const auto lhs_ele_id = lhs->GetOperandAs(i); - const auto rhs_ele_id = rhs->GetOperandAs(i); - // Elements must match or logically match. - if (lhs_ele_id == rhs_ele_id) { - continue; - } - - const auto lhs_ele = _.FindDef(lhs_ele_id); - const auto rhs_ele = _.FindDef(rhs_ele_id); - if (!lhs_ele || !rhs_ele) { - return false; - } - - if (!LogicallyMatch(_, lhs_ele, rhs_ele)) { - return false; - } - } - - // All checks passed. - return true; - } - - // No other opcodes are acceptable at this point. Arrays and structs are - // caught above and if they're elements are not arrays or structs they are - // required to match exactly. - return false; -} - spv_result_t ValidateCopyLogical(ValidationState_t& _, const Instruction* inst) { const auto result_type = _.FindDef(inst->type_id()); @@ -591,7 +523,7 @@ spv_result_t ValidateCopyLogical(ValidationState_t& _, << "Result Type must not equal the Operand type"; } - if (!LogicallyMatch(_, source_type, result_type)) { + if (!_.LogicallyMatch(source_type, result_type, false)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Result Type does not logically match the Operand type"; } diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp index 56090086..b9831941 100644 --- a/source/val/validate_function.cpp +++ b/source/val/validate_function.cpp @@ -12,27 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "source/val/validate.h" - #include #include "source/opcode.h" #include "source/val/instruction.h" +#include "source/val/validate.h" #include "source/val/validation_state.h" namespace spvtools { namespace val { namespace { -// Returns true if |a| and |b| are instruction defining pointers that point to -// the same type. -bool ArePointersToSameType(val::Instruction* a, val::Instruction* b) { +// Returns true if |a| and |b| are instructions defining pointers that point to +// types logically match and the decorations that apply to |b| are a subset +// of the decorations that apply to |a|. +bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, + ValidationState_t& _) { if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) { return false; } + const auto& dec_a = _.id_decorations(a->id()); + const auto& dec_b = _.id_decorations(b->id()); + for (const auto& dec : dec_b) { + if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { + return false; + } + } + uint32_t a_type = a->GetOperandAs(2); - return a_type && (a_type == b->GetOperandAs(2)); + uint32_t b_type = b->GetOperandAs(2); + + if (a_type == b_type) { + return true; + } + + Instruction* a_type_inst = _.FindDef(a_type); + Instruction* b_type_inst = _.FindDef(b_type); + + return _.LogicallyMatch(a_type_inst, b_type_inst, true); } spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { @@ -256,14 +274,14 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, const auto parameter_type_id = function_type->GetOperandAs(param_index); const auto parameter_type = _.FindDef(parameter_type_id); - if (!parameter_type || - (argument_type->id() != parameter_type->id() && - !(_.options()->relax_logical_pointer && - ArePointersToSameType(argument_type, parameter_type)))) { - return _.diag(SPV_ERROR_INVALID_ID, inst) - << "OpFunctionCall Argument '" << _.getIdName(argument_id) - << "'s type does not match Function '" - << _.getIdName(parameter_type_id) << "'s parameter type."; + if (!parameter_type || argument_type->id() != parameter_type->id()) { + if (!_.options()->before_hlsl_legalization || + !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpFunctionCall Argument '" << _.getIdName(argument_id) + << "'s type does not match Function '" + << _.getIdName(parameter_type_id) << "'s parameter type."; + } } if (_.addressing_model() == SpvAddressingModelLogical) { diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 325a3e85..745c82bb 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -1135,5 +1135,77 @@ std::string ValidationState_t::Disassemble(const uint32_t* words, words_, num_words_, disassembly_options); } +bool ValidationState_t::LogicallyMatch(const Instruction* lhs, + const Instruction* rhs, + bool check_decorations) { + if (lhs->opcode() != rhs->opcode()) { + return false; + } + + if (check_decorations) { + const auto& dec_a = id_decorations(lhs->id()); + const auto& dec_b = id_decorations(rhs->id()); + + for (const auto& dec : dec_b) { + if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { + return false; + } + } + } + + if (lhs->opcode() == SpvOpTypeArray) { + // Size operands must match. + if (lhs->GetOperandAs(2u) != rhs->GetOperandAs(2u)) { + return false; + } + + // Elements must match or logically match. + const auto lhs_ele_id = lhs->GetOperandAs(1u); + const auto rhs_ele_id = rhs->GetOperandAs(1u); + if (lhs_ele_id == rhs_ele_id) { + return true; + } + + const auto lhs_ele = FindDef(lhs_ele_id); + const auto rhs_ele = FindDef(rhs_ele_id); + if (!lhs_ele || !rhs_ele) { + return false; + } + return LogicallyMatch(lhs_ele, rhs_ele, check_decorations); + } else if (lhs->opcode() == SpvOpTypeStruct) { + // Number of elements must match. + if (lhs->operands().size() != rhs->operands().size()) { + return false; + } + + for (size_t i = 1u; i < lhs->operands().size(); ++i) { + const auto lhs_ele_id = lhs->GetOperandAs(i); + const auto rhs_ele_id = rhs->GetOperandAs(i); + // Elements must match or logically match. + if (lhs_ele_id == rhs_ele_id) { + continue; + } + + const auto lhs_ele = FindDef(lhs_ele_id); + const auto rhs_ele = FindDef(rhs_ele_id); + if (!lhs_ele || !rhs_ele) { + return false; + } + + if (!LogicallyMatch(lhs_ele, rhs_ele, check_decorations)) { + return false; + } + } + + // All checks passed. + return true; + } + + // No other opcodes are acceptable at this point. Arrays and structs are + // caught above and if they're elements are not arrays or structs they are + // required to match exactly. + return false; +} + } // namespace val } // namespace spvtools diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 5e84e24d..9b0a5860 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -664,6 +664,21 @@ class ValidationState_t { spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst, uint32_t m1, uint32_t m2); + // Returns true if |lhs| and |rhs| logically match and, if the decorations of + // |rhs| are a subset of |lhs|. + // + // 1. Must both be either OpTypeArray or OpTypeStruct + // 2. If OpTypeArray, then + // * Length must be the same + // * Element type must match or logically match + // 3. If OpTypeStruct, then + // * Both have same number of elements + // * Element N for both structs must match or logically match + // + // If |check_decorations| is false, then the decorations are not checked. + bool LogicallyMatch(const Instruction* lhs, const Instruction* rhs, + bool check_decorations); + private: ValidationState_t(const ValidationState_t&); diff --git a/test/val/val_function_test.cpp b/test/val/val_function_test.cpp index 6c0e8a10..af0199a6 100644 --- a/test/val/val_function_test.cpp +++ b/test/val/val_function_test.cpp @@ -416,6 +416,423 @@ TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationVariablePointers) { } } +TEST_F(ValidateFunctionCall, LogicallyMatchingPointers) { + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_runtimearr__struct_3 ArrayStride 4 + OpMemberDecorate %_struct_5 0 Offset 0 + OpDecorate %_struct_5 BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_3 = OpTypeStruct %int +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 + %_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %_struct_15 = OpTypeStruct %int +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 +%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3 + %18 = OpTypeFunction %void %_ptr_Function__struct_15 + %2 = OpVariable %_ptr_Uniform__struct_5 Uniform + %1 = OpFunction %void None %14 + %19 = OpLabel + %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0 + %21 = OpFunctionCall %void %22 %20 + OpReturn + OpFunctionEnd + %22 = OpFunction %void None %18 + %23 = OpFunctionParameter %_ptr_Function__struct_15 + %24 = OpLabel + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateFunctionCall, LogicallyMatchingPointersNestedStruct) { + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpMemberDecorate %_struct_3 0 Offset 0 + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %_runtimearr__struct_4 ArrayStride 4 + OpMemberDecorate %_struct_6 0 Offset 0 + OpDecorate %_struct_6 BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_3 = OpTypeStruct %int + %_struct_4 = OpTypeStruct %_struct_3 +%_runtimearr__struct_4 = OpTypeRuntimeArray %_struct_4 + %_struct_6 = OpTypeStruct %_runtimearr__struct_4 +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 + %void = OpTypeVoid + %13 = OpTypeFunction %void + %_struct_14 = OpTypeStruct %int + %_struct_15 = OpTypeStruct %_struct_14 +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 + %18 = OpTypeFunction %void %_ptr_Function__struct_15 + %2 = OpVariable %_ptr_Uniform__struct_6 Uniform + %1 = OpFunction %void None %13 + %19 = OpLabel + %20 = OpVariable %_ptr_Function__struct_15 Function + %21 = OpAccessChain %_ptr_Uniform__struct_4 %2 %int_0 %uint_0 + %22 = OpFunctionCall %void %23 %21 + OpReturn + OpFunctionEnd + %23 = OpFunction %void None %18 + %24 = OpFunctionParameter %_ptr_Function__struct_15 + %25 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateFunctionCall, LogicallyMatchingPointersNestedArray) { + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpDecorate %_arr_int_uint_10 ArrayStride 4 + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %_runtimearr__struct_4 ArrayStride 40 + OpMemberDecorate %_struct_6 0 Offset 0 + OpDecorate %_struct_6 BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_10 = OpConstant %uint 10 +%_arr_int_uint_10 = OpTypeArray %int %uint_10 + %_struct_4 = OpTypeStruct %_arr_int_uint_10 +%_runtimearr__struct_4 = OpTypeRuntimeArray %_struct_4 + %_struct_6 = OpTypeStruct %_runtimearr__struct_4 +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 + %void = OpTypeVoid + %14 = OpTypeFunction %void +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 +%_arr_int_uint_10_0 = OpTypeArray %int %uint_10 + %_struct_17 = OpTypeStruct %_arr_int_uint_10_0 +%_ptr_Function__struct_17 = OpTypePointer Function %_struct_17 + %19 = OpTypeFunction %void %_ptr_Function__struct_17 + %2 = OpVariable %_ptr_Uniform__struct_6 Uniform + %1 = OpFunction %void None %14 + %20 = OpLabel + %21 = OpAccessChain %_ptr_Uniform__struct_4 %2 %int_0 %uint_0 + %22 = OpFunctionCall %void %23 %21 + OpReturn + OpFunctionEnd + %23 = OpFunction %void None %19 + %24 = OpFunctionParameter %_ptr_Function__struct_17 + %25 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateFunctionCall, LogicallyMismatchedPointersMissingMember) { + // Validation should fail because the formal parameter type has two members, + // while the actual parameter only has 1. + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_runtimearr__struct_3 ArrayStride 4 + OpMemberDecorate %_struct_5 0 Offset 0 + OpDecorate %_struct_5 BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_3 = OpTypeStruct %int +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 + %_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %_struct_15 = OpTypeStruct %int %int +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 +%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3 + %18 = OpTypeFunction %void %_ptr_Function__struct_15 + %2 = OpVariable %_ptr_Uniform__struct_5 Uniform + %1 = OpFunction %void None %14 + %19 = OpLabel + %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0 + %21 = OpFunctionCall %void %22 %20 + OpReturn + OpFunctionEnd + %22 = OpFunction %void None %18 + %23 = OpFunctionParameter %_ptr_Function__struct_15 + %24 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpFunctionCall Argument ")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("type does not match Function ")); +} + +TEST_F(ValidateFunctionCall, LogicallyMismatchedPointersDifferentMemberType) { + // Validation should fail because the formal parameter has a member that is + // a different type than the actual parameter. + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_runtimearr__struct_3 ArrayStride 4 + OpMemberDecorate %_struct_5 0 Offset 0 + OpDecorate %_struct_5 BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_3 = OpTypeStruct %uint +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 + %_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %_struct_15 = OpTypeStruct %int +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 +%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3 + %18 = OpTypeFunction %void %_ptr_Function__struct_15 + %2 = OpVariable %_ptr_Uniform__struct_5 Uniform + %1 = OpFunction %void None %14 + %19 = OpLabel + %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0 + %21 = OpFunctionCall %void %22 %20 + OpReturn + OpFunctionEnd + %22 = OpFunction %void None %18 + %23 = OpFunctionParameter %_ptr_Function__struct_15 + %24 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpFunctionCall Argument ")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("type does not match Function ")); +} + +TEST_F(ValidateFunctionCall, + LogicallyMismatchedPointersIncompatableDecorations) { + // Validation should fail because the formal parameter has an incompatible + // decoration. + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_runtimearr__struct_3 ArrayStride 4 + OpMemberDecorate %_struct_5 0 Offset 0 + OpDecorate %_struct_5 Block + OpMemberDecorate %_struct_15 0 NonWritable + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_3 = OpTypeStruct %int +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 + %_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_StorageBuffer__struct_5 = OpTypePointer StorageBuffer %_struct_5 + %void = OpTypeVoid + %14 = OpTypeFunction %void + %_struct_15 = OpTypeStruct %int +%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15 +%_ptr_StorageBuffer__struct_3 = OpTypePointer StorageBuffer %_struct_3 + %18 = OpTypeFunction %void %_ptr_Function__struct_15 + %2 = OpVariable %_ptr_StorageBuffer__struct_5 StorageBuffer + %1 = OpFunction %void None %14 + %19 = OpLabel + %20 = OpAccessChain %_ptr_StorageBuffer__struct_3 %2 %int_0 %uint_0 + %21 = OpFunctionCall %void %22 %20 + OpReturn + OpFunctionEnd + %22 = OpFunction %void None %18 + %23 = OpFunctionParameter %_ptr_Function__struct_15 + %24 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_4); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_4)); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpFunctionCall Argument ")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("type does not match Function ")); +} + +TEST_F(ValidateFunctionCall, + LogicallyMismatchedPointersIncompatableDecorations2) { + // Validation should fail because the formal parameter has an incompatible + // decoration. + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpMemberDecorate %_struct_3 0 Offset 0 + OpDecorate %_runtimearr__struct_3 ArrayStride 4 + OpMemberDecorate %_struct_5 0 Offset 0 + OpDecorate %_struct_5 BufferBlock + OpDecorate %_ptr_Uniform__struct_3 ArrayStride 4 + OpDecorate %_ptr_Uniform__struct_3_0 ArrayStride 8 + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %_struct_3 = OpTypeStruct %int +%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3 + %_struct_5 = OpTypeStruct %_runtimearr__struct_3 +%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5 + %void = OpTypeVoid + %14 = OpTypeFunction %void +%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3 +%_ptr_Uniform__struct_3_0 = OpTypePointer Uniform %_struct_3 + %18 = OpTypeFunction %void %_ptr_Uniform__struct_3_0 + %2 = OpVariable %_ptr_Uniform__struct_5 Uniform + %1 = OpFunction %void None %14 + %19 = OpLabel + %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0 + %21 = OpFunctionCall %void %22 %20 + OpReturn + OpFunctionEnd + %22 = OpFunction %void None %18 + %23 = OpFunctionParameter %_ptr_Uniform__struct_3_0 + %24 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpFunctionCall Argument ")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("type does not match Function ")); +} + +TEST_F(ValidateFunctionCall, LogicallyMismatchedPointersArraySize) { + // Validation should fail because the formal parameter array has a different + // number of element than the actual parameter. + std::string spirv = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + OpSource HLSL 600 + OpDecorate %2 DescriptorSet 0 + OpDecorate %2 Binding 0 + OpDecorate %_arr_int_uint_10 ArrayStride 4 + OpMemberDecorate %_struct_4 0 Offset 0 + OpDecorate %_runtimearr__struct_4 ArrayStride 40 + OpMemberDecorate %_struct_6 0 Offset 0 + OpDecorate %_struct_6 BufferBlock + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_5 = OpConstant %uint 5 + %uint_10 = OpConstant %uint 10 +%_arr_int_uint_10 = OpTypeArray %int %uint_10 + %_struct_4 = OpTypeStruct %_arr_int_uint_10 +%_runtimearr__struct_4 = OpTypeRuntimeArray %_struct_4 + %_struct_6 = OpTypeStruct %_runtimearr__struct_4 +%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6 + %void = OpTypeVoid + %14 = OpTypeFunction %void +%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4 +%_arr_int_uint_5 = OpTypeArray %int %uint_5 + %_struct_17 = OpTypeStruct %_arr_int_uint_5 +%_ptr_Function__struct_17 = OpTypePointer Function %_struct_17 + %19 = OpTypeFunction %void %_ptr_Function__struct_17 + %2 = OpVariable %_ptr_Uniform__struct_6 Uniform + %1 = OpFunction %void None %14 + %20 = OpLabel + %21 = OpAccessChain %_ptr_Uniform__struct_4 %2 %int_0 %uint_0 + %22 = OpFunctionCall %void %23 %21 + OpReturn + OpFunctionEnd + %23 = OpFunction %void None %19 + %24 = OpFunctionParameter %_ptr_Function__struct_17 + %25 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv); + spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("OpFunctionCall Argument ")); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("type does not match Function ")); +} + INSTANTIATE_TEST_SUITE_P(StorageClass, ValidateFunctionCall, Values("UniformConstant", "Input", "Uniform", "Output", "Workgroup", "Private", "Function", diff --git a/test/val/val_storage_test.cpp b/test/val/val_storage_test.cpp index c02b7690..f54b425e 100644 --- a/test/val/val_storage_test.cpp +++ b/test/val/val_storage_test.cpp @@ -213,7 +213,7 @@ TEST_F(ValidateStorage, RelaxedLogicalPointerFunctionParam) { OpFunctionEnd )"; CompileSuccessfully(str); - getValidatorOptions()->relax_logical_pointer = true; + getValidatorOptions()->before_hlsl_legalization = true; ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 3e579809..cc495401 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -760,7 +760,7 @@ OptStatus ParseFlags(int argc, const char** argv, // If we were requested to legalize SPIR-V generated from the HLSL // front-end, skip validation. if (0 == strcmp(cur_arg, "--legalize-hlsl")) { - validator_options->SetRelaxLogicalPointer(true); + validator_options->SetBeforeHlslLegalization(true); } } } else { diff --git a/tools/reduce/reduce.cpp b/tools/reduce/reduce.cpp index 55d09833..7de3aa8b 100644 --- a/tools/reduce/reduce.cpp +++ b/tools/reduce/reduce.cpp @@ -104,11 +104,12 @@ Options (in lexicographical order): Display reducer version information. Supported validator options are as follows. See `spirv-val --help` for details. - --relax-logical-pointer + --before-hlsl-legalization --relax-block-layout + --relax-logical-pointer + --relax-struct-store --scalar-block-layout --skip-block-layout - --relax-struct-store )", program, program); } @@ -166,6 +167,8 @@ ReduceStatus ParseFlags(int argc, const char** argv, const char** in_file, positional_arg_index++; } else if (0 == strcmp(cur_arg, "--fail-on-validation-error")) { reducer_options->set_fail_on_validation_error(true); + } else if (0 == strcmp(cur_arg, "--before-hlsl-legalization")) { + validator_options->SetBeforeHlslLegalization(true); } else if (0 == strcmp(cur_arg, "--relax-logical-pointer")) { validator_options->SetRelaxLogicalPointer(true); } else if (0 == strcmp(cur_arg, "--relax-block-layout")) { diff --git a/tools/val/val.cpp b/tools/val/val.cpp index 49ff8a93..a61f4d1f 100644 --- a/tools/val/val.cpp +++ b/tools/val/val.cpp @@ -62,6 +62,8 @@ Options: --relax-struct-store Allow store from one struct type to a different type with compatible layout and members. + --before-hlsl-legalization Allows code patterns that are intended to be + fixed by spirv-opt's legalization passes. --version Display validator version information. --target-env {vulkan1.0|vulkan1.1|vulkan1.1spv1.4|opencl2.2|spv1.0|spv1.1| spv1.2|spv1.3|spv1.4|webgpu0} @@ -138,6 +140,8 @@ int main(int argc, char** argv) { continue_processing = false; return_code = 1; } + } else if (0 == strcmp(cur_arg, "--before-hlsl-legalization")) { + options.SetBeforeHlslLegalization(true); } else if (0 == strcmp(cur_arg, "--relax-logical-pointer")) { options.SetRelaxLogicalPointer(true); } else if (0 == strcmp(cur_arg, "--relax-block-layout")) {