[dxil2spv] Translate createHandle and bufferLoad (#4389)

Add support for translating createHandle and bufferLoad DXIL operations
to SPIR-V instructions. The most significant limitation of this current
implementation is the naive translation of descriptor set and binding
numbers, but it is sufficient for simple passthrough shaders.
This commit is contained in:
Natalie Chouinard 2022-04-21 06:38:03 -07:00 коммит произвёл GitHub
Родитель d3a3683d8c
Коммит 316b849cfa
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 235 добавлений и 35 удалений

Просмотреть файл

@ -105,6 +105,9 @@ public:
SpirvVariable *addFnVar(QualType valueType, SourceLocation,
llvm::StringRef name = "", bool isPrecise = false,
SpirvInstruction *init = nullptr);
SpirvVariable *addFnVar(const spirv::SpirvType *valueType, SourceLocation,
llvm::StringRef name = "", bool isPrecise = false,
SpirvInstruction *init = nullptr);
/// \brief Ends building of the current function. All basic blocks constructed
/// from the beginning or after ending the previous function will be collected

Просмотреть файл

@ -147,6 +147,18 @@ SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
return var;
}
SpirvVariable *SpirvBuilder::addFnVar(const spirv::SpirvType *valueType,
SourceLocation loc, llvm::StringRef name,
bool isPrecise, SpirvInstruction *init) {
assert(function && "found detached local variable");
// TODO: Handle potential bindless array of an opaque type.
SpirvVariable *var = new (context) SpirvVariable(
valueType, loc, spv::StorageClass::Function, isPrecise, init);
var->setDebugName(name);
function->addVariable(var);
return var;
}
void SpirvBuilder::endFunction() {
assert(function && "no active function");
mod->addFunctionToListOfSortedModuleFunctions(function);

Просмотреть файл

@ -86,7 +86,7 @@ attributes #2 = { nounwind }
; ; SPIR-V
; ; Version: 1.0
; ; Generator: Google spiregg; 0
; ; Bound: 22
; ; Bound: 56
; ; Schema: 0
; OpCapability Shader
; OpMemoryModel Logical GLSL450
@ -95,6 +95,11 @@ attributes #2 = { nounwind }
; OpName %type_ByteAddressBuffer "type.ByteAddressBuffer"
; OpName %type_RWByteAddressBuffer "type.RWByteAddressBuffer"
; OpName %main "main"
; OpName %dx_types_ResRet_i32 "dx.types.ResRet.i32"
; OpDecorate %3 DescriptorSet 0
; OpDecorate %3 Binding 0
; OpDecorate %4 DescriptorSet 0
; OpDecorate %4 Binding 1
; OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
; OpDecorate %_runtimearr_uint ArrayStride 4
; OpMemberDecorate %type_ByteAddressBuffer 0 Offset 0
@ -105,29 +110,65 @@ attributes #2 = { nounwind }
; %uint = OpTypeInt 32 0
; %uint_0 = OpConstant %uint 0
; %uint_2 = OpConstant %uint 2
; %v3uint = OpTypeVector %uint 3
; %_ptr_Input_v3uint = OpTypePointer Input %v3uint
; %uint_1 = OpConstant %uint 1
; %uint_3 = OpConstant %uint 3
; %uint_4 = OpConstant %uint 4
; %_runtimearr_uint = OpTypeRuntimeArray %uint
; %type_ByteAddressBuffer = OpTypeStruct %_runtimearr_uint
; %_ptr_Uniform_type_ByteAddressBuffer = OpTypePointer Uniform %type_ByteAddressBuffer
; %type_RWByteAddressBuffer = OpTypeStruct %_runtimearr_uint
; %_ptr_Uniform_type_RWByteAddressBuffer = OpTypePointer Uniform %type_RWByteAddressBuffer
; %v3uint = OpTypeVector %uint 3
; %_ptr_Input_v3uint = OpTypePointer Input %v3uint
; %void = OpTypeVoid
; %16 = OpTypeFunction %void
; %19 = OpTypeFunction %void
; %int = OpTypeInt 32 1
; %dx_types_ResRet_i32 = OpTypeStruct %int %int %int %int %int
; %_ptr_Function_dx_types_ResRet_i32 = OpTypePointer Function %dx_types_ResRet_i32
; %_ptr_Input_uint = OpTypePointer Input %uint
; %_ptr_Uniform_uint = OpTypePointer Uniform %uint
; %_ptr_Function_int = OpTypePointer Function %int
; %3 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
; %4 = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
; %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
; %11 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
; %14 = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
; %main = OpFunction %void None %16
; %17 = OpLabel
; %19 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0
; %20 = OpLoad %uint %19
; %21 = OpShiftLeftLogical %uint %20 %uint_2
; %main = OpFunction %void None %19
; %20 = OpLabel
; %24 = OpVariable %_ptr_Function_dx_types_ResRet_i32 Function
; %26 = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0
; %27 = OpLoad %uint %26
; %28 = OpShiftLeftLogical %uint %27 %uint_2
; %29 = OpIAdd %uint %28 %uint_0
; %31 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %29
; %32 = OpLoad %uint %31
; %34 = OpAccessChain %_ptr_Function_int %24 %uint_0
; %35 = OpBitcast %int %32
; OpStore %34 %35
; %36 = OpIAdd %uint %28 %uint_1
; %37 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %36
; %38 = OpLoad %uint %37
; %39 = OpAccessChain %_ptr_Function_int %24 %uint_1
; %40 = OpBitcast %int %38
; OpStore %39 %40
; %41 = OpIAdd %uint %28 %uint_2
; %42 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %41
; %43 = OpLoad %uint %42
; %44 = OpAccessChain %_ptr_Function_int %24 %uint_2
; %45 = OpBitcast %int %43
; OpStore %44 %45
; %46 = OpIAdd %uint %28 %uint_3
; %47 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %46
; %48 = OpLoad %uint %47
; %49 = OpAccessChain %_ptr_Function_int %24 %uint_3
; %50 = OpBitcast %int %48
; OpStore %49 %50
; %51 = OpIAdd %uint %28 %uint_4
; %52 = OpAccessChain %_ptr_Uniform_uint %3 %uint_0 %51
; %53 = OpLoad %uint %52
; %54 = OpAccessChain %_ptr_Function_int %24 %uint_4
; %55 = OpBitcast %int %53
; OpStore %54 %55
; OpReturn
; OpFunctionEnd
; CHECK-ERRORS:
; error: Unhandled DXIL opcode: CreateHandle
; error: Unhandled DXIL opcode: CreateHandle
; error: Unhandled DXIL opcode: BufferLoad
; error: Unhandled LLVM instruction: %6 = extractvalue %dx.types.ResRet.i32 %5, 0
; error: Unhandled DXIL opcode: BufferStore

Просмотреть файл

@ -17,6 +17,7 @@
#include "dxc/Support/ErrorCodes.h"
#include "dxc/Support/Global.h"
#include "clang/SPIRV/SpirvInstruction.h"
#include "clang/SPIRV/SpirvType.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
@ -114,6 +115,10 @@ int Translator::Run() {
createStageIOVariables(program.GetInputSignature().GetElements(),
program.GetOutputSignature().GetElements());
// Add HLSL resources.
createModuleVariables(program.GetSRVs());
createModuleVariables(program.GetUAVs());
// Create entry function.
spirv::SpirvFunction *entryFunction =
createEntryFunction(program.GetEntryFunction());
@ -131,10 +136,6 @@ int Translator::Run() {
{});
}
// Add HLSL resources.
createModuleVariables(program.GetSRVs());
createModuleVariables(program.GetUAVs());
// Contsruct the SPIR-V module.
std::vector<uint32_t> m = spvBuilder.takeModuleForDxilToSpv();
@ -204,8 +205,12 @@ void Translator::createModuleVariables(
assert(hlslType->isPointerTy());
llvm::Type *pointeeType =
cast<llvm::PointerType>(hlslType)->getPointerElementType();
spvBuilder.addModuleVar(toSpirvType(pointeeType),
spv::StorageClass::Uniform, false);
spirv::SpirvVariable *moduleVar = spvBuilder.addModuleVar(
toSpirvType(pointeeType), spv::StorageClass::Uniform, false);
spvBuilder.decorateDSetBinding(moduleVar, nextDescriptorSet,
nextBindingNo++);
resourceMap[{static_cast<unsigned>(resource->GetClass()),
resource->GetID()}] = moduleVar;
}
}
@ -263,6 +268,12 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
case hlsl::DXIL::OpCode::ThreadId: {
createThreadIdInstruction(callInstruction);
} break;
case hlsl::DXIL::OpCode::CreateHandle: {
createHandleInstruction(callInstruction);
} break;
case hlsl::DXIL::OpCode::BufferLoad: {
createBufferLoadInstruction(callInstruction);
} break;
default: {
emitError("Unhandled DXIL opcode: %0")
<< hlsl::OP::GetOpCodeName(dxilOpcode);
@ -281,10 +292,7 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
}
// Unhandled instruction type.
else {
std::string instStr;
llvm::raw_string_ostream os(instStr);
instruction.print(os);
emitError("Unhandled LLVM instruction: %0") << os.str();
emitError("Unhandled LLVM instruction: %0", instruction);
}
}
@ -344,8 +352,8 @@ void Translator::createStoreOutputInstruction(llvm::CallInst &instruction) {
spirv::SpirvAccessChain *outputVarPtr =
spvBuilder.createAccessChain(elemType, outputVar, {index}, {});
spirv::SpirvInstruction *valueToStore =
instructionMap[instruction.getArgOperand(
hlsl::DXIL::OperandIndex::kStoreOutputValOpIdx)];
getSpirvInstruction(instruction.getArgOperand(
hlsl::DXIL::OperandIndex::kStoreOutputValOpIdx));
spvBuilder.createStore(outputVarPtr, valueToStore, {});
}
@ -384,15 +392,8 @@ void Translator::createBinaryOpInstruction(llvm::BinaryOperator &instruction) {
// Shift left instruction.
case llvm::Instruction::Shl: {
// Value to be shifted.
spirv::SpirvInstruction *val = instructionMap[instruction.getOperand(0)];
if (!val) {
std::string instStr;
llvm::raw_string_ostream os(instStr);
instruction.print(os);
emitError("Could not find translation of instruction operand 0: %0")
<< os.str();
return;
}
spirv::SpirvInstruction *val =
getSpirvInstruction(instruction.getOperand(0));
// Amount to shift by.
const spirv::IntegerType *uint32 = spvContext.getUIntType(32);
@ -412,6 +413,108 @@ void Translator::createBinaryOpInstruction(llvm::BinaryOperator &instruction) {
instructionMap[&instruction] = result;
}
void Translator::createHandleInstruction(llvm::CallInst &instruction) {
unsigned resourceClass =
cast<llvm::ConstantInt>(
instruction.getArgOperand(
hlsl::DXIL::OperandIndex::kCreateHandleResClassOpIdx))
->getLimitedValue();
unsigned resourceRangeId =
cast<llvm::ConstantInt>(
instruction.getArgOperand(
hlsl::DXIL::OperandIndex::kCreateHandleResIDOpIdx))
->getLimitedValue();
spirv::SpirvVariable *inputVar =
resourceMap[{resourceClass, resourceRangeId}];
if (!inputVar) {
emitError("No resource found corresponding to handle: %0", instruction);
return;
}
instructionMap[&instruction] = inputVar;
}
void Translator::createBufferLoadInstruction(llvm::CallInst &instruction) {
// TODO: Extend this function to work with all buffer types on which it is
// used, not just ByteAddressBuffers.
// ByteAddressBuffers are represented as a struct with one member that is a
// runtime array of unsigned integers. The SPIR-V OpAccessChain instruction is
// then used to access that offset, and OpLoad is used to load integers
// into a corresponding struct.
// clang-format off
// For example, the following DXIL instruction:
// %dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 }
// %ret = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle %res, i32 %index, i32 undef)
// would translate to the following SPIR-V instructions:
// %dx_types_ResRet_i32 = OpTypeStruct %int %int %int %int %int
// %_ptr_Function_dx_types_ResRet_i32 = OpTypePointer Function %dx_types_ResRet_i32
// %ret = OpVariable %_ptr_Function_dx_types_ResRet_i32 Function
// %i = OpLoad %uint %index
// for %offset = {0, 1, 2, 3, 4}, repeat:
// %v0 = OpIAdd %uint %i %offset
// %v1 = OpAccessChain %_ptr_Uniform_uint %res %offset %v0
// %v2 = OpLoad %uint %v1
// %v3 = OpAccessChain %_ptr_Function_int %ret %offset
// %v4 = OpBitcast %int %v2
// OpStore %v3 %v4
// clang-format on
// Get module input variable corresponding to given DXIL handle.
spirv::SpirvInstruction *inputVar =
getSpirvInstruction(instruction.getArgOperand(
hlsl::DXIL::OperandIndex::kBufferLoadHandleOpIdx));
// Translate DXIL instruction return type (expected to be a struct of
// integers) to a SPIR-V type.
const spirv::SpirvType *returnType = toSpirvType(instruction.getType());
assert(isa<spirv::StructType>(returnType));
const spirv::StructType *structType = cast<spirv::StructType>(returnType);
// Create a return variable to initialize with values loaded from the buffer.
spirv::SpirvVariable *returnVar =
spvBuilder.addFnVar(structType, {}, "", false, nullptr);
// Translate indices into resource buffer to SPIR-V instructions.
auto uint32 = spvContext.getUIntType(32);
spirv::SpirvConstant *indexIntoStruct =
spvBuilder.getConstantInt(uint32, llvm::APInt(32, 0));
spirv::SpirvInstruction *baseArrayIndex =
getSpirvInstruction(instruction.getArgOperand(
hlsl::DXIL::OperandIndex::kBufferLoadCoord0OpIdx));
// Initialize each field in the struct.
for (size_t i = 0; i < structType->getFields().size(); i++) {
// Add offset for current field.
spirv::SpirvConstant *fieldOffset =
spvBuilder.getConstantInt(uint32, llvm::APInt(32, i));
spirv::SpirvInstruction *indexIntoArray = spvBuilder.createBinaryOp(
spv::Op::OpIAdd, uint32, baseArrayIndex, fieldOffset, {});
// Create access chain and load.
spirv::SpirvAccessChain *loadPtr = spvBuilder.createAccessChain(
uint32, inputVar, {indexIntoStruct, indexIntoArray}, {});
spirv::SpirvInstruction *loadInstr =
spvBuilder.createLoad(uint32, loadPtr, {});
// Create access chain and store.
const spirv::SpirvType *fieldType = structType->getFields()[i].type;
spirv::SpirvAccessChain *storePtr =
spvBuilder.createAccessChain(fieldType, returnVar, fieldOffset, {});
// LLVM types are signless, so type conversions are not 1-to-1. A bitcast on
// the unsigned integer may be necessary before storing.
spirv::SpirvInstruction *valToStore =
fieldType == uint32 ? loadInstr
: spvBuilder.createUnaryOp(
spv::Op::OpBitcast, fieldType, loadInstr, {});
spvBuilder.createStore(storePtr, valToStore, {});
}
instructionMap[&instruction] = returnVar;
}
bool Translator::spirvToolsValidate(std::vector<uint32_t> *mod,
std::string *messages) {
spvtools::SpirvTools tools(featureManager.getTargetEnv());
@ -500,6 +603,17 @@ const spirv::SpirvType *Translator::toSpirvType(llvm::StructType *structType) {
return spvContext.getStructType(fields, name);
}
spirv::SpirvInstruction *
Translator::getSpirvInstruction(llvm::Value *instruction) {
spirv::SpirvInstruction *spirvInstruction = instructionMap[instruction];
if (!spirvInstruction) {
emitError("Expected SPIR-V instruction not found for DXIL instruction: %0",
*instruction);
return nullptr;
}
return spirvInstruction;
}
template <unsigned N>
DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
const auto diagId =
@ -507,5 +621,14 @@ DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
return diagnosticsEngine.Report({}, diagId);
}
template <unsigned N>
DiagnosticBuilder Translator::emitError(const char (&message)[N],
llvm::Value &value) {
std::string str;
llvm::raw_string_ostream os(str);
value.print(os);
return emitError(message) << os.str();
}
} // namespace dxil2spv
} // namespace clang

