[SPIR-V] Fix for an issue when noninterpolation is decorated on structure inputs. (#6041)

**Fix for**

Fix issue under:
https://github.com/microsoft/DirectXShaderCompiler/issues/2955
For structure type, when `nointerpolation` is decorated on a structure
input, this flag should be broadcast to its members.

PR (https://github.com/microsoft/DirectXShaderCompiler/pull/6018) also
helps to resolve an issue when SPIRV backend generated variable has no
AST type.

**Test**
Please see same case under:
https://github.com/microsoft/DirectXShaderCompiler/pull/6018

Besides, example as below should be invalid as the first parameter of
function `compute` should has only one spirv-type.

If its parameter type declaration has been expanded to an array
implicitly, it should not accept other interpolated inputs as its input
parameter.

```
struct S {
  float4 a : COLOR;
};

float compute(float4 a) {
  return GetAttributeAtVertex(a, 2)[0];
}

float4 main(nointerpolation S s, float4 b : COLOR2) : SV_TARGET
//float4 main(nointerpolation S s) : SV_TARGET
{
  return float4(0, 0, compute(b), compute(s.a));
  //return float4(0, 0, 0, compute(s.a));
}
```
I added an error report point in this commit and gets following reports:
```
1.hlsl:12:31: error: Current function could only use noninterpolated variable as input.
  return float4(0, 0, compute(b), compute(s.a));
                              ^
fatal error: generated SPIR-V is invalid: OpFunctionCall Argument <id> '37[%param_var_a_0]'s type does not match Function <id> '24[%_ptr_Function_v4float]'s parameter type.
  %45 = OpFunctionCall %float %compute %param_var_a_0

note: please file a bug report on https://github.com/Microsoft/DirectXShaderCompiler/issues with source code if possible
```
Any idea to let the error only reported before spirv validator?

**Ref**:
https://github.com/microsoft/DirectXShaderCompiler/issues/2955
https://github.com/microsoft/DirectXShaderCompiler/pull/6018

---------

Co-authored-by: Zhou, Shaochi(AMD) <shaozhou@amd.com>
Co-authored-by: Natalie Chouinard <sudonatalie@google.com>
This commit is contained in:
Chow 2024-03-05 23:09:59 +08:00 коммит произвёл GitHub
Родитель 3da718cd61
Коммит 359b3e9dc5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
7 изменённых файлов: 149 добавлений и 15 удалений

Просмотреть файл

@ -943,8 +943,14 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
type, "in.var", loadedValue);
} else {
StageVarDataBundle stageVarData = {
paramDecl, &inheritSemantic, false, sigPoint,
type, arraySize, "in.var", llvm::None};
paramDecl,
&inheritSemantic,
paramDecl->hasAttr<HLSLNoInterpolationAttr>(),
sigPoint,
type,
arraySize,
"in.var",
llvm::None};
return createStageVars(stageVarData, /*asInput=*/true, loadedValue,
/*noWriteBack=*/false);
}

Просмотреть файл

