[spirv] handle hull shader OutputPatch (#4119)

The existing code creates a temporary variable with Output storage class
for the parameter with OutputPatch type for the patch constant
functions. However, this adds an additional output stage variable and
consumes more locations. In addition, it can also change the rendering
result depending on the driver.

This commit removes the additional output stage variable and uses the
actual output stage variables for passing the OutputPatch argument.
This works correctly because the output stage variables keep the output
values for all invocation IDs, so we can simply reuse them.
This commit is contained in:
Jaebaek Seo 2021-12-03 13:54:58 +09:00 коммит произвёл GitHub
Родитель 235c801d6e
Коммит b6360c6c0c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 129 добавлений и 110 удалений

Просмотреть файл

@ -2986,8 +2986,12 @@ patch constant function. This would include information about each of the ``N``
vertices that are input to the tessellation control shader.
OutputPatch is an array containing ``N`` elements (where ``N`` is the number of
output vertices). Each element of the array contains information about an
output vertex. OutputPatch may also be passed to the patch constant function.
output vertices). Each element of the array is the hull shader output for each
output vertex. For example, element ``i`` of ``OutputPatch<HSOutput, 3>`` is the
value returned by the hull shader function for ``SV_OutputControlPointID`` ``i``.
The array is shared between threads, i.e., in the patch constant function, all
threads for the same patch must see the same values for the elements of
``OutputPatch<HSOutput, 3>``.
The SPIR-V ``InvocationID`` (``SV_OutputControlPointID`` in HLSL) is used to index
into the InputPatch and OutputPatch arrays to read/write information for the given
@ -3009,7 +3013,11 @@ As mentioned above, the patch constant function is to be invoked only once per p
As a result, in the SPIR-V module, the `entry function wrapper`_ will first invoke the
main entry function, and then use an ``OpControlBarrier`` to wait for all vertex
processing to finish. After the barrier, *only* the first thread (with InvocationID of 0)
will invoke the patch constant function.
will invoke the patch constant function. Since the first thread has to see the
OutputPatch that contains the outputs of the hull shader function from all the
other threads, the OutputPatch passed as an input to the patch constant function
must be populated from the output stage variables (with Output storage class) of
the hull shader function.
The information resulting from the patch constant function will also be returned
as stage output variables. The output struct of the patch constant function must include

Просмотреть файл

@ -3234,43 +3234,6 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
return var;
}
/// Builds a stage variable called `name` with the given `type` on behalf of
/// `decl`, records it in `stageVars` and `stageVarInstructions`, and decorates
/// the resulting SPIR-V variable with its HLSL semantic string.
///
/// Returns the created SPIR-V variable, or nullptr when createSpirvStageVar()
/// does not produce one.
SpirvVariable *DeclResultIdMapper::createSpirvIntermediateOutputStageVar(
    const NamedDecl *decl, const llvm::StringRef name, QualType type) {
  // The given name doubles as the semantic string of the variable.
  SemanticInfo semanticInfo{name, hlsl::Semantic::GetByName(name), name, 0,
                            decl->getLocation()};

  const auto *signaturePoint =
      deduceSigPoint(cast<DeclaratorDecl>(decl), /*asInput=*/false,
                     spvContext.getCurrentShaderModelKind(), /*forPCF=*/false);

  StageVar outVar(signaturePoint, semanticInfo, decl->getAttr<VKBuiltInAttr>(),
                  type, /*locCount=*/1);

  SpirvVariable *spirvVar =
      createSpirvStageVar(&outVar, decl, name, semanticInfo.loc);
  if (spirvVar == nullptr)
    return nullptr;

  outVar.setSpirvInstr(spirvVar);
  outVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
  outVar.setIndexAttr(decl->getAttr<VKIndexAttr>());

  // Only Input/Output variables are tied to the current entry function.
  switch (outVar.getStorageClass()) {
  case spv::StorageClass::Input:
  case spv::StorageClass::Output:
    outVar.setEntryPoint(entryFunction);
    break;
  default:
    break;
  }

  stageVars.push_back(outVar);

  // Link the SPIR-V variable with the HLSL semantic it was created for.
  spvBuilder.decorateHlslSemantic(spirvVar, outVar.getSemanticStr());

  // A decl with a semantic attached is a function/parameter/variable, all of
  // which are DeclaratorDecls, so the cast below is expected to succeed.
  stageVarInstructions[cast<DeclaratorDecl>(decl)] = spirvVar;

  return spirvVar;
}
SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
StageVar *stageVar, const NamedDecl *decl, const llvm::StringRef name,
SourceLocation srcLoc) {
@ -4008,21 +3971,6 @@ void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
astDecls[varDecl].instr = constVal;
}
/// Creates the stage variable "temp.var.hullMainRetVal" whose type is an array
/// of `numOutputControlPoints` elements of `retType`, and registers it as the
/// SPIR-V instruction for `param` in `astDecls`.
///
/// `param` must not already have an instruction registered (asserted).
/// Returns the value produced by createSpirvIntermediateOutputStageVar().
SpirvInstruction *
DeclResultIdMapper::createHullMainOutputPatch(const ParmVarDecl *param,
                                              const QualType retType,
                                              uint32_t numOutputControlPoints) {
  // One array element per output control point.
  const llvm::APInt elementCount(32, numOutputControlPoints);
  const QualType patchArrayType = astContext.getConstantArrayType(
      retType, elementCount, clang::ArrayType::Normal, 0);

  SpirvInstruction *patchVar = createSpirvIntermediateOutputStageVar(
      param, "temp.var.hullMainRetVal", patchArrayType);

  auto &paramInfo = astDecls[param];
  assert(paramInfo.instr == nullptr);
  paramInfo.instr = patchVar;
  return patchVar;
}
template <typename Functor>
void DeclResultIdMapper::decorateWithIntrinsicAttrs(const NamedDecl *decl,
SpirvVariable *varInst,
@ -4048,5 +3996,53 @@ void DeclResultIdMapper::decorateVariableWithIntrinsicAttrs(
decorateWithIntrinsicAttrs(decl, varInst, [](VKDecorateExtAttr *) {});
}
/// For every output control point in [0, numOutputControlPoints), forms an
/// access chain to the matching element of `hullMainOutputPatch` and asks
/// storeOutStageVarsToStorage() to fill that element from the output stage
/// variables associated with `outputPatchDecl`.
void DeclResultIdMapper::copyHullOutStageVarsToOutputPatch(
    SpirvInstruction *hullMainOutputPatch, const ParmVarDecl *outputPatchDecl,
    QualType outputControlPointType, uint32_t numOutputControlPoints) {
  for (uint32_t ctrlPoint = 0; ctrlPoint != numOutputControlPoints;
       ++ctrlPoint) {
    // The same constant indexes both the destination array element and the
    // source stage variable arrays.
    SpirvConstant *ctrlPointIdx = spvBuilder.getConstantInt(
        astContext.UnsignedIntTy, llvm::APInt(32, ctrlPoint));

    auto *elementPtr =
        spvBuilder.createAccessChain(outputControlPointType,
                                     hullMainOutputPatch, {ctrlPointIdx},
                                     /*loc=*/{});

    storeOutStageVarsToStorage(cast<DeclaratorDecl>(outputPatchDecl),
                               ctrlPointIdx, outputControlPointType,
                               elementPtr);
  }
}
/// Copies the value held by the output stage variable(s) of `outputPatchDecl`
/// for control point `ctrlPointID` into the storage pointed to by `ptr`.
///
/// \param outputPatchDecl decl used as the key into `stageVarInstructions`
///        (for struct types, each field decl is used as the key instead).
/// \param ctrlPointID index of the element to read from the stage variable.
/// \param outputControlPointType type of the value to copy.
/// \param ptr destination pointer the loaded value is stored to.
///
/// For a non-struct type, the stage variable registered for the decl is
/// indexed with `ctrlPointID`, loaded, and stored to `ptr`. For a struct type,
/// the function recurses into every field, using an access chain on `ptr` with
/// the field index as the destination.
///
/// If no stage variable is registered for a non-struct decl, an error is
/// emitted and no instructions are generated for it.
void DeclResultIdMapper::storeOutStageVarsToStorage(
    const DeclaratorDecl *outputPatchDecl, SpirvConstant *ctrlPointID,
    QualType outputControlPointType, SpirvInstruction *ptr) {
  if (!outputControlPointType->isStructureType()) {
    const auto found = stageVarInstructions.find(outputPatchDecl);
    if (found == stageVarInstructions.end()) {
      emitError("Shader output variable '%0' was not created", {})
          << outputPatchDecl->getName();
      // Must not fall through: `found` is the end iterator here, and
      // dereferencing it below would be undefined behavior.
      return;
    }
    auto *ptrToOutputStageVar = spvBuilder.createAccessChain(
        outputControlPointType, found->second, {ctrlPointID}, /*loc=*/{});
    auto *load = spvBuilder.createLoad(outputControlPointType,
                                       ptrToOutputStageVar, /*loc=*/{});
    spvBuilder.createStore(ptr, load, /*loc=*/{});
    return;
  }

  const auto *recordType = outputControlPointType->getAs<RecordType>();
  assert(recordType != nullptr);
  const auto *structDecl = recordType->getDecl();
  assert(structDecl != nullptr);

  // Walk the fields in declaration order; `index` is the field's position in
  // the struct and is used as the access-chain index on the destination.
  uint32_t index = 0;
  for (const auto *field : structDecl->fields()) {
    SpirvConstant *indexInst = spvBuilder.getConstantInt(
        astContext.UnsignedIntTy, llvm::APInt(32, index));
    auto *tempLocation = spvBuilder.createAccessChain(field->getType(), ptr,
                                                      {indexInst}, /*loc=*/{});
    storeOutStageVarsToStorage(cast<DeclaratorDecl>(field), ctrlPointID,
                               field->getType(), tempLocation);
    ++index;
  }
}
} // end namespace spirv
} // end namespace clang

