Fixed incorrect checks for HLSL aggregate types leading to crashes
Some of our code was using clang::Type::isAggregateType, which uses a specific C++ definition that doesn't make sense in HLSL world since derived structs are definitely aggregates for us. This could lead to crashes as we considered some aggregates as scalar.
This commit is contained in:
Коммит
cbc699d56a
|
@ -382,9 +382,8 @@ bool IsHLSLLineStreamType(clang::QualType type);
|
|||
bool IsHLSLTriangleStreamType(clang::QualType type);
|
||||
bool IsHLSLStreamOutputType(clang::QualType type);
|
||||
bool IsHLSLResourceType(clang::QualType type);
|
||||
bool IsHLSLNumeric(clang::QualType type);
|
||||
bool IsHLSLNumericUserDefinedType(clang::QualType type);
|
||||
bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type);
|
||||
bool IsHLSLAggregateType(clang::QualType type);
|
||||
clang::QualType GetHLSLResourceResultType(clang::QualType type);
|
||||
bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type);
|
||||
clang::QualType GetHLSLInputPatchElementType(clang::QualType type);
|
||||
|
|
|
@ -91,7 +91,7 @@ bool IsHLSLVecType(clang::QualType type) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsHLSLNumeric(clang::QualType type) {
|
||||
static bool IsHLSLNumeric(clang::QualType type) {
|
||||
const clang::Type *Ty = type.getCanonicalType().getTypePtr();
|
||||
if (isa<RecordType>(Ty)) {
|
||||
if (IsHLSLVecMatType(type))
|
||||
|
@ -125,9 +125,11 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type) {
|
||||
// Aggregate types are arrays and user-defined structs
|
||||
if (context.getAsArrayType(type) != nullptr) return true;
|
||||
// Aggregate types are arrays and user-defined structs
|
||||
bool IsHLSLAggregateType(clang::QualType type) {
|
||||
type = type.getCanonicalType();
|
||||
if (isa<clang::ArrayType>(type)) return true;
|
||||
|
||||
const RecordType *Record = dyn_cast<RecordType>(type);
|
||||
return Record != nullptr
|
||||
&& !IsHLSLVecMatType(type) && !IsHLSLResourceType(type)
|
||||
|
|
|
@ -1397,7 +1397,7 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
|
|||
}
|
||||
}
|
||||
|
||||
if (hlsl::IsHLSLAggregateType(getContext(), LV.getType())) {
|
||||
if (hlsl::IsHLSLAggregateType(LV.getType())) {
|
||||
// We cannot load the value because we don't expect to ever have
|
||||
// user-defined struct or array-typed llvm registers, only pointers to them.
|
||||
// To preserve the snapshot semantics of LValue loads, we copy the
|
||||
|
|
|
@ -1825,7 +1825,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
|
|||
// If the aggregate type is the cast source, it should be a pointer.
|
||||
// Aggregate to aggregate casts are handled in CGExprAgg.cpp
|
||||
auto areCompoundAndNumeric = [this](QualType lhs, QualType rhs) {
|
||||
return hlsl::IsHLSLAggregateType(CGF.getContext(), lhs)
|
||||
return hlsl::IsHLSLAggregateType(lhs)
|
||||
&& (rhs->isBuiltinType() || hlsl::IsHLSLVecMatType(rhs));
|
||||
};
|
||||
assert(Src->getType()->isPointerTy()
|
||||
|
@ -1843,7 +1843,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
|
|||
return CGF.CGM.getHLSLRuntime().EmitHLSLMatrixLoad(CGF, DstPtr, DestTy);
|
||||
|
||||
// Structs/arrays are pointers to temporaries
|
||||
if (hlsl::IsHLSLAggregateType(CGF.getContext(), DestTy))
|
||||
if (hlsl::IsHLSLAggregateType(DestTy))
|
||||
return DstPtr;
|
||||
|
||||
// Scalars/vectors are loaded regularly
|
||||
|
|
|
@ -2644,7 +2644,7 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
|
|||
EltTy = hlsl::GetHLSLVecElementType(Ty);
|
||||
} else if (hlsl::IsHLSLMatType(Ty)) {
|
||||
EltTy = hlsl::GetHLSLMatElementType(Ty);
|
||||
} else if (resultTy->isAggregateType()) {
|
||||
} else if (hlsl::IsHLSLAggregateType(resultTy)) {
|
||||
// Struct or array in a none-struct resource.
|
||||
std::vector<QualType> ScalarTys;
|
||||
CollectScalarTypes(ScalarTys, resultTy);
|
||||
|
@ -7012,7 +7012,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
|
|||
QualType ParamTy = Param->getType().getNonReferenceType();
|
||||
bool RValOnRef = false;
|
||||
if (!Param->isModifierOut()) {
|
||||
if (!ParamTy->isAggregateType() || hlsl::IsHLSLMatType(ParamTy)) {
|
||||
if (!hlsl::IsHLSLAggregateType(ParamTy)) {
|
||||
if (Arg->isRValue() && Param->getType()->isReferenceType()) {
|
||||
// RValue on a reference type.
|
||||
if (const CStyleCastExpr *cCast = dyn_cast<CStyleCastExpr>(Arg)) {
|
||||
|
@ -7108,7 +7108,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
|
|||
!isObject) {
|
||||
QualType ArgTy = Arg->getType();
|
||||
Value *outVal = nullptr;
|
||||
bool isAggregateTy = ParamTy->isAggregateType() && !IsHLSLVecMatType(ParamTy);
|
||||
bool isAggregateTy = hlsl::IsHLSLAggregateType(ParamTy);
|
||||
if (!isAggregateTy) {
|
||||
if (!IsHLSLMatType(ParamTy)) {
|
||||
RValue outRVal = CGF.EmitLoadOfLValue(argLV, SourceLocation());
|
||||
|
@ -7151,13 +7151,12 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
|
|||
|
||||
Value *outVal = nullptr;
|
||||
|
||||
bool isAggrageteTy = ArgTy->isAggregateType();
|
||||
isAggrageteTy &= !IsHLSLVecMatType(ArgTy);
|
||||
bool isAggregateTy = hlsl::IsHLSLAggregateType(ArgTy);
|
||||
|
||||
bool isObject = dxilutil::IsHLSLObjectType(
|
||||
tmpArgAddr->getType()->getPointerElementType());
|
||||
if (!isObject) {
|
||||
if (!isAggrageteTy) {
|
||||
if (!isAggregateTy) {
|
||||
if (!IsHLSLMatType(ParamTy))
|
||||
outVal = CGF.Builder.CreateLoad(tmpArgAddr);
|
||||
else
|
||||
|
|
|
@ -4639,7 +4639,7 @@ public:
|
|||
|
||||
// Change return type to rvalue reference type for aggregate types
|
||||
QualType retTy = parameterTypes[0];
|
||||
if (retTy->isAggregateType() && !IsHLSLVecMatType(retTy))
|
||||
if (hlsl::IsHLSLAggregateType(retTy))
|
||||
parameterTypes[0] = m_context->getRValueReferenceType(retTy);
|
||||
|
||||
// Create a new specialization.
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
|
||||
|
||||
// Regression test for GitHub #1929, where we used the C++ definition
|
||||
// of an aggregate type and failed to match derived structs.
|
||||
|
||||
// CHECK: ret void
|
||||
|
||||
struct Base {};
|
||||
struct Derived : Base {};
|
||||
void f(inout Derived d) {}
|
||||
void main()
|
||||
{
|
||||
Derived d;
|
||||
f(d);
|
||||
}
|
Загрузка…
Ссылка в новой задаче