Fix constexpr cast involving matrix type (#3593)

This commit is contained in:
Vishal Sharma 2021-03-17 15:36:14 -07:00 коммит произвёл GitHub
Родитель 640c9af748
Коммит 63ce61aee5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 273 добавлений и 34 удалений

Просмотреть файл

@ -636,6 +636,72 @@ public:
return Visit(E->getInitializer());
}
// HLSL changes begin
static void ExtractConstantValueElems(llvm::Constant *constVec, llvm::SmallVector<llvm::Constant*, 4> &Elems, unsigned vecSize) {
if (llvm::ConstantDataVector *CDV = dyn_cast<llvm::ConstantDataVector>(constVec)) {
for (unsigned c = 0; c < vecSize; c++) {
Elems[c] = CDV->getElementAsConstant(c);
}
}
else if (llvm::ConstantVector *CV = dyn_cast<llvm::ConstantVector>(constVec)) {
for (unsigned c = 0; c < vecSize; c++) {
Elems[c] = CV->getOperand(c);
}
}
else {
llvm::ConstantAggregateZero *CAZ = cast<llvm::ConstantAggregateZero>(constVec);
for (unsigned c = 0; c < vecSize; c++) {
Elems[c] = CAZ->getElementValue(c);
}
}
}
static llvm::Constant* ConvertToMatchDestType (const clang::Type *srcTy, const clang::Type *destTy,
llvm::Type *srcLLVMTy, llvm::Type *destLLVMTy, llvm::Constant *C, CodeGenModule &CGM) {
assert(srcTy->isFloatingType() || srcTy->isIntegerType());
assert(destTy->isFloatingType() || destTy->isIntegerType());
// Special handling for cast to boolean type
if (destLLVMTy->isIntegerTy() && destLLVMTy->getScalarSizeInBits() == 1) {
return C->isZeroValue() ? llvm::ConstantInt::get(destLLVMTy, 0)
: llvm::ConstantInt::get(destLLVMTy, 1);
}
llvm::Instruction::CastOps castOp = llvm::Instruction::CastOpsEnd;
if (srcLLVMTy->isFloatingPointTy() && destLLVMTy->isFloatingPointTy()) {
if (srcLLVMTy->getScalarSizeInBits() > destLLVMTy->getScalarSizeInBits()) {
castOp = llvm::Instruction::FPTrunc;
}
else {
castOp = llvm::Instruction::FPExt;
}
}
else if (srcLLVMTy->isFloatingPointTy() && destLLVMTy->isIntegerTy()) {
castOp = destTy->isSignedIntegerType() ? llvm::Instruction::FPToSI : llvm::Instruction::FPToUI;
}
else if (srcLLVMTy->isIntegerTy() && destLLVMTy->isFloatingPointTy()) {
castOp = srcTy->isSignedIntegerType() ? llvm::Instruction::SIToFP : llvm::Instruction::UIToFP;
}
else {
// Both src and dest should be of integer type here.
assert(srcLLVMTy->isIntegerTy() && destLLVMTy->isIntegerTy());
if (srcLLVMTy->getScalarSizeInBits() > destLLVMTy->getScalarSizeInBits()) {
castOp = llvm::Instruction::Trunc;
}
else {
castOp = srcTy->isSignedIntegerType() ? llvm::Instruction::SExt : llvm::Instruction::ZExt;
}
}
assert(castOp != llvm::Instruction::CastOpsEnd);
return llvm::ConstantExpr::getCast(castOp, C, destLLVMTy);
}
// HLSL changes end
llvm::Constant *VisitCastExpr(CastExpr* E) {
Expr *subExpr = E->getSubExpr();
llvm::Constant *C = CGM.EmitConstantExpr(subExpr, subExpr->getType(), CGF);
@ -748,10 +814,68 @@ public:
case CK_HLSLCC_IntegralToBoolean:
case CK_HLSLCC_IntegralToFloating:
case CK_HLSLCC_FloatingToIntegral:
case CK_HLSLCC_FloatingToBoolean:
// Since these cast kinds have already been handled in ExprConstant.cpp,
// we can reuse the logic there.
return CGM.EmitConstantExpr(E, E->getType(), CGF);
case CK_HLSLCC_FloatingToBoolean: {
bool isMatrixCast = hlsl::IsHLSLMatType(E->getType()) && hlsl::IsHLSLMatType(E->getSubExpr()->getType());
if (!isMatrixCast) {
// Since these cast kinds have already been handled in ExprConstant.cpp,
// we can reuse the logic there.
return CGM.EmitConstantExpr(E, E->getType(), CGF);
}
else {
// For cast involving matrix type, if the subexperssion has already
// been successfully evaluated to a constant, then just cast it to
// match the destination type.
llvm::Constant *SubExprResult = C;
const clang::Type * srcEltType = hlsl::GetHLSLMatElementType(E->getSubExpr()->getType()).getCanonicalType().getTypePtr();
const clang::Type * destEltType = hlsl::GetHLSLMatElementType(E->getType()).getCanonicalType().getTypePtr();
// If the dest type is same as the src type, then trivially
// return the result of the subexpression evaluation.
llvm::Type *srcEltLLVMTy = CGM.getTypes().ConvertType(srcEltType->getCanonicalTypeInternal());
llvm::Type *destEltLLVMTy = CGM.getTypes().ConvertType(destEltType->getCanonicalTypeInternal());
// Use desugared llvm type for comparison as half and float could both mean float type
// when -enable-16bit-types flag is not used.
if (srcEltLLVMTy == destEltLLVMTy) {
return SubExprResult;
}
unsigned destRow, destCol;
hlsl::GetHLSLMatRowColCount(E->getType(), destRow, destCol);
unsigned srcRow, srcCol;
hlsl::GetHLSLMatRowColCount(E->getSubExpr()->getType(), srcRow, srcCol);
// Src and Dest matrices must have same order
assert(destRow == srcRow && destCol == srcCol);
if (llvm::ConstantStruct *srcVal = dyn_cast<llvm::ConstantStruct>(SubExprResult)) {
llvm::ConstantArray *srcMat = cast<llvm::ConstantArray>(srcVal->getOperand(0));
llvm::SmallVector<llvm::Constant*, 4> destRowElts;
for (unsigned r = 0; r < srcRow; r++) {
llvm::SmallVector<llvm::Constant*, 4> destColElts(srcCol);
llvm::Constant *srcColVal = srcMat->getOperand(r);
ExtractConstantValueElems(srcColVal, destColElts, srcCol);
for (unsigned i = 0; i < srcCol; i++) {
destColElts[i] = ConvertToMatchDestType(srcEltType, destEltType, srcEltLLVMTy, destEltLLVMTy, destColElts[i], CGM);
}
llvm::Constant *destCols = llvm::ConstantVector::get(destColElts);
destRowElts.emplace_back(destCols);
}
llvm::StructType *destValType = cast<llvm::StructType>(destType);
llvm::Constant *destMat = llvm::ConstantArray::get(
cast<llvm::ArrayType>(destValType->getElementType(0)), destRowElts);
llvm::Constant* destVal = llvm::ConstantStruct::get(destValType, destMat);
return destVal;
}
else if (llvm::ConstantAggregateZero *CAZ = dyn_cast<llvm::ConstantAggregateZero>(SubExprResult)) {
return llvm::Constant::getNullValue(destType);
}
}
}
case CK_FlatConversion:
return nullptr;
case CK_HLSLVectorSplat: {
@ -773,54 +897,51 @@ public:
case CK_HLSLVectorTruncationCast: {
unsigned vecSize = hlsl::GetHLSLVecSize(E->getType());
SmallVector<llvm::Constant*, 4> Elts(vecSize);
if (llvm::ConstantDataVector *CDV = dyn_cast<llvm::ConstantDataVector>(C)) {
for (unsigned i = 0; i < vecSize; i++)
Elts[i] = CDV->getElementAsConstant(i);
} else if (llvm::ConstantVector* CV = dyn_cast<llvm::ConstantVector>(C)) {
for (unsigned i = 0; i < vecSize; i++)
Elts[i] = CV->getOperand(i);
} else {
llvm::ConstantAggregateZero* CAZ = cast<llvm::ConstantAggregateZero>(C);
for (unsigned i = 0; i < vecSize; i++)
Elts[i] = CAZ->getElementValue(i);
}
ExtractConstantValueElems(C, Elts, vecSize);
return llvm::ConstantVector::get(Elts);
}
case CK_HLSLVectorToScalarCast: {
if (llvm::ConstantDataVector* CDV = dyn_cast<llvm::ConstantDataVector>(C)) {
return CDV->getElementAsConstant(0);
SmallVector<llvm::Constant*, 4> Elts(1);
ExtractConstantValueElems(C, Elts, 1);
return Elts[0];
}
case CK_HLSLMatrixToScalarCast: {
unsigned rowCt, colCt;
hlsl::GetHLSLMatRowColCount(E->getType(), rowCt, colCt);
if (llvm::ConstantStruct *CS = dyn_cast<llvm::ConstantStruct>(C)) {
llvm::ConstantArray *CA = dyn_cast<llvm::ConstantArray>(CS->getOperand(0));
SmallVector<llvm::Constant*, 4> Elts(colCt);
ExtractConstantValueElems(CA->getOperand(0), Elts, colCt);
return Elts[0];
}
else if (llvm::ConstantVector* CV = dyn_cast<llvm::ConstantVector>(C)) {
return CV->getOperand(0);
} else {
llvm::ConstantAggregateZero* CAZ = cast<llvm::ConstantAggregateZero>(C);
return CAZ->getElementValue((unsigned)0);
else if (llvm::ConstantAggregateZero *CAZ = dyn_cast<llvm::ConstantAggregateZero>(C)) {
llvm::Constant *destVal = llvm::Constant::getNullValue(destType);
return destVal;
}
}
case CK_HLSLMatrixTruncationCast: {
llvm::StructType *ST =
cast<llvm::StructType>(CGM.getTypes().ConvertType(E->getType()));
unsigned rowCt,colCt;
hlsl::GetHLSLMatRowColCount(E->getType(), rowCt, colCt);
if (llvm::ConstantStruct *CS = dyn_cast<llvm::ConstantStruct>(C)) {
unsigned rowCt, colCt;
hlsl::GetHLSLMatRowColCount(E->getType(), rowCt, colCt);
llvm::ConstantArray *CA = dyn_cast<llvm::ConstantArray>(CS->getOperand(0));
SmallVector<llvm::Constant *, 4> Rows(rowCt);
for (unsigned i = 0; i < rowCt; i++) {
SmallVector<llvm::Constant*, 4> Elts(colCt);
if (llvm::ConstantDataVector *CDV = dyn_cast<llvm::ConstantDataVector>(CA->getOperand(i))) {
for (unsigned j = 0; j < colCt; j++)
Elts[j] = CDV->getElementAsConstant(j);
} else {
llvm::ConstantVector *CV = cast<llvm::ConstantVector>(CA->getOperand(i));
for (unsigned j = 0; j < colCt; j++)
Elts[j] = CV->getOperand(j);
}
ExtractConstantValueElems(CA->getOperand(i), Elts, colCt);
Rows[i] = llvm::ConstantVector::get(Elts);
}
// Create truncated matrix
llvm::StructType *ST =
cast<llvm::StructType>(CGM.getTypes().ConvertType(E->getType()));
llvm::Constant *Mat = llvm::ConstantArray::get(
cast<llvm::ArrayType>(ST->getElementType(0)), Rows);
return llvm::ConstantStruct::get(ST, Mat);
}
else if (llvm::ConstantAggregateZero *CAZ = dyn_cast<llvm::ConstantAggregateZero>(C)) {
llvm::Constant *destVal = llvm::Constant::getNullValue(destType);
return destVal;
}
}
// HLSL Change Ends.
}

Просмотреть файл

@ -0,0 +1,63 @@
// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
// RUN: %dxc -E main -T vs_6_2 -enable-16bit-types %s | FileCheck %s
// This test checks that constexpr truncation cast involving matrix type.
// CHECK: define void @main()
void main() : OUT {
const float v1 = min16float1x1(0);
const int v2 = min16float1x1(-1);
const uint v3 = int1x2(0, 0);
const double v4 = bool1x2(1, 2);
const bool v5 = float2x1(0, 0);
const min16int v6 = double2x1(1, 2);
const uint v7 = int2x2(0, 0, 0, 0);
const int v8 = min16uint2x2(1, 2, 3, 4);
const uint v9 = double2x3(0, 0, 0, 0, 0, 0);
const min16int v10 = min16float2x3(1, 2, 3, 4, 5, 6);
const uint v11 = double3x2(0, 0, 0, 0, 0, 0);
const min16int v12 = min16float3x2(1, 2, 3, 4, 5, 6);
const bool v13 = min16float3x3(0, 0, 0, 0, 0, 0, 0, 0, 0);
const uint v14 = min16int3x3(1, 2, 3, 4, 5, 6, 7, 8, 9);
const bool v15 = min16float4x4(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
const int v16 = double4x4(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
const uint1x1 v17 = int1x2(0, 0);
const double1x1 v18 = bool1x2(1, 2);
const bool1x1 v19 = float2x1(0, 0);
const min16int1x1 v20 = double2x1(1, 2);
const uint2x1 v21 = int2x2(0, 0, 0, 0);
const int2x1 v22 = min16uint2x2(1, 2, 3, 4);
const uint1x2 v23 = int2x2(0, 0, 0, 0);
const int1x2 v24 = min16uint2x2(1, 2, 3, 4);
const uint2x2 v25 = double2x3(0, 0, 0, 0, 0, 0);
const min16int2x1 v26 = min16float2x3(1, 2, 3, 4, 5, 6);
const uint1x2 v27 = double3x2(0, 0, 0, 0, 0, 0);
const min16int1x1 v28 = min16float3x2(1, 2, 3, 4, 5, 6);
const bool2x3 v29 = min16float3x3(0, 0, 0, 0, 0, 0, 0, 0, 0);
const uint3x2 v30 = min16int3x3(1, 2, 3, 4, 5, 6, 7, 8, 9);
const bool1x1 v31 = min16float3x3(0, 0, 0, 0, 0, 0, 0, 0, 0);
const uint2x2 v32 = min16int3x3(1, 2, 3, 4, 5, 6, 7, 8, 9);
const bool1x1 v33 = min16float4x4(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
const int2x3 v34 = double4x4(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
const bool3x3 v35 = min16float4x4(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
const int3x4 v36 = double4x4(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
}

Просмотреть файл

@ -0,0 +1,22 @@
// RUN: %dxc -E main -T ps_6_2 %s | FileCheck %s
// RUN: %dxc -E main -T ps_6_2 -enable-16bit-types %s | FileCheck %s
// This is a regression test for github issue #3041.
// CHECK: define void @main()
static const float sFoo = 1.5;
static const float3x3 sBar = half3x3
(
sFoo, 0.0f, -sFoo * 0.5f,
0.0f, sFoo, -sFoo * 0.5f,
0.0f, 0.0f, 1.0f
);
half4 main() : SV_Target
{
half3 result = half3(0.0, 0.0, 0.0);
return(float4(result, 0));
}

Просмотреть файл

@ -0,0 +1,33 @@
// RUN: %dxc -E main -T vs_6_2 %s | FileCheck %s
// RUN: %dxc -E main -T vs_6_2 -enable-16bit-types %s | FileCheck %s
// This test checks that constexpr cast between matrices of same dimension but different component type compile fines.
// CHECK: define void @main()
void main() : OUT {
const float1x1 v1 = min16float1x1(0);
const int1x1 v2 = min16float1x1(-1);
const uint1x2 v3 = int1x2(0, 0);
const double1x2 v4 = bool1x2(1, 2);
const bool2x1 v5 = float2x1(0, 0);
const min16int2x1 v6 = double2x1(1, 2);
const uint2x2 v7 = int2x2(0, 0, 0, 0);
const int2x2 v8 = min16uint2x2(1, 2, 3, 4);
const uint2x3 v9 = double2x3(0, 0, 0, 0, 0, 0);
const min16int2x3 v10 = min16float2x3(1, 2, 3, 4, 5, 6);
const uint3x2 v11 = double3x2(0, 0, 0, 0, 0, 0);
const min16int3x2 v12 = min16float3x2(1, 2, 3, 4, 5, 6);
const bool3x3 v13 = min16float3x3(0, 0, 0, 0, 0, 0, 0, 0, 0);
const uint3x3 v14 = min16int3x3(1, 2, 3, 4, 5, 6, 7, 8, 9);
const bool4x4 v15 = min16float4x4(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
const int4x4 v16 = double4x4(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
}