Skip to content

Commit 3312228

Browse files
jtronge and hppritcha
committed
Update coll framework count/disp arrays for bigcount
This updates the coll framework functions that take count/displacement arrays so that they support bigcount. Instead of directly using pointers to bigcount or non-bigcount arrays, this adds special descriptor types, ompi_count_array_t and ompi_disp_array_t, which must be accessed through inline accessor functions to ensure the correct element type is used, whether bigcount or non-bigcount. Internally, these descriptors are typedefs of intptr_t and hold both a pointer and a flag indicating the pointer type.

Co-authored-by: Howard Pritchard <[email protected]>
Signed-off-by: Jake Tronge <[email protected]>
1 parent ff12b69 commit 3312228

File tree

116 files changed

+1828
-1341
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+1828
-1341
lines changed

ompi/communicator/comm.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2402,6 +2402,8 @@ int ompi_comm_determine_first ( ompi_communicator_t *intercomm, int high )
24022402
int rank, rsize;
24032403
int *rcounts;
24042404
int *rdisps;
2405+
ompi_count_array_t rcounts_desc;
2406+
ompi_disp_array_t rdisps_desc;
24052407
int scount=0;
24062408
int rc;
24072409

@@ -2429,8 +2431,10 @@ int ompi_comm_determine_first ( ompi_communicator_t *intercomm, int high )
24292431
scount = 1;
24302432
}
24312433

2434+
OMPI_COUNT_ARRAY_INIT(&rcounts_desc, rcounts);
2435+
OMPI_DISP_ARRAY_INIT(&rdisps_desc, rdisps);
24322436
rc = intercomm->c_coll->coll_allgatherv(&high, scount, MPI_INT,
2433-
&rhigh, rcounts, rdisps,
2437+
&rhigh, rcounts_desc, rdisps_desc,
24342438
MPI_INT, intercomm,
24352439
intercomm->c_coll->coll_allgatherv_module);
24362440
if ( NULL != rdisps ) {

ompi/mca/coll/base/coll_base_allgatherv.c

Lines changed: 94 additions & 59 deletions
Large diffs are not rendered by default.

ompi/mca/coll/base/coll_base_alltoallv.c

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
* and count) to send the data to the other.
5151
*/
5252
int
53-
mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts, const int *rdisps,
53+
mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
5454
struct ompi_datatype_t *rdtype,
5555
struct ompi_communicator_t *comm,
5656
mca_coll_base_module_t *module)
@@ -72,7 +72,7 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
7272
if (i == rank) {
7373
continue;
7474
}
75-
packed_size = rcounts[i] * type_size;
75+
packed_size = ompi_count_array_get(rcounts, i) * type_size;
7676
max_size = opal_max(packed_size, max_size);
7777
}
7878

@@ -111,11 +111,11 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
111111
right = (rank + i) % size;
112112
left = (rank + size - i) % size;
113113

