HLMatrixLower: Handle matrix return values and matrix parameters of unflattened library functions.
This commit is contained in:
Родитель
ea73161d6f
Коммит
2ae113596a
|
@ -269,6 +269,9 @@ private:
|
|||
void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
|
||||
void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
|
||||
|
||||
// Get new matrix value corresponding to vecVal
|
||||
Value *GetMatrixForVec(Value *vecVal, Type *matTy);
|
||||
|
||||
// Replace matVal with vecVal on matUseInst.
|
||||
void TrivialMatReplace(Value *matVal, Value *vecVal,
|
||||
CallInst *matUseInst);
|
||||
|
@ -282,6 +285,10 @@ private:
|
|||
void DeleteDeadInsts();
|
||||
// Map from matrix value to its vector version.
|
||||
DenseMap<Value *, Value *> matToVecMap;
|
||||
// Map from new vector version to matrix version needed by user call or return.
|
||||
DenseMap<Value *, Value *> vecToMatMap;
|
||||
// Record matrix defining instructions that need preserving (in library functions).
|
||||
std::vector<Instruction*> matInstsToKeep;
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -841,6 +848,20 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|||
case HLOpcodeGroup::HLSubscript: {
|
||||
vecVal = MatSubscriptToVec(CI);
|
||||
} break;
|
||||
case HLOpcodeGroup::NotHL: {
|
||||
// Translate user function return
|
||||
vecVal = BitCastValueOrPtr( matInst,
|
||||
matInst->getNextNode(),
|
||||
HLMatrixLower::LowerMatrixType(matInst->getType()),
|
||||
/*bOrigAllocaTy*/ false,
|
||||
matInst->getName());
|
||||
// matrix equivalent of this new vector will be the original, retained user call
|
||||
vecToMatMap[vecVal] = matInst;
|
||||
// Add to matInstsToKeep so we don't delete this call
|
||||
matInstsToKeep.push_back(matInst);
|
||||
} break;
|
||||
default:
|
||||
DXASSERT(0, "invalid inst");
|
||||
}
|
||||
} else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
|
||||
Type *Ty = AI->getAllocatedType();
|
||||
|
@ -2069,6 +2090,23 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
|
|||
AddToDeadInsts(matGEP);
|
||||
}
|
||||
|
||||
// Return the matrix-typed value that corresponds to vecVal.
//
// vecVal: the vector-version value; when no matrix is recorded for it yet,
//         it must be an Instruction (enforced by the cast<> below).
// matTy:  the matrix type requested for the converted value.
//
// A matrix already recorded in vecToMatMap for vecVal is returned as is.
// Otherwise a conversion is requested from BitCastValueOrPtr and the result
// is cached in vecToMatMap so subsequent replacements share it.
Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
  // Reuse the recorded matrix version when one exists.
  auto cachedIt = vecToMatMap.find(vecVal);
  if (cachedIt != vecToMatMap.end())
    return cachedIt->second;

  // Place the conversion right after the vecVal def so it is available to
  // all potential uses.
  Instruction *insertPt = cast<Instruction>(vecVal)->getNextNode();
  Value *matForVec = BitCastValueOrPtr(vecVal, insertPt, matTy,
                                       /*bOrigAllocaTy*/ true,
                                       vecVal->getName());

  // Cache the result; the map is indexed again here rather than through
  // cachedIt, which is not relied upon after the helper call.
  vecToMatMap[vecVal] = matForVec;
  return matForVec;
}
|
||||
|
||||
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
||||
Value *vecVal) {
|
||||
for (Value::user_iterator user = matVal->user_begin();
|
||||
|
@ -2140,10 +2178,24 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|||
DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
|
||||
TranslateMatInit(useCall);
|
||||
} break;
|
||||
case HLOpcodeGroup::NotHL: {
|
||||
// translate user function parameters as necessary
|
||||
for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
|
||||
if (useCall->getArgOperand(i) == matVal) {
|
||||
// update the user call with the correct matrix value in new code sequence
|
||||
Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
||||
if (matVal != newMatVal)
|
||||
useCall->setArgOperand(i, newMatVal);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
}
|
||||
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
|
||||
// Just replace the src with vec version.
|
||||
useInst->setOperand(0, vecVal);
|
||||
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
||||
Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
||||
RI->setOperand(0, newMatVal);
|
||||
} else {
|
||||
// Must be GEP on mat array alloca.
|
||||
GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
|
||||
|
@ -2462,6 +2514,11 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|||
finalMatTranslation(matToVec->first);
|
||||
}
|
||||
|
||||
// Remove matInstsToKeep from matToVecMap before adding the rest to dead insts.
|
||||
for (auto I : matInstsToKeep) {
|
||||
matToVecMap.erase(I);
|
||||
}
|
||||
|
||||
// Delete the matrix version insts.
|
||||
for (auto matToVecIter = matToVecMap.begin();
|
||||
matToVecIter != matToVecMap.end();) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче