diff options
Diffstat (limited to 'llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp index 52eecb000d0c..a7a22e042aef 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp @@ -628,3 +628,163 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond, Term->addMetadata(LLVMContext::MD_prof, BranchWeights); } } + +bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { + auto GetMinMaxCompareValue = [](VPReductionPHIRecipe *RedPhiR) -> VPValue * { + auto *MinMaxR = dyn_cast<VPRecipeWithIRFlags>( + RedPhiR->getBackedgeValue()->getDefiningRecipe()); + if (!MinMaxR) + return nullptr; + + auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxR); + if (!isa<VPWidenIntrinsicRecipe>(MinMaxR) && + !(RepR && isa<IntrinsicInst>(RepR->getUnderlyingInstr()))) + return nullptr; + +#ifndef NDEBUG + Intrinsic::ID RdxIntrinsicId = + RedPhiR->getRecurrenceKind() == RecurKind::FMaxNum ? Intrinsic::maxnum + : Intrinsic::minnum; + assert((isa<VPWidenIntrinsicRecipe>(MinMaxR) && + cast<VPWidenIntrinsicRecipe>(MinMaxR)->getVectorIntrinsicID() == + RdxIntrinsicId) || + (RepR && + cast<IntrinsicInst>(RepR->getUnderlyingInstr())->getIntrinsicID() == + RdxIntrinsicId) && + "Intrinsic did not match recurrence kind"); +#endif + + if (MinMaxR->getOperand(0) == RedPhiR) + return MinMaxR->getOperand(1); + + assert(MinMaxR->getOperand(1) == RedPhiR && + "Reduction phi operand expected"); + return MinMaxR->getOperand(0); + }; + + VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); + VPReductionPHIRecipe *RedPhiR = nullptr; + bool HasUnsupportedPhi = false; + for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) { + if (isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe>(&R)) + continue; + auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R); + if (!Cur) { + // TODO: Also support fixed-order recurrence phis. + HasUnsupportedPhi = true; + continue; + } + // For now, only a single reduction is supported. + // TODO: Support multiple MaxNum/MinNum reductions and other reductions. + if (RedPhiR) + return false; + if (Cur->getRecurrenceKind() != RecurKind::FMaxNum && + Cur->getRecurrenceKind() != RecurKind::FMinNum) { + HasUnsupportedPhi = true; + continue; + } + RedPhiR = Cur; + } + + if (!RedPhiR) + return true; + + // We won't be able to resume execution in the scalar tail, if there are + // unsupported header phis or there is no scalar tail at all, due to + // tail-folding. + if (HasUnsupportedPhi || !Plan.hasScalarTail()) + return false; + + VPValue *MinMaxOp = GetMinMaxCompareValue(RedPhiR); + if (!MinMaxOp) + return false; + + RecurKind RedPhiRK = RedPhiR->getRecurrenceKind(); + assert((RedPhiRK == RecurKind::FMaxNum || RedPhiRK == RecurKind::FMinNum) && + "unsupported reduction"); + + /// Check if the vector loop of \p Plan can early exit and restart + /// execution of last vector iteration in the scalar loop. This requires all + /// recipes up to early exit point be side-effect free as they are + /// re-executed. Currently we check that the loop is free of any recipe that + /// may write to memory. Expected to operate on an early VPlan w/o nested + /// regions. + for (VPBlockBase *VPB : vp_depth_first_shallow( + Plan.getVectorLoopRegion()->getEntryBasicBlock())) { + auto *VPBB = cast<VPBasicBlock>(VPB); + for (auto &R : *VPBB) { + if (R.mayWriteToMemory() && + !match(&R, m_BranchOnCount(m_VPValue(), m_VPValue()))) + return false; + } + } + + VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock(); + VPBuilder Builder(LatchVPBB->getTerminator()); + auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator()); + assert(LatchExitingBranch->getOpcode() == VPInstruction::BranchOnCount && + "Unexpected terminator"); + auto *IsLatchExitTaken = + Builder.createICmp(CmpInst::ICMP_EQ, LatchExitingBranch->getOperand(0), + LatchExitingBranch->getOperand(1)); + + VPValue *IsNaN = Builder.createFCmp(CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp); + VPValue *AnyNaN = Builder.createNaryOp(VPInstruction::AnyOf, {IsNaN}); + auto *AnyExitTaken = + Builder.createNaryOp(Instruction::Or, {AnyNaN, IsLatchExitTaken}); + Builder.createNaryOp(VPInstruction::BranchOnCond, AnyExitTaken); + LatchExitingBranch->eraseFromParent(); + + // If we exit early due to NaNs, compute the final reduction result based on + // the reduction phi at the beginning of the last vector iteration. + auto *RdxResult = find_singleton<VPSingleDefRecipe>( + RedPhiR->users(), [](VPUser *U, bool) -> VPSingleDefRecipe * { + auto *VPI = dyn_cast<VPInstruction>(U); + if (VPI && VPI->getOpcode() == VPInstruction::ComputeReductionResult) + return VPI; + return nullptr; + }); + + auto *MiddleVPBB = Plan.getMiddleBlock(); + Builder.setInsertPoint(MiddleVPBB, MiddleVPBB->begin()); + auto *NewSel = + Builder.createSelect(AnyNaN, RedPhiR, RdxResult->getOperand(1)); + RdxResult->setOperand(1, NewSel); + + auto *ScalarPH = Plan.getScalarPreheader(); + // Update resume phis for inductions in the scalar preheader. If AnyNaN is + // true, the resume from the start of the last vector iteration via the + // canonical IV, otherwise from the original value. + for (auto &R : ScalarPH->phis()) { + auto *ResumeR = cast<VPPhi>(&R); + VPValue *VecV = ResumeR->getOperand(0); + if (VecV == RdxResult) + continue; + if (auto *DerivedIV = dyn_cast<VPDerivedIVRecipe>(VecV)) { + if (DerivedIV->getNumUsers() == 1 && + DerivedIV->getOperand(1) == &Plan.getVectorTripCount()) { + auto *NewSel = Builder.createSelect(AnyNaN, Plan.getCanonicalIV(), + &Plan.getVectorTripCount()); + DerivedIV->moveAfter(&*Builder.getInsertPoint()); + DerivedIV->setOperand(1, NewSel); + continue; + } + } + // Bail out and abandon the current, partially modified, VPlan if we + // encounter resume phi that cannot be updated yet. + if (VecV != &Plan.getVectorTripCount()) { + LLVM_DEBUG(dbgs() << "Found resume phi we cannot update for VPlan with " + "FMaxNum/FMinNum reduction.\n"); + return false; + } + auto *NewSel = Builder.createSelect(AnyNaN, Plan.getCanonicalIV(), VecV); + ResumeR->setOperand(0, NewSel); + } + + auto *MiddleTerm = MiddleVPBB->getTerminator(); + Builder.setInsertPoint(MiddleTerm); + VPValue *MiddleCond = MiddleTerm->getOperand(0); + VPValue *NewCond = Builder.createAnd(MiddleCond, Builder.createNot(AnyNaN)); + MiddleTerm->setOperand(0, NewCond); + return true; +} |