Просмотреть файл

@ -47,9 +47,17 @@ private:
llvm::DenseMap<unsigned, spirv::SpirvVariable *> inputSignatureElementMap;
llvm::DenseMap<unsigned, spirv::SpirvVariable *> outputSignatureElementMap;
// Map from HLSL resource class and range ID to corresponding SPIR-V variable.
llvm::DenseMap<std::pair<unsigned, unsigned>, spirv::SpirvVariable *>
resourceMap;
// Map from DXIL instructions (values) to SPIR-V instructions.
llvm::DenseMap<llvm::Value *, spirv::SpirvInstruction *> instructionMap;
// Get corresponding SPIR-V instruction for a given DXIL instruction, with
// error checking.
spirv::SpirvInstruction *getSpirvInstruction(llvm::Value *instruction);
// Create SPIR-V stage IO variable from DXIL input and output signatures.
void createStageIOVariables(
const std::vector<std::unique_ptr<hlsl::DxilSignatureElement>>
@ -73,6 +81,8 @@ private:
void createStoreOutputInstruction(llvm::CallInst &instruction);
void createThreadIdInstruction(llvm::CallInst &instruction);
void createBinaryOpInstruction(llvm::BinaryOperator &instruction);
void createHandleInstruction(llvm::CallInst &instruction);
void createBufferLoadInstruction(llvm::CallInst &instruction);
// SPIR-V Tools wrapper functions.
bool spirvToolsValidate(std::vector<uint32_t> *mod, std::string *messages);
@ -83,7 +93,18 @@ private:
const spirv::SpirvType *toSpirvType(llvm::Type *llvmType);
const spirv::SpirvType *toSpirvType(llvm::StructType *structType);
// TODO: These variables are used for a temporary hack to assign descriptor
// set and binding numbers that works only for the most simple cases (always
// use descriptor set 0, increment binding number for each resource). Further
// work is needed to translate non-trivial shaders.
unsigned nextDescriptorSet = 0;
unsigned nextBindingNo = 0;
// Helper diagnostic functions for emitting error messages.
template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]);
template <unsigned N>
DiagnosticBuilder emitError(const char (&message)[N],
llvm::Value &instruction);
};
} // namespace dxil2spv