[Validator] Remove the pModule input for RunInternalValidator (#6845)

This change ensures that data validation occurs within the container
itself,
rather than relying on the module—especially since the module may be
modified during container assembly.

Furthermore, simplifying the validator’s interface would be an added
benefit.
This commit is contained in:
Xiang Li 2024-08-05 10:43:44 -07:00 коммит произвёл GitHub
Родитель cc6c6656d5
Коммит 8652894e69
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 75 добавлений и 159 удалений

Просмотреть файл

@ -87,8 +87,7 @@ HRESULT ValidateDxilContainer(const void *pContainer, uint32_t ContainerSize,
// Full container validation, including ValidateDxilModule, with debug module
HRESULT ValidateDxilContainer(const void *pContainer, uint32_t ContainerSize,
const void *pOptDebugBitcode,
uint32_t OptDebugBitcodeSize,
llvm::Module *pDebugModule,
llvm::raw_ostream &DiagStream);
class PrintDiagnosticContext {

Просмотреть файл

@ -75,17 +75,32 @@ namespace {
// Utility class for setting and restoring the diagnostic context so we may
// capture errors/warnings
struct DiagRestore {
LLVMContext &Ctx;
LLVMContext *Ctx = nullptr;
void *OrigDiagContext;
LLVMContext::DiagnosticHandlerTy OrigHandler;
DiagRestore(llvm::LLVMContext &Ctx, void *DiagContext) : Ctx(Ctx) {
OrigHandler = Ctx.getDiagnosticHandler();
OrigDiagContext = Ctx.getDiagnosticContext();
Ctx.setDiagnosticHandler(
DiagRestore(llvm::LLVMContext &InputCtx, void *DiagContext) : Ctx(&InputCtx) {
init(DiagContext);
}
DiagRestore(Module *M, void *DiagContext) {
if (!M)
return;
Ctx = &M->getContext();
init(DiagContext);
}
~DiagRestore() {
if (!Ctx)
return;
Ctx->setDiagnosticHandler(OrigHandler, OrigDiagContext);
}
private:
void init(void *DiagContext) {
OrigHandler = Ctx->getDiagnosticHandler();
OrigDiagContext = Ctx->getDiagnosticContext();
Ctx->setDiagnosticHandler(
hlsl::PrintDiagnosticContext::PrintDiagnosticHandler, DiagContext);
}
~DiagRestore() { Ctx.setDiagnosticHandler(OrigHandler, OrigDiagContext); }
};
static void emitDxilDiag(LLVMContext &Ctx, const char *str) {
@ -6984,11 +6999,10 @@ HRESULT ValidateLoadModuleFromContainerLazy(
}
HRESULT ValidateDxilContainer(const void *pContainer, uint32_t ContainerSize,
const void *pOptDebugBitcode,
uint32_t OptDebugBitcodeSize,
llvm::Module *pDebugModule,
llvm::raw_ostream &DiagStream) {
LLVMContext Ctx, DbgCtx;
std::unique_ptr<llvm::Module> pModule, pDebugModule;
std::unique_ptr<llvm::Module> pModule, pDebugModuleInContainer;
llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
PrintDiagnosticContext DiagContext(DiagPrinter);
@ -6997,31 +7011,29 @@ HRESULT ValidateDxilContainer(const void *pContainer, uint32_t ContainerSize,
DbgCtx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
&DiagContext, true);
IFR(ValidateLoadModuleFromContainer(pContainer, ContainerSize, pModule,
pDebugModule, Ctx, DbgCtx, DiagStream));
DiagRestore DR(pDebugModule, &DiagContext);
if (!pDebugModule && pOptDebugBitcode) {
// TODO: lazy load for perf
IFR(ValidateLoadModule((const char *)pOptDebugBitcode, OptDebugBitcodeSize,
pDebugModule, DbgCtx, DiagStream,
/*bLazyLoad*/ false));
}
IFR(ValidateLoadModuleFromContainer(pContainer, ContainerSize, pModule,
pDebugModuleInContainer, Ctx, DbgCtx,
DiagStream));
if (pDebugModuleInContainer)
pDebugModule = pDebugModuleInContainer.get();
// Validate DXIL Module
IFR(ValidateDxilModule(pModule.get(), pDebugModule.get()));
IFR(ValidateDxilModule(pModule.get(), pDebugModule));
if (DiagContext.HasErrors() || DiagContext.HasWarnings()) {
return DXC_E_IR_VERIFICATION_FAILED;
}
return ValidateDxilContainerParts(
pModule.get(), pDebugModule.get(),
pModule.get(), pDebugModule,
IsDxilContainerLike(pContainer, ContainerSize), ContainerSize);
}
HRESULT ValidateDxilContainer(const void *pContainer, uint32_t ContainerSize,
llvm::raw_ostream &DiagStream) {
return ValidateDxilContainer(pContainer, ContainerSize, nullptr, 0,
DiagStream);
return ValidateDxilContainer(pContainer, ContainerSize, nullptr, DiagStream);
}
} // namespace hlsl

Просмотреть файл

@ -77,17 +77,6 @@ using namespace clang;
using namespace hlsl;
using std::string;
// This declaration is used for the locally-linked validator.
HRESULT CreateDxcValidator(REFIID riid, LPVOID *ppv);
// This internal call allows the validator to avoid having to re-deserialize
// the module. It trusts that the caller didn't make any changes and is
// kept internal because the layout of the module class may change based
// on changes across modules, or picking a different compiler version or CRT.
HRESULT RunInternalValidator(IDxcValidator *pValidator, llvm::Module *pModule,
llvm::Module *pDebugModule, IDxcBlob *pShader,
UINT32 Flags, IDxcOperationResult **ppResult);
static bool ShouldBeCopiedIntoPDB(UINT32 FourCC) {
switch (FourCC) {
case hlsl::DFCC_ShaderDebugName:

Просмотреть файл

@ -42,7 +42,7 @@ HRESULT CreateDxcValidator(REFIID riid, LPVOID *ppv);
// the module. It trusts that the caller didn't make any changes and is
// kept internal because the layout of the module class may change based
// on changes across modules, or picking a different compiler version or CRT.
HRESULT RunInternalValidator(IDxcValidator *pValidator, llvm::Module *pModule,
HRESULT RunInternalValidator(IDxcValidator *pValidator,
llvm::Module *pDebugModule, IDxcBlob *pShader,
UINT32 Flags, IDxcOperationResult **ppResult);
@ -223,8 +223,7 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) {
// Important: in-place edit is required so the blob is reused and thus
// dxil.dll can be released.
if (bInternalValidator) {
IFT(RunInternalValidator(pValidator, inputs.pM.get(),
llvmModuleWithDebugInfo.get(),
IFT(RunInternalValidator(pValidator, llvmModuleWithDebugInfo.get(),
inputs.pOutputContainerBlob,
DxcValidatorFlags_InPlaceEdit, &pValResult));
} else {

Просмотреть файл

@ -45,10 +45,9 @@ public:
}
// For internal use only.
HRESULT ValidateWithOptModules(
HRESULT ValidateWithOptDebugModule(
IDxcBlob *pShader, // Shader to validate.
UINT32 Flags, // Validation flags.
llvm::Module *pModule, // Module to validate, if available.
llvm::Module *pDebugModule, // Debug module to validate, if available
IDxcOperationResult *
*ppResult // Validation output status, buffer, and errors
@ -90,7 +89,7 @@ HRESULT STDMETHODCALLTYPE DxcValidator::Validate(
IDxcOperationResult *
*ppResult // Validation output status, buffer, and errors
) {
return hlsl::validate(pShader, Flags, ppResult);
return hlsl::validateWithDebug(pShader, Flags, nullptr, ppResult);
}
HRESULT STDMETHODCALLTYPE DxcValidator::ValidateWithDebug(
@ -104,16 +103,15 @@ HRESULT STDMETHODCALLTYPE DxcValidator::ValidateWithDebug(
return hlsl::validateWithDebug(pShader, Flags, pOptDebugBitcode, ppResult);
}
HRESULT DxcValidator::ValidateWithOptModules(
HRESULT DxcValidator::ValidateWithOptDebugModule(
IDxcBlob *pShader, // Shader to validate.
UINT32 Flags, // Validation flags.
llvm::Module *pModule, // Module to validate, if available.
llvm::Module *pDebugModule, // Debug module to validate, if available
IDxcOperationResult *
*ppResult // Validation output status, buffer, and errors
) {
return hlsl::validateWithOptModules(pShader, Flags, pModule, pDebugModule,
ppResult);
return hlsl::validateWithOptDebugModule(pShader, Flags, pDebugModule,
ppResult);
}
HRESULT STDMETHODCALLTYPE DxcValidator::GetVersion(UINT32 *pMajor,
@ -153,17 +151,16 @@ HRESULT STDMETHODCALLTYPE DxcValidator::GetFlags(UINT32 *pFlags) {
///////////////////////////////////////////////////////////////////////////////
HRESULT RunInternalValidator(IDxcValidator *pValidator, llvm::Module *pModule,
HRESULT RunInternalValidator(IDxcValidator *pValidator,
llvm::Module *pDebugModule, IDxcBlob *pShader,
UINT32 Flags, IDxcOperationResult **ppResult) {
DXASSERT_NOMSG(pValidator != nullptr);
DXASSERT_NOMSG(pModule != nullptr);
DXASSERT_NOMSG(pShader != nullptr);
DXASSERT_NOMSG(ppResult != nullptr);
DxcValidator *pInternalValidator = (DxcValidator *)pValidator;
return pInternalValidator->ValidateWithOptModules(pShader, Flags, pModule,
pDebugModule, ppResult);
return pInternalValidator->ValidateWithOptDebugModule(pShader, Flags,
pDebugModule, ppResult);
}
HRESULT CreateDxcValidator(REFIID riid, LPVOID *ppv) {

Просмотреть файл

@ -31,29 +31,11 @@
using namespace llvm;
using namespace hlsl;
// Utility class for setting and restoring the diagnostic context so we may
// capture errors/warnings
struct DiagRestore {
LLVMContext &Ctx;
void *OrigDiagContext;
LLVMContext::DiagnosticHandlerTy OrigHandler;
DiagRestore(llvm::LLVMContext &Ctx, void *DiagContext) : Ctx(Ctx) {
OrigHandler = Ctx.getDiagnosticHandler();
OrigDiagContext = Ctx.getDiagnosticContext();
Ctx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
DiagContext);
}
~DiagRestore() { Ctx.setDiagnosticHandler(OrigHandler, OrigDiagContext); }
};
static uint32_t runValidation(
IDxcBlob *Shader,
uint32_t Flags, // Validation flags.
llvm::Module *Module, // Module to validate, if available.
llvm::Module *DebugModule, // Debug module to validate, if available
AbstractMemoryStream *DiagMemStream) {
// Run validation may throw, but that indicates an inability to validate,
// not that the validation failed (eg out of memory). That is indicated
// by a failing HRESULT, and possibly error messages in the diagnostics
@ -61,49 +43,9 @@ static uint32_t runValidation(
raw_stream_ostream DiagStream(DiagMemStream);
if (Flags & DxcValidatorFlags_ModuleOnly) {
if (IsDxilContainerLike(Shader->GetBufferPointer(),
Shader->GetBufferSize()))
return E_INVALIDARG;
} else {
if (!IsDxilContainerLike(Shader->GetBufferPointer(),
Shader->GetBufferSize()))
return DXC_E_CONTAINER_INVALID;
}
if (!Module) {
DXASSERT_NOMSG(DebugModule == nullptr);
if (Flags & DxcValidatorFlags_ModuleOnly) {
return ValidateDxilBitcode((const char *)Shader->GetBufferPointer(),
(uint32_t)Shader->GetBufferSize(), DiagStream);
} else {
return ValidateDxilContainer(Shader->GetBufferPointer(),
Shader->GetBufferSize(), DiagStream);
}
}
llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
PrintDiagnosticContext DiagContext(DiagPrinter);
DiagRestore DR(Module->getContext(), &DiagContext);
HRESULT hr = hlsl::ValidateDxilModule(Module, DebugModule);
if (FAILED(hr))
return hr;
if (!(Flags & DxcValidatorFlags_ModuleOnly)) {
hr = ValidateDxilContainerParts(
Module, DebugModule,
IsDxilContainerLike(Shader->GetBufferPointer(),
Shader->GetBufferSize()),
(uint32_t)Shader->GetBufferSize());
if (FAILED(hr))
return hr;
}
if (DiagContext.HasErrors() || DiagContext.HasWarnings()) {
return DXC_E_IR_VERIFICATION_FAILED;
}
return S_OK;
return ValidateDxilContainer(Shader->GetBufferPointer(),
Shader->GetBufferSize(), DebugModule,
DiagStream);
}
static uint32_t
@ -151,23 +93,14 @@ runRootSignatureValidation(IDxcBlob *Shader,
return S_OK;
}
// Compile a single entry point to the target shader model
uint32_t hlsl::validate(
IDxcBlob *Shader, // Shader to validate.
uint32_t Flags, // Validation flags.
IDxcOperationResult **Result // Validation output status, buffer, and errors
) {
DxcThreadMalloc TM(DxcGetThreadMallocNoRef());
if (Result == nullptr)
return false;
*Result = nullptr;
if (Shader == nullptr || Flags & ~DxcValidatorFlags_ValidMask)
return false;
if ((Flags & DxcValidatorFlags_ModuleOnly) &&
(Flags &
(DxcValidatorFlags_InPlaceEdit | DxcValidatorFlags_RootSignatureOnly)))
return false;
return validateWithOptModules(Shader, Flags, nullptr, nullptr, Result);
static uint32_t runDxilModuleValidation(IDxcBlob *Shader, // Shader to validate.
AbstractMemoryStream *DiagMemStream) {
if (IsDxilContainerLike(Shader->GetBufferPointer(), Shader->GetBufferSize()))
return E_INVALIDARG;
raw_stream_ostream DiagStream(DiagMemStream);
return ValidateDxilBitcode((const char *)Shader->GetBufferPointer(),
(uint32_t)Shader->GetBufferSize(), DiagStream);
}
uint32_t hlsl::validateWithDebug(
@ -212,17 +145,15 @@ uint32_t hlsl::validateWithDebug(
if (FAILED(hr))
throw hlsl::Exception(hr);
}
return validateWithOptModules(Shader, Flags, nullptr, DebugModule.get(),
Result);
return validateWithOptDebugModule(Shader, Flags, DebugModule.get(), Result);
}
CATCH_CPP_ASSIGN_HRESULT();
return hr;
}
uint32_t hlsl::validateWithOptModules(
uint32_t hlsl::validateWithOptDebugModule(
IDxcBlob *Shader, // Shader to validate.
uint32_t Flags, // Validation flags.
llvm::Module *Module, // Module to validate, if available.
llvm::Module *DebugModule, // Debug module to validate, if available
IDxcOperationResult **Result // Validation output status, buffer, and errors
) {
@ -238,12 +169,12 @@ uint32_t hlsl::validateWithOptModules(
throw hlsl::Exception(hr);
// Run validation may throw, but that indicates an inability to validate,
// not that the validation failed (eg out of memory).
if (Flags & DxcValidatorFlags_RootSignatureOnly) {
if (Flags & DxcValidatorFlags_RootSignatureOnly)
validationStatus = runRootSignatureValidation(Shader, DiagStream);
} else {
validationStatus =
runValidation(Shader, Flags, Module, DebugModule, DiagStream);
}
else if (Flags & DxcValidatorFlags_ModuleOnly)
validationStatus = runDxilModuleValidation(Shader, DiagStream);
else
validationStatus = runValidation(Shader, Flags, DebugModule, DiagStream);
if (FAILED(validationStatus)) {
std::string msg("Validation failed.\n");
ULONG cbWritten;

Просмотреть файл

@ -24,19 +24,12 @@ class LLVMContext;
} // namespace llvm
namespace hlsl {
// For internal use only.
uint32_t validateWithOptModules(
IDxcBlob *Shader, // Shader to validate.
uint32_t Flags, // Validation flags.
llvm::Module *Module, // Module to validate, if available.
llvm::Module *DebugModule, // Debug module to validate, if available
IDxcOperationResult **Result // Validation output status, buffer, and errors
);
// IDxcValidator
uint32_t validate(
// For internal use only.
uint32_t validateWithOptDebugModule(
IDxcBlob *Shader, // Shader to validate.
uint32_t Flags, // Validation flags.
llvm::Module *DebugModule, // Debug module to validate, if available
IDxcOperationResult **Result // Validation output status, buffer, and errors
);

Просмотреть файл

@ -40,7 +40,7 @@ HRESULT CreateDxcValidator(REFIID riid, LPVOID *ppv);
// the module. It trusts that the caller didn't make any changes and is
// kept internal because the layout of the module class may change based
// on changes across modules, or picking a different compiler version or CRT.
HRESULT RunInternalValidator(IDxcValidator *pValidator, llvm::Module *pModule,
HRESULT RunInternalValidator(IDxcValidator *pValidator,
llvm::Module *pDebugModule, IDxcBlob *pShader,
UINT32 Flags, IDxcOperationResult **ppResult);
@ -190,8 +190,7 @@ HRESULT ValidateAndAssembleToContainer(AssembleInputs &inputs) {
// Important: in-place edit is required so the blob is reused and thus
// dxil.dll can be released.
if (bInternalValidator) {
IFT(RunInternalValidator(pValidator, inputs.pM.get(),
llvmModuleWithDebugInfo.get(),
IFT(RunInternalValidator(pValidator, llvmModuleWithDebugInfo.get(),
inputs.pOutputContainerBlob,
DxcValidatorFlags_InPlaceEdit, &pValResult));
} else {

Просмотреть файл

@ -35,10 +35,9 @@ public:
}
// For internal use only.
HRESULT ValidateWithOptModules(
HRESULT ValidateWithOptDebugModule(
IDxcBlob *pShader, // Shader to validate.
UINT32 Flags, // Validation flags.
llvm::Module *pModule, // Module to validate, if available.
llvm::Module *pDebugModule, // Debug module to validate, if available
IDxcOperationResult *
*ppResult // Validation output status, buffer, and errors
@ -64,19 +63,18 @@ HRESULT STDMETHODCALLTYPE DxcValidator::Validate(
IDxcOperationResult *
*ppResult // Validation output status, buffer, and errors
) {
return hlsl::validate(pShader, Flags, ppResult);
return hlsl::validateWithDebug(pShader, Flags, nullptr, ppResult);
}
HRESULT DxcValidator::ValidateWithOptModules(
HRESULT DxcValidator::ValidateWithOptDebugModule(
IDxcBlob *pShader, // Shader to validate.
UINT32 Flags, // Validation flags.
llvm::Module *pModule, // Module to validate, if available.
llvm::Module *pDebugModule, // Debug module to validate, if available
IDxcOperationResult *
*ppResult // Validation output status, buffer, and errors
) {
return hlsl::validateWithOptModules(pShader, Flags, pModule, pDebugModule,
ppResult);
return hlsl::validateWithOptDebugModule(pShader, Flags, pDebugModule,
ppResult);
}
HRESULT STDMETHODCALLTYPE DxcValidator::GetVersion(UINT32 *pMajor,
@ -97,17 +95,16 @@ HRESULT STDMETHODCALLTYPE DxcValidator::GetFlags(UINT32 *pFlags) {
///////////////////////////////////////////////////////////////////////////////
HRESULT RunInternalValidator(IDxcValidator *pValidator, llvm::Module *pModule,
HRESULT RunInternalValidator(IDxcValidator *pValidator,
llvm::Module *pDebugModule, IDxcBlob *pShader,
UINT32 Flags, IDxcOperationResult **ppResult) {
DXASSERT_NOMSG(pValidator != nullptr);
DXASSERT_NOMSG(pModule != nullptr);
DXASSERT_NOMSG(pShader != nullptr);
DXASSERT_NOMSG(ppResult != nullptr);
DxcValidator *pInternalValidator = (DxcValidator *)pValidator;
return pInternalValidator->ValidateWithOptModules(pShader, Flags, pModule,
pDebugModule, ppResult);
return pInternalValidator->ValidateWithOptDebugModule(pShader, Flags,
pDebugModule, ppResult);
}
HRESULT CreateDxcValidator(REFIID riid, LPVOID *ppv) {