Просмотреть файл

@ -447,11 +447,15 @@ public:
/// VarDecls (such as some ray tracing enums).
void tryToCreateImplicitConstVar(const ValueDecl *);
/// \brief Creates a variable for hull shader output patch with Output
/// storage class, and registers the SPIR-V variable for the given decl.
SpirvInstruction *createHullMainOutputPatch(const ParmVarDecl *param,
const QualType retType,
uint32_t numOutputControlPoints);
/// \brief Creates instructions that copy the output stage variables
/// associated with outputPatchDecl into hullMainOutputPatch, the variable
/// used for passing the OutputPatch argument.
///
/// One element of hullMainOutputPatch is filled for every output control
/// point, from index 0 up to (but not including) numOutputControlPoints.
///
/// \param hullMainOutputPatch destination variable; indexed by output control
///        point.
/// \param outputPatchDecl the OutputPatch parameter whose output stage
///        variables are the source of the copy.
/// \param outputControlPointType the template parameter type of OutputPatch,
///        i.e. the type of a single element.
/// \param numOutputControlPoints the number of output control points.
void copyHullOutStageVarsToOutputPatch(SpirvInstruction *hullMainOutputPatch,
const ParmVarDecl *outputPatchDecl,
QualType outputControlPointType,
uint32_t numOutputControlPoints);
/// \brief An enum class for representing what the DeclContext is used for
enum class ContextUsageKind {
@ -612,6 +616,17 @@ public:
void decorateVariableWithIntrinsicAttrs(const NamedDecl *decl,
SpirvVariable *varInst);
/// \brief Creates instructions to load the value of the output stage variable
/// defined by outputPatchDecl and store it to ptr.
///
/// Since the output stage variable for OutputPatch is an array whose number
/// of elements is the number of output control points, ctrlPointID selects
/// which output control point is the source of the copy.
///
/// When outputControlPointType is a struct type, the copy is performed field
/// by field: each field's own output stage variable is loaded and stored to
/// the corresponding member of ptr. An error is emitted if no output stage
/// variable was created for a non-struct decl.
///
/// \param outputPatchDecl decl whose output stage variable is read.
/// \param ctrlPointID index of the output control point to read.
/// \param outputControlPointType the template parameter type of OutputPatch
///        (or the type of the field being copied).
/// \param ptr destination pointer the loaded value is stored to.
void storeOutStageVarsToStorage(const DeclaratorDecl *outputPatchDecl,
SpirvConstant *ctrlPointID,
QualType outputControlPointType,
SpirvInstruction *ptr);
private:
/// \brief Wrapper method to create a fatal error message and report it
/// in the diagnostic engine associated with this consumer.
@ -726,11 +741,6 @@ private:
const llvm::StringRef name,
SourceLocation);
// Create intermediate output variable to communicate patch constant
// data in hull shader since workgroup memory is not allowed there.
SpirvVariable *createSpirvIntermediateOutputStageVar(
const NamedDecl *decl, const llvm::StringRef name, QualType asType);
/// Returns true if all vk:: attributes usages are valid.
bool validateVKAttributes(const NamedDecl *decl);

