128128import org .hibernate .query .sqm .mutation .internal .SqmInsertStrategyHelper ;
129129import org .hibernate .query .sqm .produce .function .internal .PatternRenderer ;
130130import org .hibernate .query .sqm .spi .BaseSemanticQueryWalker ;
131+ import org .hibernate .query .sqm .spi .SqmCreationHelper ;
131132import org .hibernate .query .sqm .sql .internal .AnyDiscriminatorPathInterpretation ;
132133import org .hibernate .query .sqm .sql .internal .AsWrappedExpression ;
133134import org .hibernate .query .sqm .sql .internal .BasicValuedPathInterpretation ;
227228import org .hibernate .query .sqm .tree .from .SqmFrom ;
228229import org .hibernate .query .sqm .tree .from .SqmFromClause ;
229230import org .hibernate .query .sqm .tree .from .SqmJoin ;
231+ import org .hibernate .query .sqm .tree .from .SqmQualifiedJoin ;
230232import org .hibernate .query .sqm .tree .from .SqmRoot ;
231233import org .hibernate .query .sqm .tree .insert .SqmConflictClause ;
232234import org .hibernate .query .sqm .tree .insert .SqmConflictUpdateAction ;
385387import org .hibernate .sql .results .graph .FetchParent ;
386388import org .hibernate .sql .results .graph .Fetchable ;
387389import org .hibernate .sql .results .graph .FetchableContainer ;
388- import org .hibernate .sql .results .graph .collection .internal .EagerCollectionFetch ;
389390import org .hibernate .sql .results .graph .entity .EntityResultGraphNode ;
390391import org .hibernate .sql .results .graph .instantiation .internal .DynamicInstantiation ;
391392import org .hibernate .sql .results .graph .internal .ImmutableFetchList ;
@@ -2684,7 +2685,12 @@ protected void consumeFromClauseCorrelatedRoot(SqmRoot<?> sqmRoot) {
26842685 // as roots anyway, so nothing to worry about
26852686
26862687 log .tracef ( "Resolved SqmRoot [%s] to correlated TableGroup [%s]" , sqmRoot , tableGroup );
2687- consumeExplicitJoins ( from , tableGroup );
2688+ if ( from instanceof SqmRoot <?> ) {
2689+ consumeJoins ( (SqmRoot <?>) from , fromClauseIndex , tableGroup );
2690+ }
2691+ else {
2692+ consumeExplicitJoins ( from , tableGroup );
2693+ }
26882694 return ;
26892695 }
26902696 else {
@@ -3347,6 +3353,39 @@ private TableGroup consumeAttributeJoin(
33473353 SqmMappingModelHelper .resolveExplicitTreatTarget ( sqmJoin , this )
33483354 );
33493355
3356+ final List <SqmFrom <?, ?>> sqmTreats = sqmJoin .getSqmTreats ();
3357+ final SqmPredicate joinPredicate ;
3358+ final SqmPredicate [] treatPredicates ;
3359+ final boolean hasPredicate ;
3360+ if ( !sqmTreats .isEmpty () ) {
3361+ if ( sqmTreats .size () == 1 ) {
3362+ // If there is only a single treat, combine the predicates just as they are
3363+ joinPredicate = SqmCreationHelper .combinePredicates (
3364+ sqmJoin .getJoinPredicate (),
3365+ ( (SqmQualifiedJoin <?, ?>) sqmTreats .get ( 0 ) ).getJoinPredicate ()
3366+ );
3367+ treatPredicates = null ;
3368+ hasPredicate = joinPredicate != null ;
3369+ }
3370+ else {
3371+ // When there are multiple predicates, we have to apply type filters
3372+ joinPredicate = sqmJoin .getJoinPredicate ();
3373+ treatPredicates = new SqmPredicate [sqmTreats .size ()];
3374+ boolean hasTreatPredicate = false ;
3375+ for ( int i = 0 ; i < sqmTreats .size (); i ++ ) {
3376+ final var p = ( (SqmQualifiedJoin <?, ?>) sqmTreats .get ( i ) ).getJoinPredicate ();
3377+ treatPredicates [i ] = p ;
3378+ hasTreatPredicate = hasTreatPredicate || p != null ;
3379+ }
3380+ hasPredicate = joinPredicate != null || hasTreatPredicate ;
3381+ }
3382+ }
3383+ else {
3384+ joinPredicate = sqmJoin .getJoinPredicate ();
3385+ treatPredicates = null ;
3386+ hasPredicate = joinPredicate != null ;
3387+ }
3388+
33503389 if ( pathSource instanceof PluralPersistentAttribute ) {
33513390 assert modelPart instanceof PluralAttributeMapping ;
33523391
@@ -3363,7 +3402,7 @@ private TableGroup consumeAttributeJoin(
33633402 null ,
33643403 sqmJoinType .getCorrespondingSqlJoinType (),
33653404 sqmJoin .isFetched (),
3366- sqmJoin . getJoinPredicate () != null ,
3405+ hasPredicate ,
33673406 this
33683407 );
33693408
@@ -3379,7 +3418,7 @@ private TableGroup consumeAttributeJoin(
33793418 null ,
33803419 sqmJoinType .getCorrespondingSqlJoinType (),
33813420 sqmJoin .isFetched (),
3382- sqmJoin . getJoinPredicate () != null ,
3421+ hasPredicate ,
33833422 this
33843423 );
33853424
@@ -3388,7 +3427,7 @@ private TableGroup consumeAttributeJoin(
33883427 // Since this is an explicit join, we force the initialization of a possible lazy table group
33893428 // to retain the cardinality, but only if this is a non-trivial attribute join.
33903429 // Left or inner singular attribute joins without a predicate can be safely optimized away
3391- if ( sqmJoin . getJoinPredicate () != null || sqmJoinType != SqmJoinType .INNER && sqmJoinType != SqmJoinType .LEFT ) {
3430+ if ( hasPredicate || sqmJoinType != SqmJoinType .INNER && sqmJoinType != SqmJoinType .LEFT ) {
33923431 joinedTableGroup .getPrimaryTableReference ();
33933432 }
33943433 }
@@ -3425,14 +3464,26 @@ private TableGroup consumeAttributeJoin(
34253464 final TableGroupJoin joinForPredicate ;
34263465
34273466 // add any additional join restrictions
3428- if ( sqmJoin . getJoinPredicate () != null ) {
3467+ if ( hasPredicate ) {
34293468 if ( sqmJoin .isFetched () ) {
34303469 QueryLogging .QUERY_MESSAGE_LOGGER .debugf ( "Join fetch [%s] is restricted" , sqmJoinNavigablePath );
34313470 }
34323471
34333472 final SqmJoin <?, ?> oldJoin = currentlyProcessingJoin ;
34343473 currentlyProcessingJoin = sqmJoin ;
3435- final Predicate predicate = visitNestedTopLevelPredicate ( sqmJoin .getJoinPredicate () );
3474+ Predicate predicate = joinPredicate == null ? null : visitNestedTopLevelPredicate ( joinPredicate );
3475+ if ( treatPredicates != null ) {
3476+ final Junction orPredicate = new Junction ( Junction .Nature .DISJUNCTION );
3477+ for ( int i = 0 ; i < treatPredicates .length ; i ++ ) {
3478+ final EntityDomainType <?> treatType =
3479+ (EntityDomainType <?>) ( (SqmTreatedPath <?, ?>) sqmTreats .get ( i ) ).getTreatTarget ();
3480+ orPredicate .add ( combinePredicates (
3481+ createTreatTypeRestriction ( sqmJoin , treatType ),
3482+ treatPredicates [i ] == null ? null : visitNestedTopLevelPredicate ( treatPredicates [i ] )
3483+ ) );
3484+ }
3485+ predicate = predicate != null ? combinePredicates ( predicate , orPredicate ) : orPredicate ;
3486+ }
34363487 joinForPredicate = TableGroupJoinHelper .determineJoinForPredicateApply ( joinedTableGroupJoin );
34373488 // If translating the join predicate didn't initialize the table group,
34383489 // we can safely apply it on the collection table group instead
0 commit comments