Support const vector argument for isnan(), atan2() (#3908)

Fixes #3823, #3824
This commit is contained in:
Jaebaek Seo 2021-08-17 14:25:04 -04:00 коммит произвёл GitHub
Родитель 5dc5e42c8d
Коммит 5eb8a5b1eb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 95 добавлений и 27 удалений

Просмотреть файл

@ -897,6 +897,67 @@ void AddOpcodeParamForIntrinsics(
}
}
// Returns whether the first argument of CI is NaN or not. If the argument is
// a vector, returns a vector of boolean values.
Constant *IsNaN(CallInst *CI) {
Value *V = CI->getArgOperand(0);
llvm::Type *Ty = V->getType();
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
Constant *CV = cast<Constant>(V);
SmallVector<Constant *, 4> ConstVec;
llvm::Type *CIElemTy =
cast<llvm::VectorType>(CI->getType())->getElementType();
for (unsigned i = 0; i < VT->getNumElements(); i++) {
ConstantFP *fpV = cast<ConstantFP>(CV->getAggregateElement(i));
bool isNan = fpV->getValueAPF().isNaN();
ConstVec.push_back(ConstantInt::get(CIElemTy, isNan ? 1 : 0));
}
return ConstantVector::get(ConstVec);
} else {
ConstantFP *fV = cast<ConstantFP>(V);
bool isNan = fV->getValueAPF().isNaN();
return ConstantInt::get(CI->getType(), isNan ? 1 : 0);
}
}
// Returns a constant for atan2() intrinsic function for scalars.
Constant *Atan2ForScalar(llvm::Type *ResultTy, ConstantFP *fpV0,
ConstantFP *fpV1) {
if (ResultTy->isDoubleTy()) {
double dV0 = fpV0->getValueAPF().convertToDouble();
double dV1 = fpV1->getValueAPF().convertToDouble();
return ConstantFP::get(ResultTy, atan2(dV0, dV1));
} else {
DXASSERT_NOMSG(ResultTy->isFloatTy());
float fV0 = fpV0->getValueAPF().convertToFloat();
float fV1 = fpV1->getValueAPF().convertToFloat();
return ConstantFP::get(ResultTy, atan2f(fV0, fV1));
}
}
// Returns Value for atan2() intrinsic function. If the argument of CI has
// a vector type, it returns the vector value of atan2().
Value *Atan2(CallInst *CI) {
Value *V0 = CI->getArgOperand(0);
Value *V1 = CI->getArgOperand(1);
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(V0->getType())) {
Constant *CV0 = cast<Constant>(V0);
Constant *CV1 = cast<Constant>(V1);
SmallVector<Constant *, 4> ConstVec;
llvm::Type *CIElemTy =
cast<llvm::VectorType>(CI->getType())->getElementType();
for (unsigned i = 0; i < VT->getNumElements(); i++) {
ConstantFP *fpV0 = cast<ConstantFP>(CV0->getAggregateElement(i));
ConstantFP *fpV1 = cast<ConstantFP>(CV1->getAggregateElement(i));
ConstVec.push_back(Atan2ForScalar(CIElemTy, fpV0, fpV1));
}
return ConstantVector::get(ConstVec);
} else {
ConstantFP *fpV0 = cast<ConstantFP>(V0);
ConstantFP *fpV1 = cast<ConstantFP>(V1);
return Atan2ForScalar(CI->getType(), fpV0, fpV1);
}
}
} // namespace
namespace {
@ -1760,30 +1821,10 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
return EvalUnaryIntrinsic(CI, atanf, atan);
} break;
case IntrinsicOp::IOP_atan2: {
Value *V0 = CI->getArgOperand(0);
ConstantFP *fpV0 = cast<ConstantFP>(V0);
Value *V1 = CI->getArgOperand(1);
ConstantFP *fpV1 = cast<ConstantFP>(V1);
llvm::Type *Ty = CI->getType();
Value *Result = nullptr;
if (Ty->isDoubleTy()) {
double dV0 = fpV0->getValueAPF().convertToDouble();
double dV1 = fpV1->getValueAPF().convertToDouble();
Value *atanV = ConstantFP::get(CI->getType(), atan2(dV0, dV1));
CI->replaceAllUsesWith(atanV);
Result = atanV;
} else {
DXASSERT_NOMSG(Ty->isFloatTy());
float fV0 = fpV0->getValueAPF().convertToFloat();
float fV1 = fpV1->getValueAPF().convertToFloat();
Value *atanV = ConstantFP::get(CI->getType(), atan2f(fV0, fV1));
CI->replaceAllUsesWith(atanV);
Result = atanV;
}
Value *atanV = Atan2(CI);
CI->replaceAllUsesWith(atanV);
CI->eraseFromParent();
return Result;
return atanV;
} break;
case IntrinsicOp::IOP_sqrt: {
return EvalUnaryIntrinsic(CI, sqrtf, sqrt);
@ -1885,10 +1926,7 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
return EvalUnaryIntrinsic(CI, fracF, fracD);
} break;
case IntrinsicOp::IOP_isnan: {
Value *V = CI->getArgOperand(0);
ConstantFP *fV = cast<ConstantFP>(V);
bool isNan = fV->getValueAPF().isNaN();
Constant *cNan = ConstantInt::get(CI->getType(), isNan ? 1 : 0);
Constant *cNan = IsNaN(CI);
CI->replaceAllUsesWith(cNan);
CI->eraseFromParent();
return cNan;

Просмотреть файл

@ -0,0 +1,10 @@
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
// CHECK: call void @dx.op.storeOutput.f32(i32 [[outputSigId:[0-9]+]], i32 0, i32 0, i8 0, float 0.0
// CHECK: call void @dx.op.storeOutput.f32(i32 [[outputSigId]], i32 0, i32 0, i8 1, float 0.0
// CHECK: call void @dx.op.storeOutput.f32(i32 [[outputSigId]], i32 0, i32 0, i8 2, float 0.0
float3 main() : SV_TARGET
{
return atan2(float3(0.0f, 0.0f, 0.0f), float3(0.0f, 0.0f, 0.0f));
}

Просмотреть файл

@ -0,0 +1,10 @@
// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
// CHECK: call void @dx.op.storeOutput.i32(i32 [[outputSigId:[0-9]+]], i32 0, i32 0, i8 0, i32 0)
// CHECK: call void @dx.op.storeOutput.i32(i32 [[outputSigId]], i32 0, i32 0, i8 1, i32 0)
// CHECK: call void @dx.op.storeOutput.i32(i32 [[outputSigId]], i32 0, i32 0, i8 2, i32 0)
bool3 main() : SV_TARGET
{
return isnan(float3(0.0f, 0.0f, 0.0f));
}

Просмотреть файл

@ -188,6 +188,8 @@ public:
TEST_METHOD(CodeGenRootSigProfile2)
TEST_METHOD(CodeGenRootSigProfile5)
TEST_METHOD(CodeGenWaveSize)
TEST_METHOD(CodeGenVectorIsnan)
TEST_METHOD(CodeGenVectorAtan2)
TEST_METHOD(PreprocessWhenValidThenOK)
TEST_METHOD(LibGVStore)
TEST_METHOD(PreprocessWhenExpandTokenPastingOperandThenAccept)
@ -2877,6 +2879,14 @@ TEST_F(CompilerTest, CodeGenWaveSize) {
CodeGenTestCheck(L"attributes_wavesize.hlsl");
}
TEST_F(CompilerTest, CodeGenVectorIsnan) {
CodeGenTestCheck(L"isnan_vector_argument.hlsl");
}
TEST_F(CompilerTest, CodeGenVectorAtan2) {
CodeGenTestCheck(L"atan2_vector_argument.hlsl");
}
TEST_F(CompilerTest, LibGVStore) {
CComPtr<IDxcCompiler> pCompiler;
CComPtr<IDxcOperationResult> pResult;