@@ -1347,27 +1347,46 @@ class LoopVectorizationCostModel {
13471347 return InterleaveInfo.getInterleaveGroup (Instr);
13481348 }
13491349
1350+ /// Compute and cache whether a scalar epilogue is required, once for the
1351+ /// vectorizing case and once for the non-vectorizing case. \p Invalidate must
1352+ /// be true when a previously cached decision is being overwritten.
1353+ void collectScalarEpilogueRequirements (bool Invalidate) {
1354+ auto NeedsScalarEpilogue = [&](bool IsVectorizing) -> bool { // Messages omit the newline; a mode suffix follows each call.
1355+ if (!isScalarEpilogueAllowed ()) {
1356+ LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue" );
1357+ return false ;
1358+ }
1359+ // If we might exit from anywhere but the latch, must run the exiting
1360+ // iteration in scalar form.
1361+ if (TheLoop->getExitingBlock () != TheLoop->getLoopLatch ()) {
1362+ LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: not exiting "
1363+ " from latch block" ); // No newline here: the caller appends the suffix.
1364+ return true ;
1365+ }
1366+ if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue ()) {
1367+ LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: "
1368+ " interleaved group requires scalar epilogue" );
1369+ return true ;
1370+ }
1371+ LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue" );
1372+ return false ;
1373+ };
1374+
1375+ assert ((Invalidate || !RequiresScalarEpilogue) && // Overwriting a cached result needs Invalidate.
1376+ " Already determined scalar epilogue requirements!" );
1377+ std::pair<bool , bool > Result;
1378+ Result.first = NeedsScalarEpilogue (true ); // Requirement when vectorizing.
1379+ LLVM_DEBUG (dbgs () << " , when vectorizing\n " );
1380+ Result.second = NeedsScalarEpilogue (false ); // Requirement when not vectorizing.
1381+ LLVM_DEBUG (dbgs () << " , when not vectorizing\n " );
1382+ RequiresScalarEpilogue = Result;
1383+ }
1384+
13501385 /// Returns true if we're required to use a scalar epilogue for at least
13511386 /// the final iteration of the original loop. Reads the cached result.
13521387 bool requiresScalarEpilogue (bool IsVectorizing) const {
1353- if (!isScalarEpilogueAllowed ()) {
1354- LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue\n " );
1355- return false ;
1356- }
1357- // If we might exit from anywhere but the latch, must run the exiting
1358- // iteration in scalar form.
1359- if (TheLoop->getExitingBlock () != TheLoop->getLoopLatch ()) {
1360- LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: not exiting "
1361- " from latch block\n " );
1362- return true ;
1363- }
1364- if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue ()) {
1365- LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: "
1366- " interleaved group requires scalar epilogue\n " );
1367- return true ;
1368- }
1369- LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue\n " );
1370- return false ;
1388+ auto &CachedResult = *RequiresScalarEpilogue; // Unchecked dereference: assumes collectScalarEpilogueRequirements already ran.
1389+ return IsVectorizing ? CachedResult.first : CachedResult.second ; // first = vectorizing, second = not vectorizing.
13711390 }
13721391
13731392 // / Returns true if we're required to use a scalar epilogue for at least
@@ -1391,6 +1410,15 @@ class LoopVectorizationCostModel {
13911410 return ScalarEpilogueStatus == CM_ScalarEpilogueAllowed;
13921411 }
13931412
1413+ /// Set ScalarEpilogueStatus to \p Status. If the value actually changes, the
1414+ /// scalar epilogue requirements are recalculated with invalidation.
1415+ void setScalarEpilogueStatus (ScalarEpilogueLowering Status) {
1416+ if (ScalarEpilogueStatus == Status)
1417+ return; // Status unchanged: nothing to store or recompute.
1418+ ScalarEpilogueStatus = Status;
1419+ collectScalarEpilogueRequirements (/* Invalidate=*/ true );
1420+ }
1421+
13941422 // / Returns the TailFoldingStyle that is best for the current loop.
13951423 TailFoldingStyle getTailFoldingStyle (bool IVUpdateMayOverflow = true ) const {
13961424 if (!ChosenTailFoldingStyle)
@@ -1771,6 +1799,9 @@ class LoopVectorizationCostModel {
17711799
17721800 // / All element types found in the loop.
17731801 SmallPtrSet<Type *, 16 > ElementTypesInLoop;
1802+
1803+ /// Cached scalar epilogue requirements as {when vectorizing, when not vectorizing}; set by collectScalarEpilogueRequirements.
1804+ std::optional<std::pair<bool , bool >> RequiresScalarEpilogue;
17741805};
17751806} // end namespace llvm
17761807
@@ -4058,7 +4089,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
40584089 if (ScalarEpilogueStatus == CM_ScalarEpilogueNotNeededUsePredicate) {
40594090 LLVM_DEBUG (dbgs () << " LV: Cannot fold tail by masking: vectorize with a "
40604091 " scalar epilogue instead.\n " );
4061- ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
4092+ setScalarEpilogueStatus ( CM_ScalarEpilogueAllowed) ;
40624093 return computeFeasibleMaxVF (MaxTC, UserVF, false );
40634094 }
40644095 return FixedScalableVFPair::getNone ();
@@ -4074,6 +4105,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
40744105 // Note: There is no need to invalidate any cost modeling decisions here, as
40754106 // none were taken so far.
40764107 InterleaveInfo.invalidateGroupsRequiringScalarEpilogue ();
4108+ collectScalarEpilogueRequirements (/* Invalidate=*/ true );
40774109 }
40784110
40794111 FixedScalableVFPair MaxFactors = computeFeasibleMaxVF (MaxTC, UserVF, true );
@@ -4145,7 +4177,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
41454177 if (ScalarEpilogueStatus == CM_ScalarEpilogueNotNeededUsePredicate) {
41464178 LLVM_DEBUG (dbgs () << " LV: Cannot fold tail by masking: vectorize with a "
41474179 " scalar epilogue instead.\n " );
4148- ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
4180+ setScalarEpilogueStatus ( CM_ScalarEpilogueAllowed) ;
41494181 return MaxFactors;
41504182 }
41514183
@@ -7058,6 +7090,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
70587090 if (!OrigLoop->isInnermost ()) {
70597091 // If the user doesn't provide a vectorization factor, determine a
70607092 // reasonable one.
7093+ CM.collectScalarEpilogueRequirements (/* Invalidate=*/ false );
70617094 if (UserVF.isZero ()) {
70627095 VF = determineVPlanVF (TTI, CM);
70637096 LLVM_DEBUG (dbgs () << " LV: VPlan computed VF " << VF << " .\n " );
@@ -7102,6 +7135,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
71027135
71037136void LoopVectorizationPlanner::plan (ElementCount UserVF, unsigned UserIC) {
71047137 assert (OrigLoop->isInnermost () && " Inner loop expected." );
7138+ CM.collectScalarEpilogueRequirements (/* Invalidate=*/ false );
71057139 CM.collectValuesToIgnore ();
71067140 CM.collectElementTypesForWidening ();
71077141
@@ -7116,11 +7150,13 @@ void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
71167150 dbgs ()
71177151 << " LV: Invalidate all interleaved groups due to fold-tail by masking "
71187152 " which requires masked-interleaved support.\n " );
7119- if (CM.InterleaveInfo .invalidateGroups ())
7153+ if (CM.InterleaveInfo .invalidateGroups ()) {
71207154 // Invalidating interleave groups also requires invalidating all decisions
71217155 // based on them, which includes widening decisions and uniform and scalar
71227156 // values.
71237157 CM.invalidateCostModelingDecisions ();
7158+ CM.collectScalarEpilogueRequirements (/* Invalidate=*/ true );
7159+ }
71247160 }
71257161
71267162 if (CM.foldTailByMasking ())
0 commit comments