3232#include "ompi/mca/part/persist_aggregated/part_persist_aggregated_sendreq.h"
3333#include "ompi/mca/part/persist_aggregated/part_persist_aggregated_recvreq.h"
3434
35+ #include "ompi/mca/part/persist_aggregated/schemes/part_persist_aggregated_scheme_regular.h"
36+
3537static int mca_part_persist_aggregated_progress (void );
3638static int mca_part_persist_aggregated_precv_init (void * , size_t , size_t , ompi_datatype_t * , int , int , struct ompi_communicator_t * , struct ompi_info_t * , struct ompi_request_t * * );
3739static int mca_part_persist_aggregated_psend_init (const void * , size_t , size_t , ompi_datatype_t * , int , int , ompi_communicator_t * , struct ompi_info_t * , ompi_request_t * * );
@@ -49,6 +51,73 @@ ompi_part_persist_aggregated_t ompi_part_persist_aggregated = {
4951 }
5052};
5153
54+ /**
55+ * @brief selects an internal partitioning based on the user-provided partitioning
56+ * and the mca parameters for minimal partition size and maximal partition count.
57+ *
58+ * More precisely, given a partitioning into p partitions of size s, computes
59+ * an internal partitioning into p' partitions of size s' (apart from the last one,
60+ * which has potentially different size r * s):
61+ * p * s = (p' - 1) * s' + r * s
62+ * where
63+ * s' >= s
64+ * p' <= p
65+ * 0 < r * s <= s'
66+ * and
67+ * s' <= max_message_count
68+ * p' >= min_message_size
69+ * (given by mca parameters).
70+ *
71+ * @param[in] partitions number of user-provided partitions
72+ * @param[in] count size of user-provided partitions in elements
73+ * @param[out] internal_partitions number of internal partitions
74+ * @param[out] factor number of public partitions corresponding to each internal
75+ * partitions other than the last one
76+ * @param[out] last_size number of public partitions corresponding to the last internal
77+ * partition
78+ */
79+ static inline void
80+ part_persist_aggregated_select_internal_partitioning (size_t partitions ,
81+ size_t part_size ,
82+ size_t * internal_partitions ,
83+ size_t * factor ,
84+ size_t * remainder )
85+ {
86+ size_t buffer_size = partitions * part_size ;
87+ size_t min_part_size = ompi_part_persist_aggregated .min_message_size ;
88+ size_t max_part_count = ompi_part_persist_aggregated .max_message_count ;
89+
90+ // check if max_part_count imposes higher lower bound on partition size
91+ if (max_part_count > 0 && (buffer_size / max_part_count ) > min_part_size ) {
92+ min_part_size = buffer_size / max_part_count ;
93+ }
94+
95+ // cannot have partitions larger than buffer size
96+ if (min_part_size > buffer_size ) {
97+ min_part_size = buffer_size ;
98+ }
99+
100+ if (part_size < min_part_size ) {
101+ // have to use larger partititions
102+ // solve p = (p' - 1) * a + r for a (factor) and r (remainder)
103+ * factor = min_part_size / part_size ;
104+ * internal_partitions = partitions / * factor ;
105+ * remainder = partitions % (* internal_partitions );
106+
107+ if (* remainder == 0 ) { // size of last partition must be set
108+ * remainder = * factor ;
109+ } else {
110+ // number of partitions was floored, so add 1 for last (smaller) partition
111+ * internal_partitions += 1 ;
112+ }
113+ } else {
114+ // can keep original partitioning
115+ * internal_partitions = partitions ;
116+ * factor = 1 ;
117+ * remainder = 1 ;
118+ }
119+ }
120+
52121/**
53122 * This is a helper function that frees a request. This requires ompi_part_persist_aggregated.lock be held before calling.
54123 */
@@ -59,6 +128,12 @@ mca_part_persist_aggregated_free_req(struct mca_part_persist_aggregated_request_
59128 size_t i ;
60129 opal_list_remove_item (ompi_part_persist_aggregated .progress_list , (opal_list_item_t * )req -> progress_elem );
61130 OBJ_RELEASE (req -> progress_elem );
131+
132+ // if on sender side, free aggregation state
133+ if (MCA_PART_PERSIST_AGGREGATED_REQUEST_PSEND == req -> req_type ) {
134+ mca_part_persist_aggregated_psend_request_t * sendreq = (mca_part_persist_aggregated_psend_request_t * ) req ;
135+ part_persist_aggregate_regular_free (& sendreq -> aggregation_state );
136+ }
62137
63138 for (i = 0 ; i < req -> real_parts ; i ++ ) {
64139 ompi_request_free (& (req -> persist_reqs [i ]));
@@ -187,17 +262,21 @@ mca_part_persist_aggregated_progress(void)
187262
188263 /* Set up persistent sends */
189264 req -> persist_reqs = (ompi_request_t * * ) malloc (sizeof (ompi_request_t * )* (req -> real_parts ));
190- for (i = 0 ; i < req -> real_parts ; i ++ ) {
265+ for (i = 0 ; i < req -> real_parts - 1 ; i ++ ) {
191266 void * buf = ((void * ) (((char * )req -> req_addr ) + (bytes * i )));
192267 err = MCA_PML_CALL (isend_init (buf , req -> real_count , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , MCA_PML_BASE_SEND_STANDARD , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
193268 }
269+ // last transfer partition can have different size
270+ void * buf = ((void * ) (((char * )req -> req_addr ) + (bytes * i )));
271+ err = MCA_PML_CALL (isend_init (buf , req -> real_remainder , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , MCA_PML_BASE_SEND_STANDARD , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
194272 } else {
195273 /* parse message */
196- req -> world_peer = req -> setup_info [1 ].world_rank ;
197- req -> my_send_tag = req -> setup_info [1 ].start_tag ;
198- req -> my_recv_tag = req -> setup_info [1 ].setup_tag ;
199- req -> real_parts = req -> setup_info [1 ].num_parts ;
200- req -> real_count = req -> setup_info [1 ].count ;
274+ req -> world_peer = req -> setup_info [1 ].world_rank ;
275+ req -> my_send_tag = req -> setup_info [1 ].start_tag ;
276+ req -> my_recv_tag = req -> setup_info [1 ].setup_tag ;
277+ req -> real_parts = req -> setup_info [1 ].num_parts ;
278+ req -> real_count = req -> setup_info [1 ].count ;
279+ req -> real_remainder = req -> setup_info [1 ].remainder ;
201280
202281 err = opal_datatype_type_size (& (req -> req_datatype -> super ), & dt_size_ );
203282 if (OMPI_SUCCESS != err ) return OMPI_ERROR ;
@@ -207,10 +286,14 @@ mca_part_persist_aggregated_progress(void)
207286 /* Set up persistent sends */
208287 req -> persist_reqs = (ompi_request_t * * ) malloc (sizeof (ompi_request_t * )* (req -> real_parts ));
209288 req -> flags = (int * ) calloc (req -> real_parts ,sizeof (int ));
210- for (i = 0 ; i < req -> real_parts ; i ++ ) {
289+ for (i = 0 ; i < req -> real_parts - 1 ; i ++ ) {
211290 void * buf = ((void * ) (((char * )req -> req_addr ) + (bytes * i )));
212291 err = MCA_PML_CALL (irecv_init (buf , req -> real_count , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
213292 }
293+ // last transfer partition can have different size
294+ void * buf = ((void * ) (((char * )req -> req_addr ) + (bytes * i )));
295+ err = MCA_PML_CALL (irecv_init (buf , req -> real_remainder , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
296+
214297 err = req -> persist_reqs [0 ]-> req_start (req -> real_parts , (& (req -> persist_reqs [0 ])));
215298
216299 /* Send back a message */
@@ -373,19 +456,26 @@ mca_part_persist_aggregated_psend_init(const void* buf,
373456 dt_size = (dt_size_ > (size_t ) INT_MAX ) ? MPI_UNDEFINED : (int ) dt_size_ ;
374457 req -> req_bytes = parts * count * dt_size ;
375458
459+ // select internal partitioning (i.e. real_parts) here
460+ size_t factor , remaining_partitions ;
461+ part_persist_aggregated_select_internal_partitioning (parts , count , & req -> real_parts , & factor , & remaining_partitions );
462+
463+ req -> real_remainder = remaining_partitions * count ; // convert to number of elements
464+ req -> real_count = factor * count ;
465+ req -> setup_info [0 ].num_parts = req -> real_parts ; // setup info has to contain internal partitioning
466+ req -> setup_info [0 ].count = req -> real_count ;
467+ req -> setup_info [0 ].remainder = req -> real_remainder ;
468+ opal_output_verbose (5 , ompi_part_base_framework .framework_output , "mapped given %lu*%lu partitioning to internal partitioning of %lu*%lu + %lu\n" , parts , count , req -> real_parts - 1 , req -> real_count , req -> real_remainder );
376469
470+ // init aggregation state
471+ part_persist_aggregate_regular_init (& sendreq -> aggregation_state , req -> real_parts , factor , remaining_partitions );
377472
378473 /* non-blocking send set-up data */
379474 req -> setup_info [0 ].world_rank = ompi_comm_rank (& ompi_mpi_comm_world .comm );
380475 req -> setup_info [0 ].start_tag = ompi_part_persist_aggregated .next_send_tag ; ompi_part_persist_aggregated .next_send_tag += parts ;
381476 req -> my_send_tag = req -> setup_info [0 ].start_tag ;
382477 req -> setup_info [0 ].setup_tag = ompi_part_persist_aggregated .next_recv_tag ; ompi_part_persist_aggregated .next_recv_tag ++ ;
383478 req -> my_recv_tag = req -> setup_info [0 ].setup_tag ;
384- req -> setup_info [0 ].num_parts = parts ;
385- req -> real_parts = parts ;
386- req -> setup_info [0 ].count = count ;
387- req -> real_count = count ;
388-
389479
390480 req -> flags = (int * ) calloc (req -> real_parts , sizeof (int ));
391481
@@ -428,6 +518,13 @@ mca_part_persist_aggregated_start(size_t count, ompi_request_t** requests)
428518
429519 for (size_t i = 0 ; i < _count && OMPI_SUCCESS == err ; i ++ ) {
430520 mca_part_persist_aggregated_request_t * req = (mca_part_persist_aggregated_request_t * )(requests [i ]);
521+
522+ // reset aggregation state here
523+ if (MCA_PART_PERSIST_AGGREGATED_REQUEST_PSEND == req -> req_type ) {
524+ mca_part_persist_aggregated_psend_request_t * sendreq = (mca_part_persist_aggregated_psend_request_t * )(req );
525+ part_persist_aggregate_regular_reset (& sendreq -> aggregation_state );
526+ }
527+
431528 /* First use is a special case, to support lazy initialization */
432529 if (false == req -> first_send )
433530 {
@@ -470,19 +567,31 @@ mca_part_persist_aggregated_pready(size_t min_part,
470567 size_t i ;
471568
472569 mca_part_persist_aggregated_request_t * req = (mca_part_persist_aggregated_request_t * )(request );
570+ int flag_value ;
473571 if (true == req -> initialized )
474572 {
475- err = req -> persist_reqs [min_part ]-> req_start (max_part - min_part + 1 , (& (req -> persist_reqs [min_part ])));
476- for (i = min_part ; i <= max_part && OMPI_SUCCESS == err ; i ++ ) {
477- req -> flags [i ] = 0 ; /* Mark partition as ready for testing */
478- }
573+ flag_value = 0 ; /* Mark partition as ready for testing */
479574 }
480575 else
481576 {
482- for (i = min_part ; i <= max_part && OMPI_SUCCESS == err ; i ++ ) {
483- req -> flags [i ] = -2 ; /* Mark partition as queued */
577+ flag_value = -2 ; /* Mark partition as queued */
578+ }
579+
580+ mca_part_persist_aggregated_psend_request_t * sendreq = (mca_part_persist_aggregated_psend_request_t * )(request );
581+ int internal_part_ready ;
582+ for (i = min_part ; i <= max_part && OMPI_SUCCESS == err ; i ++ ) {
583+ part_persist_aggregate_regular_pready (& sendreq -> aggregation_state , i , & internal_part_ready );
584+
585+ if (-1 != internal_part_ready ) {
586+ // transfer partition is ready
587+ if (true == req -> initialized ) {
588+ err = req -> persist_reqs [internal_part_ready ]-> req_start (1 , (& (req -> persist_reqs [internal_part_ready ])));
589+ }
590+
591+ req -> flags [internal_part_ready ] = flag_value ;
484592 }
485593 }
594+
486595 return err ;
487596}
488597
0 commit comments