Support case struct parameter used as function call argument. (#429)
* Support case struct parameter used as function call argument. Also skip input param copy.
This commit is contained in:
Родитель
add0d6ec4c
Коммит
94acf99408
|
@ -3895,10 +3895,18 @@ public:
|
|||
if (F.user_empty())
|
||||
continue;
|
||||
}
|
||||
// Skip void(void) functions.
|
||||
if (F.getReturnType()->isVoidTy() && F.arg_size() == 0)
|
||||
continue;
|
||||
|
||||
WorkList.emplace_back(&F);
|
||||
}
|
||||
|
||||
// Preprocess aggregate function param used as function call arg.
|
||||
for (Function *F : WorkList) {
|
||||
preprocessArgUsedInCall(F);
|
||||
}
|
||||
|
||||
// Process the worklist
|
||||
while (!WorkList.empty()) {
|
||||
Function *F = WorkList.front();
|
||||
|
@ -3988,6 +3996,7 @@ public:
|
|||
|
||||
private:
|
||||
void DeleteDeadInstructions();
|
||||
void preprocessArgUsedInCall(Function *F);
|
||||
void moveFunctionBody(Function *F, Function *flatF);
|
||||
void replaceCall(Function *F, Function *flatF);
|
||||
void createFlattenedFunction(Function *F);
|
||||
|
@ -5347,6 +5356,92 @@ void SROA_Parameter_HLSL::flattenArgument(
|
|||
|
||||
}
|
||||
|
||||
// For function parameter which used in function call and need to be flattened.
|
||||
// Replace with tmp alloca.
|
||||
void SROA_Parameter_HLSL::preprocessArgUsedInCall(Function *F) {
|
||||
if (F->isDeclaration())
|
||||
return;
|
||||
|
||||
const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();
|
||||
|
||||
DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
|
||||
DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(F);
|
||||
DXASSERT(pFuncAnnot, "else invalid function");
|
||||
|
||||
IRBuilder<> AllocaBuilder(F->getEntryBlock().getFirstInsertionPt());
|
||||
|
||||
SmallVector<ReturnInst*, 2> retList;
|
||||
for (BasicBlock &bb : F->getBasicBlockList()) {
|
||||
if (ReturnInst *RI = dyn_cast<ReturnInst>(bb.getTerminator())) {
|
||||
retList.emplace_back(RI);
|
||||
}
|
||||
}
|
||||
|
||||
for (Argument &arg : F->args()) {
|
||||
Type *Ty = arg.getType();
|
||||
// Only check pointer types.
|
||||
if (!Ty->isPointerTy())
|
||||
continue;
|
||||
Ty = Ty->getPointerElementType();
|
||||
// Skip scalar types.
|
||||
if (!Ty->isAggregateType() &&
|
||||
Ty->getScalarType() == Ty)
|
||||
continue;
|
||||
bool bUsedInCall = false;
|
||||
for (User *U : arg.users()) {
|
||||
if (CallInst *CI = dyn_cast<CallInst>(U)) {
|
||||
Function *CalledF = CI->getCalledFunction();
|
||||
HLOpcodeGroup group = GetHLOpcodeGroup(CalledF);
|
||||
// Skip HL operations.
|
||||
if (group != HLOpcodeGroup::NotHL ||
|
||||
group == HLOpcodeGroup::HLExtIntrinsic) {
|
||||
continue;
|
||||
}
|
||||
// Skip llvm intrinsic.
|
||||
if (CalledF->isIntrinsic())
|
||||
continue;
|
||||
|
||||
bUsedInCall = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (bUsedInCall) {
|
||||
// Create tmp.
|
||||
Value *TmpArg = AllocaBuilder.CreateAlloca(Ty);
|
||||
// Replace arg with tmp.
|
||||
arg.replaceAllUsesWith(TmpArg);
|
||||
|
||||
DxilParameterAnnotation ¶mAnnot = pFuncAnnot->GetParameterAnnotation(arg.getArgNo());
|
||||
DxilParamInputQual inputQual = paramAnnot.GetParamInputQual();
|
||||
unsigned size = DL.getTypeAllocSize(Ty);
|
||||
// Copy between arg and tmp.
|
||||
if (inputQual == DxilParamInputQual::In ||
|
||||
inputQual == DxilParamInputQual::Inout) {
|
||||
// copy arg to tmp.
|
||||
CallInst *argToTmp = AllocaBuilder.CreateMemCpy(TmpArg, &arg, size, 0);
|
||||
// Split the memcpy.
|
||||
MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(argToTmp), DL, nullptr,
|
||||
typeSys);
|
||||
}
|
||||
if (inputQual == DxilParamInputQual::Out ||
|
||||
inputQual == DxilParamInputQual::Inout) {
|
||||
for (ReturnInst *RI : retList) {
|
||||
IRBuilder<> Builder(RI);
|
||||
// copy tmp to arg.
|
||||
CallInst *tmpToArg =
|
||||
AllocaBuilder.CreateMemCpy(&arg, TmpArg, size, 0);
|
||||
// Split the memcpy.
|
||||
MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(tmpToArg), DL, nullptr,
|
||||
typeSys);
|
||||
}
|
||||
}
|
||||
// TODO: support other DxilParamInputQual.
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// moveFunctionBlocks - Move body of F to flatF.
|
||||
void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
|
||||
bool updateRetType = F->getReturnType() != flatF->getReturnType();
|
||||
|
|
|
@ -3030,7 +3030,10 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
|
|||
LValue L = EmitLValue(cast<CastExpr>(E)->getSubExpr());
|
||||
assert(L.isSimple());
|
||||
if (L.getAlignment() >= getContext().getTypeAlignInChars(type)) {
|
||||
args.add(L.asAggregateRValue(), type, /*NeedsCopy*/true);
|
||||
// HLSL Change Begin - don't copy input arg.
|
||||
// Copy for out param is done at CGMSHLSLRuntime::EmitHLSLOutParamConversion*.
|
||||
args.add(L.asAggregateRValue(), type); // /*NeedsCopy*/true);
|
||||
// HLSL Change End
|
||||
} else {
|
||||
// We can't represent a misaligned lvalue in the CallArgList, so copy
|
||||
// to an aligned temporary now.
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
// RUN: %dxc -T lib_6_1 %s | FileCheck %s
|
||||
|
||||
// Make sure function call on external function is flattened.
|
||||
|
||||
// CHECK: call void @"\01?test_extern@@YAMUT@@Y01U1@0AAV?$matrix@M$01$01@@@Z"(float %{{.*}}, float %{{.*}}, [2 x float]* nonnull %{{.*}}, [2 x float]* nonnull %{{.*}}, float* nonnull %{{.*}}, float* nonnull %{{.*}}, [4 x float]* nonnull %{{.*}}, float* nonnull %{{.*}})
|
||||
|
||||
struct T {
|
||||
float a;
|
||||
float b;
|
||||
};
|
||||
|
||||
float test_extern(T t, T t2[2], out T t3, inout float2x2 m);
|
||||
|
||||
float test(T t, T t2[2], out T t3, inout float2x2 m)
|
||||
{
|
||||
return test_extern(t, t2, t3, m);
|
||||
}
|
|
@ -520,6 +520,7 @@ public:
|
|||
TEST_METHOD(CodeGenIntrinsic5)
|
||||
TEST_METHOD(CodeGenInvalidInputOutputTypes)
|
||||
TEST_METHOD(CodeGenLegacyStruct)
|
||||
TEST_METHOD(CodeGenLibArgFlatten)
|
||||
TEST_METHOD(CodeGenLibCsEntry)
|
||||
TEST_METHOD(CodeGenLibCsEntry2)
|
||||
TEST_METHOD(CodeGenLibCsEntry3)
|
||||
|
@ -3064,6 +3065,10 @@ TEST_F(CompilerTest, CodeGenLegacyStruct) {
|
|||
CodeGenTestCheck(L"..\\CodeGenHLSL\\legacy_struct.hlsl");
|
||||
}
|
||||
|
||||
TEST_F(CompilerTest, CodeGenLibArgFlatten) {
|
||||
CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_arg_flatten.hlsl");
|
||||
}
|
||||
|
||||
TEST_F(CompilerTest, CodeGenLibCsEntry) {
|
||||
CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_cs_entry.hlsl");
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче