unroll(n) can now override loop bound. unroll(negative) now fails correctly (#2241)

This commit is contained in:
Adam Yang 2019-06-05 13:34:17 -07:00 коммит произвёл GitHub
Родитель 213de5049f
Коммит dc6203ad5b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 62 добавлений и 10 удалений

Просмотреть файл

@ -203,14 +203,13 @@ static bool IsMarkedFullUnroll(Loop *L) {
return false;
}
static bool IsMarkedUnrollCount(Loop *L, unsigned *OutCount) {
static bool IsMarkedUnrollCount(Loop *L, int *OutCount) {
if (MDNode *LoopID = L->getLoopID()) {
if (MDNode *MD = GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
assert(MD->getNumOperands() == 2 &&
"Unroll count hint metadata should have two operands.");
unsigned Count =
mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
assert(Count >= 1 && "Unroll count must be positive.");
ConstantInt *Val = mdconst::extract<ConstantInt>(MD->getOperand(1));
int Count = Val->getZExtValue();
*OutCount = Count;
return true;
}
@ -683,22 +682,32 @@ static void RecursivelyRemoveLoopFromQueue(LPPassManager &LPM, Loop *L) {
bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
Function *F = L->getHeader()->getParent();
bool HasExplicitLoopCount = false;
unsigned UnrollCount = 0;
int ExplicitUnrollCountSigned = 0;
// If the loop is not marked as [unroll], don't do anything.
if (IsMarkedUnrollCount(L, &UnrollCount)) {
if (IsMarkedUnrollCount(L, &ExplicitUnrollCountSigned)) {
HasExplicitLoopCount = true;
}
else if (!IsMarkedFullUnroll(L)) {
return false;
}
unsigned ExplicitUnrollCount = 0;
if (HasExplicitLoopCount) {
if (ExplicitUnrollCountSigned < 1) {
FailLoopUnroll(false, F->getContext(), LoopLoc, "Could not unroll loop. Invalid unroll count.");
return false;
}
ExplicitUnrollCount = (unsigned)ExplicitUnrollCountSigned;
}
if (!L->isSafeToClone())
return false;
DebugLoc LoopLoc = L->getStartLoc(); // Debug location for the start of the loop.
Function *F = L->getHeader()->getParent();
bool FxcCompatMode = false;
if (F->getParent()->HasHLModule()) {
HLModule &HM = F->getParent()->GetHLModule();
@ -830,6 +839,9 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
SmallVector<std::unique_ptr<LoopIteration>, 16> Iterations; // List of cloned iterations
bool Succeeded = false;
if (HasExplicitLoopCount) {
this->MaxIterationAttempt = std::max(this->MaxIterationAttempt, ExplicitUnrollCount);
}
for (unsigned IterationI = 0; IterationI < this->MaxIterationAttempt; IterationI++) {
LoopIteration *PrevIteration = nullptr;
@ -945,7 +957,7 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
}
// We've reached the N defined in [unroll(N)]
if (HasExplicitLoopCount && IterationI+1 >= UnrollCount) {
if (HasExplicitLoopCount && IterationI+1 >= ExplicitUnrollCount) {
Succeeded = true;
BranchInst *BI = cast<BranchInst>(CurIteration.Latch->getTerminator());
@ -1049,7 +1061,9 @@ bool DxilLoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
// If we were unsuccessful in unrolling the loop
else {
FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc, "Could not unroll loop.");
FailLoopUnroll(FxcCompatMode /*warn only*/, F->getContext(), LoopLoc,
"Could not unroll loop. Loop bound could not be deduced at compile time. "
"To give an explicit unroll bound, use unroll(n).");
// Remove all the cloned blocks
for (std::unique_ptr<LoopIteration> &Ptr : Iterations) {

Просмотреть файл

@ -0,0 +1,20 @@
// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
// CHECK: @main
// Confirm that the 128 limit on loop unroll can be overritten by an explicit
// loop count
[RootSignature("")]
float main(float y : Y) : SV_Target {
float x = 0;
static const uint kLoopCount = 512;
[unroll(kLoopCount)]
for (uint i = 0; i < kLoopCount; ++i)
{
x = x * x + y;
}
return x;
}

Просмотреть файл

@ -0,0 +1,18 @@
// RUN: %dxc -Od -E main -T ps_6_0 %s | FileCheck %s
// CHECK: Could not unroll loop
// CHECK: To give an explicit unroll bound, use unroll(n)
// CHECK-NOT: @main
[RootSignature("")]
float main(float y : Y) : SV_Target {
float x = 0;
static const uint kLoopCount = 512;
[unroll]
for (uint i = 0; i < kLoopCount; ++i)
{
x = x * x + y;
}
return x;
}