[spirv] Relax SV_Position type requirements (#4275)

A valid vertex shader output variable with SV_Position semantics may be
constructed from any HLSL BuiltinType that translates to a 32-bit
floating point type in the SPIR-V backend, so relax the requirements to
allow the use of additonal types (such as half4) when
-enable-16bit-types is false.

Fixes #4262
This commit is contained in:
Natalie Chouinard 2022-02-22 11:58:44 -05:00 коммит произвёл GitHub
Родитель 3f8e22cec1
Коммит 3fcd83e43b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 92 добавлений и 29 удалений

Просмотреть файл

@ -517,9 +517,27 @@ bool insertSeenSemanticsForEntryPointIfNotExist(
return true;
}
// Returns whether the type is float4 or a composite type recursively including
// only float4 e.g., float4, float4[1], struct S { float4 foo[1]; }.
bool containOnlyVecWithFourFloats(QualType type) {
// Returns whether the type is translated to a 32-bit floating point type,
// depending on whether SPIR-V codegen options are configured to use 16-bit
// types when possible.
bool is32BitFloatingPointType(BuiltinType::Kind kind, bool use16Bit) {
// Always translated into 32-bit floating point types.
if (kind == BuiltinType::Float || kind == BuiltinType::LitFloat)
return true;
// Translated into 32-bit floating point types when run without
// -enable-16bit-types.
if (kind == BuiltinType::Half || kind == BuiltinType::HalfFloat ||
kind == BuiltinType::Min10Float || kind == BuiltinType::Min16Float)
return !use16Bit;
return false;
}
// Returns whether the type is a 4-component 32-bit float or a composite type
// recursively including only such a vector e.g., float4, float4[1], struct S {
// float4 foo[1]; }.
bool containOnlyVecWithFourFloats(QualType type, bool use16Bit) {
if (type->isReferenceType())
type = type->getPointeeType();
@ -532,7 +550,7 @@ bool containOnlyVecWithFourFloats(QualType type) {
(const ConstantArrayType *)type->getAsArrayTypeUnsafe();
elemCount = hlsl::GetArraySize(type);
return elemCount == 1 &&
containOnlyVecWithFourFloats(arrayType->getElementType());
containOnlyVecWithFourFloats(arrayType->getElementType(), use16Bit);
}
if (const auto *structType = type->getAs<RecordType>()) {
@ -540,7 +558,7 @@ bool containOnlyVecWithFourFloats(QualType type) {
for (const auto *field : structType->getDecl()->fields()) {
if (fieldCount != 0)
return false;
if (!containOnlyVecWithFourFloats(field->getType()))
if (!containOnlyVecWithFourFloats(field->getType(), use16Bit))
return false;
++fieldCount;
}
@ -550,7 +568,8 @@ bool containOnlyVecWithFourFloats(QualType type) {
QualType elemType = {};
if (isVectorType(type, &elemType, &elemCount)) {
if (const auto *builtinType = elemType->getAs<BuiltinType>()) {
return elemCount == 4 && builtinType->getKind() == BuiltinType::Float;
return elemCount == 4 &&
is32BitFloatingPointType(builtinType->getKind(), use16Bit);
}
return false;
}
@ -3300,9 +3319,10 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
// by VSOut, HS/DS/GS In/Out, MSOut.
case hlsl::Semantic::Kind::Position: {
if (sigPointKind == hlsl::SigPoint::Kind::VSOut &&
!containOnlyVecWithFourFloats(type)) {
emitError("semantic Position must be float4 or a composite type "
"recursively including only float4",
!containOnlyVecWithFourFloats(
type, theEmitter.getSpirvOptions().enable16BitTypes)) {
emitError("SV_Position must be a 4-component 32-bit float vector or a "
"composite which recursively contains only such a vector",
srcLoc);
}

Просмотреть файл

@ -2929,9 +2929,11 @@ float4 PSMain(PSInput input) : SV_TARGET
std::string getVertexPositionTypeTestShader(const std::string &subType,
const std::string &positionType,
const std::string &check) {
const std::string command(R"(// RUN: %dxc -T vs_6_0 -E main)");
const std::string code = command + subType + R"(
const std::string &check,
bool use16bit) {
const std::string code = std::string(R"(// RUN: %dxc -T vs_6_2 -E main)") +
(use16bit ? R"( -enable-16bit-types)" : R"()") + R"(
)" + subType + R"(
struct output {
)" + positionType + R"(
};
@ -2946,33 +2948,37 @@ output main() : SV_Position
}
const char *kInvalidPositionTypeForVSErrorMessage =
"// CHECK: error: semantic Position must be float4 or a composite type "
"recursively including only float4";
"// CHECK: error: SV_Position must be a 4-component 32-bit float vector or "
"a composite which recursively contains only such a vector";
TEST_F(FileTest, PositionInVSWithArrayType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "float x[4];", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "float x[4];", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithDoubleType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "double4 x;", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "double4 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithIntType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "int4 x;", kInvalidPositionTypeForVSErrorMessage),
"", "int4 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithMatrixType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "float1x4 x;", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "float1x4 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithInvalidFloatVectorType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "float3 x;", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "float3 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithInvalidInnerStructType) {
runCodeTest(getVertexPositionTypeTestShader(
@ -2980,7 +2986,8 @@ TEST_F(FileTest, PositionInVSWithInvalidInnerStructType) {
struct InvalidType {
float3 x;
};)",
"InvalidType x;", kInvalidPositionTypeForVSErrorMessage),
"InvalidType x;", kInvalidPositionTypeForVSErrorMessage,
false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithValidInnerStructType) {
@ -2991,7 +2998,43 @@ struct validType {
"validType x;", R"(
// CHECK: %validType = OpTypeStruct %v4float
// CHECK: %output = OpTypeStruct %validType
)"));
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidFloatType) {
runCodeTest(getVertexPositionTypeTestShader("", "float4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidMin10Float4Type) {
runCodeTest(getVertexPositionTypeTestShader("", "min10float4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidMin16Float4Type) {
runCodeTest(getVertexPositionTypeTestShader("", "min16float4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidHalf4Type) {
runCodeTest(getVertexPositionTypeTestShader("", "half4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithInvalidHalf4Type) {
runCodeTest(getVertexPositionTypeTestShader(
"", "half4 x;", kInvalidPositionTypeForVSErrorMessage, true),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithInvalidMin10Float4Type) {
runCodeTest(
getVertexPositionTypeTestShader(
"", "min10float4 x;", kInvalidPositionTypeForVSErrorMessage, true),
Expect::Failure);
}
TEST_F(FileTest, ShaderDebugInfoFunction) {
runFileTest("shader.debug.function.hlsl");