Support matrix transpose. (#900)
This commit is contained in:
Родитель
d7a96e9e80
Коммит
455885b3fc
|
@ -251,7 +251,7 @@ private:
|
|||
void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
|
||||
CallInst *castInst);
|
||||
void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
|
||||
CallInst *castInst, bool rowToCol);
|
||||
CallInst *castInst, bool rowToCol, bool transpose);
|
||||
// Replace matInst with vecInst in matSubscript
|
||||
void TranslateMatSubscript(Value *matInst, Value *vecInst,
|
||||
CallInst *matSubInst);
|
||||
|
@ -1073,7 +1073,8 @@ void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
|
|||
Instruction *vecInst,
|
||||
CallInst *transposeInst) {
|
||||
// Matrix value is row major, transpose is cast it to col major.
|
||||
TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
|
||||
TranslateMatMajorCast(matInst, vecInst, transposeInst,
|
||||
/*bRowToCol*/ true, /*bTranspose*/ true);
|
||||
}
|
||||
|
||||
static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
|
||||
|
@ -1194,10 +1195,22 @@ void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
|
|||
void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
|
||||
Instruction *vecInst,
|
||||
CallInst *castInst,
|
||||
bool bRowToCol) {
|
||||
bool bRowToCol,
|
||||
bool bTranspose) {
|
||||
unsigned col, row;
|
||||
if (!bTranspose) {
|
||||
GetMatrixInfo(castInst->getType(), col, row);
|
||||
DXASSERT(castInst->getType() == matInst->getType(), "type must match");
|
||||
} else {
|
||||
unsigned castCol, castRow;
|
||||
Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
|
||||
unsigned srcCol, srcRow;
|
||||
Type *srcTy = GetMatrixInfo(matInst->getType(), srcCol, srcRow);
|
||||
DXASSERT(srcTy == castTy, "type must match");
|
||||
DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
|
||||
col = srcCol;
|
||||
row = srcRow;
|
||||
}
|
||||
|
||||
IRBuilder<> Builder(castInst);
|
||||
|
||||
|
@ -1321,7 +1334,8 @@ void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
|
|||
if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
|
||||
opcode == HLCastOpcode::RowMatrixToColMatrix) {
|
||||
TranslateMatMajorCast(matInst, vecInst, castInst,
|
||||
opcode == HLCastOpcode::RowMatrixToColMatrix);
|
||||
opcode == HLCastOpcode::RowMatrixToColMatrix,
|
||||
/*bTranspose*/false);
|
||||
} else {
|
||||
bool ToMat = IsMatrixType(castInst->getType());
|
||||
bool FromMat = IsMatrixType(matInst->getType());
|
||||
|
|
|
@ -3397,7 +3397,8 @@ static Value *CastLdValue(Value *Ptr, llvm::Type *FromTy, llvm::Type *ToTy, IRBu
|
|||
Value *V = Builder.CreateLoad(Ptr);
|
||||
// VectorTrunc
|
||||
// Change vector into vec1.
|
||||
return Builder.CreateShuffleVector(V, V, {0});
|
||||
int mask[] = {0};
|
||||
return Builder.CreateShuffleVector(V, V, mask);
|
||||
} else if (FromTy->isArrayTy()) {
|
||||
llvm::Type *FromEltTy = FromTy->getArrayElementType();
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
|
||||
|
||||
// Make sure get cb0[0].y and cb0[1].y.
|
||||
// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
|
||||
// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 1
|
||||
// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
|
||||
// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 1
|
||||
|
||||
row_major float2x3 m;
|
||||
|
||||
float2 main(int i : A) : SV_TARGET
|
||||
{
|
||||
return transpose(m)[1];
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
|
||||
|
||||
// Make sure get cb0[1].xy.
|
||||
// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
|
||||
// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 0
|
||||
// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 1
|
||||
|
||||
float2x3 m;
|
||||
|
||||
float2 main(int i : A) : SV_TARGET
|
||||
{
|
||||
return transpose(m)[1];
|
||||
}
|
Загрузка…
Ссылка в новой задаче