diff --git a/lib/HLSL/DxilValidation.cpp b/lib/HLSL/DxilValidation.cpp index 61af3889c..08de2a961 100644 --- a/lib/HLSL/DxilValidation.cpp +++ b/lib/HLSL/DxilValidation.cpp @@ -15,6 +15,7 @@ #include "dxc/HLSL/DxilModule.h" #include "dxc/HLSL/DxilShaderModel.h" #include "dxc/HLSL/DxilContainer.h" +#include "dxc/hlsl/DxilFunctionProps.h" #include "dxc/Support/Global.h" #include "dxc/HLSL/DxilUtil.h" #include "dxc/HLSL/DxilInstructions.h" @@ -350,8 +351,10 @@ struct ValidationContext { std::unordered_set entryFuncCallSet; std::unordered_set patchConstFuncCallSet; std::unordered_map UavCounterIncMap; + std::unordered_map ResTypeMap; bool hasOutputPosition[DXIL::kNumOutputStreams]; bool hasViewID; + bool isLibProfile; unsigned OutputPositionMask[DXIL::kNumOutputStreams]; std::vector outputCols; std::vector patchConstCols; @@ -382,8 +385,31 @@ struct ValidationContext { hasOutputPosition[i] = false; OutputPositionMask[i] = 0; } - outputCols.resize(DxilMod.GetOutputSignature().GetElements().size(), 0); - patchConstCols.resize(DxilMod.GetPatchConstantSignature().GetElements().size(), 0); + isLibProfile = dxilModule.GetShaderModel()->IsLib(); + if (!isLibProfile) { + outputCols.resize(DxilMod.GetOutputSignature().GetElements().size(), 0); + patchConstCols.resize( + DxilMod.GetPatchConstantSignature().GetElements().size(), 0); + } else { + auto collectResTy = [&](auto &ResTab) { + for (auto &Res : ResTab) { + Type *Ty = Res->GetGlobalSymbol()->getType()->getPointerElementType(); + Ty = dxilutil::GetArrayEltTy(Ty); + ResTypeMap[Ty] = Res.get(); + } + }; + collectResTy(DxilMod.GetCBuffers()); + collectResTy(DxilMod.GetUAVs()); + collectResTy(DxilMod.GetSRVs()); + collectResTy(DxilMod.GetSamplers()); + } + } + DxilResourceBase *GetResFromTy(Type *Ty) { + auto it = ResTypeMap.find(Ty); + if (it == ResTypeMap.end()) + return nullptr; + else + return it->second; } // Provide direct access to the raw_ostream in DiagPrinter. @@ -686,8 +712,31 @@ static DXIL::SamplerKind GetSamplerKind(Value *samplerHandle, ValidationContext DxilInst_CreateHandle createHandle(cast(samplerHandle)); if (!createHandle) { - ValCtx.EmitInstrError(cast(samplerHandle), ValidationRule::InstrHandleNotFromCreateHandle); - return DXIL::SamplerKind::Invalid; + auto EmitError = [&]() -> DXIL::SamplerKind { + ValCtx.EmitInstrError(cast(samplerHandle), + ValidationRule::InstrHandleNotFromCreateHandle); + return DXIL::SamplerKind::Invalid; + }; + if (!ValCtx.isLibProfile) { + return EmitError(); + } + + DxilInst_CreateHandleFromResourceStructForLib createHandleFromRes( + cast(samplerHandle)); + if (!createHandleFromRes) { + return EmitError(); + } + + DxilResourceBase *Res = + ValCtx.GetResFromTy(createHandleFromRes.get_Resource()->getType()); + if (!Res) { + return EmitError(); + } + if (DxilSampler *S = dynamic_cast(Res)) { + return S->GetSamplerKind(); + } else { + return EmitError(); + } } Value *resClass = createHandle.get_resourceClass(); @@ -747,8 +796,32 @@ static DXIL::ResourceKind GetResourceKindAndCompTy(Value *handle, DXIL::Componen DxilInst_CreateHandle createHandle(cast(handle)); if (!createHandle) { - ValCtx.EmitInstrError(cast(handle), ValidationRule::InstrHandleNotFromCreateHandle); - return DXIL::ResourceKind::Invalid; + auto EmitError = [&]() -> DXIL::ResourceKind { + ValCtx.EmitInstrError(cast(handle), + ValidationRule::InstrHandleNotFromCreateHandle); + return DXIL::ResourceKind::Invalid; + }; + if (!ValCtx.isLibProfile) { + return EmitError(); + } + DxilInst_CreateHandleFromResourceStructForLib createHandleFromRes( + cast(handle)); + if (!createHandleFromRes) { + return EmitError(); + } + DxilResourceBase *res = + ValCtx.GetResFromTy(createHandleFromRes.get_Resource()->getType()); + if (!res) { + return EmitError(); + } + // TODO: resIndex for Uav Counter. + if (DxilResource *Res = dynamic_cast(res)) { + CompTy = Res->GetCompType().GetKind(); + } else { + return EmitError(); + } + ResClass = res->GetClass(); + return res->GetKind(); } Value *resourceClass = createHandle.get_resourceClass(); @@ -1101,9 +1174,30 @@ static unsigned StoreValueToMask(ArrayRef vals) { static int GetCBufSize(Value *cbHandle, ValidationContext &ValCtx) { DxilInst_CreateHandle createHandle(cast(cbHandle)); if (!createHandle) { - ValCtx.EmitInstrError(cast(cbHandle), - ValidationRule::InstrHandleNotFromCreateHandle); - return -1; + auto EmitError = [&]() -> int { + ValCtx.EmitInstrError(cast(cbHandle), + ValidationRule::InstrHandleNotFromCreateHandle); + return -1; + }; + if (!ValCtx.isLibProfile) { + return EmitError(); + } + DxilInst_CreateHandleFromResourceStructForLib createHandleFromRes( + cast(cbHandle)); + if (!createHandleFromRes) { + return EmitError(); + } + + DxilResourceBase *Res = + ValCtx.GetResFromTy(createHandleFromRes.get_Resource()->getType()); + if (!Res) { + return EmitError(); + } + if (DxilCBuffer *CB = dynamic_cast(Res)) { + return CB->GetSize(); + } else { + return EmitError(); + } } Value *resourceClass = createHandle.get_resourceClass(); @@ -1180,10 +1274,21 @@ static unsigned GetNumVertices(DXIL::InputPrimitive inputPrimitive) { return InputPrimitiveVertexTab[primitiveIdx]; } +static void ValidateDxilOperationCallInLibProfile(CallInst *CI, + DXIL::OpCode opcode, + ValidationContext &ValCtx) { + // TODO: validation for lib profile. +} + static void ValidateDxilOperationCallInProfile(CallInst *CI, DXIL::OpCode opcode, const ShaderModel *pSM, ValidationContext &ValCtx) { + if (ValCtx.isLibProfile) { + ValidateDxilOperationCallInLibProfile(CI, opcode, ValCtx); + return; + } + switch (opcode) { case DXIL::OpCode::LoadInput: { Value *inputID = CI->getArgOperand(DXIL::OperandIndex::kLoadInputIDOpIdx); @@ -1962,7 +2067,7 @@ static bool IsDxilFunction(llvm::Function *F) { } static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) { - if (!IsDxilFunction(F)) { + if (!IsDxilFunction(F) && !ValCtx.isLibProfile) { ValCtx.EmitGlobalValueError(F, ValidationRule::DeclDxilFnExtern); return; } @@ -2028,7 +2133,7 @@ static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) { continue; } - if (!ValidateOpcodeInProfile(dxilOpcode, pSM)) { + if (!ValCtx.isLibProfile && !ValidateOpcodeInProfile(dxilOpcode, pSM)) { // Opcode not available in profile. ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcode, {hlslOP->GetOpCodeName(dxilOpcode), @@ -2137,6 +2242,9 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx) { } return true; } + // Lib profile allow all types except those hit ValidationRule::InstrDxilStructUser. + if (ValCtx.isLibProfile) + return true; if (Ty->isVectorTy()) { ValCtx.EmitTypeError(Ty, ValidationRule::TypesNoVector); @@ -2430,6 +2538,29 @@ static void ValidateFunctionMetadata(Function *F, ValidationContext &ValCtx) { } } +static bool IsLLVMInstructionAllowedForLib(Instruction &I, ValidationContext &ValCtx) { + if (!ValCtx.isLibProfile) + return false; + switch (I.getOpcode()) { + case Instruction::InsertElement: + case Instruction::ExtractElement: + return true; + case Instruction::Unreachable: + if (Instruction *Prev = I.getPrevNode()) { + if (CallInst *CI = dyn_cast(Prev)) { + Function *F = CI->getCalledFunction(); + if (IsDxilFunction(F) && + F->hasFnAttribute(Attribute::AttrKind::NoReturn)) { + return true; + } + } + } + return false; + default: + return false; + } +} + static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { bool SupportsMinPrecision = ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision; @@ -2446,8 +2577,10 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { // Instructions must be allowed. if (!IsLLVMInstructionAllowed(I)) { - ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed); - continue; + if (!IsLLVMInstructionAllowedForLib(I, ValCtx)) { + ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed); + continue; + } } // Instructions marked precise may not have minprecision arguments. @@ -2499,9 +2632,15 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { } for (Value *op : I.operands()) { - if (!isa(&I) && isa(op)) { - ValCtx.EmitInstrError(&I, - ValidationRule::InstrNoReadingUninitialized); + if (isa(op)) { + bool legalUndef = isa(&I); + if (InsertElementInst *InsertInst = dyn_cast(&I)) { + legalUndef = op == I.getOperand(0); + } + + if (!legalUndef) + ValCtx.EmitInstrError(&I, + ValidationRule::InstrNoReadingUninitialized); } else if (ConstantExpr *CE = dyn_cast(op)) { for (Value *opCE : CE->operands()) { if (isa(opCE)) { @@ -2640,7 +2779,7 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) { ToTy = ToTy->getArrayElementType(); } } - if (isa(FromTy) || isa(ToTy)) { + if ((isa(FromTy) || isa(ToTy)) && !ValCtx.isLibProfile) { ValCtx.EmitInstrError(Cast, ValidationRule::InstrStructBitCast); continue; } @@ -2688,7 +2827,25 @@ static void ValidateFunction(Function &F, ValidationContext &ValCtx) { if (F.isDeclaration()) { ValidateExternalFunction(&F, ValCtx); } else { - if (!F.arg_empty()) + bool isNoArgEntry = ValCtx.DxilMod.HasDxilFunctionProps(&F); + if (isNoArgEntry) { + switch (ValCtx.DxilMod.GetDxilFunctionProps(&F).shaderKind) { + case DXIL::ShaderKind::AnyHit: + case DXIL::ShaderKind::Callable: + case DXIL::ShaderKind::ClosestHit: + case DXIL::ShaderKind::Miss: + isNoArgEntry = false; + break; + default: + isNoArgEntry = true; + break; + } + } else { + isNoArgEntry = &F == ValCtx.DxilMod.GetEntryFunction(); + isNoArgEntry |= &F == ValCtx.DxilMod.GetPatchConstantFunction(); + } + // Entry function should not have parameter. + if (!F.arg_empty() && isNoArgEntry) ValCtx.EmitFormatError(ValidationRule::FlowFunctionCall, {F.getName().str()}); @@ -2709,7 +2866,7 @@ static void ValidateFunction(Function &F, ValidationContext &ValCtx) { argTy = argTy->getArrayElementType(); } - if (argTy->isStructTy()) { + if (argTy->isStructTy() && !ValCtx.isLibProfile) { if (arg.hasName()) ValCtx.EmitFormatError( ValidationRule::DeclFnFlattenParam, @@ -2724,8 +2881,9 @@ static void ValidateFunction(Function &F, ValidationContext &ValCtx) { ValidateFunctionBody(&F, ValCtx); } - - ValidateFunctionAttribute(&F, ValCtx); + // TODO: Remove attribute for lib? + if (!ValCtx.isLibProfile) + ValidateFunctionAttribute(&F, ValCtx); if (F.hasMetadata()) { ValidateFunctionMetadata(&F, ValCtx); @@ -2737,6 +2895,22 @@ static void ValidateGlobalVariable(GlobalVariable &GV, bool isInternalGV = dxilutil::IsStaticGlobal(&GV) || dxilutil::IsSharedMemoryGlobal(&GV); + if (ValCtx.isLibProfile) { + auto isResourceGlobal = [&](auto &ResTab) -> bool { + for (auto &Res : ResTab) { + if (Res->GetGlobalSymbol() == &GV) + return true; + } + return false; + }; + + bool isRes = isResourceGlobal(ValCtx.DxilMod.GetCBuffers()); + isRes |= isResourceGlobal(ValCtx.DxilMod.GetUAVs()); + isRes |= isResourceGlobal(ValCtx.DxilMod.GetSRVs()); + isRes |= isResourceGlobal(ValCtx.DxilMod.GetSamplers()); + isInternalGV |= isRes; + } + if (!isInternalGV) { if (!GV.user_empty()) { bool hasInstructionUser = false; @@ -2907,6 +3081,20 @@ static void ValidateTypeAnnotation(ValidationContext &ValCtx) { } } +static bool IsLibMetadata(ValidationContext &ValCtx, StringRef name) { + if (!ValCtx.isLibProfile) + return false; + // Skip dx.func.props and dx.func.signatures for now. + // And these 2 need validation also. + // Or we merge them into Entry, and validate as entry. + const char * libMetaNames[] = {"dx.func.props","dx.func.signatures"}; + for (const char *libName : libMetaNames) { + if (name.equals(libName)) + return true; + } + return false; +} + static void ValidateMetadata(ValidationContext &ValCtx) { Module *pModule = &ValCtx.M; const std::string &target = pModule->getTargetTriple(); @@ -2926,8 +3114,11 @@ static void ValidateMetadata(ValidationContext &ValCtx) { for (auto &NamedMetaNode : pModule->named_metadata()) { if (!DxilModule::IsKnownNamedMetaData(NamedMetaNode)) { StringRef name = NamedMetaNode.getName(); - if (!name.startswith_lower("llvm.")) + if (IsLibMetadata(ValCtx, name)) + continue; + if (!name.startswith_lower("llvm.")) { ValCtx.EmitFormatError(ValidationRule::MetaKnown, {name.str()}); + } else { if (llvmNamedMeta.count(name) == 0) { ValCtx.EmitFormatError(ValidationRule::MetaKnown, @@ -2964,6 +3155,10 @@ static void ValidateResourceOverlap( SpacesAllocator &spaceAllocator, ValidationContext &ValCtx) { unsigned base = res.GetLowerBound(); + if (ValCtx.isLibProfile && !res.IsAllocated()) { + // Skip unallocated resource for library. + return; + } unsigned size = res.GetRangeSize(); unsigned space = res.GetSpaceID(); @@ -3005,6 +3200,9 @@ static void ValidateResource(hlsl::DxilResource &res, case DXIL::ResourceKind::Texture2DMS: case DXIL::ResourceKind::Texture2DMSArray: break; + case DXIL::ResourceKind::RTAccelerationStructure: + // TODO: check profile. + break; default: ValCtx.EmitResourceError(&res, ValidationRule::SmInvalidResourceKind); break; @@ -3167,7 +3365,7 @@ static void ValidateResources(ValidationContext &ValCtx) { for (auto &uav : uavs) { if (uav->IsROV()) { hasROV = true; - if (!ValCtx.DxilMod.GetShaderModel()->IsPS()) { + if (!ValCtx.DxilMod.GetShaderModel()->IsPS() && !ValCtx.isLibProfile) { ValCtx.EmitResourceError(uav.get(), ValidationRule::SmROVOnlyInPS); } } @@ -3224,6 +3422,10 @@ static void ValidateResources(ValidationContext &ValCtx) { } static void ValidateShaderFlags(ValidationContext &ValCtx) { + // TODO: validate flags foreach entry. + if (ValCtx.isLibProfile) + return; + ShaderFlags calcFlags; ValCtx.DxilMod.CollectShaderFlagsForModule(calcFlags); const uint64_t mask = ShaderFlags::GetShaderFlagsRawForCollection(); @@ -3236,7 +3438,6 @@ static void ValidateShaderFlags(ValidationContext &ValCtx) { if (declaredFlagsRaw == calcFlagsRaw) { return; } - ValCtx.EmitError(ValidationRule::MetaFlagsUsage); ValCtx.DiagStream() << "Flags declared=" << declaredFlagsRaw << ", actual=" << calcFlagsRaw << "\n"; @@ -4546,6 +4747,11 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule, case DFCC_ShaderDebugName: continue; + // Lib part + case DFCC_RuntimeData: + // TODO: Validate RuntimeData. + break; + case DFCC_Container: default: ValCtx.EmitFormatError(ValidationRule::ContainerPartInvalid, {szFourCC}); @@ -4586,7 +4792,9 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule, } } } else { - ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, {"Pipeline State Validation"}); + // Not for lib. + if (!ValCtx.isLibProfile) + ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, {"Pipeline State Validation"}); } if (ValCtx.Failed) { diff --git a/tools/clang/tools/dxcompiler/dxcompilerobj.cpp b/tools/clang/tools/dxcompiler/dxcompilerobj.cpp index 765a24c14..b21a7babe 100644 --- a/tools/clang/tools/dxcompiler/dxcompilerobj.cpp +++ b/tools/clang/tools/dxcompiler/dxcompilerobj.cpp @@ -388,8 +388,12 @@ public: // validator can be used as a fallback. bool produceFullContainer = !opts.CodeGenHighLevel && !opts.AstDump && !opts.OptDump && rootSigMajor == 0; - bool needsValidation = produceFullContainer && !opts.DisableValidation && - !opts.IsLibraryProfile(); + bool needsValidation = produceFullContainer && !opts.DisableValidation; + // Disable validation for lib_6_1 and lib_6_2. + if (compiler.getCodeGenOpts().HLSLProfile == "lib_6_1" || + compiler.getCodeGenOpts().HLSLProfile == "lib_6_2") { + needsValidation = false; + } if (needsValidation || (opts.CodeGenHighLevel && !opts.DisableValidation)) { UINT32 majorVer, minorVer; diff --git a/tools/clang/unittests/HLSL/DxilContainerTest.cpp b/tools/clang/unittests/HLSL/DxilContainerTest.cpp index d4fee2944..3d7fc3960 100644 --- a/tools/clang/unittests/HLSL/DxilContainerTest.cpp +++ b/tools/clang/unittests/HLSL/DxilContainerTest.cpp @@ -685,7 +685,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) { "float function1(float x, min12int i) {" " return x + c_buf + b_buf.Load(x) + tex2[i].x; }" "float function2(float x) { return x + function_import(x); }" - "float function3(int i) {" + "void function3(int i) {" " Foo f = consume_buf.Consume();" " f.f2 += 0.5; append_buf.Append(f);" " rov_buf.Store(i, f.i2.x);"