Просмотреть файл

@ -1263,13 +1263,6 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
// Create all parameters.
for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
const ParmVarDecl *paramDecl = decl->getParamDecl(i);
if (spvContext.isHS() && decl == patchConstFunc &&
hlsl::IsHLSLOutputPatchType(paramDecl->getType())) {
// Since the output patch used in hull shaders is translated to
// a variable with Output storage class, there is no need
// to pass the variable as function parameter in SPIR-V.
continue;
}
(void)declIdMapper.createFnParam(paramDecl, i + 1 + isNonStaticMemberFn);
}
@ -11803,18 +11796,6 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
return false;
}
SpirvInstruction *hullMainOutputPatch = nullptr;
// If the patch constant function (PCF) takes the result of the Hull main
// entry point, create a temporary function-scope variable and write the
// results to it, so it can be passed to the PCF.
if (const auto *param = patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
hullMainOutputPatch = declIdMapper.createHullMainOutputPatch(
param, retType, numOutputControlPoints);
auto *tempLocation = spvBuilder.createAccessChain(
retType, hullMainOutputPatch, {outputControlPointId}, locEnd);
spvBuilder.createStore(tempLocation, retVal, locEnd);
}
// Now create a barrier before calling the Patch Constant Function (PCF).
// Flags are:
// Execution Barrier scope = Workgroup (2)
@ -11824,6 +11805,21 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
spv::MemorySemanticsMask::MaskNone,
spv::Scope::Workgroup, {});
SpirvInstruction *hullMainOutputPatch = nullptr;
// If the patch constant function (PCF) takes the result of the Hull main
// entry point, create a temporary function-scope variable and write the
// results to it, so it can be passed to the PCF.
if (const ParmVarDecl *outputPatchDecl =
patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
const QualType hullMainRetType = astContext.getConstantArrayType(
retType, llvm::APInt(32, numOutputControlPoints),
clang::ArrayType::Normal, 0);
hullMainOutputPatch =
spvBuilder.addFnVar(hullMainRetType, locEnd, "temp.var.hullMainRetVal");
declIdMapper.copyHullOutStageVarsToOutputPatch(
hullMainOutputPatch, outputPatchDecl, retType, numOutputControlPoints);
}
// The PCF should be called only once. Therefore, we check the invocationID,
// and we only allow ID 0 to call the PCF.
auto *condition = spvBuilder.createBinaryOp(
@ -11871,10 +11867,7 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
if (hlsl::IsHLSLInputPatchType(param->getType())) {
pcfParams.push_back(hullMainInputPatch);
} else if (hlsl::IsHLSLOutputPatchType(param->getType())) {
// Since the output patch used in hull shaders is translated to
// a variable with Workgroup storage class, there is no need
// to pass the variable as function parameter in SPIR-V.
continue;
pcfParams.push_back(hullMainOutputPatch);
} else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
if (!primitiveId) {
primitiveId = createParmVarAndInitFromStageInputVar(param);

Просмотреть файл

@ -10,23 +10,19 @@ struct HSPatchConstData {
float4 constData : CONSTANTDATA;
};
// CHECK: OpDecorate %temp_var_hullMainRetVal Location 2
// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Output__arr_HSCtrlPt_uint_3 Output
// CHECK: [[invoc_id:%\d+]] = OpLoad %uint %gl_InvocationID
// CHECK: [[HSResult:%\d+]] = OpFunctionCall %HSCtrlPt %src_main
// CHECK: [[OutCtrl:%\d+]] = OpAccessChain %_ptr_Output_HSCtrlPt %temp_var_hullMainRetVal [[invoc_id]]
// CHECK: OpStore [[OutCtrl]] [[HSResult]]
// CHECK: OpFunctionCall %HSPatchConstData %HSPatchConstantFunc %temp_var_hullMainRetVal
HSPatchConstData HSPatchConstantFunc(const OutputPatch<HSCtrlPt, 3> input) {
HSPatchConstData data;
// CHECK: [[OutCtrl0:%\d+]] = OpAccessChain %_ptr_Output_v4float %temp_var_hullMainRetVal %uint_0 %int_0
// CHECK: %input = OpFunctionParameter %_ptr_Function__arr_HSCtrlPt_uint_3
// CHECK: [[OutCtrl0:%\d+]] = OpAccessChain %_ptr_Function_v4float %input %uint_0 %int_0
// CHECK: [[input0:%\d+]] = OpLoad %v4float [[OutCtrl0]]
// CHECK: [[OutCtrl1:%\d+]] = OpAccessChain %_ptr_Output_v4float %temp_var_hullMainRetVal %uint_1 %int_0
// CHECK: [[OutCtrl1:%\d+]] = OpAccessChain %_ptr_Function_v4float %input %uint_1 %int_0
// CHECK: [[input1:%\d+]] = OpLoad %v4float [[OutCtrl1]]
// CHECK: [[add:%\d+]] = OpFAdd %v4float [[input0]] [[input1]]
// CHECK: [[OutCtrl2:%\d+]] = OpAccessChain %_ptr_Output_v4float %temp_var_hullMainRetVal %uint_2 %int_0
// CHECK: [[OutCtrl2:%\d+]] = OpAccessChain %_ptr_Function_v4float %input %uint_2 %int_0
// CHECK: [[input2:%\d+]] = OpLoad %v4float [[OutCtrl2]]
// CHECK: OpFAdd %v4float [[add]] [[input2]]
data.constData = input[0].ctrlPt + input[1].ctrlPt + input[2].ctrlPt;

Просмотреть файл

@ -5,23 +5,38 @@
// Test: PCF takes the output (OutputPatch) of the main entry point function.
// CHECK: %_arr_BEZIER_CONTROL_POINT_uint_16 = OpTypeArray %BEZIER_CONTROL_POINT %uint_16
// CHECK: %_ptr_Output__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Output %_arr_BEZIER_CONTROL_POINT_uint_16
// CHECK: %_arr_BEZIER_CONTROL_POINT_uint_3 = OpTypeArray %BEZIER_CONTROL_POINT %uint_3
// CHECK: %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_3 = OpTypePointer Function %_arr_BEZIER_CONTROL_POINT_uint_3
// CHECK: [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT
// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Output__arr_BEZIER_CONTROL_POINT_uint_16 Output
// CHECK: %main = OpFunction %void None {{%\d+}}
// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_3 Function
// CHECK: [[id:%\d+]] = OpLoad %uint %gl_InvocationID
// CHECK: [[mainResult:%\d+]] = OpFunctionCall %BEZIER_CONTROL_POINT %src_main %param_var_ip %param_var_i %param_var_PatchID
// CHECK: [[loc:%\d+]] = OpAccessChain %_ptr_Output_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
// CHECK: OpStore [[loc]] [[mainResult]]
// CHECK: {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF
// CHECK: [[output_patch_0:%\d+]] = OpAccessChain %_ptr_Function_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal %uint_0
// CHECK: [[output_patch_0_0:%\d+]] = OpAccessChain %_ptr_Function_v3float [[output_patch_0]] %uint_0
// CHECK: [[out_var_BEZIERPOS_0:%\d+]] = OpAccessChain %_ptr_Output_v3float %out_var_BEZIERPOS %uint_0
// CHECK: [[BEZIERPOS_0:%\d+]] = OpLoad %v3float [[out_var_BEZIERPOS_0]]
// CHECK: OpStore [[output_patch_0_0]] [[BEZIERPOS_0]]
// CHECK: [[output_patch_1:%\d+]] = OpAccessChain %_ptr_Function_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal %uint_1
// CHECK: [[output_patch_1_0:%\d+]] = OpAccessChain %_ptr_Function_v3float [[output_patch_1]] %uint_0
// CHECK: [[out_var_BEZIERPOS_1:%\d+]] = OpAccessChain %_ptr_Output_v3float %out_var_BEZIERPOS %uint_1
// CHECK: [[BEZIERPOS_1:%\d+]] = OpLoad %v3float [[out_var_BEZIERPOS_1]]
// CHECK: OpStore [[output_patch_1_0]] [[BEZIERPOS_1]]
// CHECK: [[output_patch_2:%\d+]] = OpAccessChain %_ptr_Function_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal %uint_2
// CHECK: [[output_patch_2_0:%\d+]] = OpAccessChain %_ptr_Function_v3float [[output_patch_2]] %uint_0
// CHECK: [[out_var_BEZIERPOS_2:%\d+]] = OpAccessChain %_ptr_Output_v3float %out_var_BEZIERPOS %uint_2
// CHECK: [[BEZIERPOS_2:%\d+]] = OpLoad %v3float [[out_var_BEZIERPOS_2]]
// CHECK: OpStore [[output_patch_2_0]] [[BEZIERPOS_2]]
// CHECK: {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF %temp_var_hullMainRetVal
// CHECK: %PCF = OpFunction %HS_CONSTANT_DATA_OUTPUT None [[fType]]
HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, 3> op) {
HS_CONSTANT_DATA_OUTPUT Output;
// Must initialize Edges and Inside; otherwise HLSL validation will fail.
Output.Edges[0] = 1.0;
@ -36,9 +51,9 @@ HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
[domain("isoline")]
[partitioning("fractional_odd")]
[outputtopology("line")]
[outputcontrolpoints(16)]
[outputcontrolpoints(3)]
[patchconstantfunc("PCF")]
BEZIER_CONTROL_POINT main(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POINTS> ip, uint i : SV_OutputControlPointID, uint PatchID : SV_PrimitiveID) {
BEZIER_CONTROL_POINT main(InputPatch<VS_CONTROL_POINT_OUTPUT, 3> ip, uint i : SV_OutputControlPointID, uint PatchID : SV_PrimitiveID) {
VS_CONTROL_POINT_OUTPUT vsOutput;
BEZIER_CONTROL_POINT result;
result.vPosition = vsOutput.vPosition;

Просмотреть файл

@ -31,6 +31,7 @@ struct HS_CONSTANT_DATA_OUTPUT
};
HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
// CHECK: %op = OpFunctionParameter %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16
HS_CONSTANT_DATA_OUTPUT Output;
// Must initialize Edges and Inside; otherwise HLSL validation will fail.
Output.Edges[0] = 1.0;
@ -42,12 +43,12 @@ HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
uint x = 5;
// CHECK: [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Output_v3float %temp_var_hullMainRetVal %uint_1 %int_0
// CHECK: [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Function_v3float %op %uint_1 %int_0
// CHECK-NEXT: {{%\d+}} = OpLoad %v3float [[op_1_loc]]
float3 out1pos = op[1].vPosition;
// CHECK: [[x:%\d+]] = OpLoad %uint %x
// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Output_uint %temp_var_hullMainRetVal [[x]] %int_1
// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Function_uint %op [[x]] %int_1
// CHECK-NEXT: {{%\d+}} = OpLoad %uint [[op_x_loc]]
uint out5id = op[x].pointID;