[dxil2spv] Add initial compute shader support (#4345)

Implement first part of support for a passthrough compute shader.
This iteration translates the HLSL resources used in a simple
passthrough compute shader (ByteAddressBuffer and RWByteAddressBuffer)
to the appropriate SPIR-V module variables. Errors are emitted for
unhandled instructions.
This commit is contained in:
Natalie Chouinard 2022-03-24 13:18:12 -04:00 коммит произвёл GitHub
Родитель 9f70135012
Коммит c701ece61c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 213 добавлений и 12 удалений

Просмотреть файл

@ -0,0 +1,123 @@
; RUN: %dxil2spv
;
; Input signature:
;
; Name Index Mask Register SysValue Format Used
; -------------------- ----- ------ -------- -------- ------- ------
; no parameters
;
; Output signature:
;
; Name Index Mask Register SysValue Format Used
; -------------------- ----- ------ -------- -------- ------- ------
; no parameters
; shader hash: aba83cb71e5a9eee4db93a4e5df0d6cd
;
; Pipeline Runtime Information:
;
;
;
; Buffer Definitions:
;
;
; Resource Bindings:
;
; Name Type Format Dim ID HLSL Bind Count
; ------------------------------ ---------- ------- ----------- ------- -------------- ------
; Buffer0 texture byte r/o T0 t0 1
; BufferOut UAV byte r/w U0 u1 1
;
target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
target triple = "dxil-ms-dx"
%dx.types.Handle = type { i8* }
%dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 }
%struct.ByteAddressBuffer = type { i32 }
%struct.RWByteAddressBuffer = type { i32 }
define void @main() {
%1 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 1, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex)
%2 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex)
%3 = call i32 @dx.op.threadId.i32(i32 93, i32 0) ; ThreadId(component)
%4 = shl i32 %3, 2
%5 = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle %2, i32 %4, i32 undef) ; BufferLoad(srv,index,wot)
%6 = extractvalue %dx.types.ResRet.i32 %5, 0
call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle %1, i32 %4, i32 undef, i32 %6, i32 undef, i32 undef, i32 undef, i8 1) ; BufferStore(uav,coord0,coord1,value0,value1,value2,value3,mask)
ret void
}
; Function Attrs: nounwind readnone
declare i32 @dx.op.threadId.i32(i32, i32) #0
; Function Attrs: nounwind readonly
declare %dx.types.Handle @dx.op.createHandle(i32, i8, i32, i32, i1) #1
; Function Attrs: nounwind readonly
declare %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32, %dx.types.Handle, i32, i32) #1
; Function Attrs: nounwind
declare void @dx.op.bufferStore.i32(i32, %dx.types.Handle, i32, i32, i32, i32, i32, i32, i8) #2
attributes #0 = { nounwind readnone }
attributes #1 = { nounwind readonly }
attributes #2 = { nounwind }
!llvm.ident = !{!0}
!dx.version = !{!1}
!dx.valver = !{!2}
!dx.shaderModel = !{!3}
!dx.resources = !{!4}
!dx.entryPoints = !{!9}
!0 = !{!"clang version 3.7 (tags/RELEASE_370/final)"}
!1 = !{i32 1, i32 0}
!2 = !{i32 1, i32 7}
!3 = !{!"cs", i32 6, i32 0}
!4 = !{!5, !7, null, null}
!5 = !{!6}
!6 = !{i32 0, %struct.ByteAddressBuffer* undef, !"", i32 0, i32 0, i32 1, i32 11, i32 0, null}
!7 = !{!8}
!8 = !{i32 0, %struct.RWByteAddressBuffer* undef, !"", i32 0, i32 1, i32 1, i32 11, i1 false, i1 false, i1 false, null}
!9 = !{void ()* @main, !"main", null, !4, !10}
!10 = !{i32 0, i64 16, i32 4, !11}
!11 = !{i32 1, i32 1, i32 1}
; CHECK-WHOLE-SPIR-V:
; ; SPIR-V
; ; Version: 1.0
; ; Generator: Google spiregg; 0
; ; Bound: 13
; ; Schema: 0
; OpCapability Shader
; OpMemoryModel Logical GLSL450
; OpEntryPoint GLCompute %main "main"
; OpExecutionMode %main LocalSize 1 1 1
; OpName %type_ByteAddressBuffer "type.ByteAddressBuffer"
; OpName %type_RWByteAddressBuffer "type.RWByteAddressBuffer"
; OpName %main "main"
; OpDecorate %_runtimearr_uint ArrayStride 4
; OpMemberDecorate %type_ByteAddressBuffer 0 Offset 0
; OpMemberDecorate %type_ByteAddressBuffer 0 NonWritable
; OpDecorate %type_ByteAddressBuffer BufferBlock
; OpMemberDecorate %type_RWByteAddressBuffer 0 Offset 0
; OpDecorate %type_RWByteAddressBuffer BufferBlock
; %uint = OpTypeInt 32 0
; %_runtimearr_uint = OpTypeRuntimeArray %uint
; %type_ByteAddressBuffer = OpTypeStruct %_runtimearr_uint
; %_ptr_Uniform_type_ByteAddressBuffer = OpTypePointer Uniform %type_ByteAddressBuffer
; %type_RWByteAddressBuffer = OpTypeStruct %_runtimearr_uint
; %_ptr_Uniform_type_RWByteAddressBuffer = OpTypePointer Uniform %type_RWByteAddressBuffer
; %void = OpTypeVoid
; %11 = OpTypeFunction %void
; %6 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
; %9 = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
; %main = OpFunction %void None %11
; %12 = OpLabel
; OpReturn
; OpFunctionEnd
; CHECK-ERRORS:
; error: Unhandled DXIL opcode: CreateHandle
; error: Unhandled DXIL opcode: CreateHandle
; error: Unhandled DXIL opcode: ThreadId
; error: Unhandled DXIL opcode: BufferLoad
; error: Unhandled DXIL opcode: BufferStore

