From f6256407ea1534271d59fcb8fb2bf979da3c621c Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Wed, 6 Nov 2024 16:37:00 -0600 Subject: [PATCH 01/10] ch4: refactor can_do_tag query Move the wrapper into mpidig.h so we can use it in other paths. Rename the interface to MPIDIG_can_do_tag(bool is_local). --- src/mpid/ch4/src/mpidig.h | 9 +++++++++ src/mpid/ch4/src/mpidig_pt2pt_callbacks.c | 20 ++++++++------------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index f253873c87c..6f292064ee7 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -135,6 +135,15 @@ typedef struct MPIDIG_global_t { } MPIDIG_global_t; extern MPIDIG_global_t MPIDIG_global; +MPL_STATIC_INLINE_PREFIX int MPIDIG_can_do_tag(bool is_local) +{ +#ifdef MPIDI_CH4_DIRECT_NETMOD + return MPIDI_NM_am_can_do_tag(); +#else + return is_local ? MPIDI_SHM_am_can_do_tag() : MPIDI_NM_am_can_do_tag(); +#endif +} + MPL_STATIC_INLINE_PREFIX int MPIDIG_get_next_am_tag(MPIR_Comm * comm) { int tag = comm->next_am_tag++; diff --git a/src/mpid/ch4/src/mpidig_pt2pt_callbacks.c b/src/mpid/ch4/src/mpidig_pt2pt_callbacks.c index a9c5d86c5ad..f417d87e931 100644 --- a/src/mpid/ch4/src/mpidig_pt2pt_callbacks.c +++ b/src/mpid/ch4/src/mpidig_pt2pt_callbacks.c @@ -10,15 +10,6 @@ static int handle_unexp_cmpl(MPIR_Request * rreq); static int recv_target_cmpl_cb(MPIR_Request * rreq); -static int can_do_tag(MPIR_Request * rreq) -{ -#ifdef MPIDI_CH4_DIRECT_NETMOD - return MPIDI_NM_am_can_do_tag(); -#else - return MPIDI_REQUEST(rreq, is_local) ? MPIDI_SHM_am_can_do_tag() : MPIDI_NM_am_can_do_tag(); -#endif -} - int MPIDIG_do_cts(MPIR_Request * rreq) { int mpi_errno = MPI_SUCCESS; @@ -30,13 +21,18 @@ int MPIDIG_do_cts(MPIR_Request * rreq) MPIDIG_send_cts_msg_t am_hdr; am_hdr.sreq_ptr = (MPIDIG_REQUEST(rreq, req->rreq.peer_req_ptr)); am_hdr.rreq_ptr = rreq; - if (can_do_tag(rreq)) { +#ifndef MPIDI_CH4_DIRECT_NETMOD + int is_local = MPIDI_REQUEST(rreq, is_local); +#else + int is_local = 0; +#endif + if (MPIDIG_can_do_tag(is_local)) { am_hdr.tag = MPIDIG_get_next_am_tag(rreq->comm); CH4_CALL(am_tag_recv(source_rank, rreq->comm, MPIDIG_TAG_RECV_COMPLETE, am_hdr.tag, MPIDIG_REQUEST(rreq, buffer), MPIDIG_REQUEST(rreq, count), MPIDIG_REQUEST(rreq, datatype), remote_vci, local_vci, rreq), - MPIDI_REQUEST(rreq, is_local), mpi_errno); + is_local, mpi_errno); MPIR_ERR_CHECK(mpi_errno); } else { am_hdr.tag = -1; @@ -48,7 +44,7 @@ int MPIDIG_do_cts(MPIR_Request * rreq) CH4_CALL(am_send_hdr_reply(rreq->comm, source_rank, MPIDIG_SEND_CTS, &am_hdr, sizeof(am_hdr), local_vci, remote_vci), - MPIDI_REQUEST(rreq, is_local), mpi_errno); + is_local, mpi_errno); MPIR_ERR_CHECK(mpi_errno); fn_exit: From 97e5995df6c92c7d1595cda5c2656eb9c96e757a Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 7 Nov 2024 15:51:27 -0600 Subject: [PATCH 02/10] ch4/rma: minor refactor get_target_cmpl_cb Rearrange the branches for cleaner code and prepare for the next patch. --- src/mpid/ch4/src/mpidig_rma_callbacks.c | 45 +++++++++++-------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index 5b7014a107f..5741a3a824c 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -896,42 +896,37 @@ static int get_target_cmpl_cb(MPIR_Request * rreq) get_ack.greq_ptr = MPIDIG_REQUEST(rreq, req->greq.greq_ptr); win = rreq->u.rma.win; - int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); - int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); - if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt) == NULL) { + if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt)) { + /* FIXME: MPIR_Typerep_unflatten should allocate the new object */ + MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem); + if (!dt) { + MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", + "MPIR_Datatype_mem"); + } + MPIR_Object_set_ref(dt, 1); + MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt)); + MPIDIG_REQUEST(rreq, datatype) = dt->handle; + /* count is still target_data_sz now, use it for reply */ + get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count); + MPIDIG_REQUEST(rreq, count) /= dt->size; + } else { MPIDI_Datatype_check_size(MPIDIG_REQUEST(rreq, datatype), MPIDIG_REQUEST(rreq, count), get_ack.target_data_sz); + } + + int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); + int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); + if (true) { CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), MPIDIG_GET_ACK, &get_ack, sizeof(get_ack), MPIDIG_REQUEST(rreq, buffer), MPIDIG_REQUEST(rreq, count), MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci, rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno); - MPID_Request_complete(rreq); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; } - - /* FIXME: MPIR_Typerep_unflatten should allocate the new object */ - MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem); - if (!dt) { - MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", - "MPIR_Datatype_mem"); - } - MPIR_Object_set_ref(dt, 1); - MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt)); - MPIDIG_REQUEST(rreq, datatype) = dt->handle; - /* count is still target_data_sz now, use it for reply */ - get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count); - MPIDIG_REQUEST(rreq, count) /= dt->size; - - CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), - MPIDIG_GET_ACK, &get_ack, sizeof(get_ack), - MPIDIG_REQUEST(rreq, buffer), - MPIDIG_REQUEST(rreq, count), dt->handle, local_vci, - remote_vci, rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno); MPID_Request_complete(rreq); MPIR_ERR_CHECK(mpi_errno); + fn_exit: MPIR_FUNC_EXIT; return mpi_errno; From 1f97f7be9e209fc0bd9dfebf1d86aa9489ef24fa Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 7 Nov 2024 11:10:05 -0600 Subject: [PATCH 03/10] ch4: use am_tag_{send,recv} in MPIDIG get When target reply data to origin get, use am_tag_send if available. --- src/mpid/ch4/include/mpidpre.h | 1 + src/mpid/ch4/src/ch4_types.h | 1 + src/mpid/ch4/src/mpidig.h | 1 + src/mpid/ch4/src/mpidig_init.c | 1 + src/mpid/ch4/src/mpidig_rma.h | 17 ++++++++++++++--- src/mpid/ch4/src/mpidig_rma_callbacks.c | 20 +++++++++++++++++++- src/mpid/ch4/src/mpidig_rma_callbacks.h | 1 + 7 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/mpid/ch4/include/mpidpre.h b/src/mpid/ch4/include/mpidpre.h index ea2ceb78eaa..0281a5c2d72 100644 --- a/src/mpid/ch4/include/mpidpre.h +++ b/src/mpid/ch4/include/mpidpre.h @@ -118,6 +118,7 @@ typedef struct MPIDIG_put_req_t { typedef struct MPIDIG_get_req_t { MPIR_Request *greq_ptr; void *flattened_dt; + int am_tag; } MPIDIG_get_req_t; typedef struct MPIDIG_cswap_req_t { diff --git a/src/mpid/ch4/src/ch4_types.h b/src/mpid/ch4/src/ch4_types.h index 7b2b4f0afc6..910e6ebaf4c 100644 --- a/src/mpid/ch4/src/ch4_types.h +++ b/src/mpid/ch4/src/ch4_types.h @@ -142,6 +142,7 @@ typedef struct MPIDIG_get_msg_t { MPI_Aint target_datatype; MPI_Aint target_true_lb; int flattened_sz; + int am_tag; } MPIDIG_get_msg_t; typedef struct MPIDIG_get_ack_msg_t { diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index 6f292064ee7..3b7d934467b 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -77,6 +77,7 @@ enum { enum { MPIDIG_TAG_RECV_COMPLETE = 0, + MPIDIG_TAG_GET_COMPLETE, MPIDIG_TAG_RECV_STATIC_MAX }; diff --git a/src/mpid/ch4/src/mpidig_init.c b/src/mpid/ch4/src/mpidig_init.c index 83683d63e4b..c09a0cd8c3e 100644 --- a/src/mpid/ch4/src/mpidig_init.c +++ b/src/mpid/ch4/src/mpidig_init.c @@ -158,6 +158,7 @@ int MPIDIG_am_init(void) MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts); MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete); + MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete); MPIDIG_am_comm_abort_init(); diff --git a/src/mpid/ch4/src/mpidig_rma.h b/src/mpid/ch4/src/mpidig_rma.h index 63545a456d5..62da94f0141 100644 --- a/src/mpid/ch4/src/mpidig_rma.h +++ b/src/mpid/ch4/src/mpidig_rma.h @@ -220,6 +220,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co * counter in request, thus it can be decreased at request completion. */ MPIDIG_win_cmpl_cnts_incr(win, target_rank, &sreq->dev.completion_notification); + bool is_local; + is_local = MPIDI_rank_is_local(target_rank, win->comm_ptr); + if (MPIDIG_can_do_tag(is_local)) { + am_hdr.am_tag = MPIDIG_get_next_am_tag(win->comm_ptr); + CH4_CALL(am_tag_recv(target_rank, win->comm_ptr, MPIDIG_TAG_GET_COMPLETE, am_hdr.am_tag, + origin_addr, origin_count, origin_datatype, vci_target, vci, sreq), + is_local, mpi_errno); + MPIR_ERR_CHECK(mpi_errno); + } else { + am_hdr.am_tag = -1; + } + int is_contig; MPIR_Datatype_is_contig(target_datatype, &is_contig); if (MPIR_DATATYPE_IS_PREDEFINED(target_datatype) || is_contig) { @@ -228,8 +240,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co MPIR_T_PVAR_TIMER_END(RMA, rma_amhdr_set); CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr), - NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq), - MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno); + NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq), is_local, mpi_errno); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -242,7 +253,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr), flattened_dt, flattened_sz, MPI_BYTE, vci, vci_target, sreq), - MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno); + is_local, mpi_errno); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index 5741a3a824c..98bbb7f1598 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -916,7 +916,15 @@ static int get_target_cmpl_cb(MPIR_Request * rreq) int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); - if (true) { + if (MPIDIG_REQUEST(rreq, req->greq.am_tag) >= 0) { + int src_rank = MPIDIG_REQUEST(rreq, u.target.origin_rank); + CH4_CALL(am_tag_send(src_rank, win->comm_ptr, MPIDIG_GET_ACK, + MPIDIG_REQUEST(rreq, req->greq.am_tag), + MPIDIG_REQUEST(rreq, buffer), + MPIDIG_REQUEST(rreq, count), + MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci, rreq), + MPIDI_REQUEST(rreq, is_local), mpi_errno); + } else { CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), MPIDIG_GET_ACK, &get_ack, sizeof(get_ack), MPIDIG_REQUEST(rreq, buffer), @@ -2099,6 +2107,7 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, MPIDIG_REQUEST(rreq, req->greq.flattened_dt) = NULL; MPIDIG_REQUEST(rreq, req->greq.greq_ptr) = msg_hdr->greq_ptr; MPIDIG_REQUEST(rreq, u.target.origin_rank) = msg_hdr->src_rank; + MPIDIG_REQUEST(rreq, req->greq.am_tag) = msg_hdr->am_tag; if (msg_hdr->flattened_sz) { void *flattened_dt = MPL_malloc(msg_hdr->flattened_sz, MPL_MEM_BUFFER); @@ -2159,3 +2168,12 @@ int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, MPIR_FUNC_EXIT; return mpi_errno; } + +int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status) +{ + int mpi_errno = MPI_SUCCESS; + + mpi_errno = get_ack_target_cmpl_cb(req); + + return mpi_errno; +} diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.h b/src/mpid/ch4/src/mpidig_rma_callbacks.h index 3fe440cdc7d..ac8d7c6b887 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.h +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.h @@ -112,5 +112,6 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); +int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status); #endif /* MPIDIG_RMA_CALLBACKS_H_INCLUDED */ From b0a5b5dc1bef9d8532536c34dd543238254089ec Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 7 Nov 2024 23:02:48 -0600 Subject: [PATCH 04/10] ch4/rma: minor refactor in put_dt protocol Move the code that unflatens the datatype from MPIDIG_put_data_target_msg_cb to put_dt_target_cmpl_cb. This allows us to post tag_am_recv in put_dt_target_cmpl_cb. --- src/mpid/ch4/src/mpidig_rma_callbacks.c | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index 98bbb7f1598..fabdb70ef36 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -971,6 +971,17 @@ static int put_dt_target_cmpl_cb(MPIR_Request * rreq) MPIR_FUNC_ENTER; + /* FIXME: MPIR_Typerep_unflatten should allocate the new object */ + MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem); + if (!dt) { + MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", + "MPIR_Datatype_mem"); + } + /* Note: handle is filled in by MPIR_Handle_obj_alloc() */ + MPIR_Object_set_ref(dt, 1); + MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->preq.flattened_dt)); + MPIDIG_REQUEST(rreq, datatype) = dt->handle; + ack_msg.src_rank = MPIDIG_REQUEST(rreq, u.target.origin_rank); ack_msg.origin_preq_ptr = MPIDIG_REQUEST(rreq, req->preq.preq_ptr); ack_msg.target_preq_ptr = rreq; @@ -1718,17 +1729,6 @@ int MPIDIG_put_data_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, rreq = (MPIR_Request *) msg_hdr->preq_ptr; - /* FIXME: MPIR_Typerep_unflatten should allocate the new object */ - MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem); - if (!dt) { - MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", - "MPIR_Datatype_mem"); - } - /* Note: handle is filled in by MPIR_Handle_obj_alloc() */ - MPIR_Object_set_ref(dt, 1); - MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->preq.flattened_dt)); - MPIDIG_REQUEST(rreq, datatype) = dt->handle; - MPIDIG_REQUEST(rreq, req->target_cmpl_cb) = put_target_cmpl_cb; MPIDIG_recv_type_init(MPIDIG_REQUEST(rreq, req->preq.origin_data_sz), rreq); From 53704b2e7a42e3133c17edbae75c4a05dc060084 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 7 Nov 2024 22:58:58 -0600 Subject: [PATCH 05/10] ch4: use am_tag_{send,recv} in MPIDIG put In the MPIDIG_PUT_DT_REQ protocol, use am_tag_{send,recv} when available. --- src/mpid/ch4/src/ch4_types.h | 1 + src/mpid/ch4/src/mpidig.h | 1 + src/mpid/ch4/src/mpidig_init.c | 1 + src/mpid/ch4/src/mpidig_rma_callbacks.c | 55 +++++++++++++++++++++---- src/mpid/ch4/src/mpidig_rma_callbacks.h | 1 + 5 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/mpid/ch4/src/ch4_types.h b/src/mpid/ch4/src/ch4_types.h index 910e6ebaf4c..914b89c11a6 100644 --- a/src/mpid/ch4/src/ch4_types.h +++ b/src/mpid/ch4/src/ch4_types.h @@ -117,6 +117,7 @@ typedef struct MPIDIG_put_msg_t { typedef struct MPIDIG_put_dt_ack_msg_t { int src_rank; + int am_tag; MPIR_Request *target_preq_ptr; MPIR_Request *origin_preq_ptr; } MPIDIG_put_dt_ack_msg_t; diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index 3b7d934467b..18e846a9e27 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -78,6 +78,7 @@ enum { enum { MPIDIG_TAG_RECV_COMPLETE = 0, MPIDIG_TAG_GET_COMPLETE, + MPIDIG_TAG_PUT_COMPLETE, MPIDIG_TAG_RECV_STATIC_MAX }; diff --git a/src/mpid/ch4/src/mpidig_init.c b/src/mpid/ch4/src/mpidig_init.c index c09a0cd8c3e..a3b74cc3b8e 100644 --- a/src/mpid/ch4/src/mpidig_init.c +++ b/src/mpid/ch4/src/mpidig_init.c @@ -159,6 +159,7 @@ int MPIDIG_am_init(void) MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts); MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete); MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete); + MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_PUT_COMPLETE, &MPIDIG_tag_put_complete); MPIDIG_am_comm_abort_init(); diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index fabdb70ef36..3ce80bf0ef0 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -988,6 +988,26 @@ static int put_dt_target_cmpl_cb(MPIR_Request * rreq) int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); + MPIR_Comm *comm = rreq->u.rma.win->comm_ptr; + + bool is_local; +#ifndef MPIDI_CH4_DIRECT_NETMOD + is_local = MPIDI_REQUEST(rreq, is_local); +#else + is_local = 0; +#endif + if (MPIDIG_can_do_tag(is_local)) { + ack_msg.am_tag = MPIDIG_get_next_am_tag(comm); + CH4_CALL(am_tag_recv(ack_msg.src_rank, comm, MPIDIG_TAG_PUT_COMPLETE, ack_msg.am_tag, + MPIDIG_REQUEST(rreq, buffer), + MPIDIG_REQUEST(rreq, count), + MPIDIG_REQUEST(rreq, datatype), + local_vci, remote_vci, rreq), is_local, mpi_errno); + MPIR_ERR_CHECK(mpi_errno); + } else { + ack_msg.am_tag = -1; + } + CH4_CALL(am_send_hdr_reply (rreq->u.rma.win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), MPIDIG_PUT_DT_ACK, &ack_msg, sizeof(ack_msg), local_vci, remote_vci), @@ -1605,13 +1625,25 @@ int MPIDIG_put_dt_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_s /* origin datatype to be released in MPIDIG_put_data_origin_cb */ MPIDIG_REQUEST(rreq, datatype) = MPIDIG_REQUEST(origin_req, datatype); - CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(origin_req, u.origin.target_rank), - MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg), - MPIDIG_REQUEST(origin_req, buffer), - MPIDIG_REQUEST(origin_req, count), - MPIDIG_REQUEST(origin_req, datatype), - local_vci, remote_vci, rreq), - (attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno); + int target_rank = MPIDIG_REQUEST(origin_req, u.origin.target_rank); + if (msg_hdr->am_tag >= 0) { + CH4_CALL(am_tag_send(target_rank, win->comm_ptr, MPIDIG_PUT_DAT_REQ, + msg_hdr->am_tag, + MPIDIG_REQUEST(origin_req, buffer), + MPIDIG_REQUEST(origin_req, count), + MPIDIG_REQUEST(origin_req, datatype), + local_vci, remote_vci, rreq), + (attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno); + + } else { + CH4_CALL(am_isend_reply(win->comm_ptr, target_rank, + MPIDIG_PUT_DAT_REQ, &dat_msg, sizeof(dat_msg), + MPIDIG_REQUEST(origin_req, buffer), + MPIDIG_REQUEST(origin_req, count), + MPIDIG_REQUEST(origin_req, datatype), + local_vci, remote_vci, rreq), + (attr & MPIDIG_AM_ATTR__IS_LOCAL), mpi_errno); + } MPIR_ERR_CHECK(mpi_errno); if (attr & MPIDIG_AM_ATTR__IS_ASYNC) { @@ -2177,3 +2209,12 @@ int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status) return mpi_errno; } + +int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status) +{ + int mpi_errno = MPI_SUCCESS; + + mpi_errno = put_target_cmpl_cb(req); + + return mpi_errno; +} diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.h b/src/mpid/ch4/src/mpidig_rma_callbacks.h index ac8d7c6b887..87dc729f83d 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.h +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.h @@ -113,5 +113,6 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status); +int MPIDIG_tag_put_complete(MPIR_Request * req, MPI_Status * status); #endif /* MPIDIG_RMA_CALLBACKS_H_INCLUDED */ From 59ed7658567861e45df68487aa1fe1c55138e01d Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 9 Nov 2024 17:40:14 -0600 Subject: [PATCH 06/10] ch4/rma: warning fix due to missing error check Potentially I should squash this... --- src/mpid/ch4/src/mpidig_rma_callbacks.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index 3ce80bf0ef0..6e8f1f5a8d0 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -1762,7 +1762,8 @@ int MPIDIG_put_data_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, rreq = (MPIR_Request *) msg_hdr->preq_ptr; MPIDIG_REQUEST(rreq, req->target_cmpl_cb) = put_target_cmpl_cb; - MPIDIG_recv_type_init(MPIDIG_REQUEST(rreq, req->preq.origin_data_sz), rreq); + mpi_errno = MPIDIG_recv_type_init(MPIDIG_REQUEST(rreq, req->preq.origin_data_sz), rreq); + MPIR_ERR_CHECK(mpi_errno); if (attr & MPIDIG_AM_ATTR__IS_ASYNC) { *req = rreq; From 755d7f626b4c77b64b09c2f87c1ca06d8b8fe32d Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 9 Nov 2024 22:26:44 -0600 Subject: [PATCH 07/10] datatype: avoid divide by 0 in get_elements_x When the datatype is a struct with mixed elements, we set its basic_type to MPI_DATATYPE_NULL. This certainly is not ideal. But for now, lets avoid divide by 0 error in MPIR_Type_get_basic_type_elements and simply return 0. MPIR_Type_get_basic_type_elements is used in the OFI receive complete event to check whether we received partial elements. We'll skip the error checking for struct types for now. --- src/mpi/datatype/get_elements_x.c | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/mpi/datatype/get_elements_x.c b/src/mpi/datatype/get_elements_x.c index e85a14bb6ac..eea75423a9f 100644 --- a/src/mpi/datatype/get_elements_x.c +++ b/src/mpi/datatype/get_elements_x.c @@ -96,6 +96,12 @@ static MPI_Count MPIR_Type_get_basic_type_elements(MPI_Count * bytes_p, break; } + if (type1_sz + type2_sz == 0) { + /* this is likely a struct type with mixed basic elements. Let's just bail for now */ + *bytes_p = 0; + return 0; + } + /* determine the number of elements in the region */ elements = 2 * (usable_bytes / (type1_sz + type2_sz)); if (usable_bytes % (type1_sz + type2_sz) >= type1_sz) From 32a9f0b45ee0b933633b7ba00f379180343a82f5 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 9 Nov 2024 23:37:06 -0600 Subject: [PATCH 08/10] datatype: missing MPID_Type_commit_hook in unflatten We need call MPID_Type_commit_hook for unflattened datatypes, or UCX won't able to send or receive an unflattened datatypes. Previously, unflattened datatypes are only used in RMA AM and they are handled in MPIDIG. Now with am_tag_{send,recv}, we are directly using UCX to send/recv such unflattened types. --- src/mpi/datatype/typerep/src/typerep_flatten.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mpi/datatype/typerep/src/typerep_flatten.c b/src/mpi/datatype/typerep/src/typerep_flatten.c index f4f36b2bb86..1caf81afc0d 100644 --- a/src/mpi/datatype/typerep/src/typerep_flatten.c +++ b/src/mpi/datatype/typerep/src/typerep_flatten.c @@ -132,6 +132,8 @@ int MPIR_Typerep_unflatten(MPIR_Datatype * datatype_ptr, void *flattened_type) MPIR_ERR_CHECK(mpi_errno); #endif + MPID_Type_commit_hook(datatype_ptr); + fn_exit: return mpi_errno; From c3890a3bbe630b29d8e1b8d2751ac95f57076383 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 9 Nov 2024 23:41:41 -0600 Subject: [PATCH 09/10] ch4/ucx: fix MPIDI_UCX_TAG_AM Without explicit 1ULL, C will default it to int, resulting overflows. --- src/mpid/ch4/netmod/ucx/ucx_types.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mpid/ch4/netmod/ucx/ucx_types.h b/src/mpid/ch4/netmod/ucx/ucx_types.h index 37369641605..ef53b3d14d4 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_types.h +++ b/src/mpid/ch4/netmod/ucx/ucx_types.h @@ -52,7 +52,7 @@ extern ucp_generic_dt_ops_t MPIDI_UCX_datatype_ops; #define MPIDI_UCX_TAG_BITS (64 - MPIDI_UCX_CONTEXT_ID_BITS - MPIDI_UCX_RANK_BITS - MPIDI_UCX_PROTOCOL_BITS) /* protocol bits */ -#define MPIDI_UCX_TAG_AM (1 << MPIDI_UCX_TAG_BITS) +#define MPIDI_UCX_TAG_AM (1ULL << MPIDI_UCX_TAG_BITS) #define MPIDI_UCX_RANK_SHIFT (MPIDI_UCX_TAG_BITS + MPIDI_UCX_PROTOCOL_BITS) #define MPIDI_UCX_CONTEXT_ID_SHIFT (MPIDI_UCX_TAG_BITS + MPIDI_UCX_PROTOCOL_BITS + MPIDI_UCX_RANK_BITS) From e860410da7d18ab853576f8b28d18c169f02eee1 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 10 Nov 2024 09:01:12 -0600 Subject: [PATCH 10/10] ch4: fix race condition in setting is_local in send requests We need set is_local in a requests inside the vci critical section or race condition may happen. Only the send requests that may go into RNDV active messages need it set. This fixes the occasional failures in am-only threads/pt2pt/multisend2 test. --- src/mpid/ch4/netmod/ofi/ofi_send.h | 4 ++++ src/mpid/ch4/src/ch4_send.h | 2 -- src/mpid/ch4/src/mpidig_send.h | 3 +++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_send.h b/src/mpid/ch4/netmod/ofi/ofi_send.h index 897f05b42a7..3d409ef5fa1 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_send.h +++ b/src/mpid/ch4/netmod/ofi/ofi_send.h @@ -104,6 +104,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_issue_ack_recv(MPIR_Request * sreq, MPIR_ ackreq->remote_addr = MPIDI_OFI_av_to_phys(addr, nic, vci_remote); ackreq->match_bits = match_bits; +#ifndef MPIDI_CH4_DIRECT_NETMOD + /* set is_local in case we go into active messages later */ + MPIDI_REQUEST(sreq, is_local) = 0; +#endif MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[ackreq->ctx_idx].rx, ackreq->ack_hdr, ackreq->ack_hdr_sz, NULL, ackreq->remote_addr, ackreq->match_bits, 0ULL, (void *) &(ackreq->context)), diff --git a/src/mpid/ch4/src/ch4_send.h b/src/mpid/ch4/src/ch4_send.h index 1e78055f081..a8be81b4b2b 100644 --- a/src/mpid/ch4/src/ch4_send.h +++ b/src/mpid/ch4/src/ch4_send.h @@ -29,8 +29,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_isend(const void *buf, mpi_errno = MPIDI_SHM_mpi_isend(buf, count, datatype, rank, tag, comm, attr, av, req); else mpi_errno = MPIDI_NM_mpi_isend(buf, count, datatype, rank, tag, comm, attr, av, req); - if (mpi_errno == MPI_SUCCESS) - MPIDI_REQUEST(*req, is_local) = r; #endif MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/src/mpidig_send.h b/src/mpid/ch4/src/mpidig_send.h index 6c2a0143cc6..ba64f0831f7 100644 --- a/src/mpid/ch4/src/mpidig_send.h +++ b/src/mpid/ch4/src/mpidig_send.h @@ -99,6 +99,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_isend_impl(const void *buf, MPI_Aint count, src_vci, dst_vci, sreq), is_local, mpi_errno); } else { /* RNDV send */ +#ifndef MPIDI_CH4_DIRECT_NETMOD + MPIDI_REQUEST(sreq, is_local) = is_local; +#endif MPIDIG_REQUEST(sreq, buffer) = (void *) buf; MPIDIG_REQUEST(sreq, count) = count; MPIDIG_REQUEST(sreq, datatype) = datatype;