diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 8f3d5b50a..a1847e57a 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -382,9 +382,8 @@ bool IsHLSLLineStreamType(clang::QualType type); bool IsHLSLTriangleStreamType(clang::QualType type); bool IsHLSLStreamOutputType(clang::QualType type); bool IsHLSLResourceType(clang::QualType type); -bool IsHLSLNumeric(clang::QualType type); bool IsHLSLNumericUserDefinedType(clang::QualType type); -bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type); +bool IsHLSLAggregateType(clang::QualType type); clang::QualType GetHLSLResourceResultType(clang::QualType type); bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type); clang::QualType GetHLSLInputPatchElementType(clang::QualType type); diff --git a/tools/clang/lib/AST/HlslTypes.cpp b/tools/clang/lib/AST/HlslTypes.cpp index f38173410..24301675a 100644 --- a/tools/clang/lib/AST/HlslTypes.cpp +++ b/tools/clang/lib/AST/HlslTypes.cpp @@ -91,7 +91,7 @@ bool IsHLSLVecType(clang::QualType type) { return false; } -bool IsHLSLNumeric(clang::QualType type) { +static bool IsHLSLNumeric(clang::QualType type) { const clang::Type *Ty = type.getCanonicalType().getTypePtr(); if (isa(Ty)) { if (IsHLSLVecMatType(type)) @@ -125,9 +125,11 @@ bool IsHLSLNumericUserDefinedType(clang::QualType type) { return false; } -bool IsHLSLAggregateType(clang::ASTContext& context, clang::QualType type) { - // Aggregate types are arrays and user-defined structs - if (context.getAsArrayType(type) != nullptr) return true; +// Aggregate types are arrays and user-defined structs +bool IsHLSLAggregateType(clang::QualType type) { + type = type.getCanonicalType(); + if (isa(type)) return true; + const RecordType *Record = dyn_cast(type); return Record != nullptr && !IsHLSLVecMatType(type) && !IsHLSLResourceType(type) diff --git a/tools/clang/lib/CodeGen/CGExpr.cpp b/tools/clang/lib/CodeGen/CGExpr.cpp index 5b612bb0e..0a207e10e 100644 --- a/tools/clang/lib/CodeGen/CGExpr.cpp +++ b/tools/clang/lib/CodeGen/CGExpr.cpp @@ -1397,7 +1397,7 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) { } } - if (hlsl::IsHLSLAggregateType(getContext(), LV.getType())) { + if (hlsl::IsHLSLAggregateType(LV.getType())) { // We cannot load the value because we don't expect to ever have // user-defined struct or array-typed llvm registers, only pointers to them. // To preserve the snapshot semantics of LValue loads, we copy the diff --git a/tools/clang/lib/CodeGen/CGExprScalar.cpp b/tools/clang/lib/CodeGen/CGExprScalar.cpp index f14ee177c..6072e0806 100644 --- a/tools/clang/lib/CodeGen/CGExprScalar.cpp +++ b/tools/clang/lib/CodeGen/CGExprScalar.cpp @@ -1825,7 +1825,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { // If the aggregate type is the cast source, it should be a pointer. // Aggregate to aggregate casts are handled in CGExprAgg.cpp auto areCompoundAndNumeric = [this](QualType lhs, QualType rhs) { - return hlsl::IsHLSLAggregateType(CGF.getContext(), lhs) + return hlsl::IsHLSLAggregateType(lhs) && (rhs->isBuiltinType() || hlsl::IsHLSLVecMatType(rhs)); }; assert(Src->getType()->isPointerTy() @@ -1843,7 +1843,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { return CGF.CGM.getHLSLRuntime().EmitHLSLMatrixLoad(CGF, DstPtr, DestTy); // Structs/arrays are pointers to temporaries - if (hlsl::IsHLSLAggregateType(CGF.getContext(), DestTy)) + if (hlsl::IsHLSLAggregateType(DestTy)) return DstPtr; // Scalars/vectors are loaded regularly diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 4250d3968..bbddba548 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -2644,7 +2644,7 @@ bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc, EltTy = hlsl::GetHLSLVecElementType(Ty); } else if (hlsl::IsHLSLMatType(Ty)) { EltTy = hlsl::GetHLSLMatElementType(Ty); - } else if (resultTy->isAggregateType()) { + } else if (hlsl::IsHLSLAggregateType(resultTy)) { // Struct or array in a none-struct resource. std::vector ScalarTys; CollectScalarTypes(ScalarTys, resultTy); @@ -7012,7 +7012,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit( QualType ParamTy = Param->getType().getNonReferenceType(); bool RValOnRef = false; if (!Param->isModifierOut()) { - if (!ParamTy->isAggregateType() || hlsl::IsHLSLMatType(ParamTy)) { + if (!hlsl::IsHLSLAggregateType(ParamTy)) { if (Arg->isRValue() && Param->getType()->isReferenceType()) { // RValue on a reference type. if (const CStyleCastExpr *cCast = dyn_cast(Arg)) { @@ -7108,7 +7108,7 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit( !isObject) { QualType ArgTy = Arg->getType(); Value *outVal = nullptr; - bool isAggregateTy = ParamTy->isAggregateType() && !IsHLSLVecMatType(ParamTy); + bool isAggregateTy = hlsl::IsHLSLAggregateType(ParamTy); if (!isAggregateTy) { if (!IsHLSLMatType(ParamTy)) { RValue outRVal = CGF.EmitLoadOfLValue(argLV, SourceLocation()); @@ -7151,13 +7151,12 @@ void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack( Value *outVal = nullptr; - bool isAggrageteTy = ArgTy->isAggregateType(); - isAggrageteTy &= !IsHLSLVecMatType(ArgTy); + bool isAggregateTy = hlsl::IsHLSLAggregateType(ArgTy); bool isObject = dxilutil::IsHLSLObjectType( tmpArgAddr->getType()->getPointerElementType()); if (!isObject) { - if (!isAggrageteTy) { + if (!isAggregateTy) { if (!IsHLSLMatType(ParamTy)) outVal = CGF.Builder.CreateLoad(tmpArgAddr); else diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 3ef13faf7..62a35fa10 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -4639,7 +4639,7 @@ public: // Change return type to rvalue reference type for aggregate types QualType retTy = parameterTypes[0]; - if (retTy->isAggregateType() && !IsHLSLVecMatType(retTy)) + if (hlsl::IsHLSLAggregateType(retTy)) parameterTypes[0] = m_context->getRValueReferenceType(retTy); // Create a new specialization. diff --git a/tools/clang/test/CodeGenHLSL/declarations/functions/inout_derived_struct_no_crash.hlsl b/tools/clang/test/CodeGenHLSL/declarations/functions/inout_derived_struct_no_crash.hlsl new file mode 100644 index 000000000..955fbc032 --- /dev/null +++ b/tools/clang/test/CodeGenHLSL/declarations/functions/inout_derived_struct_no_crash.hlsl @@ -0,0 +1,15 @@ +// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s + +// Regression test for GitHub #1929, where we used the C++ definition +// of an aggregate type and failed to match derived structs. + +// CHECK: ret void + +struct Base {}; +struct Derived : Base {}; +void f(inout Derived d) {} +void main() +{ + Derived d; + f(d); +} \ No newline at end of file