@ -125,14 +125,11 @@ PervertexInputVisitor::createFirstPerVertexVar(SpirvInstruction *base,
createVertexStore(vtx, createVertexLoad(base));
return vtx;
}
SpirvInstruction *PervertexInputVisitor::createProvokingVertexAccessChain(
SpirvInstruction *base, uint32_t index, QualType resultType) {
SpirvInstruction *PervertexInputVisitor::createVertexAccessChain(
QualType resultType, SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexes) {
auto loc = base->getSourceLocation();
auto range = base->getSourceRange();
llvm::SmallVector<SpirvInstruction *, 1> indexes;
indexes.push_back(spirvBuilder.getConstantInt(astContext.UnsignedIntTy,
llvm::APInt(32, index)));
SpirvInstruction *instruction =
new (context) SpirvAccessChain(resultType, loc, base, indexes, range);
instruction->setStorageClass(spv::StorageClass::Function);
@ -143,6 +140,16 @@ SpirvInstruction *PervertexInputVisitor::createProvokingVertexAccessChain(
return instruction;
}
SpirvInstruction *PervertexInputVisitor::createProvokingVertexAccessChain(
SpirvInstruction *base, uint32_t index, QualType resultType) {
llvm::SmallVector<SpirvInstruction *, 1> indexes;
indexes.push_back(spirvBuilder.getConstantInt(astContext.UnsignedIntTy,
llvm::APInt(32, index)));
SpirvInstruction *instruction =
createVertexAccessChain(resultType, base, indexes);
return instruction;
}
SpirvVariable *
PervertexInputVisitor::addFunctionTempVar(llvm::StringRef varName,
QualType valueType,
@ -237,7 +244,8 @@ bool PervertexInputVisitor::visit(SpirvFunction *sf, Phase phase) {
m_instrReplaceMap[var] = vtx0;
}
for (auto *param : currentFunc->getParameters()) {
if (!param->isNoninterpolated())
if (!param->isNoninterpolated() ||
param->getAstResultType().getTypePtr()->isStructureType())
continue;
auto *vtx0 =
createProvokingVertexAccessChain(param, 0, param->getAstResultType());
@ -306,11 +314,71 @@ bool PervertexInputVisitor::visit(SpirvFunctionCall *inst) {
return true;
/// Load/Store instructions related to this argument may have been replaced
/// with other instructions, so we need to get its original mapped variables.
for (auto *arg : inst->getArgs())
if (currentFunc->getMappedFuncParam(arg)) {
createVertexStore(arg,
createVertexLoad(currentFunc->getMappedFuncParam(arg)));
unsigned argIndex = 0;
for (auto *arg : inst->getArgs()) {
auto paramVar = currentFunc->getMappedFuncParam(arg);
if (paramVar) {
if (isa<SpirvAccessChain>(paramVar)) {
auto tempVar = paramVar;
while (isa<SpirvAccessChain>(tempVar)) {
tempVar = dyn_cast<SpirvAccessChain>(tempVar)->getBase();
}
if (tempVar->isNoninterpolated()) {
/// When function parameters have a structure type, some local
/// variables may be created and mapped to an stage inputs
/// in 'src.main' block.
///
/// We use first vertex value of those non-interpolated inputs to
/// replace normal usage of those local variables in HLSL and SPIRV.
///
/// But when those variables are then used in a function call as
/// its arguments, we need to copy the values for all of the vertices
/// to the local variable. This means copying an entire array.
///
/// At this point, original access chain to those member variables
/// have been appended an zero index at the end to access first
/// vertex for replacement before.
///
/// Hence we need to recreate a new access chain instruction and
/// and pass argument as an array to this function call.
auto paramAccessChain = dyn_cast<SpirvAccessChain>(paramVar);
auto indexes = paramAccessChain->getIndexes();
auto elemType = astContext.getConstantArrayType(
paramAccessChain->getAstResultType(), llvm::APInt(32, 3),
clang::ArrayType::Normal, 0);
llvm::SmallVector<SpirvInstruction *, 4> indices(indexes.begin(),
indexes.end());
indices.pop_back();
paramVar = createVertexAccessChain(
elemType, paramAccessChain->getBase(), indices);
}
}
createVertexStore(arg, createVertexLoad(paramVar));
}
auto funcParam = inst->getFunction()->getParameters()[argIndex];
if (arg->isNoninterpolated()) {
/// Broadcast nointerpolated flag to each called function which uses a
/// nointerpolated variable as its functionCall parameter within a call
/// chain.
funcParam->setNoninterpolated();
}
paramCaller[funcParam].push_back(arg);
if (funcParam->isNoninterpolated()) {
/// Error: this broadcast process is from top to lower, hence this
/// argument should be noninterpolated (will be expanded) here.
/// When any matched param is noninterpolated, it means one or more
/// noninterpolated variable will be passed as an expanded array
for (auto caller : paramCaller[funcParam])
if (!caller->isNoninterpolated()) {
emitError("Function '%0' could only use noninterpolated variable "
"as input.",
caller->getSourceLocation())
<< inst->getFunction()->getFunctionName().data();
return 0;
}
}
argIndex++;
}
currentFunc->addInstrCacheToFront();
return true;
}

Просмотреть файл

@ -69,6 +69,10 @@ public:
void createVertexStore(SpirvInstruction *pt, SpirvInstruction *obj);
SpirvInstruction *
createVertexAccessChain(QualType resultType, SpirvInstruction *base,
llvm::ArrayRef<SpirvInstruction *> indexes);
///< Visit different SPIR-V constructs for emitting.
using Visitor::visit;
bool visit(SpirvModule *, Phase phase) override;
@ -127,6 +131,17 @@ private:
ASTContext &astContext;
SpirvModule *currentMod;
SpirvFunction *currentFunc;
llvm::DenseMap<SpirvFunctionParameter *, std::vector<SpirvInstruction *>>
paramCaller;
/// Emits error to the diagnostic engine associated with this visitor.
template <unsigned N>
DiagnosticBuilder emitError(const char (&message)[N],
SourceLocation srcLoc = {}) {
const auto diagId = astContext.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Error, message);
return astContext.getDiagnostics().Report(srcLoc, diagId);
}
};
} // end namespace spirv

Просмотреть файл

@ -320,7 +320,11 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address,
}
if (isa<SpirvLoad>(value) && isa<SpirvVariable>(address)) {
if (isa<SpirvFunctionParameter>(dyn_cast<SpirvLoad>(value)->getPointer()))
auto paramPtr = dyn_cast<SpirvLoad>(value)->getPointer();
while (isa<SpirvAccessChain>(paramPtr)) {
paramPtr = dyn_cast<SpirvAccessChain>(paramPtr)->getBase();
}
if (isa<SpirvFunctionParameter>(paramPtr))
function->addFuncParamVarEntry(address,
dyn_cast<SpirvLoad>(value)->getPointer());
}

Просмотреть файл

@ -3018,7 +3018,8 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
// inside the function, not the variables at the call sites. Therefore, we
// do not need to mark the "param.var.*" variables as precise.
const bool isPrecise = false;
const bool isNoInterp = param->hasAttr<HLSLNoInterpolationAttr>();
const bool isNoInterp = param->hasAttr<HLSLNoInterpolationAttr>() ||
(argInst && argInst->isNoninterpolated());
auto *tempVar = spvBuilder.addFnVar(varType, arg->getLocStart(), varName,
isPrecise, isNoInterp);

Просмотреть файл

@ -0,0 +1,24 @@
// RUN: %dxc -T ps_6_1 -E main %s -spirv -fcgl 2>&1 | FileCheck %s
struct S {
float4 a : COLOR;
};
float compute(float4 a) {
return GetAttributeAtVertex(a, 2)[0];
}
float4 main(nointerpolation S s) : SV_TARGET
{
return float4(0, 0, 0, compute(s.a));
}
// CHECK: [[param_var_a:%[a-zA-Z0-9_]+]] = OpVariable %_ptr_Function__arr_v4float_uint_3 Function
// CHECK: [[inst32:%[0-9_]+]] = OpAccessChain %_ptr_Function_v4float [[param_var_a]] %uint_0
// CHECK: [[inst33:%[0-9_]+]] = OpAccessChain %_ptr_Function__arr_v4float_uint_3 [[s:%[a-zA-Z0-9_]+]] %int_0
// CHECK: [[inst34:%[0-9_]+]] = OpLoad %_arr_v4float_uint_3 [[inst33]]
// CHECK: OpStore [[param_var_a]] [[inst34]]
// CHECK: [[inst35:%[0-9_]+]] = OpAccessChain %_ptr_Function_v4float [[s]] %int_0 %uint_0
// CHECK: [[inst36:%[0-9_]+]] = OpLoad %v4float [[inst35]]
// CHECK: OpStore [[inst32]] [[inst36]]
// CHECK: [[inst37:%[0-9_]+]] = OpFunctionCall %float %compute [[param_var_a]]

Просмотреть файл

@ -0,0 +1,16 @@
// RUN: not %dxc -T ps_6_1 -E main %s -spirv 2>&1 | FileCheck %s
struct S {
float4 a : COLOR;
};
float compute(float4 a) {
return GetAttributeAtVertex(a, 2)[0];
}
float4 main(nointerpolation S s, float4 b : COLOR2) : SV_TARGET
{
return float4(0, 0, compute(b), compute(s.a));
}
// CHECK: error: Function 'compute' could only use noninterpolated variable as input.