From bc4a68b6b4059ddfc14649512e1b944b983ddb2f Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 24 Apr 2017 14:34:57 -0700 Subject: [PATCH] Use SRet for struct return type. (#243) 1. Use SRet for struct return type. 2. Add SM6.1 and DXIL1.1. --- include/dxc/HLSL/DxilShaderModel.h | 6 +- lib/HLSL/DxilGenerationPass.cpp | 20 +++ lib/HLSL/DxilModule.cpp | 1 + lib/HLSL/DxilShaderModel.cpp | 21 +++ lib/HLSL/HLModule.cpp | 5 +- tools/clang/lib/CodeGen/CGHLSLMS.cpp | 128 ++++++++++-------- tools/clang/lib/CodeGen/TargetInfo.cpp | 20 ++- .../test/CodeGenHLSL/BasicHLSL11_PS2.hlsl | 1 - tools/clang/test/CodeGenHLSL/class.hlsl | 25 ++++ tools/clang/unittests/HLSL/CompilerTest.cpp | 5 + 10 files changed, 170 insertions(+), 62 deletions(-) create mode 100644 tools/clang/test/CodeGenHLSL/class.hlsl diff --git a/include/dxc/HLSL/DxilShaderModel.h b/include/dxc/HLSL/DxilShaderModel.h index 6289bd8f6..8c56d457b 100644 --- a/include/dxc/HLSL/DxilShaderModel.h +++ b/include/dxc/HLSL/DxilShaderModel.h @@ -29,7 +29,7 @@ public: // Major/Minor version of highest shader model static const unsigned kHighestMajor = 6; - static const unsigned kHighestMinor = 0; + static const unsigned kHighestMinor = 1; bool IsPS() const { return m_Kind == Kind::Pixel; } bool IsVS() const { return m_Kind == Kind::Vertex; } @@ -42,8 +42,10 @@ public: Kind GetKind() const { return m_Kind; } unsigned GetMajor() const { return m_Major; } unsigned GetMinor() const { return m_Minor; } + void GetDxilVersion(unsigned &DxilMajor, unsigned &DxilMinor) const; bool IsSM50Plus() const { return m_Major >= 5; } bool IsSM51Plus() const { return m_Major > 5 || (m_Major == 5 && m_Minor >= 1); } + bool IsSM61Plus() const { return m_Major > 6 || (m_Major == 6 && m_Minor >= 1); } const char *GetName() const { return m_pszName; } std::string GetKindName() const; unsigned GetNumTempRegs() const { return DXIL::kMaxTempRegCount; } @@ -79,7 +81,7 @@ private: unsigned m_NumInputRegs, unsigned m_NumOutputRegs, bool m_bUAVs, bool m_bTypedUavs, unsigned m_UAVRegsLim); - static const unsigned kNumShaderModels = 27; + static const unsigned kNumShaderModels = 33; static const ShaderModel ms_ShaderModels[kNumShaderModels]; static const ShaderModel *GetInvalid(); diff --git a/lib/HLSL/DxilGenerationPass.cpp b/lib/HLSL/DxilGenerationPass.cpp index f41b76519..81338fdae 100644 --- a/lib/HLSL/DxilGenerationPass.cpp +++ b/lib/HLSL/DxilGenerationPass.cpp @@ -2838,10 +2838,30 @@ public: const char *getPassName() const override { return "HLSL DXIL Metadata Emit"; } + void patchSM60(Module &M) { + for (iplist::iterator F : M.getFunctionList()) { + for (Function::iterator BBI = F->begin(), BBE = F->end(); BBI != BBE; + ++BBI) { + BasicBlock *BB = BBI; + for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE; + ++II) { + Instruction *I = II; + if (I->getMetadata(LLVMContext::MD_noalias)) { + I->setMetadata(LLVMContext::MD_noalias, nullptr); + } + } + } + } + } + bool runOnModule(Module &M) override { if (M.HasDxilModule()) { // Remove store undef output. hlsl::OP *hlslOP = M.GetDxilModule().GetOP(); + bool bIsSM61Plus = M.GetDxilModule().GetShaderModel()->IsSM61Plus(); + if (!bIsSM61Plus) { + patchSM60(M); + } for (iplist::iterator F : M.getFunctionList()) { if (!hlslOP->IsDxilOpFunc(F)) continue; diff --git a/lib/HLSL/DxilModule.cpp b/lib/HLSL/DxilModule.cpp index 29899c942..ef0f6a2f2 100644 --- a/lib/HLSL/DxilModule.cpp +++ b/lib/HLSL/DxilModule.cpp @@ -927,6 +927,7 @@ vector &DxilModule::GetLLVMUsed() { // DXIL metadata serialization/deserialization. void DxilModule::EmitDxilMetadata() { + m_pSM->GetDxilVersion(m_DxilMajor, m_DxilMinor); m_pMDHelper->EmitDxilVersion(m_DxilMajor, m_DxilMinor); m_pMDHelper->EmitDxilShaderModel(m_pSM); diff --git a/lib/HLSL/DxilShaderModel.cpp b/lib/HLSL/DxilShaderModel.cpp index 6e17792a8..1c8e3fd5d 100644 --- a/lib/HLSL/DxilShaderModel.cpp +++ b/lib/HLSL/DxilShaderModel.cpp @@ -101,6 +101,21 @@ const ShaderModel *ShaderModel::GetByName(const char *pszName) { return Get(Kind, Major, Minor); } +void ShaderModel::GetDxilVersion(unsigned &DxilMajor, unsigned &DxilMinor) const { + DXASSERT(m_Major == 6, "invalid major"); + switch (m_Minor) { + case 0: + DxilMinor = 0; + break; + case 1: + DxilMinor = 1; + break; + default: + DXASSERT(0, "invalid minor"); + break; + } +} + std::string ShaderModel::GetKindName() const { return std::string(m_pszName).substr(0, 2); } @@ -118,32 +133,38 @@ const ShaderModel ShaderModel::ms_ShaderModels[kNumShaderModels] = { SM(Kind::Compute, 5, 0, "cs_5_0", 0, 0, true, true, 64), SM(Kind::Compute, 5, 1, "cs_5_1", 0, 0, true, true, UINT_MAX), SM(Kind::Compute, 6, 0, "cs_6_0", 0, 0, true, true, UINT_MAX), + SM(Kind::Compute, 6, 1, "cs_6_1", 0, 0, true, true, UINT_MAX), SM(Kind::Domain, 5, 0, "ds_5_0", 32, 32, true, true, 64), SM(Kind::Domain, 5, 1, "ds_5_1", 32, 32, true, true, UINT_MAX), SM(Kind::Domain, 6, 0, "ds_6_0", 32, 32, true, true, UINT_MAX), + SM(Kind::Domain, 6, 1, "ds_6_1", 32, 32, true, true, UINT_MAX), SM(Kind::Geometry, 4, 0, "gs_4_0", 16, 32, false, false, 0), SM(Kind::Geometry, 4, 1, "gs_4_1", 32, 32, false, false, 0), SM(Kind::Geometry, 5, 0, "gs_5_0", 32, 32, true, true, 64), SM(Kind::Geometry, 5, 1, "gs_5_1", 32, 32, true, true, UINT_MAX), SM(Kind::Geometry, 6, 0, "gs_6_0", 32, 32, true, true, UINT_MAX), + SM(Kind::Geometry, 6, 1, "gs_6_1", 32, 32, true, true, UINT_MAX), SM(Kind::Hull, 5, 0, "hs_5_0", 32, 32, true, true, 64), SM(Kind::Hull, 5, 1, "hs_5_1", 32, 32, true, true, UINT_MAX), SM(Kind::Hull, 6, 0, "hs_6_0", 32, 32, true, true, UINT_MAX), + SM(Kind::Hull, 6, 1, "hs_6_1", 32, 32, true, true, UINT_MAX), SM(Kind::Pixel, 4, 0, "ps_4_0", 32, 8, false, false, 0), SM(Kind::Pixel, 4, 1, "ps_4_1", 32, 8, false, false, 0), SM(Kind::Pixel, 5, 0, "ps_5_0", 32, 8, true, true, 64), SM(Kind::Pixel, 5, 1, "ps_5_1", 32, 8, true, true, UINT_MAX), SM(Kind::Pixel, 6, 0, "ps_6_0", 32, 8, true, true, UINT_MAX), + SM(Kind::Pixel, 6, 1, "ps_6_1", 32, 8, true, true, UINT_MAX), SM(Kind::Vertex, 4, 0, "vs_4_0", 16, 16, false, false, 0), SM(Kind::Vertex, 4, 1, "vs_4_1", 32, 32, false, false, 0), SM(Kind::Vertex, 5, 0, "vs_5_0", 32, 32, true, true, 64), SM(Kind::Vertex, 5, 1, "vs_5_1", 32, 32, true, true, UINT_MAX), SM(Kind::Vertex, 6, 0, "vs_6_0", 32, 32, true, true, UINT_MAX), + SM(Kind::Vertex, 6, 1, "vs_6_1", 32, 32, true, true, UINT_MAX), SM(Kind::Invalid, 0, 0, "invalid", 0, 0, false, false, 0), }; diff --git a/lib/HLSL/HLModule.cpp b/lib/HLSL/HLModule.cpp index ab756750f..5631c1eed 100644 --- a/lib/HLSL/HLModule.cpp +++ b/lib/HLSL/HLModule.cpp @@ -61,8 +61,8 @@ HLModule::HLModule(Module *pModule) pModule, llvm::make_unique(pModule))) , m_pDebugInfoFinder(nullptr) , m_pSM(nullptr) - , m_DxilMajor(1) - , m_DxilMinor(0) + , m_DxilMajor(DXIL::kDxilMajor) + , m_DxilMinor(DXIL::kDxilMinor) , m_pOP(llvm::make_unique(pModule->getContext(), pModule)) , m_pTypeSystem(llvm::make_unique(pModule)) { DXASSERT_NOMSG(m_pModule != nullptr); @@ -83,6 +83,7 @@ OP *HLModule::GetOP() const { return m_pOP.get(); } void HLModule::SetShaderModel(const ShaderModel *pSM) { DXASSERT(m_pSM == nullptr, "shader model must not change for the module"); m_pSM = pSM; + m_pSM->GetDxilVersion(m_DxilMajor, m_DxilMinor); m_pMDHelper->SetShaderModel(m_pSM); CreateSignatures(m_pSM, m_InputSignature, m_OutputSignature, m_PatchConstantSignature, m_RootSignature); } diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index e3f91cf41..bd864c731 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -301,11 +301,12 @@ CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM) const hlsl::ShaderModel *SM = hlsl::ShaderModel::GetByName(CGM.getCodeGenOpts().HLSLProfile.c_str()); // Only accept valid, 6.0 shader model. - if (!SM->IsValid() || SM->GetMajor() != 6 || SM->GetMinor() != 0) { + if (!SM->IsValid() || SM->GetMajor() != 6) { DiagnosticsEngine &Diags = CGM.getDiags(); unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "invalid profile %0"); Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLProfile; + return; } // TODO: add AllResourceBound. if (CGM.getCodeGenOpts().HLSLAvoidControlFlow && !CGM.getCodeGenOpts().HLSLAllResourcesBound) { @@ -978,7 +979,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { if (const CXXMethodDecl *MD = dyn_cast(FD)) { const CXXRecordDecl *RD = MD->getParent(); // For nested case like sample_slice_type. - if (const CXXRecordDecl *PRD = dyn_cast(RD->getDeclContext())) { + if (const CXXRecordDecl *PRD = + dyn_cast(RD->getDeclContext())) { RD = PRD; } @@ -994,7 +996,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { SetUAVSRV(FD->getLocation(), resClass, &UAV, RD); // Set global symbol to save type. UAV.SetGlobalSymbol(UndefValue::get(Ty)); - MDNode * MD = m_pHLModule->DxilUAVToMDNode(UAV); + MDNode *MD = m_pHLModule->DxilUAVToMDNode(UAV); resMetadataMap[Ty] = MD; } break; case DXIL::ResourceClass::SRV: { @@ -1002,13 +1004,13 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { SetUAVSRV(FD->getLocation(), resClass, &SRV, RD); // Set global symbol to save type. SRV.SetGlobalSymbol(UndefValue::get(Ty)); - MDNode * Meta = m_pHLModule->DxilSRVToMDNode(SRV); + MDNode *Meta = m_pHLModule->DxilSRVToMDNode(SRV); resMetadataMap[Ty] = Meta; if (FT->getNumParams() > 1) { QualType paramTy = MD->getParamDecl(0)->getType(); // Add sampler type. if (TypeToClass(paramTy) == DXIL::ResourceClass::Sampler) { - llvm::Type * Ty = FT->getParamType(1)->getPointerElementType(); + llvm::Type *Ty = FT->getParamType(1)->getPointerElementType(); DxilSampler S; const RecordType *RT = paramTy->getAs(); DxilSampler::SamplerKind kind = @@ -1034,14 +1036,15 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // Don't need to add FunctionQual for intrinsic function. return; } - + // Set entry function const std::string &entryName = m_pHLModule->GetEntryFunctionName(); bool isEntry = FD->getNameAsString() == entryName; if (isEntry) EntryFunc = F; - std::unique_ptr funcProps = llvm::make_unique(); + std::unique_ptr funcProps = + llvm::make_unique(); // Save patch constant function to patchConstantFunctionMap. bool isPatchConstantFunction = false; @@ -1051,9 +1054,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { patchConstantFunctionMap[FD->getName()] = F; else { // TODO: This is not the same as how fxc handles patch constant functions. - // This will fail if more than one function with the same name has a SV_TessFactor semantic. - // Fxc just selects the last function defined that has the matching name when referenced - // by the patchconstantfunc attribute from the hull shader currently being compiled. + // This will fail if more than one function with the same name has a + // SV_TessFactor semantic. Fxc just selects the last function defined + // that has the matching name when referenced by the patchconstantfunc + // attribute from the hull shader currently being compiled. // Report error DiagnosticsEngine &Diags = CGM.getDiags(); unsigned DiagID = @@ -1063,8 +1067,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { return; } - for (Argument &arg : F->getArgumentList()) { - const ParmVarDecl *parmDecl = FD->getParamDecl(arg.getArgNo()); + for (ParmVarDecl *parmDecl : FD->parameters()) { QualType Ty = parmDecl->getType(); if (IsHLSLOutputPatchType(Ty)) { funcProps->ShaderProps.HS.outputControlPoints = @@ -1080,7 +1083,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // TODO: how to know VS/PS? funcProps->shaderKind = DXIL::ShaderKind::Invalid; - + DiagnosticsEngine &Diags = CGM.getDiags(); // Geometry shader. bool isGS = false; @@ -1092,8 +1095,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { funcProps->ShaderProps.GS.inputPrimitive = DXIL::InputPrimitive::Undefined; if (isEntry && !SM->IsGS()) { - unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, - "attribute maxvertexcount only valid for GS."); + unsigned DiagID = + Diags.getCustomDiagID(DiagnosticsEngine::Error, + "attribute maxvertexcount only valid for GS."); Diags.Report(Attr->getLocation(), DiagID); return; } @@ -1102,13 +1106,13 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { unsigned instanceCount = Attr->getCount(); funcProps->ShaderProps.GS.instanceCount = instanceCount; if (isEntry && !SM->IsGS()) { - unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, - "attribute maxvertexcount only valid for GS."); + unsigned DiagID = + Diags.getCustomDiagID(DiagnosticsEngine::Error, + "attribute maxvertexcount only valid for GS."); Diags.Report(Attr->getLocation(), DiagID); return; } - } - else { + } else { // Set default instance count. if (isGS) funcProps->ShaderProps.GS.instanceCount = 1; @@ -1209,7 +1213,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { FD->getAttr()) { if (isHS) { DXIL::TessellatorOutputPrimitive primitive = - StringToTessOutputPrimitive(Attr->getTopology()); + StringToTessOutputPrimitive(Attr->getTopology()); funcProps->ShaderProps.HS.outputPrimitive = primitive; } else if (isEntry && !SM->IsHS()) { unsigned DiagID = @@ -1264,26 +1268,26 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { bool isVS = false; if (const HLSLClipPlanesAttr *Attr = FD->getAttr()) { if (isEntry && !SM->IsVS()) { - unsigned DiagID = - Diags.getCustomDiagID(DiagnosticsEngine::Error, - "attribute clipplane only valid for VS."); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, "attribute clipplane only valid for VS."); Diags.Report(Attr->getLocation(), DiagID); return; } isVS = true; - // The real job is done at EmitHLSLFunctionProlog where debug info is available. - // Only set shader kind here. + // The real job is done at EmitHLSLFunctionProlog where debug info is + // available. Only set shader kind here. funcProps->shaderKind = DXIL::ShaderKind::Vertex; } // Pixel shader. bool isPS = false; - if (const HLSLEarlyDepthStencilAttr *Attr = FD->getAttr()) { + if (const HLSLEarlyDepthStencilAttr *Attr = + FD->getAttr()) { if (isEntry && !SM->IsPS()) { - unsigned DiagID = - Diags.getCustomDiagID(DiagnosticsEngine::Error, - "attribute earlydepthstencil only valid for PS."); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, + "attribute earlydepthstencil only valid for PS."); Diags.Report(Attr->getLocation(), DiagID); return; } @@ -1308,7 +1312,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { profileAttributes++; // TODO: check this in front-end and report error. - DXASSERT(profileAttributes<2, "profile attributes are mutual exclusive"); + DXASSERT(profileAttributes < 2, "profile attributes are mutual exclusive"); if (isEntry) { switch (funcProps->shaderKind) { @@ -1324,11 +1328,37 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } } - DxilFunctionAnnotation *FuncAnnotation = m_pHLModule->AddFunctionAnnotation(F); + DxilFunctionAnnotation *FuncAnnotation = + m_pHLModule->AddFunctionAnnotation(F); + bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor; + + // Param Info + unsigned streamIndex = 0; + unsigned inputPatchCount = 0; + unsigned outputPatchCount = 0; + + unsigned ArgNo = 0; + unsigned ParmIdx = 0; + + if (const CXXMethodDecl *MethodDecl = dyn_cast(FD)) { + QualType ThisTy = MethodDecl->getThisType(FD->getASTContext()); + DxilParameterAnnotation ¶mAnnotation = + FuncAnnotation->GetParameterAnnotation(ArgNo++); + // Construct annoation for this pointer. + ConstructFieldAttributedAnnotation(paramAnnotation, ThisTy, + bDefaultRowMajor); + } // Ret Info - DxilParameterAnnotation &retTyAnnotation = FuncAnnotation->GetRetTypeAnnotation(); QualType retTy = FD->getReturnType(); + DxilParameterAnnotation *pRetTyAnnotation = nullptr; + if (F->getReturnType()->isVoidTy() && !retTy->isVoidType()) { + // SRet. + pRetTyAnnotation = &FuncAnnotation->GetParameterAnnotation(ArgNo++); + } else { + pRetTyAnnotation = &FuncAnnotation->GetRetTypeAnnotation(); + } + DxilParameterAnnotation &retTyAnnotation = *pRetTyAnnotation; // keep Undefined here, we cannot decide for struct retTyAnnotation.SetInterpolationMode( GetInterpMode(FD, CompType::Kind::Invalid, /*bKeepUndefined*/ true) @@ -1336,35 +1366,22 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { SourceLocation retTySemanticLoc = SetSemantic(FD, retTyAnnotation); retTyAnnotation.SetParamInputQual(DxilParamInputQual::Out); if (isEntry) { - CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation, /*isPatchConstantFunction*/false); + CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation, + /*isPatchConstantFunction*/ false); } - bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor; - ConstructFieldAttributedAnnotation(retTyAnnotation, retTy, bDefaultRowMajor); if (FD->hasAttr()) retTyAnnotation.SetPrecise(); - // Param Info - unsigned streamIndex = 0; - unsigned inputPatchCount = 0; - unsigned outputPatchCount = 0; + for (; ArgNo < F->arg_size(); ++ArgNo, ++ParmIdx) { + DxilParameterAnnotation ¶mAnnotation = + FuncAnnotation->GetParameterAnnotation(ArgNo); - for (unsigned ArgNo = 0; ArgNo < F->arg_size(); ++ArgNo) { - unsigned ParmIdx = ArgNo; - - DxilParameterAnnotation ¶mAnnotation = FuncAnnotation->GetParameterAnnotation(ArgNo); - - if (isa(FD)) { - // skip arg0 for this pointer - if (ArgNo == 0) - continue; - // update idx for rest params - ParmIdx--; - } const ParmVarDecl *parmDecl = FD->getParamDecl(ParmIdx); - - ConstructFieldAttributedAnnotation(paramAnnotation, parmDecl->getType(), bDefaultRowMajor); + + ConstructFieldAttributedAnnotation(paramAnnotation, parmDecl->getType(), + bDefaultRowMajor); if (parmDecl->hasAttr()) paramAnnotation.SetPrecise(); @@ -1532,7 +1549,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { paramAnnotation.SetParamInputQual(dxilInputQ); if (isEntry) { - CheckParameterAnnotation(paramSemanticLoc, paramAnnotation, /*isPatchConstantFunction*/false); + CheckParameterAnnotation(paramSemanticLoc, paramAnnotation, + /*isPatchConstantFunction*/ false); } } @@ -1561,7 +1579,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize); } - for (const ValueDecl*param : FD->params()) { + for (const ValueDecl *param : FD->params()) { QualType Ty = param->getType(); AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize); } diff --git a/tools/clang/lib/CodeGen/TargetInfo.cpp b/tools/clang/lib/CodeGen/TargetInfo.cpp index 0c3e02c1b..1fe4c96f5 100644 --- a/tools/clang/lib/CodeGen/TargetInfo.cpp +++ b/tools/clang/lib/CodeGen/TargetInfo.cpp @@ -6181,7 +6181,14 @@ public: ABIArgInfo classifyReturnType(QualType RetTy) const { if (RetTy->isVoidType()) return ABIArgInfo::getIgnore(); - // do not create SRet for HLSL + if (isAggregateTypeForABI(RetTy)) + return ABIArgInfo::getIndirect(0); + + // Treat an enum type as its underlying type. + if (const EnumType *EnumTy = RetTy->getAs()) + RetTy = EnumTy->getDecl()->getIntegerType(); + + // do not use extend for hlsl. return ABIArgInfo::getDirect(CGT.ConvertType(RetTy)); } @@ -6218,7 +6225,16 @@ ABIArgInfo MSDXILABIInfo::classifyArgumentType(QualType Ty) const { } void MSDXILABIInfo::computeInfo(CGFunctionInfo &FI) const { - FI.getReturnInfo() = classifyReturnType(FI.getReturnType()); + QualType RetTy = FI.getReturnType(); + if (RetTy->isVoidType()) { + FI.getReturnInfo() = ABIArgInfo::getIgnore(); + } else if (isAggregateTypeForABI(RetTy)) { + if (!getCXXABI().classifyReturnType(FI)) + FI.getReturnInfo() = classifyReturnType(RetTy); + } else { + // Make vector and matrix direct ret. + FI.getReturnInfo() = classifyReturnType(RetTy); + } for (auto &I : FI.arguments()) { I.info = classifyArgumentType(I.type); // Do not flat matrix diff --git a/tools/clang/test/CodeGenHLSL/BasicHLSL11_PS2.hlsl b/tools/clang/test/CodeGenHLSL/BasicHLSL11_PS2.hlsl index 1363c7d05..e6f3a08f0 100644 --- a/tools/clang/test/CodeGenHLSL/BasicHLSL11_PS2.hlsl +++ b/tools/clang/test/CodeGenHLSL/BasicHLSL11_PS2.hlsl @@ -20,7 +20,6 @@ // CHECK: DILocalVariable(tag: DW_TAG_auto_variable, name: "vDiffuse" // CHECK: DILocalVariable(tag: DW_TAG_auto_variable, name: "fLighting" -// CHECK: DILocalVariable(tag: DW_TAG_auto_variable, name: "Output" //-------------------------------------------------------------------------------------- diff --git a/tools/clang/test/CodeGenHLSL/class.hlsl b/tools/clang/test/CodeGenHLSL/class.hlsl new file mode 100644 index 000000000..ef28861a5 --- /dev/null +++ b/tools/clang/test/CodeGenHLSL/class.hlsl @@ -0,0 +1,25 @@ +// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s + +// Make sure N n2[2] lowered into float [2] +// CHECK:[2 x float] + + struct N { + float n; + }; + +class X { + float2x2 ma[2]; + N n2[2]; + row_major float3x3 m; + N test_inout(float idx) { + return n2[idx]; + } +}; + +X x0; + +float4 main(float4 a : A, float4 b:B) : SV_TARGET +{ + X x = x0; + return x.test_inout(a.x).n; +} diff --git a/tools/clang/unittests/HLSL/CompilerTest.cpp b/tools/clang/unittests/HLSL/CompilerTest.cpp index c733d063b..6c57f4f03 100644 --- a/tools/clang/unittests/HLSL/CompilerTest.cpp +++ b/tools/clang/unittests/HLSL/CompilerTest.cpp @@ -335,6 +335,7 @@ public: TEST_METHOD(CodeGenCbufferAlloc) TEST_METHOD(CodeGenCbufferAllocLegacy) TEST_METHOD(CodeGenCbufferInLoop) + TEST_METHOD(CodeGenClass) TEST_METHOD(CodeGenClip) TEST_METHOD(CodeGenClipPlanes) TEST_METHOD(CodeGenConstoperand1) @@ -2181,6 +2182,10 @@ TEST_F(CompilerTest, CodeGenCbufferInLoop) { CodeGenTest(L"..\\CodeGenHLSL\\cbufferInLoop.hlsl"); } +TEST_F(CompilerTest, CodeGenClass) { + CodeGenTestCheck(L"..\\CodeGenHLSL\\class.hlsl"); +} + TEST_F(CompilerTest, CodeGenClip) { CodeGenTestCheck(L"..\\CodeGenHLSL\\clip.hlsl"); }