Просмотреть файл

@ -24,6 +24,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MemoryBuffer.h"
#include "dxc/DXIL/DxilOperations.h"
#include "spirv-tools/libspirv.hpp"
#include "clang/CodeGen/CodeGenAction.h"
#include "clang/Frontend/CodeGenOptions.h"
@ -114,7 +115,25 @@ int Translator::Run() {
program.GetOutputSignature().GetElements());
// Create entry function.
createEntryFunction(program.GetEntryFunction());
spirv::SpirvFunction *entryFunction =
createEntryFunction(program.GetEntryFunction());
// Set execution mode if necessary.
if (spvContext.isPS()) {
spvBuilder.addExecutionMode(entryFunction,
spv::ExecutionMode::OriginUpperLeft, {}, {});
}
if (spvContext.isCS()) {
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{program.GetNumThreads(0),
program.GetNumThreads(1),
program.GetNumThreads(2)},
{});
}
// Add HLSL resources.
createModuleVariables(program.GetSRVs());
createModuleVariables(program.GetUAVs());
// Contsruct the SPIR-V module.
std::vector<uint32_t> m = spvBuilder.takeModuleForDxilToSpv();
@ -178,7 +197,20 @@ void Translator::createStageIOVariables(
}
}
void Translator::createEntryFunction(llvm::Function *entryFunction) {
void Translator::createModuleVariables(
const std::vector<std::unique_ptr<hlsl::DxilResource>> &resources) {
for (const std::unique_ptr<hlsl::DxilResource> &resource : resources) {
llvm::Type *hlslType = resource->GetHLSLType();
assert(hlslType->isPointerTy());
llvm::Type *pointeeType =
cast<llvm::PointerType>(hlslType)->getPointerElementType();
spvBuilder.addModuleVar(toSpirvType(pointeeType),
spv::StorageClass::Uniform, false);
}
}
spirv::SpirvFunction *
Translator::createEntryFunction(llvm::Function *entryFunction) {
spirv::SpirvFunction *spirvEntryFunction =
spvBuilder.beginFunction(toSpirvType(entryFunction->getReturnType()), {},
entryFunction->getName());
@ -200,12 +232,7 @@ void Translator::createEntryFunction(llvm::Function *entryFunction) {
spirv::SpirvUtils::getSpirvShaderStage(
spvContext.getCurrentShaderModelKind()),
spirvEntryFunction, spirvEntryFunction->getFunctionName(), interfaceVars);
// Set execution mode if necessary.
if (spvContext.isPS()) {
spvBuilder.addExecutionMode(spirvEntryFunction,
spv::ExecutionMode::OriginUpperLeft, {}, {});
}
return spirvEntryFunction;
}
void Translator::createBasicBlock(llvm::BasicBlock &basicBlock) {
@ -234,7 +261,8 @@ void Translator::createInstruction(llvm::Instruction &instruction) {
createStoreOutputInstruction(callInstruction);
} break;
default: {
emitError("Unhandled DXIL opcode");
emitError("Unhandled DXIL opcode: %0")
<< hlsl::OP::GetOpCodeName(dxilOpcode);
} break;
}
} else if (isa<llvm::ReturnInst>(instruction)) {
@ -324,7 +352,8 @@ const spirv::SpirvType *Translator::toSpirvType(hlsl::CompType compType) {
else if (compType.IsUIntTy())
return spvContext.getUIntType(compType.GetSizeInBits());
llvm_unreachable("Unhandled DXIL Component Type");
emitError("Unhandled DXIL Component Type: %0") << compType.GetName();
return nullptr;
}
const spirv::SpirvType *
@ -348,7 +377,46 @@ Translator::toSpirvType(hlsl::DxilSignatureElement *elem) {
const spirv::SpirvType *Translator::toSpirvType(llvm::Type *llvmType) {
if (llvmType->isVoidTy())
return new spirv::VoidType();
return toSpirvType(hlsl::CompType::GetCompType(llvmType));
if (llvmType->isIntegerTy() || llvmType->isFloatingPointTy())
return toSpirvType(hlsl::CompType::GetCompType(llvmType));
if (llvmType->isStructTy()) {
return toSpirvType(cast<llvm::StructType>(llvmType));
}
std::string typeStr;
llvm::raw_string_ostream os(typeStr);
llvmType->print(os);
emitError("Unhandled LLVM type: %0") << os.str();
return nullptr;
}
const spirv::SpirvType *Translator::toSpirvType(llvm::StructType *structType) {
// Remove prefix from struct name.
llvm::StringRef prefix = "struct.";
llvm::StringRef name = structType->getName();
if (name.startswith(prefix))
name = name.drop_front(prefix.size());
// ByteAddressBuffer types.
if (name == "ByteAddressBuffer") {
return spvContext.getByteAddressBufferType(/*isRW*/ false);
}
// RWByteAddressBuffer types.
if (name == "RWByteAddressBuffer") {
return spvContext.getByteAddressBufferType(/*isRW*/ true);
}
// Other struct types.
std::vector<spirv::StructType::FieldInfo> fields;
fields.reserve(structType->getNumElements());
for (llvm::Type *elemType : structType->elements()) {
fields.emplace_back(toSpirvType(elemType));
}
return spvContext.getStructType(fields, name);
}
template <unsigned N>

Просмотреть файл

@ -12,6 +12,7 @@
#ifndef __DXIL2SPV_DXIL2SPV__
#define __DXIL2SPV_DXIL2SPV__
#include "dxc/DXIL/DxilResource.h"
#include "dxc/DXIL/DxilSignature.h"
#include "dxc/Support/SPIRVOptions.h"
#include "dxc/dxcapi.h"
@ -55,8 +56,12 @@ private:
const std::vector<std::unique_ptr<hlsl::DxilSignatureElement>>
&outputSignature);
// Create SPIR-V module variables from DXIL resources.
void createModuleVariables(
const std::vector<std::unique_ptr<hlsl::DxilResource>> &resources);
// Create SPIR-V entry function from DXIL function.
void createEntryFunction(llvm::Function *function);
spirv::SpirvFunction *createEntryFunction(llvm::Function *function);
// Create SPIR-V basic block from DXIL basic block.
void createBasicBlock(llvm::BasicBlock &basicBlock);
@ -73,6 +78,7 @@ private:
const spirv::SpirvType *toSpirvType(hlsl::CompType compType);
const spirv::SpirvType *toSpirvType(hlsl::DxilSignatureElement *elem);
const spirv::SpirvType *toSpirvType(llvm::Type *llvmType);
const spirv::SpirvType *toSpirvType(llvm::StructType *structType);
template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]);
};

Просмотреть файл

@ -20,4 +20,8 @@ TEST_F(WholeFileTest, PassThruVertexShader) {
runWholeFileTest("passthru-vs.ll");
}
TEST_F(WholeFileTest, PassThruComputeShader) {
runWholeFileTest("passthru-cs.ll");
}
} // namespace