1616
1717package org .springframework .batch .integration .partition ;
1818
19- import java .util .Arrays ;
2019import java .util .Collection ;
2120import java .util .Collections ;
2221import java .util .HashSet ;
22+ import java .util .Set ;
2323import java .util .concurrent .TimeoutException ;
24+ import java .util .stream .Collectors ;
2425
2526import org .junit .jupiter .api .Test ;
2627
@@ -175,12 +176,11 @@ void testHandleWithJobRepositoryPolling() throws Exception {
175176 stepExecutions .add (partition2 );
176177 stepExecutions .add (partition3 );
177178 when (stepExecutionSplitter .split (any (StepExecution .class ), eq (1 ))).thenReturn (stepExecutions );
178- JobExecution runningJobExecution = new JobExecution (5L , new JobParameters ());
179- runningJobExecution .addStepExecutions (Arrays .asList (partition2 , partition1 , partition3 ));
180- JobExecution completedJobExecution = new JobExecution (5L , new JobParameters ());
181- completedJobExecution .addStepExecutions (Arrays .asList (partition2 , partition1 , partition4 ));
182- when (jobExplorer .getJobExecution (5L )).thenReturn (runningJobExecution , runningJobExecution , runningJobExecution ,
183- completedJobExecution );
179+ Set <Long > stepExecutionIds = stepExecutions .stream ().map (StepExecution ::getId ).collect (Collectors .toSet ());
180+ when (jobExplorer .getStepExecutionCount (stepExecutionIds , BatchStatus .RUNNING_STATUSES )).thenReturn (3L , 2L , 1L ,
181+ 0L );
182+ Set <StepExecution > completedStepExecutions = Set .of (partition2 , partition1 , partition4 );
183+ when (jobExplorer .getStepExecutions (jobExecution .getId (), stepExecutionIds )).thenReturn (completedStepExecutions );
184184
185185 // set
186186 messageChannelPartitionHandler .setMessagingOperations (operations );
@@ -200,6 +200,8 @@ void testHandleWithJobRepositoryPolling() throws Exception {
200200 assertTrue (executions .contains (partition4 ));
201201
202202 // verify
203+ verify (jobExplorer , times (4 )).getStepExecutionCount (stepExecutionIds , BatchStatus .RUNNING_STATUSES );
204+ verify (jobExplorer , times (1 )).getStepExecutions (jobExecution .getId (), stepExecutionIds );
203205 verify (operations , times (3 )).send (any (Message .class ));
204206 }
205207
@@ -225,9 +227,8 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception {
225227 stepExecutions .add (partition2 );
226228 stepExecutions .add (partition3 );
227229 when (stepExecutionSplitter .split (any (StepExecution .class ), eq (1 ))).thenReturn (stepExecutions );
228- JobExecution runningJobExecution = new JobExecution (5L , new JobParameters ());
229- runningJobExecution .addStepExecutions (Arrays .asList (partition2 , partition1 , partition3 ));
230- when (jobExplorer .getJobExecution (5L )).thenReturn (runningJobExecution );
230+ Set <Long > stepExecutionIds = stepExecutions .stream ().map (StepExecution ::getId ).collect (Collectors .toSet ());
231+ when (jobExplorer .getStepExecutionCount (stepExecutionIds , BatchStatus .RUNNING_STATUSES )).thenReturn (1L );
231232
232233 // set
233234 messageChannelPartitionHandler .setMessagingOperations (operations );
0 commit comments