114-
if( 0 != rcounts[right] ) { /* nothing to exchange with the peer on the right */
114+
if( 0 != ompi_count_array_get(rcounts, right) ) { /* nothing to exchange with the peer on the right */
115115
ompi_proc_t *right_proc = ompi_comm_peer_lookup(comm, right);
116116
opal_convertor_clone(right_proc->super.proc_convertor, &convertor, 0);
117-
opal_convertor_prepare_for_send(&convertor, &rdtype->super, rcounts[right],
118-
(char *) rbuf + rdisps[right] * extent);
117+
opal_convertor_prepare_for_send(&convertor, &rdtype->super, ompi_count_array_get(rcounts, right),
118+
(char *) rbuf + ompi_disp_array_get(rdisps, right) * extent);
119119
packed_size = max_size;
120120
err = opal_convertor_pack(&convertor, &iov, &iov_count, &packed_size);
121121
if (1 != err) {
@@ -124,17 +124,19 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
124124
}
125125

126126
/* Receive data from the right */
127-
err = MCA_PML_CALL(irecv ((char *) rbuf + rdisps[right] * extent, rcounts[right], rdtype,
127+
err = MCA_PML_CALL(irecv ((char *) rbuf + ompi_disp_array_get(rdisps, right) * extent,
128+
ompi_count_array_get(rcounts, right), rdtype,
128129
right, MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
129130
if (MPI_SUCCESS != err) {
130131
line = __LINE__;
131132
goto error_hndl;
132133
}
133134
}
134135

135-
if( (left != right) && (0 != rcounts[left]) ) {
136+
if( (left != right) && (0 != ompi_count_array_get(rcounts, left)) ) {
136137
/* Send data to the left */
137-
err = MCA_PML_CALL(send ((char *) rbuf + rdisps[left] * extent, rcounts[left], rdtype,
138+
err = MCA_PML_CALL(send ((char *) rbuf + ompi_disp_array_get(rdisps, left) * extent,
139+
ompi_count_array_get(rcounts, left), rdtype,
138140
left, MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
139141
comm));
140142
if (MPI_SUCCESS != err) {
@@ -149,15 +151,16 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
149151
}
150152

151153
/* Receive data from the left */
152-
err = MCA_PML_CALL(irecv ((char *) rbuf + rdisps[left] * extent, rcounts[left], rdtype,
154+
err = MCA_PML_CALL(irecv ((char *) rbuf + ompi_disp_array_get(rdisps, left) * extent,
155+
ompi_count_array_get(rcounts, left), rdtype,
153156
left, MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
154157
if (MPI_SUCCESS != err) {
155158
line = __LINE__;
156159
goto error_hndl;
157160
}
158161
}
159162

160-
if( 0 != rcounts[right] ) { /* nothing to exchange with the peer on the right */
163+
if( 0 != ompi_count_array_get(rcounts, right) ) { /* nothing to exchange with the peer on the right */
161164
/* Send data to the right */
162165
err = MCA_PML_CALL(send ((char *) tmp_buffer, packed_size, MPI_PACKED,
163166
right, MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
@@ -191,9 +194,9 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
191194
}
192195

193196
int
194-
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, const int *sdisps,
197+
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps,
195198
struct ompi_datatype_t *sdtype,
196-
void* rbuf, const int *rcounts, const int *rdisps,
199+
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
197200
struct ompi_datatype_t *rdtype,
198201
struct ompi_communicator_t *comm,
199202
mca_coll_base_module_t *module)
@@ -230,21 +233,21 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
230233
recvfrom = (rank + size - step) % size;
231234

232235
/* Determine sending and receiving locations */
233-
psnd = (char*)sbuf + (ptrdiff_t)sdisps[sendto] * sext;
234-
prcv = (char*)rbuf + (ptrdiff_t)rdisps[recvfrom] * rext;
236+
psnd = (char*)sbuf + ompi_disp_array_get(sdisps, sendto) * sext;
237+
prcv = (char*)rbuf + ompi_disp_array_get(rdisps, recvfrom) * rext;
235238

236239
/* send and receive */
237-
if (0 < rcounts[recvfrom] && 0 < rdtype_size) {
238-
err = MCA_PML_CALL(irecv(prcv, rcounts[recvfrom], rdtype, recvfrom,
240+
if (0 < ompi_count_array_get(rcounts, recvfrom) && 0 < rdtype_size) {
241+
err = MCA_PML_CALL(irecv(prcv, ompi_count_array_get(rcounts, recvfrom), rdtype, recvfrom,
239242
MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
240243
if (MPI_SUCCESS != err) {
241244
line = __LINE__;
242245
goto err_hndl;
243246
}
244247
}
245248

246-
if (0 < scounts[sendto] && 0 < sdtype_size) {
247-
err = MCA_PML_CALL(send(psnd, scounts[sendto], sdtype, sendto,
249+
if (0 < ompi_count_array_get(scounts, sendto) && 0 < sdtype_size) {
250+
err = MCA_PML_CALL(send(psnd, ompi_count_array_get(scounts, sendto), sdtype, sendto,
248251
MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD, comm));
249252
if (MPI_SUCCESS != err) {
250253
line = __LINE__;
@@ -280,9 +283,9 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
280283
* differently and so will not have to duplicate code.
281284
*/
282285
int
283-
ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts, const int *sdisps,
286+
ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps,
284287
struct ompi_datatype_t *sdtype,
285-
void *rbuf, const int *rcounts, const int *rdisps,
288+
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
286289
struct ompi_datatype_t *rdtype,
287290
struct ompi_communicator_t *comm,
288291
mca_coll_base_module_t *module)
@@ -313,11 +316,11 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
313316
ompi_datatype_type_extent(rdtype, &rext);
314317

315318
/* Simple optimization - handle send to self first */
316-
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[rank] * sext;
317-
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[rank] * rext;
318-
if (0 < scounts[rank] && 0 < sdtype_size) {
319-
err = ompi_datatype_sndrcv(psnd, scounts[rank], sdtype,
320-
prcv, rcounts[rank], rdtype);
319+
psnd = ((char *) sbuf) + ompi_disp_array_get(sdisps, rank) * sext;
320+
prcv = ((char *) rbuf) + ompi_disp_array_get(rdisps, rank) * rext;
321+
if (0 < ompi_count_array_get(scounts, rank) && 0 < sdtype_size) {
322+
err = ompi_datatype_sndrcv(psnd, ompi_count_array_get(scounts, rank), sdtype,
323+
prcv, ompi_count_array_get(rcounts, rank), rdtype);
321324
if (MPI_SUCCESS != err) {
322325
return err;
323326
}
@@ -339,10 +342,10 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
339342
continue;
340343
}
341344

342-
if (0 < rcounts[i] && 0 < rdtype_size) {
345+
if (0 < ompi_count_array_get(rcounts, i) && 0 < rdtype_size) {
343346
++nreqs;
344-
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[i] * rext;
345-
err = MCA_PML_CALL(irecv_init(prcv, rcounts[i], rdtype,
347+
prcv = ((char *) rbuf) + ompi_disp_array_get(rdisps, i) * rext;
348+
err = MCA_PML_CALL(irecv_init(prcv, ompi_count_array_get(rcounts, i), rdtype,
346349
i, MCA_COLL_BASE_TAG_ALLTOALLV, comm,
347350
preq++));
348351
if (MPI_SUCCESS != err) { goto err_hndl; }
@@ -355,10 +358,10 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
355358
continue;
356359
}
357360

358-
if (0 < scounts[i] && 0 < sdtype_size) {
361+
if (0 < ompi_count_array_get(scounts, i) && 0 < sdtype_size) {
359362
++nreqs;
360-
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[i] * sext;
361-
err = MCA_PML_CALL(isend_init(psnd, scounts[i], sdtype,
363+
psnd = ((char *) sbuf) + ompi_disp_array_get(sdisps, i) * sext;
364+
err = MCA_PML_CALL(isend_init(psnd, ompi_count_array_get(scounts, i), sdtype,
362365
i, MCA_COLL_BASE_TAG_ALLTOALLV,
363366
MCA_PML_BASE_SEND_STANDARD, comm,
364367
preq++));

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,27 +69,27 @@ typedef enum COLLTYPE {
6969

7070
/* defined arg lists to simplify automatic inclusion of user-overriding decision functions */
7171
#define ALLGATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
72-
#define ALLGATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
72+
#define ALLGATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t displs, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
7373
#define ALLREDUCE_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
7474
#define ALLTOALL_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
75-
#define ALLTOALLV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
76-
#define ALLTOALLW_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t * const sendtypes[], void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
75+
#define ALLTOALLV_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
76+
#define ALLTOALLW_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t * const sendtypes[], void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
7777
#define BARRIER_BASE_ARGS struct ompi_communicator_t *comm
7878
#define BCAST_BASE_ARGS void *buffer, size_t count, struct ompi_datatype_t *datatype, int root, struct ompi_communicator_t *comm
7979
#define EXSCAN_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
8080
#define GATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
81-
#define GATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
81+
#define GATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t displs, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
8282
#define REDUCE_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm
83-
#define REDUCESCATTER_BASE_ARGS const void *sendbuf, void *recvbuf, const int recvcounts[], struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
83+
#define REDUCESCATTER_BASE_ARGS const void *sendbuf, void *recvbuf, ompi_count_array_t recvcounts, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
8484
#define REDUCESCATTERBLOCK_BASE_ARGS const void *sendbuf, void *recvbuf, size_t recvcount, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
8585
#define SCAN_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
8686
#define SCATTER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
87-
#define SCATTERV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int displs[], struct ompi_datatype_t *sendtype, void *recvbuf, int recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
87+
#define SCATTERV_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t displs, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
8888
#define NEIGHBOR_ALLGATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
89-
#define NEIGHBOR_ALLGATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
89+
#define NEIGHBOR_ALLGATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t displs, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
9090
#define NEIGHBOR_ALLTOALL_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
91-
#define NEIGHBOR_ALLTOALLV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
92-
#define NEIGHBOR_ALLTOALLW_BASE_ARGS const void *sendbuf, const int sendcounts[], const MPI_Aint sdispls[], struct ompi_datatype_t * const sendtypes[], void *recvbuf, const int recvcounts[], const MPI_Aint rdispls[], struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
91+
#define NEIGHBOR_ALLTOALLV_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
92+
#define NEIGHBOR_ALLTOALLW_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t * const sendtypes[], void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
9393

9494
#define ALLGATHER_ARGS ALLGATHER_BASE_ARGS, mca_coll_base_module_t *module
9595
#define ALLGATHERV_ARGS ALLGATHERV_BASE_ARGS, mca_coll_base_module_t *module
@@ -227,7 +227,7 @@ int mca_coll_base_alltoall_intra_basic_inplace(const void *rbuf, size_t rcount,
227227
/* AlltoAllV */
228228
int ompi_coll_base_alltoallv_intra_pairwise(ALLTOALLV_ARGS);
229229
int ompi_coll_base_alltoallv_intra_basic_linear(ALLTOALLV_ARGS);
230-
int mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts, const int *rdisps,
230+
int mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
231231
struct ompi_datatype_t *rdtype,
232232
struct ompi_communicator_t *comm,
233233
mca_coll_base_module_t *module); /* special version for INPLACE */

0 commit comments

Comments
 (0)