HLMatrixLower: Handle unflattened lib function matrix return val and param.

This commit is contained in:
Tex Riddell 2018-01-29 00:01:29 -08:00
Родитель ea73161d6f
Коммит 2ae113596a
1 изменённых файлов: 57 добавлений и 0 удалений

Просмотреть файл

@ -269,6 +269,9 @@ private:
void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
// Get new matrix value corresponding to vecVal
Value *GetMatrixForVec(Value *vecVal, Type *matTy);
// Replace matVal with vecVal on matUseInst.
void TrivialMatReplace(Value *matVal, Value *vecVal,
CallInst *matUseInst);
@ -282,6 +285,10 @@ private:
void DeleteDeadInsts();
// Map from matrix value to its vector version.
DenseMap<Value *, Value *> matToVecMap;
// Map from new vector version to matrix version needed by user call or return.
DenseMap<Value *, Value *> vecToMatMap;
// Record matrix defining instructions that need preserving (in library functions).
std::vector<Instruction*> matInstsToKeep;
};
}
@ -841,6 +848,20 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
case HLOpcodeGroup::HLSubscript: {
vecVal = MatSubscriptToVec(CI);
} break;
case HLOpcodeGroup::NotHL: {
// Translate user function return
vecVal = BitCastValueOrPtr( matInst,
matInst->getNextNode(),
HLMatrixLower::LowerMatrixType(matInst->getType()),
/*bOrigAllocaTy*/ false,
matInst->getName());
// matrix equivalent of this new vector will be the original, retained user call
vecToMatMap[vecVal] = matInst;
// Add to matInstsToKeep so we don't delete this call
matInstsToKeep.push_back(matInst);
} break;
default:
DXASSERT(0, "invalid inst");
}
} else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
Type *Ty = AI->getAllocatedType();
@ -2069,6 +2090,23 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
AddToDeadInsts(matGEP);
}
Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
Value *newMatVal = nullptr;
if (vecToMatMap.count(vecVal)) {
newMatVal = vecToMatMap[vecVal];
} else {
// create conversion instructions if necessary, caching result for subsequent replacements.
// do so right after the vecVal def so it's available to all potential uses.
newMatVal = BitCastValueOrPtr(vecVal,
cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
matTy,
/*bOrigAllocaTy*/true,
vecVal->getName());
vecToMatMap[vecVal] = newMatVal;
}
return newMatVal;
}
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
Value *vecVal) {
for (Value::user_iterator user = matVal->user_begin();
@ -2140,10 +2178,24 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
TranslateMatInit(useCall);
} break;
case HLOpcodeGroup::NotHL: {
// translate user function parameters as necessary
for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
if (useCall->getArgOperand(i) == matVal) {
// update the user call with the correct matrix value in new code sequence
Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
if (matVal != newMatVal)
useCall->setArgOperand(i, newMatVal);
}
}
} break;
}
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
// Just replace the src with vec version.
useInst->setOperand(0, vecVal);
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
RI->setOperand(0, newMatVal);
} else {
// Must be GEP on mat array alloca.
GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
@ -2462,6 +2514,11 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
finalMatTranslation(matToVec->first);
}
// Remove matInstsToKeep from matToVecMap before adding the rest to dead insts.
for (auto I : matInstsToKeep) {
matToVecMap.erase(I);
}
// Delete the matrix version insts.
for (auto matToVecIter = matToVecMap.begin();
matToVecIter != matToVecMap.end();) {