Support the case where a struct parameter is used as a function call argument. (#429)

* Support the case where a struct parameter is used as a function call argument.
Also skip the copy of input parameters.
This commit is contained in:
Xiang Li 2017-07-14 13:13:25 -07:00 коммит произвёл GitHub
Родитель add0d6ec4c
Коммит 94acf99408
4 изменённых файлов: 121 добавлений и 1 удалений

Просмотреть файл

@ -3895,10 +3895,18 @@ public:
if (F.user_empty())
continue;
}
// Skip void(void) functions.
if (F.getReturnType()->isVoidTy() && F.arg_size() == 0)
continue;
WorkList.emplace_back(&F);
}
// Preprocess aggregate function param used as function call arg.
for (Function *F : WorkList) {
preprocessArgUsedInCall(F);
}
// Process the worklist
while (!WorkList.empty()) {
Function *F = WorkList.front();
@ -3988,6 +3996,7 @@ public:
private:
void DeleteDeadInstructions();
void preprocessArgUsedInCall(Function *F);
void moveFunctionBody(Function *F, Function *flatF);
void replaceCall(Function *F, Function *flatF);
void createFlattenedFunction(Function *F);
@ -5347,6 +5356,92 @@ void SROA_Parameter_HLSL::flattenArgument(
}
// For a function parameter which is used in a function call and needs to be
// flattened: replace it with a temporary alloca.
//
// Every pointer parameter of F whose pointee is an aggregate or a vector and
// which is passed directly to a call that is neither an HL operation nor an
// llvm intrinsic gets a temporary alloca in the entry block. All uses of the
// parameter are redirected to the temporary, and then:
//   - for In / Inout parameters the argument is copied into the temporary at
//     function entry;
//   - for Out / Inout parameters the temporary is copied back to the argument
//     right before every return instruction of F.
// Function declarations are left untouched.
void SROA_Parameter_HLSL::preprocessArgUsedInCall(Function *F) {
  if (F->isDeclaration())
    return;

  const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();
  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(F);
  DXASSERT(pFuncAnnot, "else invalid function");

  // Temporaries and copy-in code are emitted at the top of the entry block.
  IRBuilder<> AllocaBuilder(F->getEntryBlock().getFirstInsertionPt());

  // Collect the returns up front; copy-out code is inserted before each one.
  SmallVector<ReturnInst*, 2> retList;
  for (BasicBlock &bb : F->getBasicBlockList()) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(bb.getTerminator())) {
      retList.emplace_back(RI);
    }
  }

  for (Argument &arg : F->args()) {
    Type *Ty = arg.getType();
    // Only check pointer types.
    if (!Ty->isPointerTy())
      continue;
    Ty = Ty->getPointerElementType();
    // Skip scalar types. Vectors are kept: their scalar type differs from
    // the vector type itself.
    if (!Ty->isAggregateType() &&
        Ty->getScalarType() == Ty)
      continue;

    bool bUsedInCall = false;
    for (User *U : arg.users()) {
      if (CallInst *CI = dyn_cast<CallInst>(U)) {
        // NOTE(review): getCalledFunction() returns null for indirect calls;
        // this code assumes only direct calls reach here - confirm.
        Function *CalledF = CI->getCalledFunction();
        HLOpcodeGroup group = GetHLOpcodeGroup(CalledF);
        // Skip HL operations (this includes HLExtIntrinsic).
        if (group != HLOpcodeGroup::NotHL) {
          continue;
        }
        // Skip llvm intrinsic.
        if (CalledF->isIntrinsic())
          continue;
        bUsedInCall = true;
        break;
      }
    }

    if (bUsedInCall) {
      // Create tmp.
      Value *TmpArg = AllocaBuilder.CreateAlloca(Ty);
      // Replace arg with tmp. This must happen before the copies below are
      // created so that their own uses of arg are not rewritten.
      arg.replaceAllUsesWith(TmpArg);

      DxilParameterAnnotation &paramAnnot = pFuncAnnot->GetParameterAnnotation(arg.getArgNo());
      DxilParamInputQual inputQual = paramAnnot.GetParamInputQual();
      unsigned size = DL.getTypeAllocSize(Ty);

      // Copy between arg and tmp.
      if (inputQual == DxilParamInputQual::In ||
          inputQual == DxilParamInputQual::Inout) {
        // copy arg to tmp.
        CallInst *argToTmp = AllocaBuilder.CreateMemCpy(TmpArg, &arg, size, 0);
        // Split the memcpy.
        MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(argToTmp), DL, nullptr,
                                    typeSys);
      }
      if (inputQual == DxilParamInputQual::Out ||
          inputQual == DxilParamInputQual::Inout) {
        for (ReturnInst *RI : retList) {
          // The copy-out has to be emitted right before the return, after
          // the function body has written tmp, so use a builder positioned
          // at the return rather than the entry-block builder.
          IRBuilder<> Builder(RI);
          // copy tmp to arg.
          CallInst *tmpToArg =
              Builder.CreateMemCpy(&arg, TmpArg, size, 0);
          // Split the memcpy.
          MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(tmpToArg), DL, nullptr,
                                      typeSys);
        }
      }
      // TODO: support other DxilParamInputQual.
    }
  }
}
/// moveFunctionBody - Move body of F to flatF.
void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
bool updateRetType = F->getReturnType() != flatF->getReturnType();

Просмотреть файл

@ -3030,7 +3030,10 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
LValue L = EmitLValue(cast<CastExpr>(E)->getSubExpr());
assert(L.isSimple());
if (L.getAlignment() >= getContext().getTypeAlignInChars(type)) {
args.add(L.asAggregateRValue(), type, /*NeedsCopy*/true);
// HLSL Change Begin - don't copy input arg.
// Copy for out param is done at CGMSHLSLRuntime::EmitHLSLOutParamConversion*.
args.add(L.asAggregateRValue(), type); // /*NeedsCopy*/true);
// HLSL Change End
} else {
// We can't represent a misaligned lvalue in the CallArgList, so copy
// to an aligned temporary now.

Просмотреть файл

@ -0,0 +1,17 @@
// RUN: %dxc -T lib_6_1 %s | FileCheck %s
// Make sure function call on external function is flattened.
// CHECK: call void @"\01?test_extern@@YAMUT@@Y01U1@0AAV?$matrix@M$01$01@@@Z"(float %{{.*}}, float %{{.*}}, [2 x float]* nonnull %{{.*}}, [2 x float]* nonnull %{{.*}}, float* nonnull %{{.*}}, float* nonnull %{{.*}}, [4 x float]* nonnull %{{.*}}, float* nonnull %{{.*}})
// Two-float struct used below as a by-value, array, and out parameter type.
struct T {
float a;
float b;
};
// External function: declared here without a body.
float test_extern(T t, T t2[2], out T t3, inout float2x2 m);
// Forwards all of its parameters, unchanged and in order, to test_extern and
// returns its result.
float test(T t, T t2[2], out T t3, inout float2x2 m)
{
return test_extern(t, t2, t3, m);
}

Просмотреть файл

@ -520,6 +520,7 @@ public:
TEST_METHOD(CodeGenIntrinsic5)
TEST_METHOD(CodeGenInvalidInputOutputTypes)
TEST_METHOD(CodeGenLegacyStruct)
TEST_METHOD(CodeGenLibArgFlatten)
TEST_METHOD(CodeGenLibCsEntry)
TEST_METHOD(CodeGenLibCsEntry2)
TEST_METHOD(CodeGenLibCsEntry3)
@ -3064,6 +3065,10 @@ TEST_F(CompilerTest, CodeGenLegacyStruct) {
CodeGenTestCheck(L"..\\CodeGenHLSL\\legacy_struct.hlsl");
}
// Runs CodeGenTestCheck on the lib_arg_flatten.hlsl test file.
TEST_F(CompilerTest, CodeGenLibArgFlatten) {
CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_arg_flatten.hlsl");
}
// Runs CodeGenTestCheck on the lib_cs_entry.hlsl test file.
TEST_F(CompilerTest, CodeGenLibCsEntry) {
CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_cs_entry.hlsl");
}