From cbd5778789ac96e8a46afe7561c7453ab0ed3eee Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 27 Dec 2024 11:39:43 -0600 Subject: [PATCH 01/25] comm: refactor comm->is_tainted to comm->vcis_enabled Replace the rather ambiguous field is_tainted with vcis_enabled, thus allowing vci activation on a per-comm basis. It is inherited if the new comm is created from within an parent comm that has vcis_enabled. If the parent comm vcis_enabled=false, then all its descendents will have vcis_enabled until they are turned via separate APIs (to be added in the future). Intercomm and intercomm_merge may include processes outside originating comm, thus vcis_enabled=false by default. MPI_Comm_create may create an intercomm that inherits vcis_enabled=true. This is an exception because both local processes and remote processes are from within originating comm that has vcis_enabled. For now, we switch on vcis_enabled in comm_wrold after post_init. With future extension, it is possible to allow user explicitly set up multi-vics on a smaller comm than comm world. --- src/include/mpir_comm.h | 7 ++----- src/mpi/comm/comm_impl.c | 10 +++------- src/mpi/comm/comm_split.c | 2 +- src/mpi/comm/commutil.c | 22 +++++++++------------ src/mpid/ch4/src/ch4_init.c | 1 + src/mpid/ch4/src/ch4_vci.h | 38 +++++++++++++++---------------------- 6 files changed, 31 insertions(+), 49 deletions(-) diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h index 8af43abc6d7..c3405a7b2b2 100644 --- a/src/include/mpir_comm.h +++ b/src/include/mpir_comm.h @@ -212,11 +212,8 @@ struct MPIR_Comm { * because context_id is non-sequential and can't be used to identify user-level * communicators (due to sub-comms). */ int seq; - /* Certain comm and its offsprings should be restricted to sequence 0 due to - * various restrictions. E.g. multiple-vci doesn't support dynamic process, - * nor intercomms (even after its merge). - */ - int tainted; + /* Whether multiple-vci is enabled. This is ONLY inherited in Comm_dup and Comm_split */ + bool vcis_enabled; int hints[MPIR_COMM_HINT_MAX]; /* Hints to the communicator diff --git a/src/mpi/comm/comm_impl.c b/src/mpi/comm/comm_impl.c index 9dbba6d703f..e4af6be45fe 100644 --- a/src/mpi/comm/comm_impl.c +++ b/src/mpi/comm/comm_impl.c @@ -369,7 +369,7 @@ int MPIR_Comm_create_intra(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co mpi_errno = MPII_Comm_create_map(n, 0, mapping, NULL, mapping_comm, *newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); - (*newcomm_ptr)->tainted = comm_ptr->tainted; + (*newcomm_ptr)->vcis_enabled = comm_ptr->vcis_enabled; mpi_errno = MPIR_Comm_commit(*newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); } else { @@ -525,7 +525,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co mapping, remote_mapping, mapping_comm, *newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); - (*newcomm_ptr)->tainted = comm_ptr->tainted; + (*newcomm_ptr)->vcis_enabled = comm_ptr->vcis_enabled; mpi_errno = MPIR_Comm_commit(*newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); @@ -637,7 +637,7 @@ int MPIR_Comm_create_group_impl(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, in mpi_errno = MPII_Comm_create_map(n, 0, mapping, NULL, mapping_comm, *newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); - (*newcomm_ptr)->tainted = comm_ptr->tainted; + (*newcomm_ptr)->vcis_enabled = comm_ptr->vcis_enabled; mpi_errno = MPIR_Comm_commit(*newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); } else { @@ -1086,7 +1086,6 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader, } MPID_THREAD_CS_EXIT(VCI, local_comm_ptr->mutex); - (*new_intercomm_ptr)->tainted = 1; mpi_errno = MPIR_Comm_commit(*new_intercomm_ptr); MPIR_ERR_CHECK(mpi_errno); @@ -1134,7 +1133,6 @@ int MPIR_peer_intercomm_create(int context_id, int recvcontext_id, } MPID_THREAD_CS_EXIT(VCI, comm_self->mutex); - (*newcomm)->tainted = 1; mpi_errno = MPIR_Comm_commit(*newcomm); MPIR_ERR_CHECK(mpi_errno); @@ -1260,7 +1258,6 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i * operations within the context id algorithm, since we already * have a valid (almost - see comm_create_hook) communicator. */ - (*new_intracomm_ptr)->tainted = 1; mpi_errno = MPIR_Comm_commit((*new_intracomm_ptr)); MPIR_ERR_CHECK(mpi_errno); @@ -1292,7 +1289,6 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i mpi_errno = create_and_map(comm_ptr, local_high, (*new_intracomm_ptr)); MPIR_ERR_CHECK(mpi_errno); - (*new_intracomm_ptr)->tainted = 1; mpi_errno = MPIR_Comm_commit((*new_intracomm_ptr)); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpi/comm/comm_split.c b/src/mpi/comm/comm_split.c index 7c5519278e4..67795989220 100644 --- a/src/mpi/comm/comm_split.c +++ b/src/mpi/comm/comm_split.c @@ -341,7 +341,7 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm ** } MPID_THREAD_CS_EXIT(VCI, comm_ptr->mutex); - (*newcomm_ptr)->tainted = comm_ptr->tainted; + (*newcomm_ptr)->vcis_enabled = comm_ptr->vcis_enabled; mpi_errno = MPIR_Comm_commit(*newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); } diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index 9a51e8565ee..874a199c2a9 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -289,7 +289,7 @@ int MPII_Comm_init(MPIR_Comm * comm_p) comm_p->bsendbuffer = NULL; comm_p->name[0] = '\0'; comm_p->seq = 0; /* default to 0, to be updated at Comm_commit */ - comm_p->tainted = 0; + comm_p->vcis_enabled = false; memset(comm_p->hints, 0, sizeof(comm_p->hints)); for (int i = 0; i < next_comm_hint_index; i++) { if (MPIR_comm_hint_list[i].key) { @@ -412,9 +412,6 @@ int MPII_Setup_intercomm_localcomm(MPIR_Comm * intercomm_ptr) intercomm_ptr->local_comm = localcomm_ptr; /* sets up the SMP-aware sub-communicators and tables */ - /* This routine maybe used inside MPI_Comm_idup, so we can't synchronize - * seq using blocking collectives, thus mark as tainted. */ - localcomm_ptr->tainted = 1; mpi_errno = MPIR_Comm_commit(localcomm_ptr); MPIR_ERR_CHECK(mpi_errno); @@ -599,13 +596,13 @@ static void propagate_hints_to_subcomm(MPIR_Comm * comm, MPIR_Comm * subcomm) subcomm->hints[MPIR_COMM_HINT_VCI] = comm->hints[MPIR_COMM_HINT_VCI]; } -static void propagate_tainted_to_subcomms(MPIR_Comm * comm) +static void propagate_vcis_enabled(MPIR_Comm * comm) { if (comm->node_comm != NULL) - comm->node_comm->tainted = comm->tainted; + comm->node_comm->vcis_enabled = comm->vcis_enabled; if (comm->node_roots_comm != NULL) - comm->node_roots_comm->tainted = comm->tainted; + comm->node_roots_comm->vcis_enabled = comm->vcis_enabled; } int MPIR_Comm_create_subcomms(MPIR_Comm * comm) @@ -720,7 +717,7 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm) MPIR_ERR_CHECK(mpi_errno); } - propagate_tainted_to_subcomms(comm); + propagate_vcis_enabled(comm); comm->hierarchy_kind = MPIR_COMM_HIERARCHY_KIND__PARENT; @@ -826,7 +823,7 @@ int MPIR_Comm_commit(MPIR_Comm * comm) MPIR_ERR_CHECK(mpi_errno); } - if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && !comm->tainted) { + if (comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && comm->vcis_enabled) { mpi_errno = init_comm_seq(comm); MPIR_ERR_CHECK(mpi_errno); } @@ -1017,7 +1014,7 @@ int MPII_Comm_copy(MPIR_Comm * comm_ptr, int size, MPIR_Info * info, MPIR_Comm * MPII_Comm_set_hints(newcomm_ptr, info, true); } - newcomm_ptr->tainted = comm_ptr->tainted; + newcomm_ptr->vcis_enabled = comm_ptr->vcis_enabled; mpi_errno = MPIR_Comm_commit(newcomm_ptr); MPIR_ERR_CHECK(mpi_errno); @@ -1093,9 +1090,8 @@ int MPII_Comm_copy_data(MPIR_Comm * comm_ptr, MPIR_Info * info, MPIR_Comm ** out newcomm_ptr->attributes = 0; *outcomm_ptr = newcomm_ptr; - /* inherit tainted flag */ - newcomm_ptr->tainted = comm_ptr->tainted; - propagate_tainted_to_subcomms(newcomm_ptr); + newcomm_ptr->vcis_enabled = comm_ptr->vcis_enabled; + propagate_vcis_enabled(newcomm_ptr); fn_fail: MPIR_FUNC_EXIT; diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index 365a12b37ad..dd707fc592a 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -715,6 +715,7 @@ int MPIDI_world_post_init(void) mpi_errno = MPIDI_NM_post_init(); MPIR_ERR_CHECK(mpi_errno); + MPIR_Process.comm_world->vcis_enabled = true; MPIDI_global.is_initialized = 1; fn_exit: diff --git a/src/mpid/ch4/src/ch4_vci.h b/src/mpid/ch4/src/ch4_vci.h index e60adac148d..c8a520961b3 100644 --- a/src/mpid/ch4/src/ch4_vci.h +++ b/src/mpid/ch4/src/ch4_vci.h @@ -47,7 +47,7 @@ /* VCI hashing function (fast path) */ /* For consistent hashing, we may need differentiate between src and dst vci and whether - * it is being called from sender side or receiver side (consdier intercomm). We use an + * it is being called from sender side or receiver side (consider intercomm). We use an * integer flag to encode the information. * * The flag constants are designed as bit fields, so different hashing algorithm can easily @@ -108,7 +108,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr, MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr, int src_rank, int dst_rank, int tag) { - return MPIDI_hash_vci(comm_ptr->seq, flag, comm_ptr, src_rank, dst_rank); + if (!comm_ptr->vcis_enabled) { + return 0; + } else { + return MPIDI_hash_vci(comm_ptr->seq, flag, comm_ptr, src_rank, dst_rank); + } } #elif MPIDI_CH4_VCI_METHOD == MPICH_VCI__TAG @@ -121,7 +125,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr, int src_rank, int dst_rank, int tag) { int vci; - if (!(flag & 0x1)) { + if (!comm_ptr->vcis_enabled) { + return 0; + } else if (!(flag & 0x1)) { /* src */ vci = (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f); return MPIDI_hash_vci(vci, flag, comm_ptr, src_rank, dst_rank); @@ -158,20 +164,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_map_contextid_rank_tag_to_vci(int context_id, return MPIR_CONTEXT_READ_FIELD(PREFIX, context_id) + rank + tag; } -static bool is_vci_restricted_to_zero(MPIR_Comm * comm) -{ - bool vci_restricted = false; - if (!(comm->comm_kind == MPIR_COMM_KIND__INTRACOMM && !comm->tainted)) { - vci_restricted |= true; - } - if (!MPIDI_global.is_initialized) { - vci_restricted |= true; - } - - return vci_restricted; -} - - /* Return VCI index of a send transmit context. * Used for two purposes: * 1. For the sender side to determine which VCI index of a transmit context @@ -198,9 +190,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_sender_vci(MPIR_Comm * comm, bool use_user_defined_vci = (comm->hints[MPIR_COMM_HINT_SENDER_VCI] != MPIDI_VCI_INVALID); bool use_tag = comm->hints[MPIR_COMM_HINT_NO_ANY_TAG]; - if (is_vci_restricted_to_zero(comm)) { - vci_idx = 0; - } else if (use_user_defined_vci) { + if (use_user_defined_vci) { vci_idx = comm->hints[MPIR_COMM_HINT_SENDER_VCI]; } else { if (use_tag) { @@ -241,9 +231,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_receiver_vci(MPIR_Comm * comm, bool use_tag = comm->hints[MPIR_COMM_HINT_NO_ANY_TAG]; bool use_source = comm->hints[MPIR_COMM_HINT_NO_ANY_SOURCE]; - if (is_vci_restricted_to_zero(comm)) { - vci_idx = 0; - } else if (use_user_defined_vci) { + if (use_user_defined_vci) { vci_idx = comm->hints[MPIR_COMM_HINT_RECEIVER_VCI] % MPIDI_global.n_vcis; } else { /* If mpi_any_tag and mpi_any_source can be used for recv, all messages @@ -279,6 +267,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_receiver_vci(MPIR_Comm * comm, MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr, int src_rank, int dst_rank, int tag) { + if (!comm_ptr->vcis_enabled) { + return 0; + } + int ctxid_in_effect; if (!(flag & 0x2)) { /* called from sender */ From 31893369f0c315097bdef0289b7bbe6085d4d979 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 27 Dec 2024 17:44:04 -0600 Subject: [PATCH 02/25] vci: directly use MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES Rather than initialize per-vci mutexes in ch4 and register with request pools, directly use MPIR-layer request pool mutexes. --- src/include/mpir_objects.h | 6 -- src/include/mpir_request.h | 25 +++----- src/include/mpir_thread.h | 1 + src/mpi/attr/attrutil.c | 4 +- src/mpi/coll/op/op_impl.c | 2 +- src/mpi/comm/commutil.c | 2 +- src/mpi/datatype/typeutil.c | 2 +- src/mpi/errhan/errutil.c | 2 +- src/mpi/group/grouputil.c | 2 +- src/mpi/info/infoutil.c | 2 +- src/mpi/init/mutex.c | 11 ++++ src/mpi/request/mpir_greq.c | 2 +- src/mpi/request/mpir_request.c | 9 +-- src/mpi/rma/winutil.c | 2 +- src/mpi/session/session_util.c | 2 +- src/mpi/stream/stream_impl.c | 2 +- src/mpid/ch4/include/mpidpost.h | 4 +- .../netmod/include/netmod_am_fallback_probe.h | 8 +-- .../netmod/include/netmod_am_fallback_recv.h | 10 +-- .../netmod/include/netmod_am_fallback_send.h | 4 +- src/mpid/ch4/netmod/ofi/ofi_impl.h | 18 +++--- src/mpid/ch4/netmod/ofi/ofi_recv.h | 2 +- src/mpid/ch4/netmod/ofi/ofi_rma.h | 60 +++++++++--------- src/mpid/ch4/netmod/ofi/ofi_spawn.c | 8 +-- src/mpid/ch4/netmod/ucx/ucx_impl.h | 4 +- src/mpid/ch4/netmod/ucx/ucx_spawn.c | 8 +-- src/mpid/ch4/shm/ipc/gpu/gpu_post.c | 4 +- src/mpid/ch4/shm/posix/posix_impl.h | 4 +- src/mpid/ch4/shm/posix/posix_rma.h | 32 +++++----- src/mpid/ch4/shm/src/shm_am_fallback_probe.h | 8 +-- src/mpid/ch4/shm/src/shm_am_fallback_recv.h | 8 +-- src/mpid/ch4/shm/src/shm_am_fallback_send.h | 4 +- src/mpid/ch4/src/ch4_impl.h | 4 +- src/mpid/ch4/src/ch4_init.c | 15 ----- src/mpid/ch4/src/ch4_persist.c | 8 +-- src/mpid/ch4/src/ch4_progress.h | 12 ++-- src/mpid/ch4/src/ch4_recv.h | 8 +-- src/mpid/ch4/src/ch4_types.h | 1 + src/mpid/ch4/src/mpidig_part.c | 8 +-- src/mpid/ch4/src/mpidig_part.h | 16 ++--- src/mpid/ch4/src/mpidig_rma.h | 36 +++++------ src/mpid/ch4/src/mpidig_win.h | 62 +++++++++---------- 42 files changed, 205 insertions(+), 227 deletions(-) diff --git a/src/include/mpir_objects.h b/src/include/mpir_objects.h index 89e7aea8d35..7f57f8c25c1 100644 --- a/src/include/mpir_objects.h +++ b/src/include/mpir_objects.h @@ -526,12 +526,6 @@ typedef struct MPIR_Object_alloc_t { void *direct; /* Pointer to direct block, used * for allocation */ int direct_size; /* Size of direct block */ - void *lock; /* lower-layer may register a lock to use. This is - * mostly for multipool requests. For other objects - * or not per-vci thread granularity, this lock - * pointer is ignored. Ref. mpir_request.h. - * NOTE: it is `void *` because mutex type not defined yet. - */ /* The following padding is to avoid cache line sharing with other MPIR_Object_alloc_t. This * padding is particularly important for an array of per-vci MPI_Request pools. */ char pad[MPL_CACHELINE_SIZE]; diff --git a/src/include/mpir_request.h b/src/include/mpir_request.h index e1f68804df5..f9af8b44952 100644 --- a/src/include/mpir_request.h +++ b/src/include/mpir_request.h @@ -356,15 +356,6 @@ extern MPIR_Request MPIR_Request_direct[MPIR_REQUEST_PREALLOC]; void MPII_init_request(void); -/* To get the benefit of multiple request pool, device layer need register their per-vci lock - * with each pool that they are going to use, typically a 1-1 vci-pool mapping. - * NOTE: currently, only per-vci thread granularity utilizes multiple request pool. - */ -static inline void MPIR_Request_register_pool_lock(int pool, MPID_Thread_mutex_t * lock) -{ - MPIR_Request_mem[pool].lock = lock; -} - static inline int MPIR_Request_is_persistent(MPIR_Request * req_ptr) { return (req_ptr->kind == MPIR_REQUEST_KIND__PREQUEST_SEND || @@ -429,7 +420,7 @@ static inline MPIR_Request *MPIR_Request_create_from_pool(MPIR_Request_kind_t ki MPIR_Request *req; #ifdef MPICH_DEBUG_MUTEX - MPID_THREAD_ASSERT_IN_CS(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[pool].lock)); + MPID_THREAD_ASSERT_IN_CS(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[pool]); #endif int max_blocks = (pool == 0) ? REQUEST_NUM_BLOCKS0 : REQUEST_NUM_BLOCKS; req = MPIR_Handle_obj_alloc_unsafe(&MPIR_Request_mem[pool], max_blocks, REQUEST_NUM_INDICES); @@ -504,9 +495,9 @@ static inline MPIR_Request *MPIR_Request_create_from_pool_safe(MPIR_Request_kind { MPIR_Request *req; - MPID_THREAD_CS_ENTER(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[pool].lock)); + MPID_THREAD_CS_ENTER(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[pool]); req = MPIR_Request_create_from_pool(kind, pool, ref_count); - MPID_THREAD_CS_EXIT(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[pool].lock)); + MPID_THREAD_CS_EXIT(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[pool]); return req; } @@ -514,9 +505,9 @@ static inline MPIR_Request *MPIR_Request_create_from_pool_safe(MPIR_Request_kind static inline MPIR_Request *MPIR_Request_create(MPIR_Request_kind_t kind) { MPIR_Request *req; - MPID_THREAD_CS_ENTER(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[0].lock)); + MPID_THREAD_CS_ENTER(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[0]); req = MPIR_Request_create_from_pool(kind, 0, 1); - MPID_THREAD_CS_EXIT(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[0].lock)); + MPID_THREAD_CS_EXIT(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[0]); return req; } @@ -556,10 +547,10 @@ static inline void MPIR_Request_free_with_safety(MPIR_Request * req, } if (need_safety) { - MPID_THREAD_CS_ENTER(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[pool].lock)); + MPID_THREAD_CS_ENTER(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[pool]); } #ifdef MPICH_DEBUG_MUTEX - MPID_THREAD_ASSERT_IN_CS(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[pool].lock)); + MPID_THREAD_ASSERT_IN_CS(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[pool]); #endif /* inform the device that we are decrementing the ref-count on * this request */ @@ -638,7 +629,7 @@ static inline void MPIR_Request_free_with_safety(MPIR_Request * req, MPIR_Handle_obj_free_unsafe(&MPIR_Request_mem[pool], req, /* not info */ FALSE); } if (need_safety) { - MPID_THREAD_CS_EXIT(VCI, (*(MPID_Thread_mutex_t *) MPIR_Request_mem[pool].lock)); + MPID_THREAD_CS_EXIT(VCI, MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[pool]); } } diff --git a/src/include/mpir_thread.h b/src/include/mpir_thread.h index 5be56d42f1f..49511342660 100644 --- a/src/include/mpir_thread.h +++ b/src/include/mpir_thread.h @@ -84,6 +84,7 @@ extern MPID_Thread_mutex_t MPIR_THREAD_VCI_HANDLE_MUTEX; extern MPID_Thread_mutex_t MPIR_THREAD_VCI_CTX_MUTEX; extern MPID_Thread_mutex_t MPIR_THREAD_VCI_PMI_MUTEX; extern MPID_Thread_mutex_t MPIR_THREAD_VCI_BSEND_MUTEX; +extern MPID_Thread_mutex_t MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[]; #endif /* MPICH_THREAD_GRANULARITY */ #endif /* MPICH_IS_THREADED */ diff --git a/src/mpi/attr/attrutil.c b/src/mpi/attr/attrutil.c index 0a97095a117..77b99d12d4a 100644 --- a/src/mpi/attr/attrutil.c +++ b/src/mpi/attr/attrutil.c @@ -25,7 +25,7 @@ MPIR_Object_alloc_t MPII_Keyval_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_KEYVAL, sizeof(MPII_Keyval), MPII_Keyval_direct, MPID_KEYVAL_PREALLOC, - NULL, {0} + {0} }; /* Preallocated keyval objects */ @@ -35,7 +35,7 @@ MPIR_Object_alloc_t MPID_Attr_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_ATTR, sizeof(MPIR_Attribute), MPID_Attr_direct, MPIR_ATTR_PREALLOC, - NULL, {0} + {0} }; /* Provides a way to trap all attribute allocations when debugging leaks. */ diff --git a/src/mpi/coll/op/op_impl.c b/src/mpi/coll/op/op_impl.c index 52330cdfcc0..ed33eafa0d8 100644 --- a/src/mpi/coll/op/op_impl.c +++ b/src/mpi/coll/op/op_impl.c @@ -13,7 +13,7 @@ MPIR_Object_alloc_t MPIR_Op_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_OP, sizeof(MPIR_Op), MPIR_Op_direct, MPIR_OP_PREALLOC, - NULL, {0} + {0} }; int MPIR_Op_create_impl(MPI_User_function * user_fn, int commute, MPIR_Op ** p_op_ptr) diff --git a/src/mpi/comm/commutil.c b/src/mpi/comm/commutil.c index 874a199c2a9..046c3842d16 100644 --- a/src/mpi/comm/commutil.c +++ b/src/mpi/comm/commutil.c @@ -22,7 +22,7 @@ MPIR_Object_alloc_t MPIR_Comm_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_COMM, sizeof(MPIR_Comm), MPIR_Comm_direct, MPIR_COMM_PREALLOC, - NULL, {0} + {0} }; /* Communicator creation functions */ diff --git a/src/mpi/datatype/typeutil.c b/src/mpi/datatype/typeutil.c index d6089ea67ab..8b7d625358f 100644 --- a/src/mpi/datatype/typeutil.c +++ b/src/mpi/datatype/typeutil.c @@ -16,7 +16,7 @@ MPIR_Datatype MPIR_Datatype_direct[MPIR_DATATYPE_PREALLOC]; MPIR_Object_alloc_t MPIR_Datatype_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_DATATYPE, sizeof(MPIR_Datatype), MPIR_Datatype_direct, MPIR_DATATYPE_PREALLOC, - NULL, {0} + {0} }; MPI_Datatype MPIR_Datatype_index_to_predefined[MPIR_DATATYPE_N_PREDEFINED]; diff --git a/src/mpi/errhan/errutil.c b/src/mpi/errhan/errutil.c index f6ff0332e1d..ae209cb3b74 100644 --- a/src/mpi/errhan/errutil.c +++ b/src/mpi/errhan/errutil.c @@ -120,7 +120,7 @@ MPIR_Object_alloc_t MPIR_Errhandler_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_ERRHANDLER sizeof(MPIR_Errhandler), MPIR_Errhandler_direct, MPIR_ERRHANDLER_PREALLOC, - NULL, {0} + {0} }; static void init_builtins(void) diff --git a/src/mpi/group/grouputil.c b/src/mpi/group/grouputil.c index ac777e50305..27b1184c124 100644 --- a/src/mpi/group/grouputil.c +++ b/src/mpi/group/grouputil.c @@ -13,7 +13,7 @@ MPIR_Group MPIR_Group_direct[MPIR_GROUP_PREALLOC]; MPIR_Object_alloc_t MPIR_Group_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_GROUP, sizeof(MPIR_Group), MPIR_Group_direct, MPIR_GROUP_PREALLOC, - NULL, {0} + {0} }; MPIR_Group *const MPIR_Group_empty = &MPIR_Group_builtin[0]; diff --git a/src/mpi/info/infoutil.c b/src/mpi/info/infoutil.c index bbce8f46846..2b06e568f2c 100644 --- a/src/mpi/info/infoutil.c +++ b/src/mpi/info/infoutil.c @@ -17,7 +17,7 @@ MPIR_Info MPIR_Info_direct[MPIR_INFO_PREALLOC]; MPIR_Object_alloc_t MPIR_Info_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_INFO, sizeof(MPIR_Info), MPIR_Info_direct, MPIR_INFO_PREALLOC, - NULL, {0} + {0} }; /* Free an info structure. In the multithreaded case, this routine diff --git a/src/mpi/init/mutex.c b/src/mpi/init/mutex.c index ec1131c77ae..559145b467e 100644 --- a/src/mpi/init/mutex.c +++ b/src/mpi/init/mutex.c @@ -18,6 +18,7 @@ MPID_Thread_mutex_t MPIR_THREAD_VCI_HANDLE_MUTEX; MPID_Thread_mutex_t MPIR_THREAD_VCI_CTX_MUTEX; MPID_Thread_mutex_t MPIR_THREAD_VCI_PMI_MUTEX; MPID_Thread_mutex_t MPIR_THREAD_VCI_BSEND_MUTEX; +MPID_Thread_mutex_t MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[MPIR_REQUEST_NUM_POOLS]; #endif /* MPICH_THREAD_GRANULARITY */ /* called the first thing in init so it can enter critical section immediately */ @@ -42,6 +43,11 @@ void MPII_thread_mutex_create(void) MPID_Thread_mutex_create(&MPIR_THREAD_VCI_BSEND_MUTEX, &err); MPIR_Assert(err == 0); + for (int i = 0; i < MPIR_REQUEST_NUM_POOLS; i++) { + MPID_Thread_mutex_create(&MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[i], &err); + MPIR_Assert(err == 0); + } + #elif MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__LOCKFREE /* Updates to shared data and access to shared services is handled * without locks where ever possible. */ @@ -79,6 +85,11 @@ void MPII_thread_mutex_destroy(void) MPID_Thread_mutex_destroy(&MPIR_THREAD_VCI_BSEND_MUTEX, &err); MPIR_Assert(err == 0); + for (int i = 0; i < MPIR_REQUEST_NUM_POOLS; i++) { + MPID_Thread_mutex_destroy(&MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[i], &err); + MPIR_Assert(err == 0); + } + #elif MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__LOCKFREE /* Updates to shared data and access to shared services is handled * without locks where ever possible. */ diff --git a/src/mpi/request/mpir_greq.c b/src/mpi/request/mpir_greq.c index 5e88705c1fd..8e07a76353b 100644 --- a/src/mpi/request/mpir_greq.c +++ b/src/mpi/request/mpir_greq.c @@ -12,7 +12,7 @@ MPIR_Object_alloc_t MPIR_Grequest_class_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_GREQ_C sizeof(MPIR_Grequest_class), MPIR_Grequest_class_direct, MPIR_GREQ_CLASS_PREALLOC, - NULL, {0} + {0} }; /* We jump through some minor hoops to manage the list of classes ourselves and diff --git a/src/mpi/request/mpir_request.c b/src/mpi/request/mpir_request.c index 4c6d836e250..a04682496c7 100644 --- a/src/mpi/request/mpir_request.c +++ b/src/mpi/request/mpir_request.c @@ -31,15 +31,10 @@ static void init_builtin_request(MPIR_Request * req, int handle, MPIR_Request_ki void MPII_init_request(void) { - MPID_Thread_mutex_t *lock_ptr = NULL; -#if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI - lock_ptr = &MPIR_THREAD_VCI_HANDLE_MUTEX; -#endif - /* *INDENT-OFF* */ - MPIR_Request_mem[0] = (MPIR_Object_alloc_t) { 0, 0, 0, 0, 0, 0, 0, MPIR_REQUEST, sizeof(MPIR_Request), MPIR_Request_direct, MPIR_REQUEST_PREALLOC, lock_ptr, {0}}; + MPIR_Request_mem[0] = (MPIR_Object_alloc_t) { 0, 0, 0, 0, 0, 0, 0, MPIR_REQUEST, sizeof(MPIR_Request), MPIR_Request_direct, MPIR_REQUEST_PREALLOC, {0}}; for (int i = 1; i < MPIR_REQUEST_NUM_POOLS; i++) { - MPIR_Request_mem[i] = (MPIR_Object_alloc_t) { 0, 0, 0, 0, 0, 0, 0, MPIR_REQUEST, sizeof(MPIR_Request), NULL, 0, lock_ptr, {0}}; + MPIR_Request_mem[i] = (MPIR_Object_alloc_t) { 0, 0, 0, 0, 0, 0, 0, MPIR_REQUEST, sizeof(MPIR_Request), NULL, 0, {0}}; } /* *INDENT-ON* */ diff --git a/src/mpi/rma/winutil.c b/src/mpi/rma/winutil.c index d3bc52393a0..c6cfacaf2f3 100644 --- a/src/mpi/rma/winutil.c +++ b/src/mpi/rma/winutil.c @@ -14,5 +14,5 @@ MPIR_Win MPIR_Win_direct[MPIR_WIN_PREALLOC]; MPIR_Object_alloc_t MPIR_Win_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_WIN, sizeof(MPIR_Win), MPIR_Win_direct, MPIR_WIN_PREALLOC, - NULL, {0} + {0} }; diff --git a/src/mpi/session/session_util.c b/src/mpi/session/session_util.c index fb2680b84e5..5d6226f56fe 100644 --- a/src/mpi/session/session_util.c +++ b/src/mpi/session/session_util.c @@ -16,7 +16,7 @@ MPIR_Session MPIR_Session_direct[MPIR_SESSION_PREALLOC]; MPIR_Object_alloc_t MPIR_Session_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_SESSION, sizeof(MPIR_Session), MPIR_Session_direct, MPIR_SESSION_PREALLOC, - NULL, {0} + {0} }; int MPIR_Session_create(MPIR_Session ** p_session_ptr, int thread_level) diff --git a/src/mpi/stream/stream_impl.c b/src/mpi/stream/stream_impl.c index 6c71de24046..65d62306f9a 100644 --- a/src/mpi/stream/stream_impl.c +++ b/src/mpi/stream/stream_impl.c @@ -60,7 +60,7 @@ MPIR_Stream MPIR_Stream_direct[MPIR_STREAM_PREALLOC]; MPIR_Object_alloc_t MPIR_Stream_mem = { 0, 0, 0, 0, 0, 0, 0, MPIR_STREAM, sizeof(MPIR_Stream), MPIR_Stream_direct, MPIR_STREAM_PREALLOC, - NULL, {0} + {0} }; /* utilities for managing streams in a communicator */ diff --git a/src/mpid/ch4/include/mpidpost.h b/src/mpid/ch4/include/mpidpost.h index a1ae19f21ce..0776e979d9e 100644 --- a/src/mpid/ch4/include/mpidpost.h +++ b/src/mpid/ch4/include/mpidpost.h @@ -14,9 +14,9 @@ MPL_STATIC_INLINE_PREFIX MPIR_Request *MPID_Request_create_from_comm(MPIR_Reques { MPIR_Request *req; int vci = MPIDI_get_comm_vci(comm); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); req = MPIR_Request_create_from_pool(kind, vci, 1); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return req; } diff --git a/src/mpid/ch4/netmod/include/netmod_am_fallback_probe.h b/src/mpid/ch4/netmod/include/netmod_am_fallback_probe.h index 9a5b100f076..55d187e4bb5 100644 --- a/src/mpid/ch4/netmod/include/netmod_am_fallback_probe.h +++ b/src/mpid/ch4/netmod/include/netmod_am_fallback_probe.h @@ -16,10 +16,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_improbe(int source, int mpi_errno = MPI_SUCCESS; int context_offset = MPIR_PT2PT_ATTR_CONTEXT_OFFSET(attr); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); mpi_errno = MPIDIG_mpi_improbe(source, tag, comm, context_offset, 0, flag, false /* is_local */ , message, status); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); return mpi_errno; } @@ -33,10 +33,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_iprobe(int source, int mpi_errno = MPI_SUCCESS; int context_offset = MPIR_PT2PT_ATTR_CONTEXT_OFFSET(attr); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); mpi_errno = MPIDIG_mpi_iprobe(source, tag, comm, context_offset, 0, flag, false /* is_local */ , status); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); return mpi_errno; } diff --git a/src/mpid/ch4/netmod/include/netmod_am_fallback_recv.h b/src/mpid/ch4/netmod/include/netmod_am_fallback_recv.h index 8cef735f35c..049c7a5cd42 100644 --- a/src/mpid/ch4/netmod/include/netmod_am_fallback_recv.h +++ b/src/mpid/ch4/netmod/include/netmod_am_fallback_recv.h @@ -15,9 +15,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_imrecv(void *buf, #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI int vci = MPIDI_Request_get_vci(message); #endif - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_mpi_imrecv(buf, count, datatype, message); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return mpi_errno; } @@ -43,17 +43,17 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_irecv(void *buf, need_cs = (rank != MPI_ANY_SOURCE); #endif if (need_cs) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); } else { #ifdef MPICH_DEBUG_MUTEX - MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(0)); #endif } mpi_errno = MPIDIG_mpi_irecv(buf, count, datatype, rank, tag, comm, context_offset, 0, request, 0, partner); if (need_cs) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); } return mpi_errno; diff --git a/src/mpid/ch4/netmod/include/netmod_am_fallback_send.h b/src/mpid/ch4/netmod/include/netmod_am_fallback_send.h index 7c22e547e5c..4324f7d2305 100644 --- a/src/mpid/ch4/netmod/include/netmod_am_fallback_send.h +++ b/src/mpid/ch4/netmod/include/netmod_am_fallback_send.h @@ -24,10 +24,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_isend(const void *buf, vci_src = 0; vci_dst = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_src).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_src)); mpi_errno = MPIDIG_mpi_isend(buf, count, datatype, rank, tag, comm, context_offset, addr, vci_src, vci_dst, request, syncflag, errflag); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_src).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_src)); return mpi_errno; } diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index a3ad267eee4..f2a1088624c 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -139,9 +139,9 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret); #define MPIDI_OFI_VCI_PROGRESS(vci_) \ do { \ int made_progress = 0; \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_)); \ mpi_errno = MPIDI_NM_progress(vci_, &made_progress); \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \ MPIR_ERR_CHECK(mpi_errno); \ MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \ } while (0) @@ -149,23 +149,23 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret); #define MPIDI_OFI_VCI_PROGRESS_WHILE(vci_, cond) \ do { \ int made_progress = 0; \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_)); \ while (cond) { \ mpi_errno = MPIDI_NM_progress(vci_, &made_progress); \ if (mpi_errno) { \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \ MPIR_ERR_POP(mpi_errno); \ } \ MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \ } \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \ } while (0) #define MPIDI_OFI_VCI_CALL(FUNC,vci_,STR) \ do { \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_)); \ ssize_t _ret = FUNC; \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \ MPIDI_OFI_ERR(_ret<0, \ mpi_errno, \ MPI_ERR_OTHER, \ @@ -178,14 +178,14 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret); #define MPIDI_OFI_THREAD_CS_ENTER_VCI_OPTIONAL(vci_) \ do { \ if (!MPIDI_VCI_IS_EXPLICIT(vci_) && MPIDI_CH4_MT_MODEL != MPIDI_CH4_MT_LOCKLESS) { \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_)); \ } \ } while (0) #define MPIDI_OFI_THREAD_CS_EXIT_VCI_OPTIONAL(vci_) \ do { \ if (!MPIDI_VCI_IS_EXPLICIT(vci_) && MPIDI_CH4_MT_MODEL != MPIDI_CH4_MT_LOCKLESS) { \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_)); \ } \ } while (0) diff --git a/src/mpid/ch4/netmod/ofi/ofi_recv.h b/src/mpid/ch4/netmod/ofi/ofi_recv.h index bd5d3482fb8..186ef692161 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_recv.h +++ b/src/mpid/ch4/netmod/ofi/ofi_recv.h @@ -413,7 +413,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_irecv(void *buf, MPIDI_OFI_THREAD_CS_ENTER_VCI_OPTIONAL(vci_dst); } else { #ifdef MPICH_DEBUG_MUTEX - MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI(vci_dst).lock); + MPID_THREAD_ASSERT_IN_CS(VCI, MPIDI_VCI_LOCK(vci_dst)); #endif } if (!MPIDI_OFI_ENABLE_TAGGED) { diff --git a/src/mpid/ch4/netmod/ofi/ofi_rma.h b/src/mpid/ch4/netmod/ofi/ofi_rma.h index c05319f4fd8..3ba70d68eae 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_rma.h +++ b/src/mpid/ch4/netmod/ofi/ofi_rma.h @@ -243,7 +243,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_put(const void *origin_addr, * very slow */ if (origin_contig && target_contig && (origin_bytes <= MPIDI_OFI_global.max_buffered_write && !MPL_gpu_attr_is_dev(&attr))) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_cntr_incr(win); MPIDI_OFI_CALL_RETRY(fi_inject_write(MPIDI_OFI_WIN(win).ep, MPIR_get_contig_ptr(origin_addr, origin_true_lb), @@ -251,13 +251,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_put(const void *origin_addr, MPIDI_OFI_av_to_phys(addr, nic_target, vci_target), target_mr.addr + target_true_lb, target_mr.mr_key), vci, rdma_inject_write); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto null_op_exit; } /* large contiguous messages */ if (origin_contig && target_contig) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); if (sigreq) { MPIDI_OFI_REQUEST_CREATE(*sigreq, MPIR_REQUEST_KIND__RMA, 0); flags = FI_COMPLETION | FI_DELIVERY_COMPLETE; @@ -286,7 +286,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_put(const void *origin_addr, MPIDI_OFI_CALL_RETRY(fi_writemsg(MPIDI_OFI_WIN(win).ep, &msg, flags), vci, rdma_write); /* Complete signal request to inform completion to user. */ MPIDI_OFI_sigreq_complete(sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } @@ -297,22 +297,22 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_put(const void *origin_addr, if (origin_density >= MPIR_CVAR_CH4_IOV_DENSITY_MIN && target_density >= MPIR_CVAR_CH4_IOV_DENSITY_MIN) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_OFI_nopack_putget(origin_addr, origin_count, origin_datatype, target_rank, target_count, target_datatype, target_mr, win, addr, MPIDI_OFI_PUT, sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } if (origin_density < MPIR_CVAR_CH4_IOV_DENSITY_MIN && target_density >= MPIR_CVAR_CH4_IOV_DENSITY_MIN) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_OFI_pack_put(origin_addr, origin_count, origin_datatype, target_rank, target_count, target_datatype, target_mr, win, addr, sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } @@ -432,7 +432,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get(void *origin_addr, /* contiguous messages */ if (origin_contig && target_contig) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); if (sigreq) { MPIDI_OFI_REQUEST_CREATE(*sigreq, MPIR_REQUEST_KIND__RMA, 0); flags = FI_COMPLETION | FI_DELIVERY_COMPLETE; @@ -463,7 +463,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get(void *origin_addr, MPIDI_OFI_CALL_RETRY(fi_readmsg(MPIDI_OFI_WIN(win).ep, &msg, flags), vci, rdma_write); /* Complete signal request to inform completion to user. */ MPIDI_OFI_sigreq_complete(sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } @@ -474,22 +474,22 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get(void *origin_addr, if (origin_density >= MPIR_CVAR_CH4_IOV_DENSITY_MIN && target_density >= MPIR_CVAR_CH4_IOV_DENSITY_MIN) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_OFI_nopack_putget(origin_addr, origin_count, origin_datatype, target_rank, target_count, target_datatype, target_mr, win, addr, MPIDI_OFI_GET, sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } if (origin_density < MPIR_CVAR_CH4_IOV_DENSITY_MIN && target_density >= MPIR_CVAR_CH4_IOV_DENSITY_MIN) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_OFI_pack_get(origin_addr, origin_count, origin_datatype, target_rank, target_count, target_datatype, target_mr, win, addr, sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } @@ -687,12 +687,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_compare_and_swap(const void *origin_ad msg.context = NULL; msg.data = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_cntr_incr(win); MPIDI_OFI_CALL_RETRY(fi_compare_atomicmsg(MPIDI_OFI_WIN(win).ep, &msg, &comparev, compare_desc, 1, &resultv, result_desc, 1, 0), vci, atomicto); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); fn_exit: MPIR_FUNC_EXIT; return mpi_errno; @@ -701,9 +701,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_compare_and_swap(const void *origin_ad am_fallback: /* Wait for OFI case to complete for atomicity. * For now, there is no FI flag to track atomic only ops, we use RMA level cntr. */ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_do_progress(win, vci); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return MPIDIG_mpi_compare_and_swap(origin_addr, compare_addr, result_addr, datatype, target_rank, target_disp, win); } @@ -779,7 +779,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_accumulate(const void *origin_addr, /* Ensure completion of outstanding AMs for atomicity. */ MPIDIG_wait_am_acc(win, target_rank); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); uint64_t flags; if (sigreq) { MPIDI_OFI_REQUEST_CREATE(*sigreq, MPIR_REQUEST_KIND__RMA, 0); @@ -816,16 +816,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_accumulate(const void *origin_addr, MPIDI_OFI_CALL_RETRY(fi_atomicmsg(MPIDI_OFI_WIN(win).ep, &msg, flags), vci, rdma_atomicto); /* Complete signal request to inform completion to user. */ MPIDI_OFI_sigreq_complete(sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } am_fallback: /* Wait for OFI acc to complete for atomicity. * For now, there is no FI flag to track atomic only ops, we use RMA level cntr. */ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_do_progress(win, vci); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); if (sigreq) mpi_errno = MPIDIG_mpi_raccumulate(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, op, win, @@ -920,7 +920,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get_accumulate(const void *origin_addr /* Ensure completion of outstanding AMs for atomicity. */ MPIDIG_wait_am_acc(win, target_rank); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); uint64_t flags; if (sigreq) { MPIDI_OFI_REQUEST_CREATE(*sigreq, MPIR_REQUEST_KIND__RMA, 0); @@ -963,16 +963,16 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get_accumulate(const void *origin_addr result_desc, 1, flags), vci, rdma_readfrom); /* Complete signal request to inform completion to user. */ MPIDI_OFI_sigreq_complete(sigreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); goto fn_exit; } am_fallback: /* Wait for OFI getacc to complete for atomicity. * For now, there is no FI flag to track atomic only ops, we use RMA level cntr. */ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_do_progress(win, vci); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); if (sigreq) mpi_errno = MPIDIG_mpi_rget_accumulate(origin_addr, origin_count, origin_datatype, result_addr, @@ -1190,11 +1190,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_fetch_and_op(const void *origin_addr, msg.context = NULL; msg.data = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_cntr_incr(win); MPIDI_OFI_CALL_RETRY(fi_fetch_atomicmsg(MPIDI_OFI_WIN(win).ep, &msg, &resultv, result_desc, 1, 0), vci, rdma_readfrom); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); fn_exit: MPIR_FUNC_EXIT; @@ -1204,9 +1204,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_fetch_and_op(const void *origin_addr, am_fallback: /* Wait for OFI fetch_and_op to complete for atomicity. * For now, there is no FI flag to track atomic only ops, we use RMA level cntr. */ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_do_progress(win, vci); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return MPIDIG_mpi_fetch_and_op(origin_addr, result_addr, datatype, target_rank, target_disp, op, win); } diff --git a/src/mpid/ch4/netmod/ofi/ofi_spawn.c b/src/mpid/ch4/netmod/ofi/ofi_spawn.c index 20adc54b3b1..f4039fdeedd 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_spawn.c +++ b/src/mpid/ch4/netmod/ofi/ofi_spawn.c @@ -20,7 +20,7 @@ int MPIDI_OFI_dynamic_send(uint64_t remote_gpid, int tag, const void *buf, int s int lpid = MPIDIU_GPID_GET_LPID(remote_gpid); fi_addr_t remote_addr = MPIDI_OFI_av_to_phys(&MPIDIU_get_av(avtid, lpid), nic, vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_dynamic_process_request_t req; req.done = 0; @@ -69,7 +69,7 @@ int MPIDI_OFI_dynamic_send(uint64_t remote_gpid, int tag, const void *buf, int s } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return mpi_errno; fn_fail: goto fn_exit; @@ -91,7 +91,7 @@ int MPIDI_OFI_dynamic_recv(int tag, void *buf, int size, int timeout) match_bits = MPIDI_OFI_init_recvtag(&mask_bits, 0, MPI_ANY_SOURCE, tag); match_bits |= MPIDI_OFI_DYNPROC_SEND; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPL_time_t time_start, time_now; double time_gap; @@ -126,7 +126,7 @@ int MPIDI_OFI_dynamic_recv(int tag, void *buf, int size, int timeout) } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return mpi_errno; fn_fail: goto fn_exit; diff --git a/src/mpid/ch4/netmod/ucx/ucx_impl.h b/src/mpid/ch4/netmod/ucx/ucx_impl.h index d204383ac5b..c07a54ab1bc 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_impl.h +++ b/src/mpid/ch4/netmod/ucx/ucx_impl.h @@ -30,14 +30,14 @@ #define MPIDI_UCX_THREAD_CS_ENTER_VCI(vci) \ do { \ if (!MPIDI_VCI_IS_EXPLICIT(vci)) { \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); \ } \ } while (0) #define MPIDI_UCX_THREAD_CS_EXIT_VCI(vci) \ do { \ if (!MPIDI_VCI_IS_EXPLICIT(vci)) { \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); \ } \ } while (0) diff --git a/src/mpid/ch4/netmod/ucx/ucx_spawn.c b/src/mpid/ch4/netmod/ucx/ucx_spawn.c index 05e888d5639..a2d15f4af83 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_spawn.c +++ b/src/mpid/ch4/netmod/ucx/ucx_spawn.c @@ -27,7 +27,7 @@ int MPIDI_UCX_dynamic_send(uint64_t remote_gpid, int tag, const void *buf, int s uint64_t ucx_tag = MPIDI_UCX_DYNPROC_MASK + tag; int vci = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); int avtid = MPIDIU_GPID_GET_AVTID(remote_gpid); int lpid = MPIDIU_GPID_GET_LPID(remote_gpid); @@ -68,7 +68,7 @@ int MPIDI_UCX_dynamic_send(uint64_t remote_gpid, int tag, const void *buf, int s } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return mpi_errno; } @@ -80,7 +80,7 @@ int MPIDI_UCX_dynamic_recv(int tag, void *buf, int size, int timeout) uint64_t tag_mask = 0xffffffffffffffff; int vci = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); bool done = false; ucp_request_param_t param = { @@ -117,7 +117,7 @@ int MPIDI_UCX_dynamic_recv(int tag, void *buf, int size, int timeout) } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return mpi_errno; } diff --git a/src/mpid/ch4/shm/ipc/gpu/gpu_post.c b/src/mpid/ch4/shm/ipc/gpu/gpu_post.c index 55bfe3ca03e..8d8a7fd18fe 100644 --- a/src/mpid/ch4/shm/ipc/gpu/gpu_post.c +++ b/src/mpid/ch4/shm/ipc/gpu/gpu_post.c @@ -643,12 +643,12 @@ static int gpu_ipc_async_poll(MPIX_Async_thing thing) if (is_done) { int vci = MPIDIG_REQUEST(p->rreq, req->local_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); err = MPIDI_GPU_ipc_handle_unmap(p->src_buf, p->gpu_handle, 0); MPIR_Assertp(err == MPI_SUCCESS); err = MPIDI_IPC_complete(p->rreq, MPIDI_IPCI_TYPE__GPU); MPIR_Assertp(err == MPI_SUCCESS); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPL_free(p); return MPIX_ASYNC_DONE; diff --git a/src/mpid/ch4/shm/posix/posix_impl.h b/src/mpid/ch4/shm/posix/posix_impl.h index 2d84d60350d..ad440781ee2 100644 --- a/src/mpid/ch4/shm/posix/posix_impl.h +++ b/src/mpid/ch4/shm/posix/posix_impl.h @@ -17,14 +17,14 @@ #define MPIDI_POSIX_THREAD_CS_ENTER_VCI(vci) \ do { \ if (!MPIDI_VCI_IS_EXPLICIT(vci)) { \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); \ } \ } while (0) #define MPIDI_POSIX_THREAD_CS_EXIT_VCI(vci) \ do { \ if (!MPIDI_VCI_IS_EXPLICIT(vci)) { \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); \ } \ } while (0) diff --git a/src/mpid/ch4/shm/posix/posix_rma.h b/src/mpid/ch4/shm/posix/posix_rma.h index 360bf3e0860..db7ea1f9056 100644 --- a/src/mpid/ch4/shm/posix/posix_rma.h +++ b/src/mpid/ch4/shm/posix/posix_rma.h @@ -336,7 +336,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_do_get_accumulate(const void *origin_ad if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_LOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); } mpi_errno = MPIR_Localcopy((char *) base + disp_unit * target_disp, target_count, @@ -353,7 +353,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_do_get_accumulate(const void *origin_ad if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_UNLOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } fn_exit: @@ -404,7 +404,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_do_accumulate(const void *origin_addr, if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_LOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); } mpi_errno = MPIDI_POSIX_compute_accumulate((void *) origin_addr, origin_count, origin_datatype, @@ -413,7 +413,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_do_accumulate(const void *origin_addr, if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_UNLOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } fn_exit: @@ -445,10 +445,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_put(const void *origin_addr, #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI int vci = MPIDI_WIN(win, am_vci); #endif - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_POSIX_do_put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, winattr); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; @@ -477,10 +477,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_get(void *origin_addr, #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI int vci = MPIDI_WIN(win, am_vci); #endif - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_POSIX_do_get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, winattr); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; @@ -511,14 +511,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_rput(const void *origin_addr, } int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_POSIX_do_put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, winattr); if (mpi_errno == MPI_SUCCESS) { *request = MPIR_Request_create_complete(MPIR_REQUEST_KIND__RMA); } - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); fn_exit: MPIR_FUNC_EXIT; @@ -575,7 +575,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_compare_and_swap(const void *origin if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_LOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_Typerep_copy(result_addr, target_addr, dtype_sz, MPIR_TYPEREP_FLAG_NONE); @@ -586,7 +586,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_compare_and_swap(const void *origin if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_UNLOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } fn_exit: @@ -726,7 +726,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_fetch_and_op(const void *origin_add if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_LOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_Typerep_copy(result_addr, target_addr, dtype_sz, MPIR_TYPEREP_FLAG_NONE); @@ -745,7 +745,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_fetch_and_op(const void *origin_add if (winattr & MPIDI_WINATTR_SHM_ALLOCATED) { MPIDI_POSIX_RMA_MUTEX_UNLOCK(posix_win->shm_mutex_ptr); } else { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } fn_exit: @@ -779,14 +779,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_mpi_rget(void *origin_addr, } int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_POSIX_do_get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, winattr); if (mpi_errno == MPI_SUCCESS) { *request = MPIR_Request_create_complete(MPIR_REQUEST_KIND__RMA); } - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); fn_exit: MPIR_FUNC_EXIT; diff --git a/src/mpid/ch4/shm/src/shm_am_fallback_probe.h b/src/mpid/ch4/shm/src/shm_am_fallback_probe.h index c977ab9aa0d..67b8b20d097 100644 --- a/src/mpid/ch4/shm/src/shm_am_fallback_probe.h +++ b/src/mpid/ch4/shm/src/shm_am_fallback_probe.h @@ -15,10 +15,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_improbe(int source, { int mpi_errno = MPI_SUCCESS; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); mpi_errno = MPIDIG_mpi_improbe(source, tag, comm, context_offset, 0, flag, true /* is_local */ , message, status); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); return mpi_errno; } @@ -31,10 +31,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_iprobe(int source, { int mpi_errno = MPI_SUCCESS; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); mpi_errno = MPIDIG_mpi_iprobe(source, tag, comm, context_offset, 0, flag, true /* is_local */ , status); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); return mpi_errno; } diff --git a/src/mpid/ch4/shm/src/shm_am_fallback_recv.h b/src/mpid/ch4/shm/src/shm_am_fallback_recv.h index 178080e0903..01b37337e0c 100644 --- a/src/mpid/ch4/shm/src/shm_am_fallback_recv.h +++ b/src/mpid/ch4/shm/src/shm_am_fallback_recv.h @@ -15,9 +15,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_imrecv(void *buf, #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI int vci = MPIDI_Request_get_vci(message); #endif - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_mpi_imrecv(buf, count, datatype, message); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); return mpi_errno; } @@ -40,13 +40,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_irecv(void *buf, } else { need_lock = true; vci = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); } mpi_errno = MPIDIG_mpi_irecv(buf, count, datatype, rank, tag, comm, context_offset, vci, request, 1, NULL); if (need_lock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } return mpi_errno; diff --git a/src/mpid/ch4/shm/src/shm_am_fallback_send.h b/src/mpid/ch4/shm/src/shm_am_fallback_send.h index d53ee2ac64e..208c9d044ab 100644 --- a/src/mpid/ch4/shm/src/shm_am_fallback_send.h +++ b/src/mpid/ch4/shm/src/shm_am_fallback_send.h @@ -24,10 +24,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_SHM_mpi_isend(const void *buf, vci_src = 0; vci_dst = 0; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_src).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci_src)); mpi_errno = MPIDIG_mpi_isend(buf, count, datatype, rank, tag, comm, context_offset, addr, vci_src, vci_dst, request, syncflag, errflag); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci_src).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci_src)); return mpi_errno; } diff --git a/src/mpid/ch4/src/ch4_impl.h b/src/mpid/ch4/src/ch4_impl.h index 8991052f1a5..e630e27be05 100644 --- a/src/mpid/ch4/src/ch4_impl.h +++ b/src/mpid/ch4/src/ch4_impl.h @@ -416,7 +416,7 @@ do { \ mpi_errno = MPIDI_progress_test_vci(vci); \ MPIR_ERR_CHECK(mpi_errno); \ MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \ - MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI_LOCK(vci)); \ DEBUG_PROGRESS_CHECK; \ } \ } while (0) @@ -428,7 +428,7 @@ do { \ mpi_errno = MPIDI_progress_test_vci(vci); \ MPIR_ERR_CHECK(mpi_errno); \ MPID_THREAD_CS_YIELD(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX); \ - MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_YIELD(VCI, MPIDI_VCI_LOCK(vci)); \ DEBUG_PROGRESS_CHECK; \ } while (cond); \ } while (0) diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index dd707fc592a..4144a8223b4 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -523,17 +523,6 @@ int MPID_Init(int requested, int *provided) MPIR_Assert(MPIDI_global.n_total_vcis <= MPIR_REQUEST_NUM_POOLS); for (int i = 0; i < MPIDI_global.n_total_vcis; i++) { - int err; - MPID_Thread_mutex_create(&MPIDI_VCI(i).lock, &err); - MPIR_Assert(err == 0); - - /* NOTE: 1-1 vci-pool mapping */ - /* For lockless, use a separate set of mutexes */ - if (MPIDI_CH4_MT_MODEL == MPIDI_CH4_MT_LOCKLESS) - MPIR_Request_register_pool_lock(i, &MPIR_THREAD_VCI_HANDLE_POOL_MUTEXES[i]); - else - MPIR_Request_register_pool_lock(i, &MPIDI_VCI(i).lock); - /* Initialize registered host buffer pool to be used as temporary unpack buffers */ mpi_errno = MPIDU_genq_private_pool_create(MPIR_CVAR_CH4_PACK_BUFFER_SIZE, MPIR_CVAR_CH4_NUM_PACK_BUFFERS_PER_CHUNK, @@ -837,10 +826,6 @@ int MPID_Finalize(void) for (int i = 0; i < MPIDI_global.n_total_vcis; i++) { MPIDU_genq_private_pool_destroy(MPIDI_global.per_vci[i].pack_buf_pool); - - int err; - MPID_Thread_mutex_destroy(&MPIDI_VCI(i).lock, &err); - MPIR_Assert(err == 0); } MPL_free(MPIDI_global.all_num_vcis); diff --git a/src/mpid/ch4/src/ch4_persist.c b/src/mpid/ch4/src/ch4_persist.c index 20ffa647d43..8b82e260d52 100644 --- a/src/mpid/ch4/src/ch4_persist.c +++ b/src/mpid/ch4/src/ch4_persist.c @@ -20,9 +20,9 @@ static int psend_init(MPIDI_ptype ptype, int context_offset = MPIR_PT2PT_ATTR_CONTEXT_OFFSET(attr); int vci = MPIDI_get_vci(SRC_VCI_FROM_SENDER, comm, comm->rank, rank, tag); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_CH4_REQUEST_CREATE(sreq, MPIR_REQUEST_KIND__PREQUEST_SEND, vci, 1); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHKANDSTMT(sreq == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq"); *request = sreq; @@ -137,9 +137,9 @@ int MPID_Recv_init(void *buf, int context_offset = MPIR_PT2PT_ATTR_CONTEXT_OFFSET(attr); int vci = MPIDI_get_vci(DST_VCI_FROM_RECVER, comm, rank, comm->rank, tag); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_CH4_REQUEST_CREATE(rreq, MPIR_REQUEST_KIND__PREQUEST_RECV, vci, 1); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHKANDSTMT(rreq == NULL, mpi_errno, MPIX_ERR_NOREQ, goto fn_fail, "**nomemreq"); *request = rreq; diff --git a/src/mpid/ch4/src/ch4_progress.h b/src/mpid/ch4/src/ch4_progress.h index be35a9bafb0..786ddcf6cda 100644 --- a/src/mpid/ch4/src/ch4_progress.h +++ b/src/mpid/ch4/src/ch4_progress.h @@ -45,12 +45,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_do_global_progress(void) #define MPIDI_THREAD_CS_ENTER_VCI_OPTIONAL(vci) \ if (!MPIDI_VCI_IS_EXPLICIT(vci) && !(state->flag & MPIDI_PROGRESS_NM_LOCKLESS)) { \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); \ } #define MPIDI_THREAD_CS_EXIT_VCI_OPTIONAL(vci) \ if (!MPIDI_VCI_IS_EXPLICIT(vci) && !(state->flag & MPIDI_PROGRESS_NM_LOCKLESS)) { \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); \ } while (0) @@ -70,9 +70,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_do_global_progress(void) #define MPIDI_PROGRESS(vci) \ do { \ if (state->flag & MPIDI_PROGRESS_SHM && !made_progress) { \ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); \ mpi_errno = MPIDI_SHM_progress(vci, &made_progress); \ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); \ + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); \ MPIR_ERR_CHECK(mpi_errno); \ } \ if (state->flag & MPIDI_PROGRESS_NM && !made_progress) { \ @@ -189,9 +189,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_progress_test_vci(int vci) int mpi_errno = MPI_SUCCESS; if (!MPIDI_VCI_IS_EXPLICIT(vci) && MPIDI_do_global_progress()) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPID_Progress_test(NULL); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); } else { int made_progress = 0; mpi_errno = MPIDI_NM_progress(vci, &made_progress); diff --git a/src/mpid/ch4/src/ch4_recv.h b/src/mpid/ch4/src/ch4_recv.h index e382f283195..21839e189ed 100644 --- a/src/mpid/ch4/src/ch4_recv.h +++ b/src/mpid/ch4/src/ch4_recv.h @@ -24,7 +24,7 @@ MPL_STATIC_INLINE_PREFIX int anysource_irecv(void *buf, MPI_Aint count, MPI_Data #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI int vci; MPIDI_POSIX_RECV_VSI(vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_CH4_REQUEST_CREATE(*request, MPIR_REQUEST_KIND__RECV, vci, 1); MPIR_Assert(*request); @@ -45,7 +45,7 @@ MPL_STATIC_INLINE_PREFIX int anysource_irecv(void *buf, MPI_Aint count, MPI_Data } fn_exit: #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); #endif return mpi_errno; fn_fail: @@ -156,9 +156,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_cancel_recv_safe(MPIR_Request * rreq) * usage it's often used inside a critical section (e.g. progress and anysource * receive). Therefore, we allow recursive lock usage here. */ - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDI_cancel_recv_unsafe(rreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch4/src/ch4_types.h b/src/mpid/ch4/src/ch4_types.h index 88fbb5ac4ab..5adab8ebe74 100644 --- a/src/mpid/ch4/src/ch4_types.h +++ b/src/mpid/ch4/src/ch4_types.h @@ -254,6 +254,7 @@ typedef struct MPIDI_per_vci { } MPIDI_per_vci_t; #define MPIDI_VCI(i) MPIDI_global.per_vci[i] +#define MPIDI_VCI_LOCK(i) MPIR_THREAD_VCI_REQUEST_POOL_MUTEXES[i] typedef struct MPIDI_CH4_Global_t { int pname_set; diff --git a/src/mpid/ch4/src/mpidig_part.c b/src/mpid/ch4/src/mpidig_part.c index 94f1c227087..6ae2cb6c7e6 100644 --- a/src/mpid/ch4/src/mpidig_part.c +++ b/src/mpid/ch4/src/mpidig_part.c @@ -98,7 +98,7 @@ int MPIDIG_mpi_psend_init(const void *buf, int partitions, MPI_Aint count, int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); /* Create and initialize device-layer partitioned request */ mpi_errno = part_req_create((void *) buf, partitions, count, datatype, dest, tag, comm, @@ -123,7 +123,7 @@ int MPIDIG_mpi_psend_init(const void *buf, int partitions, MPI_Aint count, MPIDI_REQUEST(*request, is_local), mpi_errno); fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -138,7 +138,7 @@ int MPIDIG_mpi_precv_init(void *buf, int partitions, MPI_Aint count, int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); /* Create and initialize device-layer partitioned request */ mpi_errno = part_req_create(buf, partitions, count, datatype, source, tag, comm, @@ -168,7 +168,7 @@ int MPIDIG_mpi_precv_init(void *buf, int partitions, MPI_Aint count, } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: diff --git a/src/mpid/ch4/src/mpidig_part.h b/src/mpid/ch4/src/mpidig_part.h index 35377b58b36..bf755776f59 100644 --- a/src/mpid/ch4/src/mpidig_part.h +++ b/src/mpid/ch4/src/mpidig_part.h @@ -23,7 +23,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_part_start(MPIR_Request * request) int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); /* Indicate data transfer starts. * Decrease when am request completes on sender (via completion_notification), @@ -39,7 +39,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_part_start(MPIR_Request * request) MPIR_Part_request_activate(request); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); MPIR_FUNC_EXIT; return mpi_errno; } @@ -55,9 +55,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_pready_range(int partition_low, int part MPIR_cc_decr(&MPIDIG_PART_REQUEST(part_sreq, u.send).ready_cntr, &incomplete); if (!incomplete) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); mpi_errno = MPIDIG_part_issue_data(part_sreq, MPIDIG_PART_REGULAR); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); } MPIR_FUNC_EXIT; @@ -75,9 +75,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_pready_list(int length, const int array_ MPIR_cc_decr(&MPIDIG_PART_REQUEST(part_sreq, u.send).ready_cntr, &incomplete); if (!incomplete) { - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); mpi_errno = MPIDIG_part_issue_data(part_sreq, MPIDIG_PART_REGULAR); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); } MPIR_FUNC_EXIT; @@ -89,7 +89,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_parrived(MPIR_Request * request, int par int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(0)); /* Do not maintain per-partition completion. Arrived when full data transfer is done. * An inactive request returns TRUE (same for NULL req, handled at MPIR layer). */ @@ -104,7 +104,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_parrived(MPIR_Request * request, int par } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(0).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(0)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: diff --git a/src/mpid/ch4/src/mpidig_rma.h b/src/mpid/ch4/src/mpidig_rma.h index 63545a456d5..850e57f9edf 100644 --- a/src/mpid/ch4/src/mpidig_rma.h +++ b/src/mpid/ch4/src/mpidig_rma.h @@ -557,11 +557,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_put(const void *origin_addr, MPI_Aint or MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, vci, NULL); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -582,10 +582,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_rput(const void *origin_addr, MPI_Aint o MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, vci, &sreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -606,11 +606,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_get(void *origin_addr, MPI_Aint origin_c MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, vci, NULL); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -631,10 +631,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_rget(void *origin_addr, MPI_Aint origin_ MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, win, vci, &sreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -657,11 +657,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_raccumulate(const void *origin_addr, MPI MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_accumulate(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, op, win, vci, &sreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -682,11 +682,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_accumulate(const void *origin_addr, MPI_ MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_accumulate(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count, target_datatype, op, win, vci, NULL); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -712,11 +712,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_rget_accumulate(const void *origin_addr, MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_get_accumulate(origin_addr, origin_count, origin_datatype, result_addr, result_count, result_datatype, target_rank, target_disp, target_count, target_datatype, op, win, vci, &sreq); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -741,12 +741,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_get_accumulate(const void *origin_addr, MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPIDIG_do_get_accumulate(origin_addr, origin_count, origin_datatype, result_addr, result_count, result_datatype, target_rank, target_disp, target_count, target_datatype, op, win, vci, NULL); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_ERR_CHECK(mpi_errno); fn_exit: @@ -770,7 +770,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_compare_and_swap(const void *origin_addr MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_RMA_OP_CHECK_SYNC(target_rank, win); @@ -812,7 +812,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_compare_and_swap(const void *origin_addr MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno); MPIR_ERR_CHECK(mpi_errno); fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: diff --git a/src/mpid/ch4/src/mpidig_win.h b/src/mpid/ch4/src/mpidig_win.h index 6353bf7def3..29a628a0fef 100644 --- a/src/mpid/ch4/src/mpidig_win.h +++ b/src/mpid/ch4/src/mpidig_win.h @@ -134,7 +134,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_start(MPIR_Group * group, int assert MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_ACCESS_EPOCH_CHECK_NONE(win, mpi_errno, goto fn_fail); @@ -153,7 +153,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_start(MPIR_Group * group, int assert MPIDIG_WIN(win, sync).access_epoch_type = MPIDIG_EPOTYPE_START; fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -177,7 +177,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_complete(MPIR_Win * win) MPIR_Assert(group != NULL); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op completion in netmod and shmmod */ @@ -225,7 +225,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_complete(MPIR_Win * win) MPL_free(ranks_in_win_grp); if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -242,7 +242,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_post(MPIR_Group * group, int assert, MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_EXPOSURE_EPOCH_CHECK_NONE(win, mpi_errno, goto fn_fail); @@ -280,7 +280,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_post(MPIR_Group * group, int assert, fn_exit: MPL_free(ranks_in_win_grp); - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -294,7 +294,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_wait(MPIR_Win * win) MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_EXPOSURE_EPOCH_CHECK(win, MPIDIG_EPOTYPE_POST, mpi_errno, goto fn_fail); group = MPIDIG_WIN(win, sync).pw.group; @@ -306,7 +306,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_wait(MPIR_Win * win) MPIDIG_WIN(win, sync).exposure_epoch_type = MPIDIG_EPOTYPE_NONE; fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -321,7 +321,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_test(MPIR_Win * win, int *flag) #if MPICH_THREAD_GRANULARITY == MPICH_THREAD_GRANULARITY__VCI int vci = MPIDI_WIN(win, am_vci); #endif - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_EXPOSURE_EPOCH_CHECK(win, MPIDIG_EPOTYPE_POST, mpi_errno, goto fn_fail); @@ -335,14 +335,14 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_test(MPIR_Win * win, int *flag) MPIR_Group_release(group); MPIDIG_WIN(win, sync).exposure_epoch_type = MPIDIG_EPOTYPE_NONE; } else { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); mpi_errno = MPID_Progress_test(NULL); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); *flag = 0; } fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -358,7 +358,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_lock(int lock_type, int rank, int as MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); int vci_target = MPIDI_WIN_TARGET_VCI(win, rank); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_LOCK_EPOCH_CHECK_NONE(win, rank, mpi_errno, goto fn_fail); @@ -392,7 +392,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_lock(int lock_type, int rank, int as MPIDIG_WIN(win, sync).lock.count++; fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -423,7 +423,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_unlock(int rank, MPIR_Win * win) int vci = MPIDI_WIN(win, am_vci); int vci_target = MPIDI_WIN_TARGET_VCI(win, rank); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op completion in netmod and shmmod */ @@ -470,7 +470,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_unlock(int rank, MPIR_Win * win) fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -488,7 +488,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_fence(int massert, MPIR_Win * win) MPIDIG_FENCE_EPOCH_CHECK(win, mpi_errno, goto fn_fail); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op completion in netmod and shmmod */ @@ -520,13 +520,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_fence(int massert, MPIR_Win * win) /* MPIR_Barrier's state is protected by ALLFUNC_MUTEX. * In VCI granularity, individual send/recv/wait operations will take * the VCI lock internally. */ - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 0; mpi_errno = MPIR_Barrier(win->comm_ptr, MPIR_ERR_NONE); fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -641,7 +641,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush(int rank, MPIR_Win * win) MPIDIG_EPOCH_CHECK_PASSIVE(win, mpi_errno, return mpi_errno); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op completion in netmod and shmmod */ @@ -670,7 +670,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush(int rank, MPIR_Win * win) fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -687,7 +687,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush_local_all(MPIR_Win * win) MPIDIG_EPOCH_CHECK_PASSIVE(win, mpi_errno, goto fn_fail); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op local completion in netmod and shmmod */ @@ -710,7 +710,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush_local_all(MPIR_Win * win) fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -730,7 +730,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_unlock_all(MPIR_Win * win) MPIR_Assert(MPIDIG_WIN(win, sync).lockall.allLocked == win->comm_ptr->local_size); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op completion in netmod and shmmod */ @@ -775,7 +775,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_unlock_all(MPIR_Win * win) fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -793,7 +793,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush_local(int rank, MPIR_Win * win MPIDIG_EPOCH_CHECK_PASSIVE(win, mpi_errno, return mpi_errno); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op local completion in netmod and shmmod */ @@ -821,7 +821,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush_local(int rank, MPIR_Win * win fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -853,7 +853,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush_all(MPIR_Win * win) MPIDIG_EPOCH_CHECK_PASSIVE(win, mpi_errno, goto fn_fail); int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); need_unlock = 1; /* Ensure op completion in netmod and shmmod */ @@ -876,7 +876,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_flush_all(MPIR_Win * win) fn_exit: if (need_unlock) { - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); } MPIR_FUNC_EXIT; return mpi_errno; @@ -890,7 +890,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_lock_all(int assert, MPIR_Win * win) MPIR_FUNC_ENTER; int vci = MPIDI_WIN(win, am_vci); - MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDIG_ACCESS_EPOCH_CHECK_NONE(win, mpi_errno, goto fn_fail); @@ -924,7 +924,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_mpi_win_lock_all(int assert, MPIR_Win * win) MPIDIG_WIN(win, sync).access_epoch_type = MPIDIG_EPOTYPE_LOCK_ALL; fn_exit: - MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI(vci).lock); + MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); MPIR_FUNC_EXIT; return mpi_errno; fn_fail: From 2405285b2fcb43984b61c878e191765d69b824f7 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 27 Dec 2024 20:21:42 -0600 Subject: [PATCH 03/25] ch4: cleanup MPIDI_global vci settings --- src/mpid/ch4/src/ch4_init.c | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index 4144a8223b4..3553e48710d 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -508,19 +508,13 @@ int MPID_Init(int requested, int *provided) MPIR_Assert(MPIR_CVAR_CH4_NUM_VCIS >= 1); /* number of vcis used in implicit vci hashing */ MPIR_Assert(MPIR_CVAR_CH4_RESERVE_VCIS >= 0); /* maximum number of vcis can be reserved */ - MPIDI_global.n_vcis = MPIR_CVAR_CH4_NUM_VCIS; - MPIDI_global.n_total_vcis = MPIDI_global.n_vcis + MPIR_CVAR_CH4_RESERVE_VCIS; + MPIDI_global.n_vcis = 1; + MPIDI_global.n_total_vcis = 1; MPIDI_global.n_reserved_vcis = 0; MPIDI_global.share_reserved_vcis = false; - MPIDI_global.all_num_vcis = MPL_malloc(sizeof(int) * MPIR_Process.size, MPL_MEM_OTHER); + MPIDI_global.all_num_vcis = MPL_calloc(MPIR_Process.size, size(int), MPL_MEM_OTHER); MPIR_Assert(MPIDI_global.all_num_vcis); - for (int i = 0; i < MPIR_Process.size; i++) { - MPIDI_global.all_num_vcis[i] = MPIDI_global.n_vcis; - } - - MPIR_Assert(MPIDI_global.n_total_vcis <= MPIDI_CH4_MAX_VCIS); - MPIR_Assert(MPIDI_global.n_total_vcis <= MPIR_REQUEST_NUM_POOLS); for (int i = 0; i < MPIDI_global.n_total_vcis; i++) { /* Initialize registered host buffer pool to be used as temporary unpack buffers */ @@ -672,24 +666,21 @@ int MPIDI_world_post_init(void) * this restriction, then we can move MPIDI_NM_init_vcis to * MPIDI_world_pre_init. */ + int n_total_vcis = MPIR_CVAR_CH4_NUM_VCIS + MPIR_CVAR_CH4_RESERVE_VCIS; + MPIR_Assert(n_total_vcis <= MPIDI_CH4_MAX_VCIS); + MPIR_Assert(n_total_vcis <= MPIR_REQUEST_NUM_POOLS); + int num_vcis_actual; - mpi_errno = MPIDI_NM_init_vcis(MPIDI_global.n_total_vcis, &num_vcis_actual); + mpi_errno = MPIDI_NM_init_vcis(n_total_vcis, &num_vcis_actual); MPIR_ERR_CHECK(mpi_errno); #if MPIDI_CH4_MAX_VCIS == 1 MPIR_Assert(num_vcis_actual == 1); #else MPIR_Assert(num_vcis_actual > 0 && num_vcis_actual <= MPIDI_global.n_total_vcis); - int diff = MPIDI_global.n_total_vcis - num_vcis_actual; - /* we can shrink implicit vcis down to 1, then n_reserved_vcis down to 0 */ - MPIDI_global.n_total_vcis -= diff; - if (MPIDI_global.n_vcis > diff + 1) { - MPIDI_global.n_vcis -= diff; - } else { - diff -= (MPIDI_global.n_vcis - 1); - MPIDI_global.n_vcis = 1; - MPIDI_global.n_reserved_vcis -= diff; - } + + MPIDI_global.n_total_vcis = num_vcis_actual; + MPIDI_global.n_vcis = MPL_MIN(MPIR_CVAR_CH4_NUM_VCIS, MPIDI_global.n_total_vcis); mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT, MPIDI_global.all_num_vcis, 1, MPI_INT, From 63998a3d192adc10370c7a08e97a9d3f3eaa12d5 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 26 Dec 2024 00:08:57 -0600 Subject: [PATCH 04/25] ch4: add API comm_set_vcis --- src/mpid/ch4/ch4_api.txt | 4 ++++ src/mpid/ch4/include/mpidimpl.h | 1 + src/mpid/ch4/netmod/ofi/Makefile.mk | 1 + src/mpid/ch4/netmod/ofi/ofi_vci.c | 17 +++++++++++++++++ src/mpid/ch4/netmod/ucx/Makefile.mk | 1 + src/mpid/ch4/netmod/ucx/ucx_vci.c | 13 +++++++++++++ src/mpid/ch4/shm/posix/Makefile.mk | 1 + src/mpid/ch4/shm/posix/posix_vci.c | 13 +++++++++++++ src/mpid/ch4/shm/src/shm_hooks.c | 5 +++++ src/mpid/ch4/src/Makefile.mk | 1 + src/mpid/ch4/src/ch4_vci.c | 13 +++++++++++++ 11 files changed, 70 insertions(+) create mode 100644 src/mpid/ch4/netmod/ofi/ofi_vci.c create mode 100644 src/mpid/ch4/netmod/ucx/ucx_vci.c create mode 100644 src/mpid/ch4/shm/posix/posix_vci.c create mode 100644 src/mpid/ch4/src/ch4_vci.c diff --git a/src/mpid/ch4/ch4_api.txt b/src/mpid/ch4/ch4_api.txt index c1778e546ff..24026e01748 100644 --- a/src/mpid/ch4/ch4_api.txt +++ b/src/mpid/ch4/ch4_api.txt @@ -87,6 +87,9 @@ Non Native API: am_tag_recv : int NM*: rank, comm, handler_id, tag, buf-2, count, datatype, src_vci, dst_vci, rreq SHM*: rank, comm, handler_id, tag, buf-2, count, datatype, src_vci, dst_vci, rreq + comm_set_vcis : int + NM : comm, num_vcis, all_num_vcis + SHM : comm, num_vcis comm_get_gpid : int NM*: comm_ptr, idx, gpid_ptr, is_remote get_local_upids : int @@ -485,6 +488,7 @@ PARAM: newcomm_ptr: MPIR_Comm ** num_vcis: int num_vcis_actual: int * + all_num_vcis: int * op: MPI_Op op_p: MPIR_Op * origin_addr: const void * diff --git a/src/mpid/ch4/include/mpidimpl.h b/src/mpid/ch4/include/mpidimpl.h index 93c6dccd9e8..6bd230b21f7 100644 --- a/src/mpid/ch4/include/mpidimpl.h +++ b/src/mpid/ch4/include/mpidimpl.h @@ -20,5 +20,6 @@ int MPIDI_world_pre_init(void); int MPIDI_world_post_init(void); +int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis); #endif /* MPIDIMPL_H_INCLUDED */ diff --git a/src/mpid/ch4/netmod/ofi/Makefile.mk b/src/mpid/ch4/netmod/ofi/Makefile.mk index 3903ebcd22c..0ee8c26e384 100644 --- a/src/mpid/ch4/netmod/ofi/Makefile.mk +++ b/src/mpid/ch4/netmod/ofi/Makefile.mk @@ -22,6 +22,7 @@ mpi_core_sources += src/mpid/ch4/netmod/ofi/func_table.c \ src/mpid/ch4/netmod/ofi/ofi_progress.c \ src/mpid/ch4/netmod/ofi/ofi_am_events.c \ src/mpid/ch4/netmod/ofi/ofi_nic.c \ + src/mpid/ch4/netmod/ofi/ofi_vci.c \ src/mpid/ch4/netmod/ofi/globals.c \ src/mpid/ch4/netmod/ofi/init_provider.c \ src/mpid/ch4/netmod/ofi/init_settings.c \ diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c new file mode 100644 index 00000000000..b525b90da58 --- /dev/null +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -0,0 +1,17 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include "mpidimpl.h" +#include "ofi_impl.h" + +int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm, int num_vcis) +{ + int mpi_errno = MPI_SUCCESS; + /* 0. get num_nics from CVARs */ + /* 1. check that MPIDI_OFI_global.n_total_vcis = 0 */ + /* 2. allocate and initialize local vcis */ + /* 3. exchange addresses */ + return mpi_errno; +} diff --git a/src/mpid/ch4/netmod/ucx/Makefile.mk b/src/mpid/ch4/netmod/ucx/Makefile.mk index fb61b0628cc..3765549f5ad 100644 --- a/src/mpid/ch4/netmod/ucx/Makefile.mk +++ b/src/mpid/ch4/netmod/ucx/Makefile.mk @@ -15,6 +15,7 @@ mpi_core_sources += src/mpid/ch4/netmod/ucx/func_table.c\ src/mpid/ch4/netmod/ucx/ucx_win.c \ src/mpid/ch4/netmod/ucx/ucx_part.c \ src/mpid/ch4/netmod/ucx/ucx_am.c \ + src/mpid/ch4/netmod/ucx/ucx_vci.c \ src/mpid/ch4/netmod/ucx/globals.c errnames_txt_files += src/mpid/ch4/netmod/ucx/errnames.txt diff --git a/src/mpid/ch4/netmod/ucx/ucx_vci.c b/src/mpid/ch4/netmod/ucx/ucx_vci.c new file mode 100644 index 00000000000..2153d73c4f1 --- /dev/null +++ b/src/mpid/ch4/netmod/ucx/ucx_vci.c @@ -0,0 +1,13 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include "mpidimpl.h" +#include "ucx_impl.h" + +int MPIDI_UCX_comm_set_vcis(MPIR_Comm * comm, int num_vcis) +{ + int mpi_errno = MPI_SUCCESS; + return mpi_errno; +} diff --git a/src/mpid/ch4/shm/posix/Makefile.mk b/src/mpid/ch4/shm/posix/Makefile.mk index cb0faedecaa..6b59f8e87be 100644 --- a/src/mpid/ch4/shm/posix/Makefile.mk +++ b/src/mpid/ch4/shm/posix/Makefile.mk @@ -34,6 +34,7 @@ mpi_core_sources += src/mpid/ch4/shm/posix/globals.c \ src/mpid/ch4/shm/posix/posix_datatype.c \ src/mpid/ch4/shm/posix/posix_win.c \ src/mpid/ch4/shm/posix/posix_part.c \ + src/mpid/ch4/shm/posix/posix_vci.c \ src/mpid/ch4/shm/posix/posix_eager_array.c include $(top_srcdir)/src/mpid/ch4/shm/posix/eager/Makefile.mk diff --git a/src/mpid/ch4/shm/posix/posix_vci.c b/src/mpid/ch4/shm/posix/posix_vci.c new file mode 100644 index 00000000000..9b5534be040 --- /dev/null +++ b/src/mpid/ch4/shm/posix/posix_vci.c @@ -0,0 +1,13 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include "mpidimpl.h" +#include "posix_types.h" + +int MPIDI_POSIX_comm_set_vcis(MPIR_Comm * comm, int num_vcis) +{ + int mpi_errno = MPI_SUCCESS; + return mpi_errno; +} diff --git a/src/mpid/ch4/shm/src/shm_hooks.c b/src/mpid/ch4/shm/src/shm_hooks.c index 8cf495902c2..9f781175b98 100644 --- a/src/mpid/ch4/shm/src/shm_hooks.c +++ b/src/mpid/ch4/shm/src/shm_hooks.c @@ -198,3 +198,8 @@ int MPIDI_SHM_mpi_win_free_hook(MPIR_Win * win) fn_fail: goto fn_exit; } + +int MPIDI_SHM_comm_set_vcis(MPIR_Comm * comm, int num_vcis) +{ + return MPIDI_POSIX_comm_set_vcis(comm, num_vcis); +} diff --git a/src/mpid/ch4/src/Makefile.mk b/src/mpid/ch4/src/Makefile.mk index 0e499fd924c..13480f05915 100644 --- a/src/mpid/ch4/src/Makefile.mk +++ b/src/mpid/ch4/src/Makefile.mk @@ -51,6 +51,7 @@ mpi_core_sources += src/mpid/ch4/src/ch4_globals.c \ src/mpid/ch4/src/ch4_proc.c \ src/mpid/ch4/src/ch4_stream_enqueue.c \ src/mpid/ch4/src/ch4_persist.c \ + src/mpid/ch4/src/ch4_vci.c \ src/mpid/ch4/src/mpidig_init.c \ src/mpid/ch4/src/mpidig_recvq.c \ src/mpid/ch4/src/mpidig_pt2pt_callbacks.c \ diff --git a/src/mpid/ch4/src/ch4_vci.c b/src/mpid/ch4/src/ch4_vci.c new file mode 100644 index 00000000000..f958e501320 --- /dev/null +++ b/src/mpid/ch4/src/ch4_vci.c @@ -0,0 +1,13 @@ +/* + * Copyright (C) by Argonne National Laboratory + * See COPYRIGHT in top-level directory + */ + +#include "mpidimpl.h" + +int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) +{ + int mpi_errno = MPI_SUCCESS; + + return mpi_errno; +} From acbc2bd90a686c1641b9925bd17d48da55dd41fc Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 28 Dec 2024 10:04:05 -0600 Subject: [PATCH 05/25] ch4: refactor vci init/finalize Move multiple vci related init/finalize code into ch4_vci.c. Wrap per-vci code into a function and only deal with vci 0 in ch4_init.c and additional vcis in ch4_vci.c. --- src/mpid/ch4/src/ch4_impl.h | 4 + src/mpid/ch4/src/ch4_init.c | 156 ++++++++++----------------------- src/mpid/ch4/src/ch4_vci.c | 132 ++++++++++++++++++++++++++++ src/mpid/ch4/src/mpidig.h | 2 + src/mpid/ch4/src/mpidig_init.c | 54 ++++++++---- 5 files changed, 221 insertions(+), 127 deletions(-) diff --git a/src/mpid/ch4/src/ch4_impl.h b/src/mpid/ch4/src/ch4_impl.h index e630e27be05..063f0667f9d 100644 --- a/src/mpid/ch4/src/ch4_impl.h +++ b/src/mpid/ch4/src/ch4_impl.h @@ -13,6 +13,10 @@ #include "ch4_self.h" #include "ch4_vci.h" +int MPIDI_vci_init(void); +int MPIDI_vci_finalize(void); +int MPIDI_init_per_vci(int vci); +int MPIDI_destroy_per_vci(int vci); int MPIDIU_Intercomm_map_bcast_intra(MPIR_Comm * local_comm, int local_leader, int *remote_size, int *is_low_group, int pure_intracomm, int *remote_upid_size, char *remote_upids, diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index 3553e48710d..7fb1ab6986a 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -75,26 +75,6 @@ direct (default) lockless - - name : MPIR_CVAR_CH4_NUM_VCIS - category : CH4 - type : int - default : 1 - class : none - verbosity : MPI_T_VERBOSITY_USER_BASIC - scope : MPI_T_SCOPE_LOCAL - description : >- - Sets the number of VCIs to be implicitly used (should be a subset of MPIDI_CH4_MAX_VCIS). - - - name : MPIR_CVAR_CH4_RESERVE_VCIS - category : CH4 - type : int - default : 0 - class : none - verbosity : MPI_T_VERBOSITY_USER_BASIC - scope : MPI_T_SCOPE_LOCAL - description : >- - Sets the number of VCIs that user can explicitly allocate (should be a subset of MPIDI_CH4_MAX_VCIS). - - name : MPIR_CVAR_CH4_COLL_SELECTION_TUNING_JSON_FILE category : COLLECTIVE type : string @@ -420,6 +400,43 @@ static void register_comm_hints(void) MPIR_COMM_HINT_TYPE_INT, 0, MPIDI_VCI_INVALID); } +int MPIDI_init_per_vci(int vci) +{ + int mpi_errno = MPI_SUCCESS; + /* Initialize registered host buffer pool to be used as temporary unpack buffers */ + mpi_errno = MPIDU_genq_private_pool_create(MPIR_CVAR_CH4_PACK_BUFFER_SIZE, + MPIR_CVAR_CH4_NUM_PACK_BUFFERS_PER_CHUNK, + MPIR_CVAR_CH4_MAX_NUM_PACK_BUFFERS, + host_alloc_registered, + host_free_registered, + &MPIDI_global.per_vci[vci].pack_buf_pool); + MPIR_ERR_CHECK(mpi_errno); + + mpi_errno = MPIDIG_init_per_vci(vci); + MPIR_ERR_CHECK(mpi_errno); + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +int MPIDI_destroy_per_vci(int vci) +{ + int mpi_errno = MPI_SUCCESS; + + mpi_errno = MPIDU_genq_private_pool_destroy(MPIDI_global.per_vci[vci].pack_buf_pool); + MPIR_ERR_CHECK(mpi_errno); + + mpi_errno = MPIDIG_destroy_per_vci(vci); + MPIR_ERR_CHECK(mpi_errno); + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + int MPID_Init(int requested, int *provided) { int mpi_errno = MPI_SUCCESS; @@ -503,31 +520,11 @@ int MPID_Init(int requested, int *provided) MPIDI_global.csel_root = NULL; MPIDI_global.csel_root_gpu = NULL; - /* Initialize multiple VCIs */ - /* TODO: add checks to ensure MPIDI_vci_t is padded or aligned to MPL_CACHELINE_SIZE */ - MPIR_Assert(MPIR_CVAR_CH4_NUM_VCIS >= 1); /* number of vcis used in implicit vci hashing */ - MPIR_Assert(MPIR_CVAR_CH4_RESERVE_VCIS >= 0); /* maximum number of vcis can be reserved */ - - MPIDI_global.n_vcis = 1; - MPIDI_global.n_total_vcis = 1; - MPIDI_global.n_reserved_vcis = 0; - MPIDI_global.share_reserved_vcis = false; - - MPIDI_global.all_num_vcis = MPL_calloc(MPIR_Process.size, size(int), MPL_MEM_OTHER); - MPIR_Assert(MPIDI_global.all_num_vcis); - - for (int i = 0; i < MPIDI_global.n_total_vcis; i++) { - /* Initialize registered host buffer pool to be used as temporary unpack buffers */ - mpi_errno = MPIDU_genq_private_pool_create(MPIR_CVAR_CH4_PACK_BUFFER_SIZE, - MPIR_CVAR_CH4_NUM_PACK_BUFFERS_PER_CHUNK, - MPIR_CVAR_CH4_MAX_NUM_PACK_BUFFERS, - host_alloc_registered, - host_free_registered, - &MPIDI_global.per_vci[i].pack_buf_pool); - MPIR_ERR_CHECK(mpi_errno); - - } + mpi_errno = MPIDI_vci_init(); + MPIR_ERR_CHECK(mpi_errno); + mpi_errno = MPIDI_init_per_vci(0); + MPIR_ERR_CHECK(mpi_errno); /* internally does per-vci am initialization */ MPIDIG_am_init(); @@ -661,33 +658,10 @@ int MPIDI_world_post_init(void) { int mpi_errno = MPI_SUCCESS; - /* FIXME: currently ofi require each process to have the same number of nics, - * thus need access to world_comm for collectives. We should remove - * this restriction, then we can move MPIDI_NM_init_vcis to - * MPIDI_world_pre_init. - */ int n_total_vcis = MPIR_CVAR_CH4_NUM_VCIS + MPIR_CVAR_CH4_RESERVE_VCIS; - MPIR_Assert(n_total_vcis <= MPIDI_CH4_MAX_VCIS); - MPIR_Assert(n_total_vcis <= MPIR_REQUEST_NUM_POOLS); - - int num_vcis_actual; - mpi_errno = MPIDI_NM_init_vcis(n_total_vcis, &num_vcis_actual); + mpi_errno = MPIDI_Comm_set_vcis(MPIR_Process.comm_world, n_total_vcis); MPIR_ERR_CHECK(mpi_errno); -#if MPIDI_CH4_MAX_VCIS == 1 - MPIR_Assert(num_vcis_actual == 1); -#else - MPIR_Assert(num_vcis_actual > 0 && num_vcis_actual <= MPIDI_global.n_total_vcis); - - MPIDI_global.n_total_vcis = num_vcis_actual; - MPIDI_global.n_vcis = MPL_MIN(MPIR_CVAR_CH4_NUM_VCIS, MPIDI_global.n_total_vcis); - - mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT, - MPIDI_global.all_num_vcis, 1, MPI_INT, - MPIR_Process.comm_world, MPIR_ERR_NONE); - MPIR_ERR_CHECK(mpi_errno); -#endif - #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_post_init(); MPIR_ERR_CHECK(mpi_errno); @@ -704,43 +678,6 @@ int MPIDI_world_post_init(void) goto fn_exit; } -int MPID_Allocate_vci(int *vci, bool is_shared) -{ - int mpi_errno = MPI_SUCCESS; - - *vci = 0; -#if MPIDI_CH4_MAX_VCIS == 1 - MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**ch4nostream"); -#else - - if (MPIDI_global.n_vcis + MPIDI_global.n_reserved_vcis >= MPIDI_global.n_total_vcis) { - MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**outofstream"); - } else { - MPIDI_global.n_reserved_vcis++; - for (int i = MPIDI_global.n_vcis; i < MPIDI_global.n_total_vcis; i++) { - if (!MPIDI_VCI(i).allocated) { - MPIDI_VCI(i).allocated = true; - *vci = i; - break; - } - } - } -#endif - if (is_shared) { - MPIDI_global.share_reserved_vcis = true; - } - return mpi_errno; -} - -int MPID_Deallocate_vci(int vci) -{ - MPIR_Assert(vci < MPIDI_global.n_total_vcis && vci >= MPIDI_global.n_vcis); - MPIR_Assert(MPIDI_VCI(vci).allocated); - MPIDI_VCI(vci).allocated = false; - MPIDI_global.n_reserved_vcis--; - return MPI_SUCCESS; -} - int MPID_Stream_create_hook(MPIR_Stream * stream) { int mpi_errno = MPI_SUCCESS; @@ -815,11 +752,12 @@ int MPID_Finalize(void) MPIR_Assert(err == 0); } - for (int i = 0; i < MPIDI_global.n_total_vcis; i++) { - MPIDU_genq_private_pool_destroy(MPIDI_global.per_vci[i].pack_buf_pool); - } - MPL_free(MPIDI_global.all_num_vcis); + mpi_errno = MPIDI_destroy_per_vci(0); + MPIR_ERR_CHECK(mpi_errno); + + mpi_errno = MPIDI_vci_finalize(); + MPIR_ERR_CHECK(mpi_errno); memset(&MPIDI_global, 0, sizeof(MPIDI_global)); diff --git a/src/mpid/ch4/src/ch4_vci.c b/src/mpid/ch4/src/ch4_vci.c index f958e501320..cf29f00e777 100644 --- a/src/mpid/ch4/src/ch4_vci.c +++ b/src/mpid/ch4/src/ch4_vci.c @@ -5,9 +5,141 @@ #include "mpidimpl.h" +/* +=== BEGIN_MPI_T_CVAR_INFO_BLOCK === + +cvars: + - name : MPIR_CVAR_CH4_NUM_VCIS + category : CH4 + type : int + default : 1 + class : none + verbosity : MPI_T_VERBOSITY_USER_BASIC + scope : MPI_T_SCOPE_LOCAL + description : >- + Sets the number of VCIs to be implicitly used (should be a subset of MPIDI_CH4_MAX_VCIS). + + - name : MPIR_CVAR_CH4_RESERVE_VCIS + category : CH4 + type : int + default : 0 + class : none + verbosity : MPI_T_VERBOSITY_USER_BASIC + scope : MPI_T_SCOPE_LOCAL + description : >- + Sets the number of VCIs that user can explicitly allocate (should be a subset of MPIDI_CH4_MAX_VCIS). + +=== END_MPI_T_CVAR_INFO_BLOCK === +*/ + +int MPIDI_vci_init(void) +{ + /* Initialize multiple VCIs */ + /* TODO: add checks to ensure MPIDI_vci_t is padded or aligned to MPL_CACHELINE_SIZE */ + MPIR_Assert(MPIR_CVAR_CH4_NUM_VCIS >= 1); /* number of vcis used in implicit vci hashing */ + MPIR_Assert(MPIR_CVAR_CH4_RESERVE_VCIS >= 0); /* maximum number of vcis can be reserved */ + + MPIDI_global.n_vcis = 1; + MPIDI_global.n_total_vcis = 1; + MPIDI_global.n_reserved_vcis = 0; + MPIDI_global.share_reserved_vcis = false; + + MPIDI_global.all_num_vcis = MPL_calloc(MPIR_Process.size, sizeof(int), MPL_MEM_OTHER); + MPIR_Assert(MPIDI_global.all_num_vcis); + + return MPI_SUCCESS; +} + +int MPIDI_vci_finalize(void) +{ + int mpi_errno = MPI_SUCCESS; + + for (int vci = 1; vci < MPIDI_global.n_total_vcis; vci++) { + mpi_errno = MPIDI_destroy_per_vci(vci); + MPIR_ERR_CHECK(mpi_errno); + } + MPL_free(MPIDI_global.all_num_vcis); + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) { int mpi_errno = MPI_SUCCESS; + /* FIXME: currently ofi require each process to have the same number of nics, + * thus need access to world_comm for collectives. We should remove + * this restriction, then we can move MPIDI_NM_init_vcis to + * MPIDI_world_pre_init. + */ + MPIR_Assert(n_total_vcis <= MPIDI_CH4_MAX_VCIS); + MPIR_Assert(n_total_vcis <= MPIR_REQUEST_NUM_POOLS); + + int num_vcis_actual; + mpi_errno = MPIDI_NM_init_vcis(n_total_vcis, &num_vcis_actual); + MPIR_ERR_CHECK(mpi_errno); + +#if MPIDI_CH4_MAX_VCIS == 1 + MPIR_Assert(num_vcis_actual == 1); +#else + MPIR_Assert(num_vcis_actual > 0 && num_vcis_actual <= MPIDI_global.n_total_vcis); + + MPIDI_global.n_total_vcis = num_vcis_actual; + MPIDI_global.n_vcis = MPL_MIN(MPIR_CVAR_CH4_NUM_VCIS, MPIDI_global.n_total_vcis); + + mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT, + MPIDI_global.all_num_vcis, 1, MPI_INT, + MPIR_Process.comm_world, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); +#endif + + for (int vci = 1; vci < MPIDI_global.n_total_vcis; vci++) { + mpi_errno = MPIDI_init_per_vci(vci); + MPIR_ERR_CHECK(mpi_errno); + } + + fn_exit: return mpi_errno; + fn_fail: + goto fn_exit; +} + +int MPID_Allocate_vci(int *vci, bool is_shared) +{ + int mpi_errno = MPI_SUCCESS; + + *vci = 0; +#if MPIDI_CH4_MAX_VCIS == 1 + MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**ch4nostream"); +#else + + if (MPIDI_global.n_vcis + MPIDI_global.n_reserved_vcis >= MPIDI_global.n_total_vcis) { + MPIR_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**outofstream"); + } else { + MPIDI_global.n_reserved_vcis++; + for (int i = MPIDI_global.n_vcis; i < MPIDI_global.n_total_vcis; i++) { + if (!MPIDI_VCI(i).allocated) { + MPIDI_VCI(i).allocated = true; + *vci = i; + break; + } + } + } +#endif + if (is_shared) { + MPIDI_global.share_reserved_vcis = true; + } + return mpi_errno; +} + +int MPID_Deallocate_vci(int vci) +{ + MPIR_Assert(vci < MPIDI_global.n_total_vcis && vci >= MPIDI_global.n_vcis); + MPIR_Assert(MPIDI_VCI(vci).allocated); + MPIDI_VCI(vci).allocated = false; + MPIDI_global.n_reserved_vcis--; + return MPI_SUCCESS; } diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index f253873c87c..5f59c3eace7 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -152,6 +152,8 @@ void MPIDIG_am_tag_recv_reg_cb(int tag_recv_id, MPIDIG_am_tag_recv_cb tag_recv_c int MPIDIG_am_init(void); void MPIDIG_am_finalize(void); +int MPIDIG_init_per_vci(int vci); +int MPIDIG_destroy_per_vci(int vci); /* am protocol prototypes */ diff --git a/src/mpid/ch4/src/mpidig_init.c b/src/mpid/ch4/src/mpidig_init.c index 83683d63e4b..34ad19706ed 100644 --- a/src/mpid/ch4/src/mpidig_init.c +++ b/src/mpid/ch4/src/mpidig_init.c @@ -79,26 +79,47 @@ void MPIDIG_am_tag_recv_reg_cb(int tag_recv_id, MPIDIG_am_tag_recv_cb tag_recv_c MPIR_FUNC_EXIT; } -int MPIDIG_am_init(void) +int MPIDIG_init_per_vci(int vci) { int mpi_errno = MPI_SUCCESS; - MPIR_FUNC_ENTER; - for (int vci = 0; vci < MPIDI_global.n_total_vcis; vci++) { - MPIDI_global.per_vci[vci].posted_list = NULL; - MPIDI_global.per_vci[vci].unexp_list = NULL; + MPIDI_global.per_vci[vci].posted_list = NULL; + MPIDI_global.per_vci[vci].unexp_list = NULL; - mpi_errno = MPIDU_genq_private_pool_create(MPIDIU_REQUEST_POOL_CELL_SIZE, - MPIDIU_REQUEST_POOL_NUM_CELLS_PER_CHUNK, - 0 /* unlimited */ , - host_alloc, host_free, - &MPIDI_global.per_vci[vci].request_pool); - MPIR_ERR_CHECK(mpi_errno); + mpi_errno = MPIDU_genq_private_pool_create(MPIDIU_REQUEST_POOL_CELL_SIZE, + MPIDIU_REQUEST_POOL_NUM_CELLS_PER_CHUNK, + 0 /* unlimited */ , + host_alloc, host_free, + &MPIDI_global.per_vci[vci].request_pool); + MPIR_ERR_CHECK(mpi_errno); - MPIDI_global.per_vci[vci].cmpl_list = NULL; - MPL_atomic_store_uint64(&MPIDI_global.per_vci[vci].exp_seq_no, 0); - MPL_atomic_store_uint64(&MPIDI_global.per_vci[vci].nxt_seq_no, 0); - } + MPIDI_global.per_vci[vci].cmpl_list = NULL; + MPL_atomic_store_uint64(&MPIDI_global.per_vci[vci].exp_seq_no, 0); + MPL_atomic_store_uint64(&MPIDI_global.per_vci[vci].nxt_seq_no, 0); + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +int MPIDIG_destroy_per_vci(int vci) +{ + int mpi_errno = MPI_SUCCESS; + + mpi_errno = MPIDU_genq_private_pool_destroy(MPIDI_global.per_vci[vci].request_pool); + MPIR_ERR_CHECK(mpi_errno); + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +int MPIDIG_am_init(void) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; MPIDI_global.part_posted_list = NULL; MPIDI_global.part_unexp_list = NULL; @@ -180,9 +201,6 @@ void MPIDIG_am_finalize(void) MPIR_FUNC_ENTER; MPIDIU_map_destroy(MPIDI_global.win_map); - for (int vci = 0; vci < MPIDI_global.n_total_vcis; vci++) { - MPIDU_genq_private_pool_destroy(MPIDI_global.per_vci[vci].request_pool); - } MPIR_FUNC_EXIT; } From 03c0384b20124d0cfc6b78ae6701592c5734a4df Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 27 Dec 2024 20:08:16 -0600 Subject: [PATCH 06/25] ch4: refactor to enable vcis in MPIDI_Comm_set_vcis Gather all multiple vci init code in MPIDI_Comm_set_vcis. So - 1. The rest of the init code only deal with root vci. 2. Prepare for future dynamic and per-comm vci. --- src/mpid/ch4/src/ch4_comm.c | 13 ++++++++ src/mpid/ch4/src/ch4_init.c | 5 --- src/mpid/ch4/src/ch4_vci.c | 64 ++++++++++++++++++++++++++----------- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/src/mpid/ch4/src/ch4_comm.c b/src/mpid/ch4/src/ch4_comm.c index 808d6f6e21b..c0d27293230 100644 --- a/src/mpid/ch4/src/ch4_comm.c +++ b/src/mpid/ch4/src/ch4_comm.c @@ -229,6 +229,19 @@ int MPID_Comm_commit_post_hook(MPIR_Comm * comm) MPIR_ERR_CHECK(mpi_errno); #endif + /* set_vcis for comm_world. + * TODO: expose MPIX_Comm_set_vcis to allow vcis for arbitrary comm + */ + if (comm == MPIR_Process.comm_world) { + int n_total_vcis = MPIR_CVAR_CH4_NUM_VCIS + MPIR_CVAR_CH4_RESERVE_VCIS; + /* we always need call set_vcis even when n_total_vcis is 1 because - + * 1. in case netmod need support multi-nics. + * 2. remote processes may have multiple vcis. + */ + mpi_errno = MPIDI_Comm_set_vcis(comm, n_total_vcis); + MPIR_ERR_CHECK(mpi_errno); + } + /* prune selection tree */ if (MPIDI_global.csel_root) { mpi_errno = MPIR_Csel_prune(MPIDI_global.csel_root, comm, &MPIDI_COMM(comm, csel_comm)); diff --git a/src/mpid/ch4/src/ch4_init.c b/src/mpid/ch4/src/ch4_init.c index 7fb1ab6986a..11f65f5eab7 100644 --- a/src/mpid/ch4/src/ch4_init.c +++ b/src/mpid/ch4/src/ch4_init.c @@ -658,10 +658,6 @@ int MPIDI_world_post_init(void) { int mpi_errno = MPI_SUCCESS; - int n_total_vcis = MPIR_CVAR_CH4_NUM_VCIS + MPIR_CVAR_CH4_RESERVE_VCIS; - mpi_errno = MPIDI_Comm_set_vcis(MPIR_Process.comm_world, n_total_vcis); - MPIR_ERR_CHECK(mpi_errno); - #ifndef MPIDI_CH4_DIRECT_NETMOD mpi_errno = MPIDI_SHM_post_init(); MPIR_ERR_CHECK(mpi_errno); @@ -669,7 +665,6 @@ int MPIDI_world_post_init(void) mpi_errno = MPIDI_NM_post_init(); MPIR_ERR_CHECK(mpi_errno); - MPIR_Process.comm_world->vcis_enabled = true; MPIDI_global.is_initialized = 1; fn_exit: diff --git a/src/mpid/ch4/src/ch4_vci.c b/src/mpid/ch4/src/ch4_vci.c index cf29f00e777..628733f3896 100644 --- a/src/mpid/ch4/src/ch4_vci.c +++ b/src/mpid/ch4/src/ch4_vci.c @@ -66,35 +66,62 @@ int MPIDI_vci_finalize(void) goto fn_exit; } +/* enable multiple vcis for this comm. + * The number of vcis below MPIR_CVAR_CH4_NUM_VCIS will be used for implicit vcis. + * The number of vcis above MPIR_CVAR_CH4_NUM_VCIS will be used as explicit (reserved) vcis. + * The netmod may create less than the requested number of vcis. + */ int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) { int mpi_errno = MPI_SUCCESS; + MPIR_CHKLMEM_DECL(3); + + /* for now, intracomm only. I believe we can enable it for intercomm in the future */ + MPIR_Assert(comm->comm_kind == MPIR_COMM_KIND__INTRACOMM); + + /* make sure multiple vcis are not previously enabled. Or it will mess up the + * internal communication during setting up vcis. */ + MPIR_Assert(!comm->vcis_enabled); + /* actually, only do it once for now */ + MPIR_Assert(MPIDI_global.n_total_vcis == 1); + + /* get global ranks */ + bool same_world = true; + int nprocs = comm->local_size; + int *granks; + MPIR_CHKLMEM_MALLOC(granks, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER, + MPL_MEM_OTHER); + for (int i = 0; i < nprocs; i++) { + int avtid; + MPIDIU_comm_rank_to_pid(comm, i, &granks[i], &avtid); + MPIR_Assert(avtid == 0); + } - /* FIXME: currently ofi require each process to have the same number of nics, - * thus need access to world_comm for collectives. We should remove - * this restriction, then we can move MPIDI_NM_init_vcis to - * MPIDI_world_pre_init. - */ - MPIR_Assert(n_total_vcis <= MPIDI_CH4_MAX_VCIS); - MPIR_Assert(n_total_vcis <= MPIR_REQUEST_NUM_POOLS); + /* for now, we only allow setup setup vcis for each remote once */ + for (int i = 0; i < nprocs; i++) { + MPIR_Assert(MPIDI_global.all_num_vcis[granks[i]] == 0); + } + /* set up local vcis */ int num_vcis_actual; - mpi_errno = MPIDI_NM_init_vcis(n_total_vcis, &num_vcis_actual); + mpi_errno = MPIDI_NM_init_vcis(MPIDI_global.n_total_vcis, &num_vcis_actual); MPIR_ERR_CHECK(mpi_errno); -#if MPIDI_CH4_MAX_VCIS == 1 - MPIR_Assert(num_vcis_actual == 1); -#else - MPIR_Assert(num_vcis_actual > 0 && num_vcis_actual <= MPIDI_global.n_total_vcis); - MPIDI_global.n_total_vcis = num_vcis_actual; - MPIDI_global.n_vcis = MPL_MIN(MPIR_CVAR_CH4_NUM_VCIS, MPIDI_global.n_total_vcis); - mpi_errno = MPIR_Allgather_fallback(&MPIDI_global.n_vcis, 1, MPI_INT, - MPIDI_global.all_num_vcis, 1, MPI_INT, - MPIR_Process.comm_world, MPIR_ERR_NONE); + /* gather the number of remote vcis */ + int *all_num_vcis; + MPIR_CHKLMEM_MALLOC(all_num_vcis, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER, + MPL_MEM_OTHER); + mpi_errno = MPIR_Allgather_impl(num_vcis_actual, 1, MPI_INT, + all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); -#endif + + for (int i = 0; i < nprocs; i++) { + MPIDI_global.all_num_vcis[granks[i]] = all_num_vcis[i]; + } + + comm->vcis_enabled = true; for (int vci = 1; vci < MPIDI_global.n_total_vcis; vci++) { mpi_errno = MPIDI_init_per_vci(vci); @@ -102,6 +129,7 @@ int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) } fn_exit: + MPIR_CHKLMEM_FREEALL(); return mpi_errno; fn_fail: goto fn_exit; From f4228f3f009b311eb94a45a806ba0d5f945a13ac Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 7 Jan 2025 22:39:44 -0600 Subject: [PATCH 07/25] ch4/ofi: use both local and remote vci/nic in av_to_phys We may relax the av insertion order which may require the full (vci_local, nic_local, vci_remote, nic_remote) to look up an actual destination address. Add MPIDI_OFI_av_to_phys_root for convenience and quick survey on where we restrict in only root vci (such as the init and spawn paths). Remove MPIDI_OFI_comm_to_phys and prefer an explicit MPIDIU_comm_rank_to_av and then MPIDI_OFI_av_to_phys. Refactor MPIDI_OFI_SET_AM_HDR_COMMON in ofi_am_impl.h to directly use dst_addr (as remote_id) rather than to recalculate it. --- src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h | 57 ++++++++++---------- src/mpid/ch4/netmod/ofi/ofi_am_events.h | 3 +- src/mpid/ch4/netmod/ofi/ofi_am_impl.h | 29 +++++----- src/mpid/ch4/netmod/ofi/ofi_events.c | 3 +- src/mpid/ch4/netmod/ofi/ofi_huge.c | 3 +- src/mpid/ch4/netmod/ofi/ofi_impl.h | 11 ++-- src/mpid/ch4/netmod/ofi/ofi_init.c | 4 +- src/mpid/ch4/netmod/ofi/ofi_probe.h | 2 +- src/mpid/ch4/netmod/ofi/ofi_recv.h | 8 +-- src/mpid/ch4/netmod/ofi/ofi_rma.c | 13 +++-- src/mpid/ch4/netmod/ofi/ofi_rma.h | 16 +++--- src/mpid/ch4/netmod/ofi/ofi_send.h | 19 +++---- src/mpid/ch4/netmod/ofi/ofi_spawn.c | 3 +- 13 files changed, 91 insertions(+), 80 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h index a9de6f689d6..591c1491d15 100644 --- a/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h +++ b/src/mpid/ch4/netmod/ofi/coll/ofi_coll_util.h @@ -96,9 +96,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf /* Post recv for RTR from children */ for (p = (int *) utarray_front(my_tree->children); p != NULL; p = (int *) utarray_next(my_tree->children, p)) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, *p); ret = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].rx, &((*works)[i + j]), - MPIDI_OFI_comm_to_phys(comm_ptr, *p, 0, 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_RECV, NULL, 0, rtr_tag, comm_ptr, *p, 0, *rcv_cntr, *rcv_cntr, false); @@ -114,10 +115,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf MPIR_ERR_POP(mpi_errno); if (!is_root) { /* non-root nodes post recv for data from parents */ + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, parent); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].rx, &((*works)[i]), - MPIDI_OFI_comm_to_phys(comm_ptr, parent, - 0, 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_RECV, buffer, count * data_sz, tag, comm_ptr, parent, 0, *rcv_cntr, *rcv_cntr, false); @@ -132,10 +133,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf if (!is_root) { /* Non-root nodes send RTR to parents */ uint64_t match_bits = MPIDI_OFI_init_sendtag(comm_ptr->context_id + context_offset, comm_ptr->rank, rtr_tag); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, parent); MPIDI_OFI_CALL_RETRY(fi_tinject (MPIDI_OFI_global.ctx[0].tx, NULL, 0, - MPIDI_OFI_comm_to_phys(comm_ptr, parent, 0, 0), match_bits), 0, - tinject); + MPIDI_OFI_av_to_phys_root(av), match_bits), 0, tinject); } if (is_root) { @@ -148,10 +149,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_tagged(void *buf /* Root and intermediate nodes send data to children */ for (p = (int *) utarray_front(my_tree->children); p != NULL; p = (int *) utarray_next(my_tree->children, p)) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, *p); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].tx, &((*works)[index + k]), - MPIDI_OFI_comm_to_phys(comm_ptr, *p, 0, 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_SEND, buffer, count * data_sz, tag, comm_ptr, *p, threshold, *rcv_cntr, *snd_cntr, false); @@ -243,12 +245,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer /* Post recv for RTR from children */ for (j = 0; j < num_children; j++) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, first_child + j); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].rx, &((*works)[i + j]), - MPIDI_OFI_comm_to_phys(comm_ptr, - first_child + j, 0, - 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_RECV, NULL, 0, rtr_tag, comm_ptr, first_child + j, 0, *rcv_cntr, *rcv_cntr, false); @@ -263,10 +264,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer MPIR_ERR_POP(mpi_errno); if (!is_root) { /* Non-root nodes post recv for data */ + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, parent); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].rx, &((*works)[i]), - MPIDI_OFI_comm_to_phys(comm_ptr, parent, 0, - 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_RECV, buffer, count * data_sz, tag, comm_ptr, parent, 0, *rcv_cntr, *rcv_cntr, false); @@ -281,10 +282,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer if (!is_root) { /* Non-root nodes send RTR to parents */ s_match_bits = MPIDI_OFI_init_sendtag(comm_ptr->context_id + context_offset, comm_ptr->rank, rtr_tag); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, parent); MPIDI_OFI_CALL_RETRY(fi_tinject (MPIDI_OFI_global.ctx[0].tx, NULL, 0, - MPIDI_OFI_comm_to_phys(comm_ptr, parent, 0, 0), s_match_bits), 0, - tinject); + MPIDI_OFI_av_to_phys_root(av), s_match_bits), 0, tinject); } if (is_root) { @@ -296,12 +297,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_tagged(void *buffer /* Root and intremediate nodes send data to children */ for (k = 0; k < num_children; k++) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, first_child + k); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].tx, &((*works)[index + k]), - MPIDI_OFI_comm_to_phys(comm_ptr, - first_child + k, 0, - 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_SEND, buffer, count * data_sz, tag, comm_ptr, first_child + k, threshold, *rcv_cntr, @@ -424,10 +424,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer /* Post recv for RTR from children; this is needed to avoid unexpected messages */ for (p = (int *) utarray_front(my_tree->children); p != NULL; p = (int *) utarray_next(my_tree->children, p)) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, *p); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].rx, &((*works)[i + j]), - MPIDI_OFI_comm_to_phys(comm_ptr, *p, 0, 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_RECV, NULL, 0, rtr_tag, comm_ptr, *p, 0, *rcv_cntr, *rcv_cntr, false); @@ -443,10 +444,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer if (!is_root) { /* Non-root nodes send RTR to parents */ s_match_bits = MPIDI_OFI_init_sendtag(comm_ptr->context_id + context_offset, comm_ptr->rank, rtr_tag); - MPIDI_OFI_CALL_RETRY(fi_tinject - (MPIDI_OFI_global.ctx[0].tx, NULL, 0, - MPIDI_OFI_comm_to_phys(comm_ptr, parent, 0, 0), s_match_bits), 0, - tinject); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, parent); + MPIDI_OFI_CALL_RETRY(fi_tinject(MPIDI_OFI_global.ctx[0].tx, NULL, 0, + MPIDI_OFI_av_to_phys_root(av), s_match_bits), 0, tinject); } if (is_root) { @@ -458,9 +458,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_knomial_triggered_rma(void *buffer /* Root and intermediate nodes send data to children */ for (p = (int *) utarray_front(my_tree->children); p != NULL; p = (int *) utarray_next(my_tree->children, p)) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, *p); mpi_errno = MPIDI_OFI_prepare_rma_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].tx, &((*works)[i++]), - MPIDI_OFI_comm_to_phys(comm_ptr, *p, 0, 0), + MPIDI_OFI_av_to_phys_root(av), buffer, fi_mr_key(*r_mr), count * data_sz, threshold, *rcv_cntr, *snd_cntr, MPIDI_OFI_TRIGGERED_RMA_WRITE, @@ -567,12 +568,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, i /* Post recv for RTR from children */ for (j = 0; j < num_children; j++) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, first_child + j); mpi_errno = MPIDI_OFI_prepare_tagged_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].rx, &((*works)[i + j]), - MPIDI_OFI_comm_to_phys(comm_ptr, - first_child + j, 0, - 0), + MPIDI_OFI_av_to_phys_root(av), MPIDI_OFI_TRIGGERED_TAGGED_RECV, NULL, 0, rtr_tag, comm_ptr, first_child + j, 0, *rcv_cntr, *rcv_cntr, false); @@ -587,10 +587,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, i if (!is_root) { /* Non-root nodes send RTR to parents; this is needed to avoid unexpected messages */ s_match_bits = MPIDI_OFI_init_sendtag(comm_ptr->context_id + context_offset, comm_ptr->rank, rtr_tag); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, parent); MPIDI_OFI_CALL_RETRY(fi_tinject (MPIDI_OFI_global.ctx[0].tx, NULL, 0, - MPIDI_OFI_comm_to_phys(comm_ptr, parent, 0, 0), s_match_bits), 0, - tinject); + MPIDI_OFI_av_to_phys_root(av), s_match_bits), 0, tinject); } if (is_root) { @@ -602,9 +602,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_Ibcast_kary_triggered_rma(void *buffer, i /* Root and intremediate nodes send data to children */ for (k = 0; k < num_children; k++) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm_ptr, first_child + k); MPIDI_OFI_prepare_rma_control_cmd(MPIDI_OFI_global.ctx[0].domain, MPIDI_OFI_global.ctx[0].tx, &((*works)[i + k]), - MPIDI_OFI_comm_to_phys(comm_ptr, first_child + k, 0, 0), + MPIDI_OFI_av_to_phys_root(av), buffer, fi_mr_key(*r_mr), count * data_sz, threshold, *rcv_cntr, *snd_cntr, MPIDI_OFI_TRIGGERED_RMA_WRITE, false); diff --git a/src/mpid/ch4/netmod/ofi/ofi_am_events.h b/src/mpid/ch4/netmod/ofi/ofi_am_events.h index 351eba3beb3..df3667ad09f 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_am_events.h +++ b/src/mpid/ch4/netmod/ofi/ofi_am_events.h @@ -129,11 +129,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_rdma_read(void *dst, .len = curr_len, .key = rma_key }; + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, src_rank); struct fi_msg_rma msg = { .msg_iov = &iov, .desc = NULL, .iov_count = 1, - .addr = MPIDI_OFI_comm_to_phys(comm, src_rank, nic, vci_remote), + .addr = MPIDI_OFI_av_to_phys(av, vci_local, nic, vci_remote, nic), .rma_iov = &rma_iov, .rma_iov_count = 1, .context = &am_req->context, diff --git a/src/mpid/ch4/netmod/ofi/ofi_am_impl.h b/src/mpid/ch4/netmod/ofi/ofi_am_impl.h index fddadc97ef3..84ffa57c40f 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_am_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_am_impl.h @@ -38,15 +38,12 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_progress_do_queue(int vci_idx); (((uint64_t) MPIR_Process.world_id << 32) + \ ((uint64_t) MPIR_Process.rank << 16) + ((nic) << 8) + (vci)) -#define MPIDI_OFI_REMOTE_ID(comm, rank, nic, vci) \ - MPIDI_OFI_comm_to_phys(comm, rank, nic, vci) - -#define MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, comm, rank, nic_src, vci_src, nic_dst, vci_dst) \ +#define MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, vci_src, vci_dst, dst_addr) \ do { \ (msg_hdr)->vci_src = vci_src; \ (msg_hdr)->vci_dst = vci_dst; \ - (msg_hdr)->src_id = MPIDI_OFI_LOCAL_ID(nic_src, vci_src); \ - uint64_t remote_id = MPIDI_OFI_REMOTE_ID(comm, rank, nic_dst, vci_dst); \ + (msg_hdr)->src_id = MPIDI_OFI_LOCAL_ID(0, vci_src); \ + uint64_t remote_id = (uint64_t) dst_addr; \ (msg_hdr)->seqno = MPIDI_OFI_am_fetch_incr_send_seqno(vci_src, remote_id); \ } while (0) @@ -238,7 +235,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_am_isend_long(int rank, MPIR_Comm * comm, MPIDI_OFI_lmt_msg_payload_t *lmt_info; int nic = 0; int ctx_idx = MPIDI_OFI_get_ctx_index(vci_src, nic); - fi_addr_t dst_addr = MPIDI_OFI_comm_to_phys(comm, rank, nic, vci_dst); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, rank); + fi_addr_t dst_addr = MPIDI_OFI_av_to_phys(av, vci_src, nic, vci_dst, nic); MPIR_FUNC_ENTER; @@ -254,7 +252,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_am_isend_long(int rank, MPIR_Comm * comm, msg_hdr->am_hdr_sz = am_hdr_sz; msg_hdr->payload_sz = 0; /* LMT info sent as header */ msg_hdr->am_type = MPIDI_AMTYPE_RDMA_READ; - MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, comm, rank, nic, vci_src, nic, vci_dst); + MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, vci_src, vci_dst, dst_addr); lmt_info = (void *) ((char *) msg_hdr + sizeof(MPIDI_OFI_am_header_t) + am_hdr_sz); lmt_info->context_id = comm->context_id; @@ -312,7 +310,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_am_isend_short(int rank, MPIR_Comm * comm int mpi_errno = MPI_SUCCESS; int nic = 0; int ctx_idx = MPIDI_OFI_get_ctx_index(vci_src, nic); - fi_addr_t dst_addr = MPIDI_OFI_comm_to_phys(comm, rank, nic, vci_dst); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, rank); + fi_addr_t dst_addr = MPIDI_OFI_av_to_phys(av, vci_src, nic, vci_dst, nic); MPIR_FUNC_ENTER; @@ -336,7 +335,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_am_isend_short(int rank, MPIR_Comm * comm msg_hdr->am_hdr_sz = MPIDI_OFI_AM_SREQ_HDR(sreq, am_hdr_sz); msg_hdr->payload_sz = data_sz; msg_hdr->am_type = MPIDI_AMTYPE_SHORT; - MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, comm, rank, nic, vci_src, nic, vci_dst); + MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, vci_src, vci_dst, dst_addr); MPIR_cc_inc(sreq->cc_ptr); MPIDI_OFI_AMREQUEST(sreq, event_id) = MPIDI_OFI_EVENT_AM_SEND; @@ -378,7 +377,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_am_isend_pipeline(int rank, MPIR_Comm * c MPIDI_OFI_am_header_t *msg_hdr; int nic = 0; int ctx_idx = MPIDI_OFI_get_ctx_index(vci_src, nic); - fi_addr_t dst_addr = MPIDI_OFI_comm_to_phys(comm, rank, nic, vci_dst); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, rank); + fi_addr_t dst_addr = MPIDI_OFI_av_to_phys(av, vci_src, nic, vci_dst, nic); MPIR_FUNC_ENTER; @@ -399,7 +399,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_am_isend_pipeline(int rank, MPIR_Comm * c msg_hdr->am_hdr_sz = am_hdr_sz; msg_hdr->payload_sz = seg_sz; msg_hdr->am_type = MPIDI_AMTYPE_PIPELINE; - MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, comm, rank, nic, vci_src, nic, vci_dst); + MPIDI_OFI_SET_AM_HDR_COMMON(msg_hdr, vci_src, vci_dst, dst_addr); MPIR_cc_inc(sreq->cc_ptr); send_req->event_id = MPIDI_OFI_EVENT_AM_SEND_PIPELINE; @@ -572,7 +572,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_inject(int rank, size_t buff_len; int nic = 0; int ctx_idx = MPIDI_OFI_get_ctx_index(vci_src, nic); - fi_addr_t dst_addr = MPIDI_OFI_comm_to_phys(comm, rank, nic, vci_dst); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, rank); + fi_addr_t dst_addr = MPIDI_OFI_av_to_phys(av, vci_src, nic, vci_dst, nic); MPIR_CHKLMEM_DECL(1); MPIR_FUNC_ENTER; @@ -584,7 +585,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_inject(int rank, msg_hdr.am_hdr_sz = am_hdr_sz; msg_hdr.payload_sz = 0; msg_hdr.am_type = MPIDI_AMTYPE_SHORT_HDR; - MPIDI_OFI_SET_AM_HDR_COMMON((&msg_hdr), comm, rank, nic, vci_src, nic, vci_dst); + MPIDI_OFI_SET_AM_HDR_COMMON(&msg_hdr, vci_src, vci_dst, dst_addr); MPIR_Assert((uint64_t) comm->rank < (1ULL << MPIDI_OFI_AM_RANK_BITS)); diff --git a/src/mpid/ch4/netmod/ofi/ofi_events.c b/src/mpid/ch4/netmod/ofi/ofi_events.c index 178cbb80ca7..5896c56d54d 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_events.c +++ b/src/mpid/ch4/netmod/ofi/ofi_events.c @@ -821,7 +821,8 @@ int MPIDI_OFI_send_ack(MPIR_Request * rreq, int context_id, void *hdr, int hdr_s int vci_dst = MPIDI_get_vci(DST_VCI_FROM_RECVER, comm, src_rank, dst_rank, tag); int nic = 0; int ctx_idx = MPIDI_OFI_get_ctx_index(vci_dst, nic); - fi_addr_t dest_addr = MPIDI_OFI_comm_to_phys(comm, src_rank, nic, vci_src); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, src_rank); + fi_addr_t dest_addr = MPIDI_OFI_av_to_phys(av, vci_dst, nic, vci_src, nic); MPIDI_OFI_CALL_RETRY(fi_tinject(MPIDI_OFI_global.ctx[ctx_idx].tx, hdr, hdr_sz, dest_addr, match_bits), vci_dst, tinject); fn_exit: diff --git a/src/mpid/ch4/netmod/ofi/ofi_huge.c b/src/mpid/ch4/netmod/ofi/ofi_huge.c index cdd47b61ff5..160ea7b29ab 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_huge.c +++ b/src/mpid/ch4/netmod/ofi/ofi_huge.c @@ -100,7 +100,8 @@ static int get_huge_issue_read(MPIR_Request * rreq) int nic = 0; while (bytesLeft > 0) { int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic); - fi_addr_t addr = MPIDI_OFI_comm_to_phys(comm, info->origin_rank, nic, vci_remote); + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, info->origin_rank); + fi_addr_t addr = MPIDI_OFI_av_to_phys(av, vci_local, nic, vci_remote, nic); uint64_t remote_key = info->rma_keys[nic]; MPI_Aint bytesToGet = MPL_MIN(chunk_size, bytesLeft); diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index f2a1088624c..368be20e672 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -439,7 +439,9 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_OFI_win_request_complete(MPIDI_OFI_win_reque * on any local endpoints, as long as we are careful in the insertion order). Thus, * we get away with simplified interface using just (nic, vci) pair. */ -MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys(MPIDI_av_entry_t * av, int nic, int vci) +MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys(MPIDI_av_entry_t * av, + int local_vci, int local_nic, + int vci, int nic) { #ifdef MPIDI_OFI_VNI_USE_DOMAIN if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) { @@ -457,11 +459,10 @@ MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys(MPIDI_av_entry_t * av, i #endif } -MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_comm_to_phys(MPIR_Comm * comm, int rank, - int nic, int vci) +/* a simpler version used where vci is not enabled, e.g. init and spawn */ +MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys_root(MPIDI_av_entry_t * av) { - MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, rank); - return MPIDI_OFI_av_to_phys(av, nic, vci); + return MPIDI_OFI_av_to_phys(av, 0, 0, 0, 0); } MPL_STATIC_INLINE_PREFIX bool MPIDI_OFI_is_tag_sync(uint64_t match_bits) diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 634d3b7facb..9fd7820e485 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -943,7 +943,7 @@ static int flush_send(int dst, int nic, int vci, MPIDI_OFI_dynamic_process_reque { int mpi_errno = MPI_SUCCESS; - fi_addr_t addr = MPIDI_OFI_av_to_phys(&MPIDIU_get_av(0, dst), nic, vci); + fi_addr_t addr = MPIDI_OFI_av_to_phys(&MPIDIU_get_av(0, dst), vci, nic, vci, nic); static int data = 0; uint64_t match_bits = MPIDI_OFI_init_sendtag(MPIDI_OFI_FLUSH_CONTEXT_ID, 0, MPIDI_OFI_FLUSH_TAG); @@ -974,7 +974,7 @@ static int flush_recv(int src, int nic, int vci, MPIDI_OFI_dynamic_process_reque { int mpi_errno = MPI_SUCCESS; - fi_addr_t addr = MPIDI_OFI_av_to_phys(&MPIDIU_get_av(0, src), nic, vci); + fi_addr_t addr = MPIDI_OFI_av_to_phys(&MPIDIU_get_av(0, src), vci, nic, vci, nic); uint64_t mask_bits = 0; uint64_t match_bits = MPIDI_OFI_init_sendtag(MPIDI_OFI_FLUSH_CONTEXT_ID, 0, MPIDI_OFI_FLUSH_TAG); diff --git a/src/mpid/ch4/netmod/ofi/ofi_probe.h b/src/mpid/ch4/netmod/ofi/ofi_probe.h index 162376eb397..8a0d7f98da9 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_probe.h +++ b/src/mpid/ch4/netmod/ofi/ofi_probe.h @@ -35,7 +35,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_iprobe(int source, } else { int sender_nic = MPIDI_OFI_multx_sender_nic_index(comm, comm->recvcontext_id, source, comm->rank, tag); - remote_proc = MPIDI_OFI_av_to_phys(addr, sender_nic, vci_src); + remote_proc = MPIDI_OFI_av_to_phys(addr, vci_dst, receiver_nic, vci_src, sender_nic); } if (message) { diff --git a/src/mpid/ch4/netmod/ofi/ofi_recv.h b/src/mpid/ch4/netmod/ofi/ofi_recv.h index 186ef692161..de2637bdc62 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_recv.h +++ b/src/mpid/ch4/netmod/ofi/ofi_recv.h @@ -79,7 +79,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_recv_iov(void *buf, MPI_Aint count, MPI_D int sender_nic = MPIDI_OFI_multx_sender_nic_index(comm, comm->recvcontext_id, rank, comm->rank, MPIDI_OFI_init_get_tag(match_bits)); - msg.addr = MPIDI_OFI_av_to_phys(addr, sender_nic, vci_remote); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci_local, receiver_nic, vci_remote, sender_nic); } MPIDI_OFI_CALL_RETRY(fi_trecvmsg(MPIDI_OFI_global.ctx[ctx_idx].rx, &msg, flags), vci_local, @@ -245,7 +245,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, int sender_nic = MPIDI_OFI_multx_sender_nic_index(comm, comm->recvcontext_id, rank, comm->rank, MPIDI_OFI_init_get_tag(match_bits)); - remote_addr = MPIDI_OFI_av_to_phys(addr, sender_nic, vci_remote); + remote_addr = + MPIDI_OFI_av_to_phys(addr, vci_local, receiver_nic, vci_remote, sender_nic); } /* Save pipeline information. */ @@ -311,7 +312,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, int sender_nic = MPIDI_OFI_multx_sender_nic_index(comm, comm->recvcontext_id, rank, comm->rank, MPIDI_OFI_init_get_tag(match_bits)); - sender_addr = MPIDI_OFI_av_to_phys(addr, sender_nic, vci_remote); + sender_addr = + MPIDI_OFI_av_to_phys(addr, vci_local, receiver_nic, vci_remote, sender_nic); } MPIDI_OFI_CALL_RETRY(fi_trecv(MPIDI_OFI_global.ctx[ctx_idx].rx, recv_buf, data_sz, desc, sender_addr, match_bits, mask_bits, diff --git a/src/mpid/ch4/netmod/ofi/ofi_rma.c b/src/mpid/ch4/netmod/ofi/ofi_rma.c index 9a477f4e89d..47a1f86d3ad 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_rma.c +++ b/src/mpid/ch4/netmod/ofi/ofi_rma.c @@ -94,9 +94,9 @@ int MPIDI_OFI_nopack_putget(const void *origin_addr, MPI_Aint origin_count, } void *desc = NULL; - int nic = MPIDI_OFI_get_pref_nic(win->comm_ptr, target_rank);; + int nic_target = MPIDI_OFI_get_pref_nic(win->comm_ptr, target_rank);; - MPIDI_OFI_gpu_rma_register(origin_addr, origin_bytes, NULL, win, nic, &desc); + MPIDI_OFI_gpu_rma_register(origin_addr, origin_bytes, NULL, win, nic_target, &desc); int i = 0, j = 0; size_t msg_len; @@ -113,8 +113,9 @@ int MPIDI_OFI_nopack_putget(const void *origin_addr, MPI_Aint origin_count, msg_len = MPL_MIN(origin_iov[origin_cur].iov_len, target_iov[target_cur].iov_len); int vci = MPIDI_WIN(win, am_vci); + int vci_target = MPIDI_WIN_TARGET_VCI(win, target_rank); msg.desc = desc; - msg.addr = MPIDI_OFI_av_to_phys(addr, nic, vci); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci, 0, vci_target, nic_target); msg.context = NULL; msg.data = 0; msg.msg_iov = &iov; @@ -207,7 +208,8 @@ static int issue_packed_put(MPIR_Win * win, MPIDI_OFI_win_request_t * req) MPIR_ERR_CHKANDSTMT(chunk == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); msg.desc = NULL; - msg.addr = MPIDI_OFI_av_to_phys(req->noncontig.put.target.addr, nic_target, vci_target); + msg.addr = + MPIDI_OFI_av_to_phys(req->noncontig.put.target.addr, vci, 0, vci_target, nic_target); msg.context = NULL; msg.data = 0; msg.msg_iov = &iov; @@ -292,7 +294,8 @@ static int issue_packed_get(MPIR_Win * win, MPIDI_OFI_win_request_t * req) MPIR_ERR_CHKANDSTMT(chunk == NULL, mpi_errno, MPI_ERR_NO_MEM, goto fn_fail, "**nomem"); msg.desc = NULL; - msg.addr = MPIDI_OFI_av_to_phys(req->noncontig.get.target.addr, nic_target, vci_target); + msg.addr = + MPIDI_OFI_av_to_phys(req->noncontig.get.target.addr, vci, 0, vci_target, nic_target); msg.context = NULL; msg.data = 0; msg.msg_iov = &iov; diff --git a/src/mpid/ch4/netmod/ofi/ofi_rma.h b/src/mpid/ch4/netmod/ofi/ofi_rma.h index 3ba70d68eae..b2cffff14f0 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_rma.h +++ b/src/mpid/ch4/netmod/ofi/ofi_rma.h @@ -245,10 +245,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_put(const void *origin_addr, (origin_bytes <= MPIDI_OFI_global.max_buffered_write && !MPL_gpu_attr_is_dev(&attr))) { MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); MPIDI_OFI_win_cntr_incr(win); + fi_addr_t dest = MPIDI_OFI_av_to_phys(addr, vci, 0, vci_target, nic_target); MPIDI_OFI_CALL_RETRY(fi_inject_write(MPIDI_OFI_WIN(win).ep, MPIR_get_contig_ptr(origin_addr, origin_true_lb), - target_bytes, - MPIDI_OFI_av_to_phys(addr, nic_target, vci_target), + target_bytes, dest, target_mr.addr + target_true_lb, target_mr.mr_key), vci, rdma_inject_write); MPID_THREAD_CS_EXIT(VCI, MPIDI_VCI_LOCK(vci)); @@ -272,7 +272,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_put(const void *origin_addr, MPIDI_OFI_gpu_rma_register(iov.iov_base, iov.iov_len, &attr, win, nic_target, &desc); msg.desc = desc; - msg.addr = MPIDI_OFI_av_to_phys(addr, nic_target, vci_target); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci, 0, vci_target, nic_target); msg.context = NULL; msg.data = 0; msg.msg_iov = &iov; @@ -451,7 +451,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get(void *origin_addr, msg.desc = desc; msg.msg_iov = &iov; msg.iov_count = 1; - msg.addr = MPIDI_OFI_av_to_phys(addr, nic_target, vci_target); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci, 0, vci_target, nic_target); msg.rma_iov = &riov; msg.rma_iov_count = 1; msg.context = NULL; @@ -679,7 +679,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_compare_and_swap(const void *origin_ad msg.msg_iov = &originv; msg.desc = desc; msg.iov_count = 1; - msg.addr = MPIDI_OFI_av_to_phys(av, nic_target, vci_target); + msg.addr = MPIDI_OFI_av_to_phys(av, vci, 0, vci_target, nic_target); msg.rma_iov = &targetv; msg.rma_iov_count = 1; msg.datatype = fi_dt; @@ -805,7 +805,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_accumulate(const void *origin_addr, msg.msg_iov = &originv; msg.desc = desc; msg.iov_count = 1; - msg.addr = MPIDI_OFI_av_to_phys(addr, nic_target, vci_target); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci, 0, vci_target, nic_target); msg.rma_iov = &targetv; msg.rma_iov_count = 1; msg.datatype = fi_dt; @@ -951,7 +951,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_get_accumulate(const void *origin_addr msg.msg_iov = &originv; msg.desc = desc; msg.iov_count = 1; - msg.addr = MPIDI_OFI_av_to_phys(addr, nic_target, vci_target); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci, 0, vci_target, nic_target); msg.rma_iov = &targetv; msg.rma_iov_count = 1; msg.datatype = fi_dt; @@ -1182,7 +1182,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_mpi_fetch_and_op(const void *origin_addr, msg.msg_iov = &originv; msg.desc = desc; msg.iov_count = 1; - msg.addr = MPIDI_OFI_av_to_phys(av, nic_target, vci_target); + msg.addr = MPIDI_OFI_av_to_phys(av, vci, 0, vci_target, nic_target); msg.rma_iov = &targetv; msg.rma_iov_count = 1; msg.datatype = fi_dt; diff --git a/src/mpid/ch4/netmod/ofi/ofi_send.h b/src/mpid/ch4/netmod/ofi/ofi_send.h index 3d409ef5fa1..29f684527ea 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_send.h +++ b/src/mpid/ch4/netmod/ofi/ofi_send.h @@ -101,7 +101,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_issue_ack_recv(MPIR_Request * sreq, MPIR_ ackreq->ack_hdr_sz = hdr_sz; ackreq->ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic); ackreq->vci_local = vci_local; - ackreq->remote_addr = MPIDI_OFI_av_to_phys(addr, nic, vci_remote); + ackreq->remote_addr = MPIDI_OFI_av_to_phys(addr, vci_local, nic, vci_remote, nic); ackreq->match_bits = match_bits; #ifndef MPIDI_CH4_DIRECT_NETMOD @@ -131,7 +131,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_lightweight(const void *buf, size_t int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, sender_nic); - fi_addr_t dest_addr = MPIDI_OFI_av_to_phys(addr, receiver_nic, vci_remote); + fi_addr_t dest_addr = + MPIDI_OFI_av_to_phys(addr, vci_local, sender_nic, vci_remote, receiver_nic); if (MPIDI_OFI_ENABLE_DATA) { MPIDI_OFI_CALL_RETRY(fi_tinjectdata(MPIDI_OFI_global.ctx[ctx_idx].tx, buf, data_sz, cq_data, dest_addr, match_bits), @@ -187,7 +188,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_iov(const void *buf, MPI_Aint count, msg.ignore = 0ULL; msg.context = (void *) &(MPIDI_OFI_REQUEST(sreq, context)); msg.data = MPIDI_OFI_ENABLE_DATA ? cq_data : 0; - msg.addr = MPIDI_OFI_av_to_phys(addr, receiver_nic, vci_remote); + msg.addr = MPIDI_OFI_av_to_phys(addr, vci_local, sender_nic, vci_remote, receiver_nic); int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, sender_nic); @@ -231,7 +232,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *data, MPI_Aint da } } - fi_addr_t dest_addr = MPIDI_OFI_av_to_phys(addr, receiver_nic, vci_remote); + fi_addr_t dest_addr = + MPIDI_OFI_av_to_phys(addr, vci_local, sender_nic, vci_remote, receiver_nic); if (MPIDI_OFI_ENABLE_DATA) { MPIDI_OFI_CALL_RETRY(fi_tsenddata(MPIDI_OFI_global.ctx[ctx_idx].tx, data, data_sz, desc, cq_data, dest_addr, @@ -341,12 +343,11 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_huge(const void *data, MPI_Aint data MPIR_cc_inc(sreq->cc_ptr); MPIDI_OFI_REQUEST(sreq, event_id) = MPIDI_OFI_EVENT_SEND_HUGE; + fi_addr_t dest = MPIDI_OFI_av_to_phys(addr, vci_local, sender_nic, vci_remote, receiver_nic); match_bits |= MPIDI_OFI_HUGE_SEND; /* Add the bit for a huge message */ MPIDI_OFI_CALL_RETRY(fi_tsenddata(MPIDI_OFI_global.ctx[ctx_idx].tx, data, msg_size, NULL /* desc */ , - cq_data, - MPIDI_OFI_av_to_phys(addr, receiver_nic, vci_remote), - match_bits, + cq_data, dest, match_bits, (void *) &(MPIDI_OFI_REQUEST(sreq, context))), vci_local, tsenddata); /* FIXME: sender_nic may not be the actual nic */ @@ -395,7 +396,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_pipeline(const void *buf, MPI_Aint c MPIDI_OFI_REQUEST(sreq, pipeline_info.chunk_sz) = chunk_size; MPIDI_OFI_REQUEST(sreq, pipeline_info.cq_data) = cq_data; MPIDI_OFI_REQUEST(sreq, pipeline_info.remote_addr) = - MPIDI_OFI_av_to_phys(addr, receiver_nic, vci_remote); + MPIDI_OFI_av_to_phys(addr, vci_local, sender_nic, vci_remote, receiver_nic); MPIDI_OFI_REQUEST(sreq, pipeline_info.vci_local) = vci_local; MPIDI_OFI_REQUEST(sreq, pipeline_info.ctx_idx) = ctx_idx; MPIDI_OFI_REQUEST(sreq, pipeline_info.match_bits) = match_bits; @@ -472,7 +473,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_fallback(const void *buf, MPI_Aint c msg.ignore = 0ULL; msg.context = (void *) &(MPIDI_OFI_REQUEST(sreq, context)); msg.data = 0; - msg.addr = MPIDI_OFI_av_to_phys(addr, 0, 0); + msg.addr = MPIDI_OFI_av_to_phys_root(addr); int flags = FI_COMPLETION | FI_TRANSMIT_COMPLETE; if (MPIDI_OFI_ENABLE_DATA) { diff --git a/src/mpid/ch4/netmod/ofi/ofi_spawn.c b/src/mpid/ch4/netmod/ofi/ofi_spawn.c index f4039fdeedd..d1da16a361f 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_spawn.c +++ b/src/mpid/ch4/netmod/ofi/ofi_spawn.c @@ -13,12 +13,11 @@ int MPIDI_OFI_dynamic_send(uint64_t remote_gpid, int tag, const void *buf, int s MPIR_Assert(MPIDI_OFI_ENABLE_TAGGED); - int nic = 0; /* dynamic process only use nic 0 */ int vci = 0; /* dynamic process only use vci 0 */ int ctx_idx = 0; int avtid = MPIDIU_GPID_GET_AVTID(remote_gpid); int lpid = MPIDIU_GPID_GET_LPID(remote_gpid); - fi_addr_t remote_addr = MPIDI_OFI_av_to_phys(&MPIDIU_get_av(avtid, lpid), nic, vci); + fi_addr_t remote_addr = MPIDI_OFI_av_to_phys_root(&MPIDIU_get_av(avtid, lpid)); MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI_LOCK(vci)); From 016512c0bf5bcf40f1ab2c7ecee34a3c42fb702f Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 28 Dec 2024 17:03:08 -0600 Subject: [PATCH 08/25] ch4/ofi: delay allocate addrs for multiple vci/nic Let MPIDI_OFI_addr_t only contain field for root vci address, and only allocate more space for additional addresses when multiple vci and nic is enabled -- potentially at runtime. This avoids wasting memory for multiple vcis unelss it is actually needed. --- src/mpid/ch4/netmod/ofi/init_addrxchg.c | 36 +++++++++++++++---------- src/mpid/ch4/netmod/ofi/ofi_impl.h | 26 +++++++++++++----- src/mpid/ch4/netmod/ofi/ofi_init.c | 19 +++++++------ src/mpid/ch4/netmod/ofi/ofi_pre.h | 7 ++--- src/mpid/ch4/netmod/ofi/ofi_spawn.c | 16 +++++------ 5 files changed, 63 insertions(+), 41 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/init_addrxchg.c b/src/mpid/ch4/netmod/ofi/init_addrxchg.c index 7a1766df84e..cf27ce9a6de 100644 --- a/src/mpid/ch4/netmod/ofi/init_addrxchg.c +++ b/src/mpid/ch4/netmod/ofi/init_addrxchg.c @@ -133,7 +133,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) for (int i = 0; i < num_nodes; i++) { MPIR_Assert(mapped_table[i] != FI_ADDR_NOTAVAIL); - MPIDI_OFI_AV(&MPIDIU_get_av(0, node_roots[i])).dest[0][0] = mapped_table[i]; + MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, node_roots[i])) = mapped_table[i]; } MPL_free(mapped_table); /* Then, allgather all address names using init_comm */ @@ -149,7 +149,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) char *addrname = (char *) table + recv_bc_len * rank_map[i]; MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[0].av, addrname, 1, &addr, 0ULL, NULL), avmap); - MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = addr; + MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, i)) = addr; } } mpi_errno = MPIDU_bc_table_destroy(); @@ -163,7 +163,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) for (int i = 0; i < size; i++) { MPIR_Assert(mapped_table[i] != FI_ADDR_NOTAVAIL); - MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[0][0] = mapped_table[i]; + MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, i)) = mapped_table[i]; } MPL_free(mapped_table); mpi_errno = MPIDU_bc_table_destroy(); @@ -173,8 +173,8 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) /* check */ if (MPIDI_OFI_ENABLE_AV_TABLE) { for (int r = 0; r < size; r++) { - MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r)); - MPIR_Assert(av->dest[0][0] == get_root_av_table_index(r)); + MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, r); + MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == get_root_av_table_index(r)); } } @@ -192,7 +192,7 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) /* Macros to reduce clutter, so we can focus on the ordering logics. * Note: they are not perfectly wrapped, but tolerable since only used here. */ #define GET_AV_AND_ADDRNAMES(rank) \ - MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, rank)); \ + MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, rank); \ char *r_names = all_names + rank * max_vcis * num_nics * name_len; #define DO_AV_INSERT(ctx_idx, nic, vci) \ @@ -245,6 +245,14 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) goto fn_exit; } + /* allocate additional av addrs */ + for (int i = 0; i < size; i++) { + MPIDI_av_entry_t *av = &MPIDIU_get_av(0, i); + MPIDI_OFI_AV(av).all_dest = MPL_malloc(max_vcis * num_nics * sizeof(fi_addr_t), + MPL_MEM_ADDRESS); + MPIR_ERR_CHKANDJUMP(!MPIDI_OFI_AV(av).all_dest, mpi_errno, MPI_ERR_OTHER, "**nomem"); + } + /* libfabric uses uniform name_len within a single provider */ int name_len = MPIDI_OFI_global.addrnamelen; int my_len = max_vcis * num_nics * name_len; @@ -275,7 +283,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { SKIP_ROOT(nic, vci); DO_AV_INSERT(root_ctx_idx, nic, vci); - av->dest[nic][vci] = addr; + MPIDI_OFI_AV_ADDR(av, vci, nic) = addr; } } } @@ -306,7 +314,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) if (is_node_roots[r]) { GET_AV_AND_ADDRNAMES(r); DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(av->dest[0][0] == addr); + MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr); } } /* non-node-root */ @@ -314,7 +322,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) if (!is_node_roots[r]) { GET_AV_AND_ADDRNAMES(r); DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(av->dest[0][0] == addr); + MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr); } } } else { @@ -322,7 +330,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) for (int r = 0; r < size; r++) { GET_AV_AND_ADDRNAMES(r); DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(av->dest[0][0] == addr); + MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr); } } @@ -333,7 +341,7 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { SKIP_ROOT(nic, vci); DO_AV_INSERT(ctx_idx, nic, vci); - MPIR_Assert(av->dest[nic][vci] == addr); + MPIR_Assert(MPIDI_OFI_AV_ADDR(av, vci, nic) == addr); } } } @@ -346,11 +354,11 @@ int MPIDI_OFI_addr_exchange_all_ctx(void) #if MPIDI_CH4_MAX_VCIS > 1 if (MPIDI_OFI_ENABLE_AV_TABLE) { for (int r = 0; r < size; r++) { - MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r)); + MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, r); for (int nic = 0; nic < num_nics; nic++) { for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - MPIR_Assert(av->dest[nic][vci] == get_av_table_index(r, nic, vci, - all_num_vcis)); + MPIR_Assert(MPIDI_OFI_AV_ADDR(av, vci, nic) == get_av_table_index(r, nic, vci, + all_num_vcis)); } } } diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index 368be20e672..6264225a74e 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -33,8 +33,22 @@ ATTRIBUTE((unused)); #define MPIDI_OFI_COMM(comm) ((comm)->dev.ch4.netmod.ofi) #define MPIDI_OFI_COMM_TO_INDEX(comm,rank) \ MPIDIU_comm_rank_to_pid(comm, rank, NULL, NULL) -#define MPIDI_OFI_TO_PHYS(avtid, lpid, _nic) \ - MPIDI_OFI_AV(&MPIDIU_get_av((avtid), (lpid))).dest[_nic][0] + +#define MPIDI_OFI_AV_ROOT_ADDR(av) MPIDI_OFI_AV(av).root_dest + +#ifdef MPIDI_OFI_VNI_USE_DOMAIN +#define MPIDI_OFI_AV_ADDR_ROOT(av) \ + MPIDI_OFI_AV(av).root_dest +#define MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic) \ + MPIDI_OFI_AV(av).all_dest[(vci)*MPIDI_OFI_global.num_nics+(nic)] +#else /* scalable endpoints - all vci share the same addr */ +#define MPIDI_OFI_AV_ADDR_ROOT(av, vci, nic) \ + MPIDI_OFI_AV(av).root_dest +#define MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic) \ + MPIDI_OFI_AV(av).all_dest[nic] +#endif +#define MPIDI_OFI_AV_ADDR(av, vci, nic) \ + ((vci==0 && nic==0) ? MPIDI_OFI_AV_ADDR_ROOT(av) : MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic)) #define MPIDI_OFI_WIN(win) ((win)->dev.netmod.ofi) @@ -445,16 +459,16 @@ MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys(MPIDI_av_entry_t * av, { #ifdef MPIDI_OFI_VNI_USE_DOMAIN if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) { - return fi_rx_addr(MPIDI_OFI_AV(av).dest[nic][vci], 0, MPIDI_OFI_MAX_ENDPOINTS_BITS); + return fi_rx_addr(MPIDI_OFI_AV_ADDR(av, vci, nic), 0, MPIDI_OFI_MAX_ENDPOINTS_BITS); } else { - return MPIDI_OFI_AV(av).dest[nic][vci]; + return MPIDI_OFI_AV_ADDR(av, vci, nic); } #else /* MPIDI_OFI_VNI_USE_SEPCTX */ if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) { - return fi_rx_addr(MPIDI_OFI_AV(av).dest[nic][0], vci, MPIDI_OFI_MAX_ENDPOINTS_BITS); + return fi_rx_addr(MPIDI_OFI_AV_ADDR(av, vci, nic), vci, MPIDI_OFI_MAX_ENDPOINTS_BITS); } else { MPIR_Assert(vci == 0); - return MPIDI_OFI_AV(av).dest[nic][0]; + return MPIDI_OFI_AV_ADDR(av, vci, nic); } #endif } diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 9fd7820e485..1df47b66b51 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -1116,6 +1116,13 @@ int MPIDI_OFI_mpi_finalize_hook(void) fi_freeinfo(MPIDI_OFI_global.prov_use[i]); } + /* free av entries for multiple vcis and nics */ + for (i = 0; i < MPIR_Process.size; i++) { + MPIDI_av_entry_t *av = &MPIDIU_get_av(0, i); + MPL_free(MPIDI_OFI_AV(av).all_dest); + MPIDI_OFI_AV(av).all_dest = NULL; + } + MPIDIU_map_destroy(MPIDI_OFI_global.win_map); if (MPIDI_OFI_ENABLE_AM) { @@ -1186,7 +1193,7 @@ static int create_sep_tx(struct fid_ep *ep, int idx, struct fid_ep **p_tx, struct fid_cq *cq, struct fid_cntr *cntr, int nic); static int create_sep_rx(struct fid_ep *ep, int idx, struct fid_ep **p_rx, struct fid_cq *cq, int nic); -static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av, int nic); +static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av); static int open_local_av(struct fid_domain *p_domain, struct fid_av **p_av); /* This function creates a vci context which includes all of the OFI-level objects needed to @@ -1420,7 +1427,7 @@ static int create_vci_domain(struct fid_domain **p_domain, struct fid_av **p_av, * Otherwise, set MPIDI_OFI_global.got_named_av and * copy the map_addr. */ - if (try_open_shared_av(domain, p_av, nic)) { + if (nic == 0 && try_open_shared_av(domain, p_av)) { MPIDI_OFI_global.got_named_av = 1; } else { mpi_errno = open_local_av(domain, p_av); @@ -1525,14 +1532,10 @@ static int create_sep_rx(struct fid_ep *ep, int idx, struct fid_ep **p_rx, struc goto fn_exit; } -static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av, int nic) +static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av) { int ret = 0; - /* It's not possible to use shared address vectors with more than one domain in a single - * process. If we're trying to do that (for example if we are using MPIDI_OFI_VNI_USE_DOMAIN or - * we have multiple VNIs because of multi-nic), attempt to open up the shared AV in one VNI and - * then copy the results to the others later. */ struct fi_av_attr av_attr; memset(&av_attr, 0, sizeof(av_attr)); if (MPIDI_OFI_ENABLE_AV_TABLE) { @@ -1555,7 +1558,7 @@ static int try_open_shared_av(struct fid_domain *domain, struct fid_av **p_av, i /* directly references the mapped fi_addr_t array instead */ fi_addr_t *mapped_table = (fi_addr_t *) av_attr.map_addr; for (int i = 0; i < MPIR_Process.size; i++) { - MPIDI_OFI_AV(&MPIDIU_get_av(0, i)).dest[nic][0] = mapped_table[i]; + MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(0, i)) = mapped_table[i]; MPL_DBG_MSG_FMT(MPIDI_CH4_DBG_MAP, VERBOSE, (MPL_DBG_FDEST, " grank mapped to: rank=%d, av=%p, dest=%" PRIu64, i, (void *) &MPIDIU_get_av(0, i), mapped_table[i])); diff --git a/src/mpid/ch4/netmod/ofi/ofi_pre.h b/src/mpid/ch4/netmod/ofi/ofi_pre.h index abb03a6291e..92166fe33dc 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_pre.h +++ b/src/mpid/ch4/netmod/ofi/ofi_pre.h @@ -312,11 +312,8 @@ typedef struct { #define MPIDI_OFI_MAX_NICS 8 typedef struct { -#ifdef MPIDI_OFI_VNI_USE_DOMAIN - fi_addr_t dest[MPIDI_OFI_MAX_NICS][MPIDI_CH4_MAX_VCIS]; /* [nic][vci] */ -#else - fi_addr_t dest[MPIDI_OFI_MAX_NICS][1]; -#endif + fi_addr_t root_dest; + fi_addr_t *all_dest; /* to be allocated into an array of [nic * vci] */ } MPIDI_OFI_addr_t; #endif /* OFI_PRE_H_INCLUDED */ diff --git a/src/mpid/ch4/netmod/ofi/ofi_spawn.c b/src/mpid/ch4/netmod/ofi/ofi_spawn.c index d1da16a361f..a7ad13a9cba 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_spawn.c +++ b/src/mpid/ch4/netmod/ofi/ofi_spawn.c @@ -143,8 +143,7 @@ int MPIDI_OFI_upids_to_gpids(int size, int *remote_upid_size, char *remote_upids int n_new_procs = 0; int n_avts; char *curr_upid; - int nic = 0; - int ctx_idx = MPIDI_OFI_get_ctx_index(0, nic); + int ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); MPIR_CHKLMEM_DECL(2); @@ -173,8 +172,9 @@ int MPIDI_OFI_upids_to_gpids(int size, int *remote_upid_size, char *remote_upids } for (j = 0; j < MPIDIU_get_av_table(k)->size; j++) { sz = MPIDI_OFI_global.addrnamelen; + MPIDI_av_entry_t *av = &MPIDIU_get_av(k, j); MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av, - MPIDI_OFI_TO_PHYS(k, j, nic), &tbladdr, &sz), 0, + MPIDI_OFI_AV_ROOT_ADDR(av), &tbladdr, &sz), 0, avlookup); if (sz == addrname_len && !memcmp(tbladdr, addrname, addrname_len)) { remote_gpids[i] = MPIDIU_GPID_CREATE(k, j); @@ -209,7 +209,7 @@ int MPIDI_OFI_upids_to_gpids(int size, int *remote_upid_size, char *remote_upids MPIDI_OFI_VCI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[ctx_idx].av, addrname, 1, &addr, 0ULL, NULL), 0, avmap); MPIR_Assert(addr != FI_ADDR_NOTAVAIL); - MPIDI_OFI_AV(&MPIDIU_get_av(avtid, i)).dest[nic][0] = addr; + MPIDI_OFI_AV_ROOT_ADDR(&MPIDIU_get_av(avtid, i)) = addr; int node_id; mpi_errno = MPIR_nodeid_lookup(hostname, &node_id); @@ -232,8 +232,7 @@ int MPIDI_OFI_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo int mpi_errno = MPI_SUCCESS; int i; char *temp_buf = NULL; - int nic = 0; - int ctx_idx = MPIDI_OFI_get_ctx_index(0, nic); + int ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); MPIR_CHKPMEM_DECL(2); @@ -264,8 +263,9 @@ int MPIDI_OFI_get_local_upids(MPIR_Comm * comm, int **local_upid_size, char **lo idx += hostname_len + 1; size_t sz = MPIDI_OFI_global.addrnamelen;; - MPIDI_OFI_addr_t *av = &MPIDI_OFI_AV(MPIDIU_comm_rank_to_av(comm, i)); - MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av, av->dest[nic][0], + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, i); + MPIDI_OFI_VCI_CALL(fi_av_lookup(MPIDI_OFI_global.ctx[ctx_idx].av, + MPIDI_OFI_AV_ROOT_ADDR(av), temp_buf + idx, &sz), 0, avlookup); idx += (int) sz; From ecf6bfa7bc287e7ed74f3bd3fe14627fd2e1b81f Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 28 Dec 2024 11:24:14 -0600 Subject: [PATCH 09/25] ch4/ofi: refactor vci-related initialization code Move code that are related to multiple-vci setup to ofi_vci.c. --- src/mpid/ch4/netmod/ofi/init_addrxchg.c | 244 ---------------- src/mpid/ch4/netmod/ofi/ofi_init.c | 134 +-------- src/mpid/ch4/netmod/ofi/ofi_init.h | 3 + src/mpid/ch4/netmod/ofi/ofi_vci.c | 364 +++++++++++++++++++++++- 4 files changed, 369 insertions(+), 376 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/init_addrxchg.c b/src/mpid/ch4/netmod/ofi/init_addrxchg.c index cf27ce9a6de..5b33ff7cb7f 100644 --- a/src/mpid/ch4/netmod/ofi/init_addrxchg.c +++ b/src/mpid/ch4/netmod/ofi/init_addrxchg.c @@ -38,58 +38,6 @@ * isolates multi-nic/vci complications from bootstrapping phase. */ -/* with MPIDI_OFI_ENABLE_AV_TABLE, we potentially can omit storing av tables. - * The following routines ensures we can do that. It is static now, but we can - * easily export to global when we need to. - */ - -#if !defined(MPIDI_OFI_VNI_USE_DOMAIN) || MPIDI_CH4_MAX_VCIS == 1 -/* NOTE: with scalable endpoint as context, all vcis share the same address. */ -#define NUM_VCIS_FOR_RANK(r) 1 -#else -#define NUM_VCIS_FOR_RANK(r) all_num_vcis[r] -#endif - -ATTRIBUTE((unused)) -static int get_root_av_table_index(int rank) -{ - if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { - /* node roots with greater ranks are inserted before this rank if it a non-node-root */ - int num_extra = 0; - - /* check node roots */ - for (int i = 0; i < MPIR_Process.num_nodes; i++) { - if (MPIR_Process.node_root_map[i] == rank) { - return i; - } else if (MPIR_Process.node_root_map[i] > rank) { - num_extra++; - } - } - - /* must be non-node-root */ - return rank + num_extra; - } else { - return rank; - } -} - -ATTRIBUTE((unused)) -static int get_av_table_index(int rank, int nic, int vci, int *all_num_vcis) -{ - if (nic == 0 && vci == 0) { - return get_root_av_table_index(rank); - } else { - int num_nics = MPIDI_OFI_global.num_nics; - int idx = 0; - idx += MPIR_Process.size; /* root entries */ - for (int i = 0; i < rank; i++) { - idx += num_nics * NUM_VCIS_FOR_RANK(i) - 1; - } - idx += nic * NUM_VCIS_FOR_RANK(rank) + vci - 1; - return idx; - } -} - /* Step 1: exchange root contexts */ int MPIDI_OFI_addr_exchange_root_ctx(void) { @@ -170,14 +118,6 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) MPIR_ERR_CHECK(mpi_errno); } - /* check */ - if (MPIDI_OFI_ENABLE_AV_TABLE) { - for (int r = 0; r < size; r++) { - MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, r); - MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == get_root_av_table_index(r)); - } - } - fn_exit: if (init_comm && !mpi_errno) { MPIDI_destroy_init_comm(&init_comm); @@ -186,187 +126,3 @@ int MPIDI_OFI_addr_exchange_root_ctx(void) fn_fail: goto fn_exit; } - -/* Step 2 & 3: exchange non-root contexts */ - -/* Macros to reduce clutter, so we can focus on the ordering logics. - * Note: they are not perfectly wrapped, but tolerable since only used here. */ -#define GET_AV_AND_ADDRNAMES(rank) \ - MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, rank); \ - char *r_names = all_names + rank * max_vcis * num_nics * name_len; - -#define DO_AV_INSERT(ctx_idx, nic, vci) \ - fi_addr_t addr; \ - MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[ctx_idx].av, \ - r_names + (vci * num_nics + nic) * name_len, 1, \ - &addr, 0ULL, NULL), avmap); - -#define SKIP_ROOT(nic, vci) \ - if (nic == 0 && vci == 0) { \ - continue; \ - } - -int MPIDI_OFI_addr_exchange_all_ctx(void) -{ - int mpi_errno = MPI_SUCCESS; - - MPIR_Comm *comm = MPIR_Process.comm_world; - int size = MPIR_Process.size; - int rank = MPIR_Process.rank; - MPIR_CHKLMEM_DECL(3); - - int max_vcis; - int *all_num_vcis; - -#if !defined(MPIDI_OFI_VNI_USE_DOMAIN) || MPIDI_CH4_MAX_VCIS == 1 - max_vcis = 1; - all_num_vcis = NULL; -#else - /* Allgather num_vcis */ - MPIR_CHKLMEM_MALLOC(all_num_vcis, void *, sizeof(int) * size, - mpi_errno, "all_num_vcis", MPL_MEM_ADDRESS); - mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT, - all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); - MPIR_ERR_CHECK(mpi_errno); - - max_vcis = 0; - for (int i = 0; i < size; i++) { - if (max_vcis < NUM_VCIS_FOR_RANK(i)) { - max_vcis = NUM_VCIS_FOR_RANK(i); - } - } -#endif - - int num_vcis = NUM_VCIS_FOR_RANK(rank); - int num_nics = MPIDI_OFI_global.num_nics; - - /* Assume num_nics are all equal */ - if (max_vcis * num_nics == 1) { - goto fn_exit; - } - - /* allocate additional av addrs */ - for (int i = 0; i < size; i++) { - MPIDI_av_entry_t *av = &MPIDIU_get_av(0, i); - MPIDI_OFI_AV(av).all_dest = MPL_malloc(max_vcis * num_nics * sizeof(fi_addr_t), - MPL_MEM_ADDRESS); - MPIR_ERR_CHKANDJUMP(!MPIDI_OFI_AV(av).all_dest, mpi_errno, MPI_ERR_OTHER, "**nomem"); - } - - /* libfabric uses uniform name_len within a single provider */ - int name_len = MPIDI_OFI_global.addrnamelen; - int my_len = max_vcis * num_nics * name_len; - char *all_names; - MPIR_CHKLMEM_MALLOC(all_names, char *, size * my_len, mpi_errno, "all_names", MPL_MEM_ADDRESS); - char *my_names = all_names + rank * my_len; - - /* put in my addrnames */ - for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < num_vcis; vci++) { - size_t actual_name_len = name_len; - char *vci_addrname = my_names + (vci * num_nics + nic) * name_len; - int ctx_idx = MPIDI_OFI_get_ctx_index(vci, nic); - MPIDI_OFI_CALL(fi_getname((fid_t) MPIDI_OFI_global.ctx[ctx_idx].ep, vci_addrname, - &actual_name_len), getname); - MPIR_Assert(actual_name_len == name_len); - } - } - /* Allgather */ - mpi_errno = MPIR_Allgather_fallback(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); - - /* Step 2: insert and store non-root nic/vci on the root context */ - int root_ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); - for (int r = 0; r < size; r++) { - GET_AV_AND_ADDRNAMES(r); - for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - SKIP_ROOT(nic, vci); - DO_AV_INSERT(root_ctx_idx, nic, vci); - MPIDI_OFI_AV_ADDR(av, vci, nic) = addr; - } - } - } - - /* Step 3: insert all nic/vci on non-root context, following exact order as step 1 and 2 */ - - int *is_node_roots = NULL; - if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { - MPIR_CHKLMEM_MALLOC(is_node_roots, int *, size * sizeof(int), - mpi_errno, "is_node_roots", MPL_MEM_ADDRESS); - for (int r = 0; r < size; r++) { - is_node_roots[r] = 0; - } - for (int i = 0; i < MPIR_Process.num_nodes; i++) { - is_node_roots[MPIR_Process.node_root_map[i]] = 1; - } - } - - for (int nic_local = 0; nic_local < num_nics; nic_local++) { - for (int vci_local = 0; vci_local < num_vcis; vci_local++) { - SKIP_ROOT(nic_local, vci_local); - int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic_local); - - /* -- same order as step 1 -- */ - if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { - /* node roots */ - for (int r = 0; r < size; r++) { - if (is_node_roots[r]) { - GET_AV_AND_ADDRNAMES(r); - DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr); - } - } - /* non-node-root */ - for (int r = 0; r < size; r++) { - if (!is_node_roots[r]) { - GET_AV_AND_ADDRNAMES(r); - DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr); - } - } - } else { - /* !MPIR_CVAR_CH4_ROOTS_ONLY_PMI */ - for (int r = 0; r < size; r++) { - GET_AV_AND_ADDRNAMES(r); - DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(MPIDI_OFI_AV_ROOT_ADDR(av) == addr); - } - } - - /* -- same order as step 2 -- */ - for (int r = 0; r < size; r++) { - GET_AV_AND_ADDRNAMES(r); - for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - SKIP_ROOT(nic, vci); - DO_AV_INSERT(ctx_idx, nic, vci); - MPIR_Assert(MPIDI_OFI_AV_ADDR(av, vci, nic) == addr); - } - } - } - } - } - mpi_errno = MPIR_Barrier_fallback(comm, MPIR_ERR_NONE); - MPIR_ERR_CHECK(mpi_errno); - - /* check */ -#if MPIDI_CH4_MAX_VCIS > 1 - if (MPIDI_OFI_ENABLE_AV_TABLE) { - for (int r = 0; r < size; r++) { - MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, r); - for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - MPIR_Assert(MPIDI_OFI_AV_ADDR(av, vci, nic) == get_av_table_index(r, nic, vci, - all_num_vcis)); - } - } - } - } -#endif - fn_exit: - MPIR_CHKLMEM_FREEALL(); - return mpi_errno; - fn_fail: - goto fn_exit; -} diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 1df47b66b51..8db1ae86ef6 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -542,13 +542,9 @@ categories : static int update_global_limits(struct fi_info *prov); static void dump_global_settings(void); static void dump_dynamic_settings(void); -static int create_vci_context(int vci, int nic); static int destroy_vci_context(int vci, int nic); static int ofi_pvar_init(void); -static int ofi_am_init(int vci); -static int ofi_am_post_recv(int vci, int nic); - static void *host_alloc(uintptr_t size); static void host_free(void *ptr); @@ -765,7 +761,7 @@ int MPIDI_OFI_init_local(int *tag_bits) /* Creating the context for vci 0 and nic 0. * This code maybe moved to a later stage */ - mpi_errno = create_vci_context(0, 0); + mpi_errno = MPIDI_OFI_create_vci_context(0, 0); MPIR_ERR_CHECK(mpi_errno); /* index datatypes for RMA atomics. */ @@ -775,8 +771,8 @@ int MPIDI_OFI_init_local(int *tag_bits) MPIR_Assert(MPIDI_OFI_DEFAULT_SHORT_SEND_SIZE <= MPIR_CVAR_CH4_PACK_BUFFER_SIZE); MPIDI_OFI_global.num_vcis = 1; - ofi_am_init(0); - ofi_am_post_recv(0, 0); + MPIDI_OFI_am_init(0); + MPIDI_OFI_am_post_recv(0, 0); fn_exit: *tag_bits = MPIDI_OFI_TAG_BITS; @@ -801,36 +797,6 @@ int MPIDI_OFI_init_world(void) goto fn_exit; } -static int check_num_nics(void); -static int setup_additional_vcis(void); - -int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) -{ - int mpi_errno = MPI_SUCCESS; - - /* Multiple vci without using domain require MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS */ -#ifndef MPIDI_OFI_VNI_USE_DOMAIN - MPIR_Assert(num_vcis == 1 || MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS); -#endif - - MPIDI_OFI_global.num_vcis = num_vcis; - - /* All processes must have the same number of NICs */ - mpi_errno = check_num_nics(); - MPIR_ERR_CHECK(mpi_errno); - - /* may update MPIDI_OFI_global.num_vcis */ - mpi_errno = setup_additional_vcis(); - MPIR_ERR_CHECK(mpi_errno); - - *num_vcis_actual = MPIDI_OFI_global.num_vcis; - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - int MPIDI_OFI_post_init(void) { int mpi_errno = MPI_SUCCESS; @@ -839,93 +805,7 @@ int MPIDI_OFI_post_init(void) dump_dynamic_settings(); } - /* Since we allow different process to have different num_vcis, we always need run exchange. */ - mpi_errno = MPIDI_OFI_addr_exchange_all_ctx(); - MPIR_ERR_CHECK(mpi_errno); - - for (int vci = 1; vci < MPIDI_OFI_global.num_vcis; vci++) { - ofi_am_init(vci); - ofi_am_post_recv(vci, 0); - } - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - -static int check_num_nics(void) -{ - int mpi_errno = MPI_SUCCESS; - - int num_nics = MPIDI_OFI_global.num_nics; - int tmp_num_vcis = MPIDI_OFI_global.num_vcis; - int tmp_num_nics = MPIDI_OFI_global.num_nics; - - /* Set the number of NICs and VNIs to 1 temporarily to avoid problems during the collective */ - MPIDI_OFI_global.num_vcis = MPIDI_OFI_global.num_nics = 1; - - /* Confirm that all processes have the same number of NICs */ - mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT, - MPI_MIN, MPIR_Process.comm_world, MPIR_ERR_NONE); - MPIDI_OFI_global.num_vcis = tmp_num_vcis; - MPIDI_OFI_global.num_nics = tmp_num_nics; - MPIR_ERR_CHECK(mpi_errno); - - /* If the user did not ask to fallback to fewer NICs, throw an error if someone is missing a - * NIC. */ - if (tmp_num_nics != num_nics) { - if (MPIR_CVAR_OFI_USE_MIN_NICS) { - MPIDI_OFI_global.num_nics = num_nics; - - /* If we fall down to 1 nic, turn off multi-nic optimizations. */ - if (num_nics == 1) { - MPIDI_OFI_COMM(MPIR_Process.comm_world).enable_striping = 0; - } - } else { - MPIR_ERR_CHKANDJUMP(num_nics != MPIDI_OFI_global.num_nics, mpi_errno, MPI_ERR_OTHER, - "**ofi_num_nics"); - } - } - - /* FIXME: It would also be helpful to check that all of the NICs can communicate so we can fall - * back to other options if that is not the case (e.g., verbs are often configured with a - * different subnet for each "set" of nics). It's unknown how to write a good check for that. */ - - /* set rx_ctx_cnt and tx_ctx_cnt for the remaining (non-0) nics */ - for (int nic = 1; nic < MPIDI_OFI_global.num_nics; nic++) { - set_sep_counters(nic); - } - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - -static int setup_additional_vcis(void) -{ - int mpi_errno = MPI_SUCCESS; - - for (int vci = 0; vci < MPIDI_OFI_global.num_vcis; vci++) { - for (int nic = 0; nic < MPIDI_OFI_global.num_nics; nic++) { - /* vci 0 nic 0 already created */ - if (vci > 0 || nic > 0) { - mpi_errno = create_vci_context(vci, nic); - if (mpi_errno != MPI_SUCCESS) { - /* running out of vcis, reduce MPIDI_OFI_global.num_vcis */ - if (vci > 0) { - MPIDI_OFI_global.num_vcis = vci; - /* FIXME: destroy already created vci_context */ - mpi_errno = MPI_SUCCESS; - goto fn_exit; - } else { - MPIR_ERR_CHECK(mpi_errno); - } - } - } - } - } + mpi_errno = MPIDI_OFI_comm_set_vcis(MPIR_Process.comm_world); fn_exit: return mpi_errno; @@ -1204,7 +1084,7 @@ static int open_local_av(struct fid_domain *p_domain, struct fid_av **p_av); * * Each nic will restart its vci indexing. This allows each VNI to use any nic if desired. */ -static int create_vci_context(int vci, int nic) +int MPIDI_OFI_create_vci_context(int vci, int nic) { int mpi_errno = MPI_SUCCESS; @@ -1732,7 +1612,7 @@ static void dump_dynamic_settings(void) /* static functions for AM */ -int ofi_am_init(int vci) +int MPIDI_OFI_am_init(int vci) { int mpi_errno = MPI_SUCCESS; @@ -1781,7 +1661,7 @@ int ofi_am_init(int vci) goto fn_exit; } -int ofi_am_post_recv(int vci, int nic) +int MPIDI_OFI_am_post_recv(int vci, int nic) { int mpi_errno = MPI_SUCCESS; diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.h b/src/mpid/ch4/netmod/ofi/ofi_init.h index 3c7312b3d2e..ab06b92fba7 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.h +++ b/src/mpid/ch4/netmod/ofi/ofi_init.h @@ -32,8 +32,11 @@ void MPIDI_OFI_update_global_settings(struct fi_info *prov); /* Determine if NIC has already been included in others */ bool MPIDI_OFI_nic_already_used(const struct fi_info *prov, struct fi_info **others, int nic_count); +int MPIDI_OFI_create_vci_context(int vci, int nic); int MPIDI_OFI_addr_exchange_root_ctx(void); int MPIDI_OFI_addr_exchange_all_ctx(void); +int MPIDI_OFI_am_init(int vci); +int MPIDI_OFI_am_post_recv(int vci, int nic); bool MPIDI_OFI_nic_is_up(struct fi_info *prov); diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c index b525b90da58..e0403482d2f 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_vci.c +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -5,13 +5,367 @@ #include "mpidimpl.h" #include "ofi_impl.h" +#include "ofi_init.h" -int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm, int num_vcis) +/* Address exchange within comm and setup multiple vcis */ + +int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm) +{ + int mpi_errno = MPI_SUCCESS; + + MPIR_Assert(comm == MPIR_Process.comm_world); + + /* Since we allow different process to have different num_vcis, we always need run exchange. */ + mpi_errno = MPIDI_OFI_addr_exchange_all_ctx(); + MPIR_ERR_CHECK(mpi_errno); + + for (int vci = 1; vci < MPIDI_OFI_global.num_vcis; vci++) { + MPIDI_OFI_am_init(vci); + MPIDI_OFI_am_post_recv(vci, 0); + } + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +/* MPIDI_OFI_init_vcis: locally create multiple vcis */ + +static int check_num_nics(void); +static int setup_additional_vcis(void); + +int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) +{ + int mpi_errno = MPI_SUCCESS; + + /* Multiple vci without using domain require MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS */ +#ifndef MPIDI_OFI_VNI_USE_DOMAIN + MPIR_Assert(num_vcis == 1 || MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS); +#endif + + MPIDI_OFI_global.num_vcis = num_vcis; + + /* All processes must have the same number of NICs */ + mpi_errno = check_num_nics(); + MPIR_ERR_CHECK(mpi_errno); + + /* may update MPIDI_OFI_global.num_vcis */ + mpi_errno = setup_additional_vcis(); + MPIR_ERR_CHECK(mpi_errno); + + *num_vcis_actual = MPIDI_OFI_global.num_vcis; + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +static int check_num_nics(void) +{ + int mpi_errno = MPI_SUCCESS; + + int num_nics = MPIDI_OFI_global.num_nics; + int tmp_num_vcis = MPIDI_OFI_global.num_vcis; + int tmp_num_nics = MPIDI_OFI_global.num_nics; + + /* Set the number of NICs and VNIs to 1 temporarily to avoid problems during the collective */ + MPIDI_OFI_global.num_vcis = MPIDI_OFI_global.num_nics = 1; + + /* Confirm that all processes have the same number of NICs */ + mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT, + MPI_MIN, MPIR_Process.comm_world, MPIR_ERR_NONE); + MPIDI_OFI_global.num_vcis = tmp_num_vcis; + MPIDI_OFI_global.num_nics = tmp_num_nics; + MPIR_ERR_CHECK(mpi_errno); + + /* If the user did not ask to fallback to fewer NICs, throw an error if someone is missing a + * NIC. */ + if (tmp_num_nics != num_nics) { + if (MPIR_CVAR_OFI_USE_MIN_NICS) { + MPIDI_OFI_global.num_nics = num_nics; + + /* If we fall down to 1 nic, turn off multi-nic optimizations. */ + if (num_nics == 1) { + MPIDI_OFI_COMM(MPIR_Process.comm_world).enable_striping = 0; + } + } else { + MPIR_ERR_CHKANDJUMP(num_nics != MPIDI_OFI_global.num_nics, mpi_errno, MPI_ERR_OTHER, + "**ofi_num_nics"); + } + } + + /* FIXME: It would also be helpful to check that all of the NICs can communicate so we can fall + * back to other options if that is not the case (e.g., verbs are often configured with a + * different subnet for each "set" of nics). It's unknown how to write a good check for that. */ + + /* set rx_ctx_cnt and tx_ctx_cnt for the remaining (non-0) nics */ + for (int nic = 1; nic < MPIDI_OFI_global.num_nics; nic++) { + set_sep_counters(nic); + } + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +static int setup_additional_vcis(void) +{ + int mpi_errno = MPI_SUCCESS; + + for (int vci = 0; vci < MPIDI_OFI_global.num_vcis; vci++) { + for (int nic = 0; nic < MPIDI_OFI_global.num_nics; nic++) { + /* vci 0 nic 0 already created */ + if (vci > 0 || nic > 0) { + mpi_errno = MPIDI_OFI_create_vci_context(vci, nic); + if (mpi_errno != MPI_SUCCESS) { + /* running out of vcis, reduce MPIDI_OFI_global.num_vcis */ + if (vci > 0) { + MPIDI_OFI_global.num_vcis = vci; + /* FIXME: destroy already created vci_context */ + mpi_errno = MPI_SUCCESS; + goto fn_exit; + } else { + MPIR_ERR_CHECK(mpi_errno); + } + } + } + } + } + + fn_exit: + return mpi_errno; + fn_fail: + goto fn_exit; +} + +/* MPIDI_OFI_addr_exchange_all_ctx: exchange addresses for multiple vcis */ + +/* Macros to reduce clutter, so we can focus on the ordering logics. + * Note: they are not perfectly wrapped, but tolerable since only used here. */ + +#if !defined(MPIDI_OFI_VNI_USE_DOMAIN) || MPIDI_CH4_MAX_VCIS == 1 +/* NOTE: with scalable endpoint as context, all vcis share the same address. */ +#define NUM_VCIS_FOR_RANK(r) 1 +#else +#define NUM_VCIS_FOR_RANK(r) all_num_vcis[r] +#endif + +#define GET_AV_AND_ADDRNAMES(rank) \ + MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, rank)); \ + char *r_names = all_names + rank * max_vcis * num_nics * name_len; + +#define DO_AV_INSERT(ctx_idx, nic, vci) \ + fi_addr_t addr; \ + MPIDI_OFI_CALL(fi_av_insert(MPIDI_OFI_global.ctx[ctx_idx].av, \ + r_names + (vci * num_nics + nic) * name_len, 1, \ + &addr, 0ULL, NULL), avmap); + +#define SKIP_ROOT(nic, vci) \ + if (nic == 0 && vci == 0) { \ + continue; \ + } + +/* with MPIDI_OFI_ENABLE_AV_TABLE, we potentially can omit storing av tables. + * The following routines ensures we can do that. It is static now, but we can + * easily export to global when we need to. + */ + +ATTRIBUTE((unused)) +static int get_root_av_table_index(int rank) +{ + if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { + /* node roots with greater ranks are inserted before this rank if it a non-node-root */ + int num_extra = 0; + + /* check node roots */ + for (int i = 0; i < MPIR_Process.num_nodes; i++) { + if (MPIR_Process.node_root_map[i] == rank) { + return i; + } else if (MPIR_Process.node_root_map[i] > rank) { + num_extra++; + } + } + + /* must be non-node-root */ + return rank + num_extra; + } else { + return rank; + } +} + +ATTRIBUTE((unused)) +static int get_av_table_index(int rank, int nic, int vci, int *all_num_vcis) +{ + if (nic == 0 && vci == 0) { + return get_root_av_table_index(rank); + } else { + int num_nics = MPIDI_OFI_global.num_nics; + int idx = 0; + idx += MPIR_Process.size; /* root entries */ + for (int i = 0; i < rank; i++) { + idx += num_nics * NUM_VCIS_FOR_RANK(i) - 1; + } + idx += nic * NUM_VCIS_FOR_RANK(rank) + vci - 1; + return idx; + } +} + +static int addr_exchange_all_ctx(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; - /* 0. get num_nics from CVARs */ - /* 1. check that MPIDI_OFI_global.n_total_vcis = 0 */ - /* 2. allocate and initialize local vcis */ - /* 3. exchange addresses */ + + MPIR_Comm *comm = MPIR_Process.comm_world; + int size = MPIR_Process.size; + int rank = MPIR_Process.rank; + MPIR_CHKLMEM_DECL(3); + + int max_vcis; + int *all_num_vcis; + +#if !defined(MPIDI_OFI_VNI_USE_DOMAIN) || MPIDI_CH4_MAX_VCIS == 1 + max_vcis = 1; + all_num_vcis = NULL; +#else + /* Allgather num_vcis */ + MPIR_CHKLMEM_MALLOC(all_num_vcis, void *, sizeof(int) * size, + mpi_errno, "all_num_vcis", MPL_MEM_ADDRESS); + mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT, + all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + + max_vcis = 0; + for (int i = 0; i < size; i++) { + if (max_vcis < NUM_VCIS_FOR_RANK(i)) { + max_vcis = NUM_VCIS_FOR_RANK(i); + } + } +#endif + + int num_vcis = NUM_VCIS_FOR_RANK(rank); + int num_nics = MPIDI_OFI_global.num_nics; + + /* Assume num_nics are all equal */ + if (max_vcis * num_nics == 1) { + goto fn_exit; + } + + /* libfabric uses uniform name_len within a single provider */ + int name_len = MPIDI_OFI_global.addrnamelen; + int my_len = max_vcis * num_nics * name_len; + char *all_names; + MPIR_CHKLMEM_MALLOC(all_names, char *, size * my_len, mpi_errno, "all_names", MPL_MEM_ADDRESS); + char *my_names = all_names + rank * my_len; + + /* put in my addrnames */ + for (int nic = 0; nic < num_nics; nic++) { + for (int vci = 0; vci < num_vcis; vci++) { + size_t actual_name_len = name_len; + char *vci_addrname = my_names + (vci * num_nics + nic) * name_len; + int ctx_idx = MPIDI_OFI_get_ctx_index(vci, nic); + MPIDI_OFI_CALL(fi_getname((fid_t) MPIDI_OFI_global.ctx[ctx_idx].ep, vci_addrname, + &actual_name_len), getname); + MPIR_Assert(actual_name_len == name_len); + } + } + /* Allgather */ + mpi_errno = MPIR_Allgather_fallback(MPI_IN_PLACE, 0, MPI_BYTE, + all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + + /* Step 2: insert and store non-root nic/vci on the root context */ + int root_ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); + for (int r = 0; r < size; r++) { + GET_AV_AND_ADDRNAMES(r); + for (int nic = 0; nic < num_nics; nic++) { + for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { + SKIP_ROOT(nic, vci); + DO_AV_INSERT(root_ctx_idx, nic, vci); + av->dest[nic][vci] = addr; + } + } + } + + /* Step 3: insert all nic/vci on non-root context, following exact order as step 1 and 2 */ + + int *is_node_roots = NULL; + if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { + MPIR_CHKLMEM_MALLOC(is_node_roots, int *, size * sizeof(int), + mpi_errno, "is_node_roots", MPL_MEM_ADDRESS); + for (int r = 0; r < size; r++) { + is_node_roots[r] = 0; + } + for (int i = 0; i < MPIR_Process.num_nodes; i++) { + is_node_roots[MPIR_Process.node_root_map[i]] = 1; + } + } + + for (int nic_local = 0; nic_local < num_nics; nic_local++) { + for (int vci_local = 0; vci_local < num_vcis; vci_local++) { + SKIP_ROOT(nic_local, vci_local); + int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic_local); + + /* -- same order as step 1 -- */ + if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { + /* node roots */ + for (int r = 0; r < size; r++) { + if (is_node_roots[r]) { + GET_AV_AND_ADDRNAMES(r); + DO_AV_INSERT(ctx_idx, 0, 0); + MPIR_Assert(av->dest[0][0] == addr); + } + } + /* non-node-root */ + for (int r = 0; r < size; r++) { + if (!is_node_roots[r]) { + GET_AV_AND_ADDRNAMES(r); + DO_AV_INSERT(ctx_idx, 0, 0); + MPIR_Assert(av->dest[0][0] == addr); + } + } + } else { + /* !MPIR_CVAR_CH4_ROOTS_ONLY_PMI */ + for (int r = 0; r < size; r++) { + GET_AV_AND_ADDRNAMES(r); + DO_AV_INSERT(ctx_idx, 0, 0); + MPIR_Assert(av->dest[0][0] == addr); + } + } + + /* -- same order as step 2 -- */ + for (int r = 0; r < size; r++) { + GET_AV_AND_ADDRNAMES(r); + for (int nic = 0; nic < num_nics; nic++) { + for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { + SKIP_ROOT(nic, vci); + DO_AV_INSERT(ctx_idx, nic, vci); + MPIR_Assert(av->dest[nic][vci] == addr); + } + } + } + } + } + mpi_errno = MPIR_Barrier_fallback(comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + + /* check */ +#if MPIDI_CH4_MAX_VCIS > 1 + if (MPIDI_OFI_ENABLE_AV_TABLE) { + for (int r = 0; r < size; r++) { + MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r)); + for (int nic = 0; nic < num_nics; nic++) { + for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { + MPIR_Assert(av->dest[nic][vci] == get_av_table_index(r, nic, vci, + all_num_vcis)); + } + } + } + } +#endif + fn_exit: + MPIR_CHKLMEM_FREEALL(); return mpi_errno; + fn_fail: + goto fn_exit; } From 2115cbceba765b8d84272bc6bf2e5d5f06937669 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 28 Dec 2024 11:26:15 -0600 Subject: [PATCH 10/25] ch4: call SHM/NM comm_set_vcis in MPIDI_Comm_set_vcis --- src/mpid/ch4/netmod/ofi/ofi_init.c | 2 -- src/mpid/ch4/src/ch4_vci.c | 11 ++++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 8db1ae86ef6..33c3d3bf554 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -805,8 +805,6 @@ int MPIDI_OFI_post_init(void) dump_dynamic_settings(); } - mpi_errno = MPIDI_OFI_comm_set_vcis(MPIR_Process.comm_world); - fn_exit: return mpi_errno; fn_fail: diff --git a/src/mpid/ch4/src/ch4_vci.c b/src/mpid/ch4/src/ch4_vci.c index 628733f3896..f10c700b2e9 100644 --- a/src/mpid/ch4/src/ch4_vci.c +++ b/src/mpid/ch4/src/ch4_vci.c @@ -121,13 +121,22 @@ int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) MPIDI_global.all_num_vcis[granks[i]] = all_num_vcis[i]; } - comm->vcis_enabled = true; + /* setup vcis in netmod and shm */ + mpi_errno = MPIDI_NM_comm_set_vcis(comm); + MPIR_ERR_CHECK(mpi_errno); +#ifndef MPIDI_CH4_DIRECT_NETMOD + mpi_errno = MPIDI_SHM_comm_set_vcis(comm, MPIDI_global.n_total_vcis); + MPIR_ERR_CHECK(mpi_errno); +#endif for (int vci = 1; vci < MPIDI_global.n_total_vcis; vci++) { mpi_errno = MPIDI_init_per_vci(vci); MPIR_ERR_CHECK(mpi_errno); } + /* enable multiple vcis */ + comm->vcis_enabled = true; + fn_exit: MPIR_CHKLMEM_FREEALL(); return mpi_errno; From fbab8f63048f1d5ad28b40df3ccec83998bebdf7 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 28 Dec 2024 14:10:49 -0600 Subject: [PATCH 11/25] ch4/ofi: do vci address exchange within a comm Add the flexibility of perform multiple vci address exchange within a comm other than the comm world. For one, this is important in a session where the comm_world may not exist. For two, this provides a mechanism to save resource when applications don't need multiple vcis for the entire comm world. --- src/mpid/ch4/netmod/ofi/init_addrxchg.c | 32 +--- src/mpid/ch4/netmod/ofi/ofi_init.h | 1 - src/mpid/ch4/netmod/ofi/ofi_vci.c | 225 ++++++++---------------- 3 files changed, 79 insertions(+), 179 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/init_addrxchg.c b/src/mpid/ch4/netmod/ofi/init_addrxchg.c index 5b33ff7cb7f..7c12c2cd012 100644 --- a/src/mpid/ch4/netmod/ofi/init_addrxchg.c +++ b/src/mpid/ch4/netmod/ofi/init_addrxchg.c @@ -8,37 +8,7 @@ #include "ofi_init.h" #include "mpidu_bc.h" -/* NOTE on av insertion order: - * - * Each nic-vci is an endpoint with a unique address, and inside libfabric maintains - * one av table. Thus to fully store the address mapping, we'll need a multi-dim table as - * av_table[src_nic][src_vci][dest_rank][dest_nic][dest_vci] - * Note, this table is for illustration, and different from MPIDI_OFI_addr_t. - * - * However, if we insert the address carefully, we can manage to make the av table inside - * each endpoint *identical*. Then, we can omit the dimension of [src_nic][src_vci]. - * - * To achieve that, we use the following 3-step process (described with above illustrative av_table). - * - * Step 1. insert and store av_table[ 0 ][ 0 ][rank][ 0 ][ 0 ] - * - * Step 2. insert and store av_table[ 0 ][ 0 ][rank][nic][vci] - * - * Step 3. insert (but not store) av_table[nic][vci][rank][nic][vci] - * - * The step 1 is done in addr_exchange_root_vci. Step 2 and 3 are done in addr_exchange_all_vcis. - * Step 3 populates av tables inside libfabric for all non-zero endpoints, but they should be - * identical to the table in root endpoint, thus no need to store them in mpich. Thus the table is - * reduced to - * av_table[rank] -> dest[nic][vci] - * - * With single-nic and single-vci, only step 1 is needed. - * - * We do step 1 during world-init, and step 2 & 3 during post-init. The separation - * isolates multi-nic/vci complications from bootstrapping phase. - */ - -/* Step 1: exchange root contexts */ +/* exchange root contexts */ int MPIDI_OFI_addr_exchange_root_ctx(void) { int mpi_errno = MPI_SUCCESS; diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.h b/src/mpid/ch4/netmod/ofi/ofi_init.h index ab06b92fba7..d22f7e99d89 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.h +++ b/src/mpid/ch4/netmod/ofi/ofi_init.h @@ -34,7 +34,6 @@ bool MPIDI_OFI_nic_already_used(const struct fi_info *prov, struct fi_info **oth int MPIDI_OFI_create_vci_context(int vci, int nic); int MPIDI_OFI_addr_exchange_root_ctx(void); -int MPIDI_OFI_addr_exchange_all_ctx(void); int MPIDI_OFI_am_init(int vci); int MPIDI_OFI_am_post_recv(int vci, int nic); diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c index e0403482d2f..f55b02914de 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_vci.c +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -7,8 +7,28 @@ #include "ofi_impl.h" #include "ofi_init.h" +/* NOTE on av insertion order: + * + * Each nic-vci is an endpoint with a unique address, and inside libfabric maintains + * one av table. Thus to fully store the address mapping, we'll need a multi-dim table as + * av_table[src_vci][src_nic][dest_rank][dest_vci][dest_nic] + * Note, this table is for illustration, and different from MPIDI_OFI_addr_t. + * + * However, if we insert the addresses in the same order between local endpoints, then the + * av table indexes will be *identical*. Then, we can omit the dimension of [src_vci][src_nic]. + * I.e. we only need av_table[rank][vci][nic], saving two dimensions of local vcis and local nics. + * + * To achieve that, we need always insert each remote address on *all* local endpoints together. + * Because we separate root addr (av_table[0][0][rank][0][0]) separately, we allow the root + * address to be inserted separately from the rest. The rest of the addresses are only + * needed when multiple vcis/nics are enabled. But we require for each remote rank, all remote + * endpoints to be inserted all at once. + */ + /* Address exchange within comm and setup multiple vcis */ +static int addr_exchange_all_ctx(MPIR_Comm * comm); + int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; @@ -16,7 +36,7 @@ int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm) MPIR_Assert(comm == MPIR_Process.comm_world); /* Since we allow different process to have different num_vcis, we always need run exchange. */ - mpi_errno = MPIDI_OFI_addr_exchange_all_ctx(); + mpi_errno = addr_exchange_all_ctx(comm); MPIR_ERR_CHECK(mpi_errno); for (int vci = 1; vci < MPIDI_OFI_global.num_vcis; vci++) { @@ -39,6 +59,9 @@ int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) { int mpi_errno = MPI_SUCCESS; + /* NOTE: we only can create contexts for additional vcis and nics once (see notes at top) */ + MPIR_Assert(MPIDI_OFI_global.num_vcis == 1 && MPIDI_OFI_global.num_nics == 1); + /* Multiple vci without using domain require MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS */ #ifndef MPIDI_OFI_VNI_USE_DOMAIN MPIR_Assert(num_vcis == 1 || MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS); @@ -141,7 +164,7 @@ static int setup_additional_vcis(void) goto fn_exit; } -/* MPIDI_OFI_addr_exchange_all_ctx: exchange addresses for multiple vcis */ +/* addr_exchange_all_ctx: exchange addresses for multiple vcis */ /* Macros to reduce clutter, so we can focus on the ordering logics. * Note: they are not perfectly wrapped, but tolerable since only used here. */ @@ -154,7 +177,7 @@ static int setup_additional_vcis(void) #endif #define GET_AV_AND_ADDRNAMES(rank) \ - MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, rank)); \ + MPIDI_av_entry_t *av ATTRIBUTE((unused)) = &MPIDIU_get_av(0, rank); \ char *r_names = all_names + rank * max_vcis * num_nics * name_len; #define DO_AV_INSERT(ctx_idx, nic, vci) \ @@ -168,83 +191,35 @@ static int setup_additional_vcis(void) continue; \ } -/* with MPIDI_OFI_ENABLE_AV_TABLE, we potentially can omit storing av tables. - * The following routines ensures we can do that. It is static now, but we can - * easily export to global when we need to. - */ - -ATTRIBUTE((unused)) -static int get_root_av_table_index(int rank) -{ - if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { - /* node roots with greater ranks are inserted before this rank if it a non-node-root */ - int num_extra = 0; - - /* check node roots */ - for (int i = 0; i < MPIR_Process.num_nodes; i++) { - if (MPIR_Process.node_root_map[i] == rank) { - return i; - } else if (MPIR_Process.node_root_map[i] > rank) { - num_extra++; - } - } - - /* must be non-node-root */ - return rank + num_extra; - } else { - return rank; - } -} - -ATTRIBUTE((unused)) -static int get_av_table_index(int rank, int nic, int vci, int *all_num_vcis) -{ - if (nic == 0 && vci == 0) { - return get_root_av_table_index(rank); - } else { - int num_nics = MPIDI_OFI_global.num_nics; - int idx = 0; - idx += MPIR_Process.size; /* root entries */ - for (int i = 0; i < rank; i++) { - idx += num_nics * NUM_VCIS_FOR_RANK(i) - 1; - } - idx += nic * NUM_VCIS_FOR_RANK(rank) + vci - 1; - return idx; - } -} - static int addr_exchange_all_ctx(MPIR_Comm * comm) { int mpi_errno = MPI_SUCCESS; - - MPIR_Comm *comm = MPIR_Process.comm_world; - int size = MPIR_Process.size; - int rank = MPIR_Process.rank; MPIR_CHKLMEM_DECL(3); - int max_vcis; - int *all_num_vcis; + int nprocs = comm->local_size; + int myrank = comm->rank; -#if !defined(MPIDI_OFI_VNI_USE_DOMAIN) || MPIDI_CH4_MAX_VCIS == 1 - max_vcis = 1; - all_num_vcis = NULL; -#else - /* Allgather num_vcis */ - MPIR_CHKLMEM_MALLOC(all_num_vcis, void *, sizeof(int) * size, - mpi_errno, "all_num_vcis", MPL_MEM_ADDRESS); - mpi_errno = MPIR_Allgather_fallback(&MPIDI_OFI_global.num_vcis, 1, MPI_INT, - all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); - MPIR_ERR_CHECK(mpi_errno); + /* get global ranks and number of remote vcis */ + int *grank, *all_num_vcis; + MPIR_CHKLMEM_MALLOC(granks, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER); + MPIR_CHKLMEM_MALLOC(all_num_vcis, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER); + for (int i = 0; i < nprocs; i++) { + int avtid; + MPIDIU_comm_rank_to_pid(comm, i, &granks[i], &avtid); + MPIR_Assert(avtid == 0); - max_vcis = 0; - for (int i = 0; i < size; i++) { + all_num_vcis[i] = MPIDI_global.all_num_vcis[granks[i]]; + } + + /* get the max_vcis (so we can do Allgather rather than Allgatherv) */ + int max_vcis = 0; + for (int i = 0; i < nprocs; i++) { if (max_vcis < NUM_VCIS_FOR_RANK(i)) { max_vcis = NUM_VCIS_FOR_RANK(i); } } -#endif - int num_vcis = NUM_VCIS_FOR_RANK(rank); + int my_num_vcis = NUM_VCIS_FOR_RANK(rank); int num_nics = MPIDI_OFI_global.num_nics; /* Assume num_nics are all equal */ @@ -252,16 +227,26 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) goto fn_exit; } + /* allocate all_dest[] in av entry */ + for (int i = 0; i < nprocs; i++) { + MPIDI_av_entry_t *av = MPIDIU_comm_rank_to_av(comm, i); + MPIR_Assert(MPIDI_OFI_AV(av).all_dest == NULL); + MPIDI_OFI_AV(av).all_dest = MPL_malloc(max_vcis * num_nics * sizeof(fi_addr_t), + MPL_MEM_ADDRESS); + MPIR_ERR_CHKANDJUMP(!MPIDI_OFI_AV(av).all_dest, mpi_errno, MPI_ERR_OTHER, "**nomem"); + } + /* libfabric uses uniform name_len within a single provider */ int name_len = MPIDI_OFI_global.addrnamelen; int my_len = max_vcis * num_nics * name_len; char *all_names; - MPIR_CHKLMEM_MALLOC(all_names, char *, size * my_len, mpi_errno, "all_names", MPL_MEM_ADDRESS); + MPIR_CHKLMEM_MALLOC(all_names, char *, nprocs * my_len, mpi_errno, "all_names", + MPL_MEM_ADDRESS); char *my_names = all_names + rank * my_len; /* put in my addrnames */ for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < num_vcis; vci++) { + for (int vci = 0; vci < my_num_vcis; vci++) { size_t actual_name_len = name_len; char *vci_addrname = my_names + (vci * num_nics + nic) * name_len; int ctx_idx = MPIDI_OFI_get_ctx_index(vci, nic); @@ -271,98 +256,44 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) } } /* Allgather */ - mpi_errno = MPIR_Allgather_fallback(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + mpi_errno = MPIR_Allgather_impl(MPI_IN_PLACE, 0, MPI_BYTE, + all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); /* Step 2: insert and store non-root nic/vci on the root context */ int root_ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); - for (int r = 0; r < size; r++) { + for (int r = 0; r < nprocs; r++) { GET_AV_AND_ADDRNAMES(r); + /* for each remote endpoints */ for (int nic = 0; nic < num_nics; nic++) { for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - SKIP_ROOT(nic, vci); - DO_AV_INSERT(root_ctx_idx, nic, vci); - av->dest[nic][vci] = addr; - } - } - } - - /* Step 3: insert all nic/vci on non-root context, following exact order as step 1 and 2 */ - - int *is_node_roots = NULL; - if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { - MPIR_CHKLMEM_MALLOC(is_node_roots, int *, size * sizeof(int), - mpi_errno, "is_node_roots", MPL_MEM_ADDRESS); - for (int r = 0; r < size; r++) { - is_node_roots[r] = 0; - } - for (int i = 0; i < MPIR_Process.num_nodes; i++) { - is_node_roots[MPIR_Process.node_root_map[i]] = 1; - } - } - - for (int nic_local = 0; nic_local < num_nics; nic_local++) { - for (int vci_local = 0; vci_local < num_vcis; vci_local++) { - SKIP_ROOT(nic_local, vci_local); - int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic_local); - - /* -- same order as step 1 -- */ - if (MPIR_CVAR_CH4_ROOTS_ONLY_PMI) { - /* node roots */ - for (int r = 0; r < size; r++) { - if (is_node_roots[r]) { - GET_AV_AND_ADDRNAMES(r); - DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(av->dest[0][0] == addr); - } - } - /* non-node-root */ - for (int r = 0; r < size; r++) { - if (!is_node_roots[r]) { - GET_AV_AND_ADDRNAMES(r); - DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(av->dest[0][0] == addr); - } - } - } else { - /* !MPIR_CVAR_CH4_ROOTS_ONLY_PMI */ - for (int r = 0; r < size; r++) { - GET_AV_AND_ADDRNAMES(r); - DO_AV_INSERT(ctx_idx, 0, 0); - MPIR_Assert(av->dest[0][0] == addr); - } - } - - /* -- same order as step 2 -- */ - for (int r = 0; r < size; r++) { - GET_AV_AND_ADDRNAMES(r); - for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - SKIP_ROOT(nic, vci); + /* for each local endpoints */ + fi_addr_t expect_addr = FI_ADDR_NOTAVAIL; + for (int nic_local = 0; nic_local < num_nics; nic_local++) { + for (int vci_local = 0; vci_local < my_num_vcis; vci_local++) { + /* skip root */ + if (nic == 0 && vci == 0 && nic_local == 0 && vci_local == 0) { + next; + } + int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic_local); DO_AV_INSERT(ctx_idx, nic, vci); - MPIR_Assert(av->dest[nic][vci] == addr); + /* we expect all resulting addr to be the same */ + if (expect_addr == FI_ADDR_NOTAVAIL) { + expect_addr = addr; + } else { + MPIR_Assert(expect_addr == addr); + } } } + MPIR_Assert(expect_addr != FI_ADDR_NOTAVAIL); + MPIDI_OFI_AV_ADDR(av, vci, nic) = addr; } } } + mpi_errno = MPIR_Barrier_fallback(comm, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - /* check */ -#if MPIDI_CH4_MAX_VCIS > 1 - if (MPIDI_OFI_ENABLE_AV_TABLE) { - for (int r = 0; r < size; r++) { - MPIDI_OFI_addr_t *av ATTRIBUTE((unused)) = &MPIDI_OFI_AV(&MPIDIU_get_av(0, r)); - for (int nic = 0; nic < num_nics; nic++) { - for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { - MPIR_Assert(av->dest[nic][vci] == get_av_table_index(r, nic, vci, - all_num_vcis)); - } - } - } - } -#endif fn_exit: MPIR_CHKLMEM_FREEALL(); return mpi_errno; From 081264288197f0f0cbfe026ffaf4f6279f000e6a Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sat, 28 Dec 2024 19:05:59 -0600 Subject: [PATCH 12/25] ch4/ofi: only activate multiple nics when vcis are set We exchange non-root endpoints in comm_set_vcis. Because we can't use multiple nics before the address exchange, we only can activate multiple nics in comm_set_vcis. --- src/mpid/ch4/netmod/ofi/ofi_impl.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index 6264225a74e..beaf34f768a 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -579,7 +579,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_multx_sender_nic_index(MPIR_Comm * comm, { int nic_idx = 0; - if (MPIDI_OFI_COMM(comm).pref_nic) { + if (!comm->vcis_enabled) { + /* address exchange may not have performed on this comm */ + nic_idx = 0; + } else if (MPIDI_OFI_COMM(comm).pref_nic) { nic_idx = MPIDI_OFI_COMM(comm).pref_nic[sender_rank]; } else if (MPIDI_OFI_COMM(comm).enable_hashing) { /* TODO - We should use the per-communicator value for the maximum number of NICs in this @@ -607,7 +610,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_multx_receiver_nic_index(MPIR_Comm * comm { int nic_idx = 0; - if (MPIDI_OFI_COMM(comm).pref_nic) { + if (!comm->vcis_enabled) { + /* address exchange may not have performed on this comm */ + nic_idx = 0; + } else if (MPIDI_OFI_COMM(comm).pref_nic) { nic_idx = MPIDI_OFI_COMM(comm).pref_nic[receiver_rank]; } else if (MPIDI_OFI_COMM(comm).enable_hashing) { /* TODO - We should use the per-communicator value for the maximum number of NICs in this From 370ff1a70ef7df1c89437a5ba043f57ebbda5157 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 29 Dec 2024 09:04:39 -0600 Subject: [PATCH 13/25] ch4/ofi: delay setting MPIDI_OFI_global.num_nics MPIDI_OFI_global.num_nics affects runtime paths such as ofi progress and large message striping. Only set it in MPIDI_OFI_init_vcis so we won't have complications when multi-nics is not ready. --- src/mpid/ch4/netmod/ofi/ofi_nic.c | 53 +++++++++++++++-------------- src/mpid/ch4/netmod/ofi/ofi_types.h | 1 + src/mpid/ch4/netmod/ofi/ofi_vci.c | 1 + 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_nic.c b/src/mpid/ch4/netmod/ofi/ofi_nic.c index 7ee60b1f60f..6e9abcc19c7 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_nic.c +++ b/src/mpid/ch4/netmod/ofi/ofi_nic.c @@ -203,7 +203,7 @@ int MPIDI_OFI_init_multi_nic(struct fi_info *prov) mpi_errno = setup_single_nic(); MPIR_ERR_CHECK(mpi_errno); } - MPIR_Assert(MPIDI_OFI_global.num_nics > 0); + MPIR_Assert(MPIDI_OFI_global.num_nics_available > 0); fn_exit: return mpi_errno; @@ -213,7 +213,7 @@ int MPIDI_OFI_init_multi_nic(struct fi_info *prov) static int setup_single_nic(void) { - MPIDI_OFI_global.num_nics = 1; + MPIDI_OFI_global.num_nics_available = 1; MPIDI_OFI_global.num_close_nics = 1; MPIDI_OFI_global.nic_info[0].nic = MPIDI_OFI_global.prov_use[0]; MPIDI_OFI_global.nic_info[0].id = 0; @@ -226,7 +226,7 @@ static int setup_single_nic(void) MPIR_Info *info_ptr = NULL; MPIR_Info_get_ptr(MPI_INFO_ENV, info_ptr); snprintf(nics_str, 32, "%d", 1); - MPIR_Info_set_impl(info_ptr, "num_nics", nics_str); + MPIR_Info_set_impl(info_ptr, "num_nics_available", nics_str); snprintf(nics_str, 32, "%d", 1); MPIR_Info_set_impl(info_ptr, "num_close_nics", nics_str); @@ -259,39 +259,41 @@ static int setup_multi_nic(int nic_count) MPIR_CHKLMEM_DECL(1); bool pref_nic_set = false; - MPIDI_OFI_global.num_nics = nic_count; + MPIDI_OFI_global.num_nics_available = nic_count; - if (MPIR_CVAR_CH4_OFI_PREF_NIC > -1 && MPIR_CVAR_CH4_OFI_PREF_NIC < MPIDI_OFI_global.num_nics) { + if (MPIR_CVAR_CH4_OFI_PREF_NIC > -1 && + MPIR_CVAR_CH4_OFI_PREF_NIC < MPIDI_OFI_global.num_nics_available) { pref_nic_set = true; } /* Initially sort the NICs by name. This way all intranode ranks have a consistent view. */ - qsort(MPIDI_OFI_global.prov_use, MPIDI_OFI_global.num_nics, sizeof(struct fi_info *), + qsort(MPIDI_OFI_global.prov_use, MPIDI_OFI_global.num_nics_available, sizeof(struct fi_info *), compare_nic_names); /* Limit the number of physical NICs depending on the CVAR */ - if (MPIR_CVAR_CH4_OFI_MAX_NICS > 0 && MPIDI_OFI_global.num_nics > MPIR_CVAR_CH4_OFI_MAX_NICS) { - for (int i = MPIR_CVAR_CH4_OFI_MAX_NICS; i < MPIDI_OFI_global.num_nics; ++i) { + if (MPIR_CVAR_CH4_OFI_MAX_NICS > 0 && + MPIDI_OFI_global.num_nics_available > MPIR_CVAR_CH4_OFI_MAX_NICS) { + for (int i = MPIR_CVAR_CH4_OFI_MAX_NICS; i < MPIDI_OFI_global.num_nics_available; ++i) { fi_freeinfo(MPIDI_OFI_global.prov_use[i]); } - MPIDI_OFI_global.num_nics = MPIR_CVAR_CH4_OFI_MAX_NICS; + MPIDI_OFI_global.num_nics_available = MPIR_CVAR_CH4_OFI_MAX_NICS; } int num_numa_nodes = MPIR_hwtopo_get_num_numa_nodes(); bool is_snc4_with_cxi_nics = false; if ((num_numa_nodes == 8 || num_numa_nodes == 16)) - if (MPIDI_OFI_global.num_nics > 1) + if (MPIDI_OFI_global.num_nics_available > 1) if (strstr(MPIDI_OFI_global.prov_use[0]->domain_attr->name, "cxi")) is_snc4_with_cxi_nics = true; /* Special case of nic assignment for SPR in SNC4 mode */ if (is_snc4_with_cxi_nics && !pref_nic_set) { - for (int i = 0; i < MPIDI_OFI_global.num_nics; ++i) { + for (int i = 0; i < MPIDI_OFI_global.num_nics_available; ++i) { nics[i].nic = MPIDI_OFI_global.prov_use[i]; nics[i].id = i; /* Set the preference of all NICs to least preferable (lower is more preferable) */ - nics[i].prefer = MPIDI_OFI_global.num_nics + 1; + nics[i].prefer = MPIDI_OFI_global.num_nics_available + 1; nics[i].count = 0; nics[i].num_close_ranks = 0; @@ -310,7 +312,7 @@ static int setup_multi_nic(int nic_count) } } /* Use num_parents to determine nic closeness */ - for (int i = 0; i < MPIDI_OFI_global.num_nics; ++i) { + for (int i = 0; i < MPIDI_OFI_global.num_nics_available; ++i) { nics[i].close = is_nic_close_snc4(&nics[i], num_parents); if (nics[i].close) MPIDI_OFI_global.num_close_nics++; @@ -321,7 +323,7 @@ static int setup_multi_nic(int nic_count) /* Now go through every NIC and set initial information * from current process's perspective */ - for (int i = 0; i < MPIDI_OFI_global.num_nics; ++i) { + for (int i = 0; i < MPIDI_OFI_global.num_nics_available; ++i) { nics[i].nic = MPIDI_OFI_global.prov_use[i]; nics[i].id = i; /* Determine NIC's "closeness" to current process */ @@ -329,7 +331,7 @@ static int setup_multi_nic(int nic_count) if (nics[i].close) MPIDI_OFI_global.num_close_nics++; /* Set the preference of all NICs to least preferable (lower is more preferable) */ - nics[i].prefer = MPIDI_OFI_global.num_nics + 1; + nics[i].prefer = MPIDI_OFI_global.num_nics_available + 1; nics[i].count = 0; nics[i].num_close_ranks = 0; /* Determine NIC's first normal parent topology @@ -355,9 +357,9 @@ static int setup_multi_nic(int nic_count) /* If there were zero NICs on my socket, then just consider every NIC close * and share them among all ranks with a similar view */ if (MPIDI_OFI_global.num_close_nics == 0) { - for (int i = 0; i < MPIDI_OFI_global.num_nics; ++i) + for (int i = 0; i < MPIDI_OFI_global.num_nics_available; ++i) nics[i].close = 1; - MPIDI_OFI_global.num_close_nics = MPIDI_OFI_global.num_nics; + MPIDI_OFI_global.num_close_nics = MPIDI_OFI_global.num_nics_available; } if (pref_nic_set) { @@ -377,9 +379,9 @@ static int setup_multi_nic(int nic_count) if (is_snc4_with_cxi_nics) { /* Use a separate sorting function for snc4 nics in order to just compare * closeness followed by nic name */ - qsort(nics, MPIDI_OFI_global.num_nics, sizeof(nics[0]), compare_nics_snc4); + qsort(nics, MPIDI_OFI_global.num_nics_available, sizeof(nics[0]), compare_nics_snc4); } else { - qsort(nics, MPIDI_OFI_global.num_nics, sizeof(nics[0]), compare_nics); + qsort(nics, MPIDI_OFI_global.num_nics_available, sizeof(nics[0]), compare_nics); } /* Because we cannot communicate with the other local processes to avoid collisions with the @@ -391,9 +393,10 @@ static int setup_multi_nic(int nic_count) if (old_idx != 0) { MPIDI_OFI_nic_info_t *old_nics; MPIR_CHKLMEM_MALLOC(old_nics, MPIDI_OFI_nic_info_t *, sizeof(MPIDI_OFI_nic_info_t) * - MPIDI_OFI_global.num_nics, mpi_errno, "temporary nic info", - MPL_MEM_ADDRESS); - memcpy(old_nics, nics, sizeof(MPIDI_OFI_nic_info_t) * MPIDI_OFI_global.num_nics); + MPIDI_OFI_global.num_nics_available, mpi_errno, + "temporary nic info", MPL_MEM_ADDRESS); + memcpy(old_nics, nics, + sizeof(MPIDI_OFI_nic_info_t) * MPIDI_OFI_global.num_nics_available); /* Rotate the preferred NIC for each process starting at old_idx. */ for (int new_idx = 0; new_idx < MPIDI_OFI_global.num_close_nics; new_idx++) { @@ -407,7 +410,7 @@ static int setup_multi_nic(int nic_count) } /* Reorder the prov_use array based on nic_info array */ - for (int i = 0; i < MPIDI_OFI_global.num_nics; ++i) { + for (int i = 0; i < MPIDI_OFI_global.num_nics_available; ++i) { MPIDI_OFI_global.prov_use[i] = nics[i].nic; } @@ -415,8 +418,8 @@ static int setup_multi_nic(int nic_count) char nics_str[32]; MPIR_Info *info_ptr = NULL; MPIR_Info_get_ptr(MPI_INFO_ENV, info_ptr); - snprintf(nics_str, 32, "%d", MPIDI_OFI_global.num_nics); - MPIR_Info_set_impl(info_ptr, "num_nics", nics_str); + snprintf(nics_str, 32, "%d", MPIDI_OFI_global.num_nics_available); + MPIR_Info_set_impl(info_ptr, "num_nics_available", nics_str); snprintf(nics_str, 32, "%d", MPIDI_OFI_global.num_close_nics); MPIR_Info_set_impl(info_ptr, "num_close_nics", nics_str); diff --git a/src/mpid/ch4/netmod/ofi/ofi_types.h b/src/mpid/ch4/netmod/ofi/ofi_types.h index 9b1309fd0e5..bc4073d802d 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_types.h +++ b/src/mpid/ch4/netmod/ofi/ofi_types.h @@ -458,6 +458,7 @@ typedef struct MPIDI_GPU_RDMA_queue_t { typedef struct { /* OFI objects */ int avtid; + int num_nics_available; struct fi_info *prov_use[MPIDI_OFI_MAX_NICS]; MPIDI_OFI_nic_info_t nic_info[MPIDI_OFI_MAX_NICS]; struct fid_fabric *fabric; diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c index f55b02914de..f8205009f27 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_vci.c +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -67,6 +67,7 @@ int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) MPIR_Assert(num_vcis == 1 || MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS); #endif + MPIDI_OFI_global.num_nics = MPIDI_OFI_global.num_nics_available; MPIDI_OFI_global.num_vcis = num_vcis; /* All processes must have the same number of NICs */ From 1c0128e9134661678d8405b2cbbcc2ae7013a61f Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 29 Dec 2024 09:23:33 -0600 Subject: [PATCH 14/25] ch4/ofi: gather vci init in MPIDI_OFI_vci_init --- src/mpid/ch4/netmod/ofi/ofi_init.c | 5 +++-- src/mpid/ch4/netmod/ofi/ofi_init.h | 1 + src/mpid/ch4/netmod/ofi/ofi_vci.c | 7 +++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index 33c3d3bf554..be9e6ae3937 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -715,6 +715,9 @@ int MPIDI_OFI_init_local(int *tag_bits) mpi_errno = ofi_pvar_init(); MPIR_ERR_CHECK(mpi_errno); + mpi_errno = MPIDI_OFI_vci_init(); + MPIR_ERR_CHECK(mpi_errno); + /* -------------------------------- */ /* Set up the libfabric provider(s) */ /* -------------------------------- */ @@ -722,8 +725,6 @@ int MPIDI_OFI_init_local(int *tag_bits) /* WB TODO - I assume that after this function is done, there will be an array of providers in * MPIDI_OFI_global.prov_use that will map to the VNI contexts below. We can also use it to * generate the addresses in the business card exchange. */ - MPIDI_OFI_global.num_nics = 1; - struct fi_info *prov = NULL; mpi_errno = MPIDI_OFI_find_provider(&prov); MPIR_ERR_CHECK(mpi_errno); diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.h b/src/mpid/ch4/netmod/ofi/ofi_init.h index d22f7e99d89..907aa321de7 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.h +++ b/src/mpid/ch4/netmod/ofi/ofi_init.h @@ -11,6 +11,7 @@ int MPIDI_OFI_get_required_version(void); int MPIDI_OFI_find_provider(struct fi_info **prov_out); void MPIDI_OFI_find_provider_cleanup(void); int MPIDI_OFI_init_multi_nic(struct fi_info *prov); +int MPIDI_OFI_vci_init(void); /* set hints based on MPIDI_OFI_global.settings */ int MPIDI_OFI_init_hints(struct fi_info *hints); diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c index f8205009f27..fe73c21c374 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_vci.c +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -25,6 +25,13 @@ * endpoints to be inserted all at once. */ +int MPIDI_OFI_vci_init(void) +{ + MPIDI_OFI_global.num_nics = 1; + MPIDI_OFI_global.num_vcis = 1; + return MPI_SUCCESS; +} + /* Address exchange within comm and setup multiple vcis */ static int addr_exchange_all_ctx(MPIR_Comm * comm); From b5f72016f4cccb6f2b7ce68b0b04d40553ad802b Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 29 Dec 2024 10:54:49 -0600 Subject: [PATCH 15/25] ch4: do local vci creations in MPIDI_NM_comm_set_vcis --- src/mpid/ch4/netmod/ofi/ofi_init.c | 23 ++--- src/mpid/ch4/netmod/ofi/ofi_vci.c | 144 +++++++++++------------------ src/mpid/ch4/src/ch4_vci.c | 37 ++++---- 3 files changed, 75 insertions(+), 129 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_init.c b/src/mpid/ch4/netmod/ofi/ofi_init.c index be9e6ae3937..2ee52a01101 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_init.c +++ b/src/mpid/ch4/netmod/ofi/ofi_init.c @@ -541,7 +541,6 @@ categories : static int update_global_limits(struct fi_info *prov); static void dump_global_settings(void); -static void dump_dynamic_settings(void); static int destroy_vci_context(int vci, int nic); static int ofi_pvar_init(void); @@ -638,7 +637,8 @@ static void set_sep_counters(int nic) /* Note: currently we request a single tx and rx ctx under MPIDI_OFI_VNI_USE_DOMAIN */ int num_ctx_per_nic = 1; #else - int num_ctx_per_nic = MPIDI_OFI_global.num_vcis; + /* the actual needed number of vcis is not known yet. Use the CVAR. */ + int num_ctx_per_nic = MPIR_CVAR_CH4_NUM_VCIS + MPIR_CVAR_CH4_RESERVE_VCIS; #endif int max_by_prov = MPL_MIN(MPIDI_OFI_global.prov_use[nic]->domain_attr->tx_ctx_cnt, MPIDI_OFI_global.prov_use[nic]->domain_attr->rx_ctx_cnt); @@ -733,6 +733,11 @@ int MPIDI_OFI_init_local(int *tag_bits) mpi_errno = MPIDI_OFI_init_multi_nic(prov); MPIR_ERR_CHECK(mpi_errno); + for (int i = 0; i < MPIDI_OFI_global.num_nics_available; i++) { + /* if MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS, set rx_ctx_cnt and tx_ctx_cnt */ + set_sep_counters(i); + } + mpi_errno = update_global_limits(MPIDI_OFI_global.prov_use[0]); MPIR_ERR_CHECK(mpi_errno); @@ -802,14 +807,7 @@ int MPIDI_OFI_post_init(void) { int mpi_errno = MPI_SUCCESS; - if (MPIR_CVAR_DEBUG_SUMMARY && MPIR_Process.rank == 0) { - dump_dynamic_settings(); - } - - fn_exit: return mpi_errno; - fn_fail: - goto fn_exit; } /* static functions needed by finalize */ @@ -1599,13 +1597,6 @@ static void dump_global_settings(void) fprintf(stdout, "MPIDI_OFI_AM_HDR_POOL_CELL_SIZE: %d\n", (int) MPIDI_OFI_AM_HDR_POOL_CELL_SIZE); fprintf(stdout, "MPIDI_OFI_DEFAULT_SHORT_SEND_SIZE: %d\n", (int) MPIDI_OFI_DEFAULT_SHORT_SEND_SIZE); -} - -static void dump_dynamic_settings(void) -{ - fprintf(stdout, "==== OFI dynamic settings ====\n"); - fprintf(stdout, "num_vcis: %d\n", MPIDI_OFI_global.num_vcis); - fprintf(stdout, "num_nics: %d\n", MPIDI_OFI_global.num_nics); fprintf(stdout, "======================================\n"); } diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c index fe73c21c374..9f52d12ae7a 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_vci.c +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -32,41 +32,18 @@ int MPIDI_OFI_vci_init(void) return MPI_SUCCESS; } -/* Address exchange within comm and setup multiple vcis */ +/* Enable multiple VCIs and perform address exchange within comm */ -static int addr_exchange_all_ctx(MPIR_Comm * comm); +static int check_num_nics(MPIR_Comm * comm, int *num_nics_out); +static int setup_additional_vcis(int num_vcis, int num_nics, int *actual_num_vcis); +static int addr_exchange_all_ctx(MPIR_Comm * comm, int *all_num_vcis); -int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm) -{ - int mpi_errno = MPI_SUCCESS; - - MPIR_Assert(comm == MPIR_Process.comm_world); - - /* Since we allow different process to have different num_vcis, we always need run exchange. */ - mpi_errno = addr_exchange_all_ctx(comm); - MPIR_ERR_CHECK(mpi_errno); - - for (int vci = 1; vci < MPIDI_OFI_global.num_vcis; vci++) { - MPIDI_OFI_am_init(vci); - MPIDI_OFI_am_post_recv(vci, 0); - } - - fn_exit: - return mpi_errno; - fn_fail: - goto fn_exit; -} - -/* MPIDI_OFI_init_vcis: locally create multiple vcis */ - -static int check_num_nics(void); -static int setup_additional_vcis(void); - -int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) +int MPIDI_OFI_comm_set_vcis(MPIR_Comm * comm, int num_vcis, int *all_num_vcis) { int mpi_errno = MPI_SUCCESS; /* NOTE: we only can create contexts for additional vcis and nics once (see notes at top) */ + /* TODO: relax it as long as num_{vcis,nics} <= MPIDI_OFI_global.num_{vcis,nics}, which we skip the creation part */ MPIR_Assert(MPIDI_OFI_global.num_vcis == 1 && MPIDI_OFI_global.num_nics == 1); /* Multiple vci without using domain require MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS */ @@ -74,18 +51,38 @@ int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) MPIR_Assert(num_vcis == 1 || MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS); #endif - MPIDI_OFI_global.num_nics = MPIDI_OFI_global.num_nics_available; - MPIDI_OFI_global.num_vcis = num_vcis; - /* All processes must have the same number of NICs */ - mpi_errno = check_num_nics(); + int num_nics; + mpi_errno = check_num_nics(comm, &num_nics); + MPIR_ERR_CHECK(mpi_errno); + + int actual_num_vcis; + /* NOOP unless num_vcis > 1 or num_nics > 1 */ + mpi_errno = setup_additional_vcis(num_vcis, num_nics, &actual_num_vcis); MPIR_ERR_CHECK(mpi_errno); - /* may update MPIDI_OFI_global.num_vcis */ - mpi_errno = setup_additional_vcis(); + MPIDI_OFI_global.num_nics = num_nics; + MPIDI_OFI_global.num_vcis = actual_num_vcis; + + if (MPIR_CVAR_DEBUG_SUMMARY && comm->rank == 0) { + printf("==== MPIDI_OFI_comm_set_vcis ====\n"); + printf("num_vcis: %d\n", MPIDI_OFI_global.num_vcis); + printf("num_nics: %d\n", MPIDI_OFI_global.num_nics); + printf("======================================\n"); + } + + mpi_errno = MPIR_Allgather_impl(&MPIDI_OFI_global.num_vcis, 1, MPI_INT, + all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - *num_vcis_actual = MPIDI_OFI_global.num_vcis; + /* NOOP unless one of the all_num_vcis[i] > 1 or num_nics > 1 */ + mpi_errno = addr_exchange_all_ctx(comm, all_num_vcis); + MPIR_ERR_CHECK(mpi_errno); + + for (int vci = 1; vci < MPIDI_OFI_global.num_vcis; vci++) { + MPIDI_OFI_am_init(vci); + MPIDI_OFI_am_post_recv(vci, 0); + } fn_exit: return mpi_errno; @@ -93,72 +90,48 @@ int MPIDI_OFI_init_vcis(int num_vcis, int *num_vcis_actual) goto fn_exit; } -static int check_num_nics(void) +static int check_num_nics(MPIR_Comm * comm, int *num_nics_out) { int mpi_errno = MPI_SUCCESS; - int num_nics = MPIDI_OFI_global.num_nics; - int tmp_num_vcis = MPIDI_OFI_global.num_vcis; - int tmp_num_nics = MPIDI_OFI_global.num_nics; - - /* Set the number of NICs and VNIs to 1 temporarily to avoid problems during the collective */ - MPIDI_OFI_global.num_vcis = MPIDI_OFI_global.num_nics = 1; - - /* Confirm that all processes have the same number of NICs */ - mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, &num_nics, 1, MPI_INT, - MPI_MIN, MPIR_Process.comm_world, MPIR_ERR_NONE); - MPIDI_OFI_global.num_vcis = tmp_num_vcis; - MPIDI_OFI_global.num_nics = tmp_num_nics; + int tmp_num_nics = MPIDI_OFI_global.num_nics_available; + mpi_errno = MPIR_Allreduce_allcomm_auto(&tmp_num_nics, num_nics_out, 1, MPI_INT, + MPI_MIN, comm, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); /* If the user did not ask to fallback to fewer NICs, throw an error if someone is missing a * NIC. */ - if (tmp_num_nics != num_nics) { - if (MPIR_CVAR_OFI_USE_MIN_NICS) { - MPIDI_OFI_global.num_nics = num_nics; - - /* If we fall down to 1 nic, turn off multi-nic optimizations. */ - if (num_nics == 1) { - MPIDI_OFI_COMM(MPIR_Process.comm_world).enable_striping = 0; - } - } else { - MPIR_ERR_CHKANDJUMP(num_nics != MPIDI_OFI_global.num_nics, mpi_errno, MPI_ERR_OTHER, - "**ofi_num_nics"); + if (tmp_num_nics != *num_nics_out) { + if (!MPIR_CVAR_OFI_USE_MIN_NICS) { + MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**ofi_num_nics"); } } - /* FIXME: It would also be helpful to check that all of the NICs can communicate so we can fall - * back to other options if that is not the case (e.g., verbs are often configured with a - * different subnet for each "set" of nics). It's unknown how to write a good check for that. */ - - /* set rx_ctx_cnt and tx_ctx_cnt for the remaining (non-0) nics */ - for (int nic = 1; nic < MPIDI_OFI_global.num_nics; nic++) { - set_sep_counters(nic); - } - fn_exit: return mpi_errno; fn_fail: goto fn_exit; } -static int setup_additional_vcis(void) +static int setup_additional_vcis(int num_vcis, int num_nics, int *actual_num_vcis) { int mpi_errno = MPI_SUCCESS; - for (int vci = 0; vci < MPIDI_OFI_global.num_vcis; vci++) { - for (int nic = 0; nic < MPIDI_OFI_global.num_nics; nic++) { + *actual_num_vcis = num_vcis; + for (int vci = 0; vci < num_vcis; vci++) { + for (int nic = 0; nic < num_nics; nic++) { /* vci 0 nic 0 already created */ if (vci > 0 || nic > 0) { mpi_errno = MPIDI_OFI_create_vci_context(vci, nic); if (mpi_errno != MPI_SUCCESS) { /* running out of vcis, reduce MPIDI_OFI_global.num_vcis */ if (vci > 0) { - MPIDI_OFI_global.num_vcis = vci; + *actual_num_vcis = vci; /* FIXME: destroy already created vci_context */ mpi_errno = MPI_SUCCESS; goto fn_exit; } else { + /* fatal error if we can not enable all nics on vci 0 */ MPIR_ERR_CHECK(mpi_errno); } } @@ -199,7 +172,7 @@ static int setup_additional_vcis(void) continue; \ } -static int addr_exchange_all_ctx(MPIR_Comm * comm) +static int addr_exchange_all_ctx(MPIR_Comm * comm, int *all_num_vcis) { int mpi_errno = MPI_SUCCESS; MPIR_CHKLMEM_DECL(3); @@ -207,18 +180,6 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) int nprocs = comm->local_size; int myrank = comm->rank; - /* get global ranks and number of remote vcis */ - int *grank, *all_num_vcis; - MPIR_CHKLMEM_MALLOC(granks, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER); - MPIR_CHKLMEM_MALLOC(all_num_vcis, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER); - for (int i = 0; i < nprocs; i++) { - int avtid; - MPIDIU_comm_rank_to_pid(comm, i, &granks[i], &avtid); - MPIR_Assert(avtid == 0); - - all_num_vcis[i] = MPIDI_global.all_num_vcis[granks[i]]; - } - /* get the max_vcis (so we can do Allgather rather than Allgatherv) */ int max_vcis = 0; for (int i = 0; i < nprocs; i++) { @@ -227,7 +188,7 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) } } - int my_num_vcis = NUM_VCIS_FOR_RANK(rank); + int my_num_vcis = NUM_VCIS_FOR_RANK(comm->rank); int num_nics = MPIDI_OFI_global.num_nics; /* Assume num_nics are all equal */ @@ -250,7 +211,7 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) char *all_names; MPIR_CHKLMEM_MALLOC(all_names, char *, nprocs * my_len, mpi_errno, "all_names", MPL_MEM_ADDRESS); - char *my_names = all_names + rank * my_len; + char *my_names = all_names + myrank * my_len; /* put in my addrnames */ for (int nic = 0; nic < num_nics; nic++) { @@ -268,8 +229,7 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); MPIR_ERR_CHECK(mpi_errno); - /* Step 2: insert and store non-root nic/vci on the root context */ - int root_ctx_idx = MPIDI_OFI_get_ctx_index(0, 0); + /* insert and store non-root nic/vci on the root context */ for (int r = 0; r < nprocs; r++) { GET_AV_AND_ADDRNAMES(r); /* for each remote endpoints */ @@ -281,7 +241,7 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) for (int vci_local = 0; vci_local < my_num_vcis; vci_local++) { /* skip root */ if (nic == 0 && vci == 0 && nic_local == 0 && vci_local == 0) { - next; + continue; } int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic_local); DO_AV_INSERT(ctx_idx, nic, vci); @@ -294,7 +254,7 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm) } } MPIR_Assert(expect_addr != FI_ADDR_NOTAVAIL); - MPIDI_OFI_AV_ADDR(av, vci, nic) = addr; + MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic) = expect_addr; } } } diff --git a/src/mpid/ch4/src/ch4_vci.c b/src/mpid/ch4/src/ch4_vci.c index f10c700b2e9..01b8c7d6705 100644 --- a/src/mpid/ch4/src/ch4_vci.c +++ b/src/mpid/ch4/src/ch4_vci.c @@ -85,8 +85,8 @@ int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) /* actually, only do it once for now */ MPIR_Assert(MPIDI_global.n_total_vcis == 1); - /* get global ranks */ - bool same_world = true; + /* TODO: check and assert that all processes are inside the comm world */ + int nprocs = comm->local_size; int *granks; MPIR_CHKLMEM_MALLOC(granks, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER, @@ -102,38 +102,33 @@ int MPIDI_Comm_set_vcis(MPIR_Comm * comm, int num_vcis) MPIR_Assert(MPIDI_global.all_num_vcis[granks[i]] == 0); } - /* set up local vcis */ - int num_vcis_actual; - mpi_errno = MPIDI_NM_init_vcis(MPIDI_global.n_total_vcis, &num_vcis_actual); - MPIR_ERR_CHECK(mpi_errno); - - MPIDI_global.n_total_vcis = num_vcis_actual; - - /* gather the number of remote vcis */ + /* setup vcis in netmod and shm */ + /* Netmod will decide that actual number of vcis (and nics) and gather from all ranks in all_num_vcis */ int *all_num_vcis; MPIR_CHKLMEM_MALLOC(all_num_vcis, int *, nprocs * sizeof(int), mpi_errno, MPL_MEM_OTHER, MPL_MEM_OTHER); - mpi_errno = MPIR_Allgather_impl(num_vcis_actual, 1, MPI_INT, - all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); + mpi_errno = MPIDI_NM_comm_set_vcis(comm, num_vcis, all_num_vcis); MPIR_ERR_CHECK(mpi_errno); - for (int i = 0; i < nprocs; i++) { - MPIDI_global.all_num_vcis[granks[i]] = all_num_vcis[i]; - } - - /* setup vcis in netmod and shm */ - mpi_errno = MPIDI_NM_comm_set_vcis(comm); - MPIR_ERR_CHECK(mpi_errno); + int n_total_vcis = all_num_vcis[comm->rank]; #ifndef MPIDI_CH4_DIRECT_NETMOD - mpi_errno = MPIDI_SHM_comm_set_vcis(comm, MPIDI_global.n_total_vcis); + mpi_errno = MPIDI_SHM_comm_set_vcis(comm, n_total_vcis); MPIR_ERR_CHECK(mpi_errno); #endif - for (int vci = 1; vci < MPIDI_global.n_total_vcis; vci++) { + for (int vci = 1; vci < n_total_vcis; vci++) { mpi_errno = MPIDI_init_per_vci(vci); MPIR_ERR_CHECK(mpi_errno); } + /* update global vci settings */ + MPIDI_global.n_total_vcis = all_num_vcis[comm->rank]; + MPIDI_global.n_vcis = MPL_MIN(MPIR_CVAR_CH4_NUM_VCIS, MPIDI_global.n_total_vcis); + MPIDI_global.n_reserved_vcis = MPIDI_global.n_total_vcis - MPIDI_global.n_vcis; + for (int i = 0; i < nprocs; i++) { + MPIDI_global.all_num_vcis[granks[i]] = all_num_vcis[i]; + } + /* enable multiple vcis */ comm->vcis_enabled = true; From 8d2e1a9cdbac437d4e904d01974bcaf2492db7ce Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 5 Jan 2025 16:20:00 -0600 Subject: [PATCH 16/25] ch4/ucx: move multivci code to ucx_vci.c --- src/mpid/ch4/netmod/ucx/ucx_impl.h | 3 + src/mpid/ch4/netmod/ucx/ucx_init.c | 124 +---------------------------- src/mpid/ch4/netmod/ucx/ucx_vci.c | 112 +++++++++++++++++++++++++- 3 files changed, 118 insertions(+), 121 deletions(-) diff --git a/src/mpid/ch4/netmod/ucx/ucx_impl.h b/src/mpid/ch4/netmod/ucx/ucx_impl.h index c07a54ab1bc..b6d9253ee81 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_impl.h +++ b/src/mpid/ch4/netmod/ucx/ucx_impl.h @@ -129,6 +129,9 @@ MPL_STATIC_INLINE_PREFIX bool MPIDI_UCX_is_reachable_target(int rank, MPIR_Win * #define MPIDI_UCX_WIN_AV_TO_EP(av, vci, vci_target) MPIDI_UCX_AV((av)).dest[vci][vci_target] +int MPIDI_UCX_init_world(void); +int MPIDI_UCX_init_worker(int vci); + /* am handler for message sent by ucp_am_send_nb */ ucs_status_t MPIDI_UCX_am_handler(void *arg, void *data, size_t length, ucp_ep_h reply_ep, unsigned flags); diff --git a/src/mpid/ch4/netmod/ucx/ucx_init.c b/src/mpid/ch4/netmod/ucx/ucx_init.c index fd7698bbebf..e43ae3af4e1 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_init.c +++ b/src/mpid/ch4/netmod/ucx/ucx_init.c @@ -28,13 +28,7 @@ static void request_init_callback(void *request) } -static void init_num_vcis(void) -{ - /* TODO: check capabilities, abort if we can't support the requested number of vcis. */ - MPIDI_UCX_global.num_vcis = MPIDI_global.n_total_vcis; -} - -static int init_worker(int vci) +int MPIDI_UCX_init_worker(int vci) { int mpi_errno = MPI_SUCCESS; @@ -147,63 +141,6 @@ static int initial_address_exchange(void) goto fn_exit; } -static int all_vcis_address_exchange(void) -{ - int mpi_errno = MPI_SUCCESS; - - int size = MPIR_Process.size; - int rank = MPIR_Process.rank; - int num_vcis = MPIDI_UCX_global.num_vcis; - - /* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */ - size_t name_len = MPID_MAX_BC_SIZE; - - int my_len = num_vcis * name_len; - char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS); - MPIR_Assert(all_names); - - char *my_names = all_names + rank * my_len; - - /* put in my addrnames */ - for (int i = 0; i < num_vcis; i++) { - char *vci_addrname = my_names + i * name_len; - memcpy(vci_addrname, MPIDI_UCX_global.ctx[i].if_address, - MPIDI_UCX_global.ctx[i].addrname_len); - } - /* Allgather */ - MPIR_Comm *comm = MPIR_Process.comm_world; - mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE, - all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); - MPIR_ERR_CHECK(mpi_errno); - - /* insert the addresses */ - ucp_ep_params_t ep_params; - for (int vci_local = 0; vci_local < num_vcis; vci_local++) { - for (int r = 0; r < size; r++) { - MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r)); - for (int vci_remote = 0; vci_remote < num_vcis; vci_remote++) { - if (vci_local == 0 && vci_remote == 0) { - /* don't overwrite existing addr, or bad things will happen */ - continue; - } - int idx = r * num_vcis + vci_remote; - ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = (ucp_address_t *) (all_names + idx * name_len); - - ucs_status_t ucx_status; - ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vci_local].worker, - &ep_params, &av->dest[vci_local][vci_remote]); - MPIDI_UCX_CHK_STATUS(ucx_status); - } - } - } - fn_exit: - MPL_free(all_names); - return mpi_errno; - fn_fail: - goto fn_exit; -} - int MPIDI_UCX_init_local(int *tag_bits) { int mpi_errno = MPI_SUCCESS; @@ -213,7 +150,7 @@ int MPIDI_UCX_init_local(int *tag_bits) uint64_t features = 0; ucp_params_t ucp_params; - init_num_vcis(); + MPIDI_UCX_global.num_vcis = 1; /* unable to support extended context id in current match bit configuration */ MPL_COMPILE_TIME_ASSERT(MPIR_CONTEXT_ID_BITS <= MPIDI_UCX_CONTEXT_ID_BITS); @@ -233,7 +170,7 @@ int MPIDI_UCX_init_local(int *tag_bits) UCP_PARAM_FIELD_REQUEST_SIZE | UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT; - if (MPIDI_UCX_global.num_vcis > 1) { + if (MPICH_IS_THREADED) { ucp_params.mt_workers_shared = 1; ucp_params.field_mask |= UCP_PARAM_FIELD_MT_WORKERS_SHARED; } @@ -271,7 +208,7 @@ int MPIDI_UCX_init_world(void) int mpi_errno = MPI_SUCCESS; /* initialize worker for vci 0 */ - mpi_errno = init_worker(0); + mpi_errno = MPIDI_UCX_init_worker(0); MPIR_ERR_CHECK(mpi_errno); mpi_errno = initial_address_exchange(); @@ -286,64 +223,11 @@ int MPIDI_UCX_init_world(void) goto fn_exit; } -int MPIDI_UCX_init_vcis(int num_vcis, int *num_vcis_actual) -{ - int mpi_errno = MPI_SUCCESS; - *num_vcis_actual = num_vcis; - return mpi_errno; -} - -/* static functions for MPIDI_UCX_post_init */ -static void flush_cb(void *request, ucs_status_t status) -{ -} - -static void flush_all(void) -{ - void *reqs[MPIDI_CH4_MAX_VCIS]; - for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) { - reqs[vci] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vci].worker, 0, &flush_cb); - } - for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) { - if (reqs[vci] == NULL) { - continue; - } else if (UCS_PTR_IS_ERR(reqs[vci])) { - continue; - } else { - ucs_status_t status; - do { - MPID_Progress_test(NULL); - status = ucp_request_check_status(reqs[vci]); - } while (status == UCS_INPROGRESS); - ucp_request_release(reqs[vci]); - } - } -} - int MPIDI_UCX_post_init(void) { int mpi_errno = MPI_SUCCESS; - if (MPIDI_UCX_global.num_vcis == 1) { - goto fn_exit; - } - - for (int i = 1; i < MPIDI_UCX_global.num_vcis; i++) { - mpi_errno = init_worker(i); - MPIR_ERR_CHECK(mpi_errno); - } - mpi_errno = all_vcis_address_exchange(); - MPIR_ERR_CHECK(mpi_errno); - - /* Flush all pending wireup operations or it may interfere with RMA flush_ops count. - * Since this require progress in non-zero vcis, we need switch on is_initialized. */ - MPIDI_global.is_initialized = 1; - flush_all(); - - fn_exit: return mpi_errno; - fn_fail: - goto fn_exit; } int MPIDI_UCX_mpi_finalize_hook(void) diff --git a/src/mpid/ch4/netmod/ucx/ucx_vci.c b/src/mpid/ch4/netmod/ucx/ucx_vci.c index 2153d73c4f1..6427af60487 100644 --- a/src/mpid/ch4/netmod/ucx/ucx_vci.c +++ b/src/mpid/ch4/netmod/ucx/ucx_vci.c @@ -5,9 +5,119 @@ #include "mpidimpl.h" #include "ucx_impl.h" +#include "mpidu_bc.h" -int MPIDI_UCX_comm_set_vcis(MPIR_Comm * comm, int num_vcis) +static int all_vcis_address_exchange(void); +static void flush_all(void); + +int MPIDI_UCX_comm_set_vcis(MPIR_Comm * comm, int num_vcis, int *all_num_vcis) { int mpi_errno = MPI_SUCCESS; + + MPIR_Assert(MPIDI_UCX_global.num_vcis == 1); + + MPIDI_UCX_global.num_vcis = num_vcis; + + mpi_errno = MPIR_Allgather_impl(&MPIDI_UCX_global.num_vcis, 1, MPI_INT, + all_num_vcis, 1, MPI_INT, comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + + for (int i = 1; i < MPIDI_UCX_global.num_vcis; i++) { + mpi_errno = MPIDI_UCX_init_worker(i); + MPIR_ERR_CHECK(mpi_errno); + } + mpi_errno = all_vcis_address_exchange(); + MPIR_ERR_CHECK(mpi_errno); + + /* Flush all pending wireup operations or it may interfere with RMA flush_ops count. + * Since this require progress in non-zero vcis, we need switch on is_initialized. */ + flush_all(); + + fn_exit: return mpi_errno; + fn_fail: + goto fn_exit; +} + +static int all_vcis_address_exchange(void) +{ + int mpi_errno = MPI_SUCCESS; + + int size = MPIR_Process.size; + int rank = MPIR_Process.rank; + int num_vcis = MPIDI_UCX_global.num_vcis; + + /* ucx address lengths are non-uniform, use MPID_MAX_BC_SIZE */ + size_t name_len = MPID_MAX_BC_SIZE; + + int my_len = num_vcis * name_len; + char *all_names = MPL_malloc(size * my_len, MPL_MEM_ADDRESS); + MPIR_Assert(all_names); + + char *my_names = all_names + rank * my_len; + + /* put in my addrnames */ + for (int i = 0; i < num_vcis; i++) { + char *vci_addrname = my_names + i * name_len; + memcpy(vci_addrname, MPIDI_UCX_global.ctx[i].if_address, + MPIDI_UCX_global.ctx[i].addrname_len); + } + /* Allgather */ + MPIR_Comm *comm = MPIR_Process.comm_world; + mpi_errno = MPIR_Allgather_allcomm_auto(MPI_IN_PLACE, 0, MPI_BYTE, + all_names, my_len, MPI_BYTE, comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + + /* insert the addresses */ + ucp_ep_params_t ep_params; + for (int vci_local = 0; vci_local < num_vcis; vci_local++) { + for (int r = 0; r < size; r++) { + MPIDI_UCX_addr_t *av = &MPIDI_UCX_AV(&MPIDIU_get_av(0, r)); + for (int vci_remote = 0; vci_remote < num_vcis; vci_remote++) { + if (vci_local == 0 && vci_remote == 0) { + /* don't overwrite existing addr, or bad things will happen */ + continue; + } + int idx = r * num_vcis + vci_remote; + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + ep_params.address = (ucp_address_t *) (all_names + idx * name_len); + + ucs_status_t ucx_status; + ucx_status = ucp_ep_create(MPIDI_UCX_global.ctx[vci_local].worker, + &ep_params, &av->dest[vci_local][vci_remote]); + MPIDI_UCX_CHK_STATUS(ucx_status); + } + } + } + fn_exit: + MPL_free(all_names); + return mpi_errno; + fn_fail: + goto fn_exit; +} + +static void flush_cb(void *request, ucs_status_t status) +{ +} + +static void flush_all(void) +{ + void *reqs[MPIDI_CH4_MAX_VCIS]; + for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) { + reqs[vci] = ucp_worker_flush_nb(MPIDI_UCX_global.ctx[vci].worker, 0, &flush_cb); + } + for (int vci = 0; vci < MPIDI_UCX_global.num_vcis; vci++) { + if (reqs[vci] == NULL) { + continue; + } else if (UCS_PTR_IS_ERR(reqs[vci])) { + continue; + } else { + ucs_status_t status; + do { + MPID_Progress_test(NULL); + status = ucp_request_check_status(reqs[vci]); + } while (status == UCS_INPROGRESS); + ucp_request_release(reqs[vci]); + } + } } From 53f1baa392a7620d22ea570c79aa055fc2ee8641 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 6 Jan 2025 16:07:04 -0600 Subject: [PATCH 17/25] ch4/api: remove netmod api MPIDI_NM_init_vcis This has been superseded by MPIDI_NM_comm_set_vcis. --- src/mpid/ch4/ch4_api.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mpid/ch4/ch4_api.txt b/src/mpid/ch4/ch4_api.txt index 24026e01748..9965f0831e3 100644 --- a/src/mpid/ch4/ch4_api.txt +++ b/src/mpid/ch4/ch4_api.txt @@ -40,8 +40,6 @@ Non Native API: mpi_finalize_hook : int NM : void SHM : void - init_vcis: int - NM : num_vcis, num_vcis_actual post_init : int NM : void SHM : void From e68757fcfc351b27c404d7f825251ffb6ebd7409 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 29 Dec 2024 19:50:43 -0600 Subject: [PATCH 18/25] ch4/shm: refactor iqueue shmem allocation Consolidate the shmem allocations in iqueue to 2 slabs. One root slab that is initialized at world_init. The other all_slab for per-vci transport, initialized at the time of init vcis. The goal is to eventually allow more flexible shm creation, potentially allow init within a non-world communicator. --- .../ch4/shm/posix/eager/iqueue/iqueue_init.c | 75 ++++++++++++++----- .../ch4/shm/posix/eager/iqueue/iqueue_types.h | 6 ++ src/mpid/common/genq/mpidu_genq_shmem_pool.c | 27 ++++--- src/mpid/common/genq/mpidu_genq_shmem_pool.h | 7 +- 4 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c index b4ec1e546e4..df7e489d0b6 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c @@ -36,7 +36,7 @@ MPIDI_POSIX_eager_iqueue_global_t MPIDI_POSIX_eager_iqueue_global; -static int init_transport(int vci_src, int vci_dst) +static int init_transport(void *slab, int vci_src, int vci_dst) { int mpi_errno = MPI_SUCCESS; @@ -51,28 +51,24 @@ static int init_transport(int vci_src, int vci_dst) MPIDU_GENQ_SHMEM_QUEUE_TYPE__MPSC, MPIDU_GENQ_SHMEM_QUEUE_TYPE__MPMC }; - mpi_errno = MPIDU_genq_shmem_pool_create(transport->size_of_cell, transport->num_cells, + mpi_errno = MPIDU_genq_shmem_pool_create(slab, MPIDI_POSIX_eager_iqueue_global.slab_size, + transport->size_of_cell, transport->num_cells, MPIR_Process.local_size, MPIR_Process.local_rank, 2, queue_types, &transport->cell_pool); MPIR_ERR_CHECK(mpi_errno); } else { int queue_type = MPIDU_GENQ_SHMEM_QUEUE_TYPE__MPSC; - mpi_errno = MPIDU_genq_shmem_pool_create(transport->size_of_cell, transport->num_cells, + mpi_errno = MPIDU_genq_shmem_pool_create(slab, MPIDI_POSIX_eager_iqueue_global.slab_size, + transport->size_of_cell, transport->num_cells, MPIR_Process.local_size, MPIR_Process.local_rank, 1, &queue_type, &transport->cell_pool); MPIR_ERR_CHECK(mpi_errno); } - size_t size_of_terminals; - /* Create one terminal for each process with which we will be able to communicate. */ - size_of_terminals = (size_t) MPIR_Process.local_size * sizeof(MPIDU_genq_shmem_queue_u); - - /* Create the shared memory regions that will be used for the iqueue cells and terminals. */ - mpi_errno = MPIDU_Init_shm_alloc(size_of_terminals, (void *) &transport->terminals); - MPIR_ERR_CHECK(mpi_errno); - + transport->terminals = (void *) ((char *) slab + + MPIDI_POSIX_eager_iqueue_global.terminal_offset); transport->my_terminal = &transport->terminals[MPIR_Process.local_rank]; mpi_errno = MPIDU_genq_shmem_queue_init(transport->my_terminal, @@ -98,7 +94,27 @@ int MPIDI_POSIX_iqueue_init(int rank, int size) /* Init vci 0. Communication on vci 0 is enabled afterwards. */ MPIDI_POSIX_eager_iqueue_global.max_vcis = 1; - mpi_errno = init_transport(0, 0); + /* calculate needed shmem size per (vci_src, vci_dst) */ + int num_free_queue = MPIR_CVAR_CH4_SHM_POSIX_TOPO_ENABLE ? 2 : 1; + int cell_size = MPIR_CVAR_CH4_SHM_POSIX_IQUEUE_CELL_SIZE; + int num_cells = MPIR_CVAR_CH4_SHM_POSIX_IQUEUE_NUM_CELLS; + int nprocs = MPIR_Process.local_size; + + int pool_size = MPIDU_genq_shmem_pool_size(cell_size, num_cells, nprocs, num_free_queue); + int terminal_size = num_proc * sizeof(MPIDU_genq_shmem_queue_u); + + int slab_size = pool_size + terminal_size; + + /* Create the shared memory regions that will be used for the iqueue cells and terminals. */ + void *slab; + mpi_errno = MPIDU_Init_shm_alloc(slab_size, (void *) &slab); + MPIR_ERR_CHECK(mpi_errno); + + MPIDI_POSIX_eager_iqueue_global.slab_size = slab_size; + MPIDI_POSIX_eager_iqueue_global.terminal_offset = pool_size; + MPIDI_POSIX_eager_iqueue_global.root_slab = slab; + + mpi_errno = init_transport(slab, 0, 0); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIDU_Init_shm_barrier(); @@ -127,18 +143,27 @@ int MPIDI_POSIX_iqueue_post_init(void) max_vcis = num; } } + MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis; MPIDU_Init_shm_barrier(); - MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis; + int slab_size = MPIDI_POSIX_eager_iqueue_global.slab_size * max_vcis * max_vcis; + /* Create the shared memory regions for all vcis */ + /* TODO: do shm alloc in a comm */ + void *slab; + mpi_errno = MPIDU_Init_shm_alloc(slab_size, (void *) &slab); + MPIR_ERR_CHECK(mpi_errno); + + MPIDI_POSIX_eager_iqueue_global.all_slab = slab; for (int vci_src = 0; vci_src < max_vcis; vci_src++) { for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) { if (vci_src == 0 && vci_dst == 0) { continue; } - mpi_errno = init_transport(vci_src, vci_dst); + void *p = (char *) slab + (vci_src * max_vcis + vci_dst) * + MPIDI_POSIX_eager_iqueue_global.slab_size; + mpi_errno = init_transport(p, vci_src, vci_dst); MPIR_ERR_CHECK(mpi_errno); - } } @@ -157,18 +182,34 @@ int MPIDI_POSIX_iqueue_finalize(void) MPIR_FUNC_ENTER; + if (MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab) { + MPIDI_POSIX_eager_iqueue_transport_t *transport; + transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst); + + mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool); + MPIR_ERR_CHECK(mpi_errno); + + mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab); + MPIR_ERR_CHECK(mpi_errno); + MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab = NULL; + } + + if (!MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab) { + goto fn_exit; + } int max_vcis = MPIDI_POSIX_eager_iqueue_global.max_vcis; for (int vci_src = 0; vci_src < max_vcis; vci_src++) { for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) { MPIDI_POSIX_eager_iqueue_transport_t *transport; transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst); - mpi_errno = MPIDU_Init_shm_free(transport->terminals); - MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool); MPIR_ERR_CHECK(mpi_errno); } } + mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab); + MPIR_ERR_CHECK(mpi_errno); + MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab = NULL; fn_exit: MPIR_FUNC_EXIT; diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h index 9af9ef29812..105eb82a4da 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h @@ -35,6 +35,12 @@ typedef struct MPIDI_POSIX_eager_iqueue_transport { typedef struct MPIDI_POSIX_eager_iqueue_global { int max_vcis; + /* sizes for shmem slabs */ + int slab_size; + int terminal_offset; + /* shmem slabs */ + void *root_slab; + void *all_slab; /* 2d array indexed with [src_vci][dst_vci] */ MPIDI_POSIX_eager_iqueue_transport_t transports[MPIDI_CH4_MAX_VCIS][MPIDI_CH4_MAX_VCIS]; } MPIDI_POSIX_eager_iqueue_global_t; diff --git a/src/mpid/common/genq/mpidu_genq_shmem_pool.c b/src/mpid/common/genq/mpidu_genq_shmem_pool.c index ee7adfdb500..3ee2c9fad9e 100644 --- a/src/mpid/common/genq/mpidu_genq_shmem_pool.c +++ b/src/mpid/common/genq/mpidu_genq_shmem_pool.c @@ -95,13 +95,23 @@ static int cell_block_alloc(MPIDU_genqi_shmem_pool_s * pool, int rank) goto fn_exit; } -int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_queue, - uintptr_t num_proc, int rank, uintptr_t num_free_queue, +int MPIDU_genq_shmem_pool_size(int cell_size, int cells_per_free_queue, + int num_proc, int num_free_queue) +{ + int aligned_cell_size = RESIZE_TO_MAX_ALIGN(cell_size); + int cell_alloc_size = sizeof(MPIDU_genqi_shmem_cell_header_s) + aligned_cell_size; + int total_cells_size = num_proc * num_free_queue * cells_per_free_queue * cell_alloc_size; + int free_queue_size = num_proc * num_free_queue * sizeof(MPIDU_genq_shmem_queue_u); + return total_cells_size + free_queue_size; +} + +int MPIDU_genq_shmem_pool_create(void *slab, int slab_size, + int cell_size, int cells_per_free_queue, + int num_proc, int rank, int num_free_queue, int *queue_types, MPIDU_genq_shmem_pool_t * pool) { int rc = MPI_SUCCESS; MPIDU_genqi_shmem_pool_s *pool_obj; - uintptr_t slab_size = 0; uintptr_t aligned_cell_size = 0; MPIR_FUNC_ENTER; @@ -117,15 +127,13 @@ int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_q pool_obj->num_free_queue = num_free_queue; pool_obj->rank = rank; pool_obj->gpu_registered = false; + pool_obj->slab = slab; /* the global_block_index is at the end of the slab to avoid extra need of alignment */ int total_cells_size = num_proc * num_free_queue * cells_per_free_queue * pool_obj->cell_alloc_size; int free_queue_size = num_proc * num_free_queue * sizeof(MPIDU_genq_shmem_queue_u); - slab_size = total_cells_size + free_queue_size; - - rc = MPIDU_Init_shm_alloc(slab_size, &pool_obj->slab); - MPIR_ERR_CHECK(rc); + MPIR_Assertp(slab_size >= total_cells_size + free_queue_size); pool_obj->cell_header_base = (MPIDU_genqi_shmem_cell_header_s *) pool_obj->slab; pool_obj->free_queues = @@ -140,16 +148,12 @@ int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_q rc = cell_block_alloc(pool_obj, rank); MPIR_ERR_CHECK(rc); - rc = MPIDU_Init_shm_barrier(); - MPIR_ERR_CHECK(rc); - *pool = (MPIDU_genq_shmem_pool_t) pool_obj; fn_exit: MPIR_FUNC_EXIT; return rc; fn_fail: - MPIDU_Init_shm_free(pool_obj->slab); MPL_free(pool_obj); goto fn_exit; } @@ -166,7 +170,6 @@ int MPIDU_genq_shmem_pool_destroy(MPIDU_genq_shmem_pool_t pool) if (pool_obj->gpu_registered) { MPIR_gpu_unregister_host(pool_obj->slab); } - MPIDU_Init_shm_free(pool_obj->slab); /* free self */ MPL_free(pool_obj); diff --git a/src/mpid/common/genq/mpidu_genq_shmem_pool.h b/src/mpid/common/genq/mpidu_genq_shmem_pool.h index cb43ac634e7..e3de456ba1c 100644 --- a/src/mpid/common/genq/mpidu_genq_shmem_pool.h +++ b/src/mpid/common/genq/mpidu_genq_shmem_pool.h @@ -14,8 +14,11 @@ #include #include -int MPIDU_genq_shmem_pool_create(uintptr_t cell_size, uintptr_t cells_per_free_queue, - uintptr_t num_proc, int rank, uintptr_t num_free_queue, +int MPIDU_genq_shmem_pool_size(int cell_size, int cells_per_free_queue, + int num_proc, int num_free_queue); +int MPIDU_genq_shmem_pool_create(void *slab, int slab_size, + int cell_size, int cells_per_free_queue, + int num_proc, int rank, int num_free_queue, int *queue_types, MPIDU_genq_shmem_pool_t * pool); int MPIDU_genq_shmem_pool_destroy(MPIDU_genq_shmem_pool_t pool); int MPIDU_genqi_shmem_pool_register(MPIDU_genqi_shmem_pool_s * pool_obj); From e5c96ec6797976bb929b470c53012e3f151b66d5 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Sun, 29 Dec 2024 20:29:38 -0600 Subject: [PATCH 19/25] ch4/shm: move vci setup from post_init to set_vcis Transition from world init to per-comm vci init. --- .../ch4/shm/posix/eager/include/posix_eager.h | 3 +++ .../ch4/shm/posix/eager/iqueue/func_table.c | 1 + .../ch4/shm/posix/eager/iqueue/iqueue_init.c | 23 ++++++++-------- .../shm/posix/eager/iqueue/iqueue_noinline.h | 2 ++ .../shm/posix/eager/src/posix_eager_impl.c | 5 ++++ src/mpid/ch4/shm/posix/eager/stub/stub_init.c | 6 +++++ .../ch4/shm/posix/eager/stub/stub_noinline.h | 2 ++ src/mpid/ch4/shm/posix/posix_impl.h | 1 + src/mpid/ch4/shm/posix/posix_init.c | 13 ++------- src/mpid/ch4/shm/posix/posix_types.h | 2 +- src/mpid/ch4/shm/posix/posix_vci.c | 27 +++++++++++++++++++ 11 files changed, 61 insertions(+), 24 deletions(-) diff --git a/src/mpid/ch4/shm/posix/eager/include/posix_eager.h b/src/mpid/ch4/shm/posix/eager/include/posix_eager.h index 21d31b3024a..e1970c38b66 100644 --- a/src/mpid/ch4/shm/posix/eager/include/posix_eager.h +++ b/src/mpid/ch4/shm/posix/eager/include/posix_eager.h @@ -12,6 +12,7 @@ typedef int (*MPIDI_POSIX_eager_init_t) (int rank, int size); typedef int (*MPIDI_POSIX_eager_post_init_t) (void); +typedef int (*MPIDI_POSIX_eager_set_vcis_t) (MPIR_Comm * comm); typedef int (*MPIDI_POSIX_eager_finalize_t) (void); typedef int (*MPIDI_POSIX_eager_send_t) (int grank, MPIDI_POSIX_am_header_t * msg_hdr, @@ -37,6 +38,7 @@ typedef size_t(*MPIDI_POSIX_eager_buf_limit_t) (void); typedef struct { MPIDI_POSIX_eager_init_t init; MPIDI_POSIX_eager_post_init_t post_init; + MPIDI_POSIX_eager_set_vcis_t set_vcis; MPIDI_POSIX_eager_finalize_t finalize; MPIDI_POSIX_eager_send_t send; @@ -59,6 +61,7 @@ extern char MPIDI_POSIX_eager_strings[][MPIDI_MAX_POSIX_EAGER_STRING_LEN]; int MPIDI_POSIX_eager_init(int rank, int size); int MPIDI_POSIX_eager_post_init(void); +int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm); int MPIDI_POSIX_eager_finalize(void); MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_eager_send(int grank, MPIDI_POSIX_am_header_t * msg_hdr, diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/func_table.c b/src/mpid/ch4/shm/posix/eager/iqueue/func_table.c index a4710f8c0eb..bfaf5326eda 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/func_table.c +++ b/src/mpid/ch4/shm/posix/eager/iqueue/func_table.c @@ -15,6 +15,7 @@ MPIDI_POSIX_eager_funcs_t MPIDI_POSIX_eager_iqueue_funcs = { MPIDI_POSIX_iqueue_init, MPIDI_POSIX_iqueue_post_init, + MPIDI_POSIX_iqueue_set_vcis, MPIDI_POSIX_iqueue_finalize, MPIDI_POSIX_eager_send, diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c index df7e489d0b6..415b7ac970f 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c @@ -130,19 +130,18 @@ int MPIDI_POSIX_iqueue_init(int rank, int size) int MPIDI_POSIX_iqueue_post_init(void) { int mpi_errno = MPI_SUCCESS; + return mpi_errno; +} - /* gather max_vcis */ - int max_vcis = 0; - max_vcis = 0; - MPIDU_Init_shm_put(&MPIDI_POSIX_global.num_vcis, sizeof(int)); - MPIDU_Init_shm_barrier(); - for (int i = 0; i < MPIR_Process.local_size; i++) { - int num; - MPIDU_Init_shm_get(i, sizeof(int), &num); - if (max_vcis < num) { - max_vcis = num; - } - } +int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm) +{ + int mpi_errno = MPI_SUCCESS; + MPIR_FUNC_ENTER; + + MPIR_Assert(comm == MPIR_Process.comm_world); /* TODO: relax this */ + MPIR_Assert(MPIDI_POSIX_eager_iqueue_global.all_slab == NULL); + + int max_vcis = MPIDI_POSIX_global.num_vcis; MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis; MPIDU_Init_shm_barrier(); diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h index e70ee9dd51c..02dd65f4773 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h @@ -11,11 +11,13 @@ int MPIDI_POSIX_iqueue_init(int rank, int size); int MPIDI_POSIX_iqueue_post_init(void); +int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm); int MPIDI_POSIX_iqueue_finalize(void); #ifdef POSIX_EAGER_INLINE #define MPIDI_POSIX_eager_init MPIDI_POSIX_iqueue_init #define MPIDI_POSIX_eager_post_init MPIDI_POSIX_iqueue_post_init +#define MPIDI_POSIX_eager_set_vcis MPIDI_POSIX_iqueue_set_vcis #define MPIDI_POSIX_eager_finalize MPIDI_POSIX_iqueue_finalize #endif diff --git a/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c b/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c index 62cdc2cbc94..fe1f60efcc5 100644 --- a/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c +++ b/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c @@ -19,6 +19,11 @@ int MPIDI_POSIX_eager_post_init(void) return MPIDI_POSIX_eager_func->post_init(); } +int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm) +{ + return MPIDI_POSIX_eager_func->set_vcis(comm); +} + int MPIDI_POSIX_eager_finalize(void) { return MPIDI_POSIX_eager_func->finalize(); diff --git a/src/mpid/ch4/shm/posix/eager/stub/stub_init.c b/src/mpid/ch4/shm/posix/eager/stub/stub_init.c index 7be6d102184..695cae572b2 100644 --- a/src/mpid/ch4/shm/posix/eager/stub/stub_init.c +++ b/src/mpid/ch4/shm/posix/eager/stub/stub_init.c @@ -20,6 +20,12 @@ int MPIDI_POSIX_stub_post_init(void) return MPI_SUCCESS; } +int MPIDI_POSIX_stub_set_vcis(MPIR_Comm * comm) +{ + MPIR_Assert(0); + return MPI_SUCCESS; +} + int MPIDI_POSIX_stub_finalize() { MPIR_Assert(0); diff --git a/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h b/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h index 371bd1078c0..81c51f6d2ee 100644 --- a/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h +++ b/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h @@ -10,11 +10,13 @@ int MPIDI_POSIX_stub_init(int rank, int size); int MPIDI_POSIX_stub_post_init(void); +int MPIDI_POSIX_stub_set_vcis(MPIR_Comm * comm); int MPIDI_POSIX_stub_finalize(void); #ifdef POSIX_EAGER_INLINE #define MPIDI_POSIX_eager_init MPIDI_POSIX_stub_init #define MPIDI_POSIX_eager_post_init MPIDI_POSIX_stub_post_init +#define MPIDI_POSIX_eager_set_vcis MPIDI_POSIX_stub_set_vcis #define MPIDI_POSIX_eager_finalize MPIDI_POSIX_stub_finalize #endif diff --git a/src/mpid/ch4/shm/posix/posix_impl.h b/src/mpid/ch4/shm/posix/posix_impl.h index ad440781ee2..ec8f8d23e4e 100644 --- a/src/mpid/ch4/shm/posix/posix_impl.h +++ b/src/mpid/ch4/shm/posix/posix_impl.h @@ -28,6 +28,7 @@ } \ } while (0) +int MPIDI_POSIX_init_vci(int vci); void MPIDI_POSIX_delay_shm_mutex_destroy(int rank, MPL_proc_mutex_t * shm_mutex_ptr); #endif /* POSIX_IMPL_H_INCLUDED */ diff --git a/src/mpid/ch4/shm/posix/posix_init.c b/src/mpid/ch4/shm/posix/posix_init.c index ecb64c7d2e6..a9550125f9a 100644 --- a/src/mpid/ch4/shm/posix/posix_init.c +++ b/src/mpid/ch4/shm/posix/posix_init.c @@ -149,7 +149,7 @@ static void *create_container(struct json_object *obj) return cnt; } -static int init_vci(int vci) +int MPIDI_POSIX_init_vci(int vci) { int mpi_errno = MPI_SUCCESS; @@ -255,7 +255,7 @@ int MPIDI_POSIX_init_world(void) MPIDI_POSIX_global.num_vcis = 1; - mpi_errno = init_vci(0); + mpi_errno = MPIDI_POSIX_init_vci(0); MPIR_ERR_CHECK(mpi_errno); mpi_errno = MPIDI_POSIX_eager_init(rank, size); @@ -279,15 +279,6 @@ int MPIDI_POSIX_post_init(void) int mpi_errno = MPI_SUCCESS; MPIDI_POSIX_topo_info_t *local_rank_topo = NULL; - MPIDI_POSIX_global.num_vcis = MPIDI_global.n_total_vcis; - for (int i = 1; i < MPIDI_POSIX_global.num_vcis; i++) { - mpi_errno = init_vci(i); - MPIR_ERR_CHECK(mpi_errno); - } - - mpi_errno = MPIDI_POSIX_eager_post_init(); - MPIR_ERR_CHECK(mpi_errno); - /* gather topo info from local procs and calculate distance */ if (MPIR_CVAR_CH4_SHM_POSIX_TOPO_ENABLE && MPIR_Process.local_size > 1) { int topo_info_size = sizeof(MPIDI_POSIX_topo_info_t); diff --git a/src/mpid/ch4/shm/posix/posix_types.h b/src/mpid/ch4/shm/posix/posix_types.h index 1884d776b6c..8740afa77e3 100644 --- a/src/mpid/ch4/shm/posix/posix_types.h +++ b/src/mpid/ch4/shm/posix/posix_types.h @@ -51,7 +51,7 @@ typedef struct { int *local_ranks; int *local_procs; int local_rank_0; - int num_vcis; + int num_vcis; /* num_vcis in POSIX need >= MPIDI_global.n_total_vcis */ int *local_rank_dist; MPIDI_POSIX_topo_info_t topo; } MPIDI_POSIX_global_t; diff --git a/src/mpid/ch4/shm/posix/posix_vci.c b/src/mpid/ch4/shm/posix/posix_vci.c index 9b5534be040..173ac3be5f0 100644 --- a/src/mpid/ch4/shm/posix/posix_vci.c +++ b/src/mpid/ch4/shm/posix/posix_vci.c @@ -9,5 +9,32 @@ int MPIDI_POSIX_comm_set_vcis(MPIR_Comm * comm, int num_vcis) { int mpi_errno = MPI_SUCCESS; + + /* We only set up vcis once */ + MPIR_Assert(MPIDI_POSIX_global.num_vcis == 1); + + MPIR_Comm *node_comm = comm->node_comm; + if (node_comm == NULL) { + /* nothing to do if there is no local domain */ + goto fn_exit; + } + + int max_vcis; + mpi_errno = MPIR_Allreduce_impl(&num_vcis, &max_vcis, 1, MPI_INT, MPI_MAX, node_comm, + MPIR_ERR_NONE); + + MPIDI_POSIX_global.num_vcis = max_vcis; + + for (int i = 1; i < MPIDI_POSIX_global.num_vcis; i++) { + mpi_errno = MPIDI_POSIX_init_vci(i); + MPIR_ERR_CHECK(mpi_errno); + } + + mpi_errno = MPIDI_POSIX_eager_set_vcis(comm); + MPIR_ERR_CHECK(mpi_errno); + + fn_exit: return mpi_errno; + fn_fail: + goto fn_exit; } From 665e0def7f53f1f86b063fda7f2e88629a0c753f Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Wed, 8 Jan 2025 11:24:57 -0600 Subject: [PATCH 20/25] mpid/shm: remove MPIDU_shm_seg_t and refactor MPIDU_shm_seg_t was used by mpidu_shm_alloc.c, mpidu_init_shm.c, and mpidu_init_shm_alloc.c. However, the usages are all slightly different and some fields are only used in one but not the other. It is simpler to locally define it or, in the case of mpidu_init_shm.c, just use static globals. --- src/mpid/common/shm/mpidu_init_shm.c | 40 +++++++++++----------- src/mpid/common/shm/mpidu_init_shm_alloc.c | 31 +++++++++++------ src/mpid/common/shm/mpidu_shm_alloc.c | 11 +++++- src/mpid/common/shm/mpidu_shm_seg.h | 23 ------------- 4 files changed, 50 insertions(+), 55 deletions(-) delete mode 100644 src/mpid/common/shm/mpidu_shm_seg.h diff --git a/src/mpid/common/shm/mpidu_init_shm.c b/src/mpid/common/shm/mpidu_init_shm.c index 899865e2488..357fcfa8412 100644 --- a/src/mpid/common/shm/mpidu_init_shm.c +++ b/src/mpid/common/shm/mpidu_init_shm.c @@ -8,7 +8,6 @@ #include "mpl_shm.h" #include "mpidimpl.h" #include "mpir_pmi.h" -#include "mpidu_shm_seg.h" static int init_shm_initialized; @@ -57,7 +56,11 @@ typedef struct Init_shm_barrier { static int local_size; static int my_local_rank; -static MPIDU_shm_seg_t memory; + +static size_t init_shm_len; +static MPL_shm_hnd_t init_shm_hnd; +static char *init_shm_addr; + static Init_shm_barrier_t *barrier; static void *baseaddr; @@ -69,7 +72,7 @@ static int Init_shm_barrier_init(int is_root) MPIR_FUNC_ENTER; - barrier = (Init_shm_barrier_t *) memory.base_addr; + barrier = (Init_shm_barrier_t *) init_shm_addr; if (is_root) { MPL_atomic_store_int(&barrier->val, 0); MPL_atomic_store_int(&barrier->wait, 0); @@ -125,10 +128,10 @@ int MPIDU_Init_shm_init(void) char *serialized_hnd = NULL; int serialized_hnd_size = 0; - mpl_err = MPL_shm_hnd_init(&(memory.hnd)); + mpl_err = MPL_shm_hnd_init(&init_shm_hnd); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); - memory.segment_len = segment_len; + init_shm_len = segment_len; if (local_size == 1) { char *addr; @@ -136,24 +139,23 @@ int MPIDU_Init_shm_init(void) MPIR_CHKPMEM_MALLOC(addr, char *, segment_len + MPIDU_SHM_CACHE_LINE_LEN, mpi_errno, "segment", MPL_MEM_SHM); - memory.base_addr = addr; + init_shm_addr = addr; baseaddr = (char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) & (~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1))); - memory.symmetrical = 0; mpi_errno = Init_shm_barrier_init(TRUE); MPIR_ERR_CHECK(mpi_errno); } else { if (my_local_rank == 0) { /* root prepare shm segment */ - mpl_err = MPL_shm_seg_create_and_attach(memory.hnd, memory.segment_len, - (void **) &(memory.base_addr), 0); + mpl_err = MPL_shm_seg_create_and_attach(init_shm_hnd, init_shm_len, + (void **) &(init_shm_addr), 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); MPIR_Assert(MPIR_Process.node_local_map[0] == MPIR_Process.rank); - mpl_err = MPL_shm_hnd_get_serialized_by_ref(memory.hnd, &serialized_hnd); + mpl_err = MPL_shm_hnd_get_serialized_by_ref(init_shm_hnd, &serialized_hnd); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); serialized_hnd_size = strlen(serialized_hnd) + 1; MPIR_Assert(serialized_hnd_size < MPIR_pmi_max_val_size()); @@ -176,11 +178,10 @@ int MPIDU_Init_shm_init(void) MPIR_Assert(local_size > 1); if (my_local_rank > 0) { /* non-root attach shm segment */ - mpl_err = MPL_shm_hnd_deserialize(memory.hnd, serialized_hnd, strlen(serialized_hnd)); + mpl_err = MPL_shm_hnd_deserialize(init_shm_hnd, serialized_hnd, strlen(serialized_hnd)); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); - mpl_err = MPL_shm_seg_attach(memory.hnd, memory.segment_len, - (void **) &memory.base_addr, 0); + mpl_err = MPL_shm_seg_attach(init_shm_hnd, init_shm_len, (void **) &init_shm_addr, 0); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); mpi_errno = Init_shm_barrier_init(FALSE); @@ -191,13 +192,12 @@ int MPIDU_Init_shm_init(void) MPIR_ERR_CHECK(mpi_errno); if (my_local_rank == 0) { - /* memory->hnd no longer needed */ - mpl_err = MPL_shm_seg_remove(memory.hnd); + /* init_shm_hnd no longer needed */ + mpl_err = MPL_shm_seg_remove(init_shm_hnd); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem"); } - baseaddr = memory.base_addr + MPIDU_SHM_CACHE_LINE_LEN; - memory.symmetrical = 0; + baseaddr = init_shm_addr + MPIDU_SHM_CACHE_LINE_LEN; } mpi_errno = Init_shm_barrier(); @@ -225,13 +225,13 @@ int MPIDU_Init_shm_finalize(void) } if (local_size == 1) - MPL_free(memory.base_addr); + MPL_free(init_shm_addr); else { - mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len); + mpl_err = MPL_shm_seg_detach(init_shm_hnd, (void **) &(init_shm_addr), init_shm_len); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem"); } - MPL_shm_hnd_finalize(&(memory.hnd)); + MPL_shm_hnd_finalize(&(init_shm_hnd)); init_shm_initialized = 0; diff --git a/src/mpid/common/shm/mpidu_init_shm_alloc.c b/src/mpid/common/shm/mpidu_init_shm_alloc.c index 61b3c7d8943..379c4b1cbaf 100644 --- a/src/mpid/common/shm/mpidu_init_shm_alloc.c +++ b/src/mpid/common/shm/mpidu_init_shm_alloc.c @@ -6,7 +6,6 @@ #include #include "mpl_shm.h" #include "mpidu_init_shm.h" -#include "mpidu_shm_seg.h" #include #ifdef HAVE_UNISTD_H @@ -19,16 +18,24 @@ #include #endif +struct memory_seg { + size_t segment_len; + MPL_shm_hnd_t hnd; + char *base_addr; + bool symmetrical; + bool is_shm; +}; + typedef struct memory_list { void *ptr; - MPIDU_shm_seg_t *memory; + struct memory_seg *memory; struct memory_list *next; } memory_list_t; static memory_list_t *memory_head = NULL; static memory_list_t *memory_tail = NULL; -static int check_alloc(MPIDU_shm_seg_t * memory); +static int check_alloc(struct memory_seg *memory); /* MPIDU_Init_shm_alloc(len, ptr_p) @@ -41,7 +48,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) size_t segment_len = len; int local_rank = MPIR_Process.local_rank; int num_local = MPIR_Process.local_size; - MPIDU_shm_seg_t *memory = NULL; + struct memory_seg *memory = NULL; memory_list_t *memory_node = NULL; MPIR_CHKPMEM_DECL(3); @@ -49,7 +56,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) MPIR_Assert(segment_len > 0); - MPIR_CHKPMEM_MALLOC(memory, MPIDU_shm_seg_t *, sizeof(*memory), mpi_errno, "memory_handle", + MPIR_CHKPMEM_MALLOC(memory, struct memory_seg *, sizeof(*memory), mpi_errno, "memory_handle", MPL_MEM_OTHER); mpl_err = MPL_shm_hnd_init(&(memory->hnd)); @@ -70,7 +77,8 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) current_addr = (char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) & (~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1))); - memory->symmetrical = 1; + memory->symmetrical = true; + memory->is_shm = false; } else { if (local_rank == 0) { /* root prepare shm segment */ @@ -106,7 +114,8 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem"); } current_addr = memory->base_addr; - memory->symmetrical = 0; + memory->symmetrical = false; + memory->is_shm = true; mpi_errno = check_alloc(memory); MPIR_ERR_CHECK(mpi_errno); @@ -139,7 +148,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) int MPIDU_Init_shm_free(void *ptr) { int mpi_errno = MPI_SUCCESS, mpl_err = 0; - MPIDU_shm_seg_t *memory = NULL; + struct memory_seg *memory = NULL; memory_list_t *el = NULL; MPIR_FUNC_ENTER; @@ -192,7 +201,7 @@ int MPIDU_Init_shm_is_symm(void *ptr) /* check_alloc() checks to see whether the shared memory segment is allocated at the same virtual memory address at each process. */ -static int check_alloc(MPIDU_shm_seg_t * memory) +static int check_alloc(struct memory_seg *memory) { int mpi_errno = MPI_SUCCESS; int is_sym; @@ -225,9 +234,9 @@ static int check_alloc(MPIDU_shm_seg_t * memory) } if (is_sym) { - memory->symmetrical = 1; + memory->symmetrical = true; } else { - memory->symmetrical = 0; + memory->symmetrical = false; } MPIR_FUNC_EXIT; diff --git a/src/mpid/common/shm/mpidu_shm_alloc.c b/src/mpid/common/shm/mpidu_shm_alloc.c index 72fdaa0fa55..8e5926d7929 100644 --- a/src/mpid/common/shm/mpidu_shm_alloc.c +++ b/src/mpid/common/shm/mpidu_shm_alloc.c @@ -6,7 +6,6 @@ #include #include "mpl_shm.h" #include "mpidu_shm.h" -#include "mpidu_shm_seg.h" #include #ifdef HAVE_UNISTD_H @@ -58,6 +57,16 @@ enum { SYMSHM_OTHER_FAIL /* other failure reported by MPL shm */ }; +typedef struct MPIDU_shm_seg { + size_t segment_len; + /* Handle to shm seg */ + MPL_shm_hnd_t hnd; + /* Pointers */ + char *base_addr; + /* Misc */ + int symmetrical; +} MPIDU_shm_seg_t; + /* Linked list internally used to keep track * of allocate shared memory segments */ typedef struct seg_list { diff --git a/src/mpid/common/shm/mpidu_shm_seg.h b/src/mpid/common/shm/mpidu_shm_seg.h deleted file mode 100644 index 842ee14a74d..00000000000 --- a/src/mpid/common/shm/mpidu_shm_seg.h +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (C) by Argonne National Laboratory - * See COPYRIGHT in top-level directory - */ - -#ifndef MPIDU_SHM_SEG_H_INCLUDED -#define MPIDU_SHM_SEG_H_INCLUDED - -#include "mpidu_init_shm.h" - -typedef struct MPIDU_shm_seg { - size_t segment_len; - /* Handle to shm seg */ - MPL_shm_hnd_t hnd; - /* Pointers */ - char *base_addr; - /* Misc */ - char file_name[MPIDU_SHM_MAX_FNAME_LEN]; - int base_descs; - int symmetrical; -} MPIDU_shm_seg_t; - -#endif /* MPIDU_SHM_SEG_H_INCLUDED */ From 46226c9c5c7bccb914f696e1f63e1c5fde96d051 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 30 Dec 2024 10:10:41 -0600 Subject: [PATCH 21/25] init_shm: add MPIDU_Init_shm_comm_alloc Add routine to support allocating a shared memory by a comm, which allows - * create shared memory by a smaller comm than a comm_world * attach the shared memory by later processes * potentially allowing shm communication with dynamic processes - we need a way to discover and attach to Init_shm (via intercomm) and the initial shared memory need pre-allocate to account for new processes. For now, we need this to support MPIDI_POSIX_comm_set_vcis. --- src/mpid/common/shm/mpidu_init_shm.c | 49 +++++++-- src/mpid/common/shm/mpidu_init_shm.h | 13 ++- src/mpid/common/shm/mpidu_init_shm_alloc.c | 117 ++++++++++++++++++++- 3 files changed, 167 insertions(+), 12 deletions(-) diff --git a/src/mpid/common/shm/mpidu_init_shm.c b/src/mpid/common/shm/mpidu_init_shm.c index 357fcfa8412..fc0e45463da 100644 --- a/src/mpid/common/shm/mpidu_init_shm.c +++ b/src/mpid/common/shm/mpidu_init_shm.c @@ -52,6 +52,10 @@ int MPIDU_Init_shm_query(int local_rank, void **target_addr) typedef struct Init_shm_barrier { MPL_atomic_int_t val; MPL_atomic_int_t wait; + /* fields that support async shm alloc */ + MPL_atomic_int_t lock; + MPL_atomic_int_t alloc_count; + char serialized_hnd[MPIDU_INIT_SHM_BLOCK_SIZE]; } Init_shm_barrier_t; static int local_size; @@ -123,7 +127,7 @@ int MPIDU_Init_shm_init(void) local_size = MPIR_Process.local_size; my_local_rank = MPIR_Process.local_rank; - size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + sizeof(MPIDU_Init_shm_block_t) * local_size; + size_t segment_len = sizeof(Init_shm_barrier_t) + MPIDU_INIT_SHM_BLOCK_SIZE * local_size; char *serialized_hnd = NULL; int serialized_hnd_size = 0; @@ -197,7 +201,7 @@ int MPIDU_Init_shm_init(void) MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem"); } - baseaddr = init_shm_addr + MPIDU_SHM_CACHE_LINE_LEN; + baseaddr = init_shm_addr + sizeof(Init_shm_barrier_t); } mpi_errno = Init_shm_barrier(); @@ -261,8 +265,8 @@ int MPIDU_Init_shm_put(void *orig, size_t len) MPIR_FUNC_ENTER; - MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t)); - MPIR_Memcpy((char *) baseaddr + my_local_rank * sizeof(MPIDU_Init_shm_block_t), orig, len); + MPIR_Assert(len <= MPIDU_INIT_SHM_BLOCK_SIZE); + MPIR_Memcpy((char *) baseaddr + my_local_rank * MPIDU_INIT_SHM_BLOCK_SIZE, orig, len); MPIR_FUNC_EXIT; @@ -275,8 +279,8 @@ int MPIDU_Init_shm_get(int local_rank, size_t len, void *target) MPIR_FUNC_ENTER; - MPIR_Assert(local_rank < local_size && len <= sizeof(MPIDU_Init_shm_block_t)); - MPIR_Memcpy(target, (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t), len); + MPIR_Assert(local_rank < local_size && len <= MPIDU_INIT_SHM_BLOCK_SIZE); + MPIR_Memcpy(target, (char *) baseaddr + local_rank * MPIDU_INIT_SHM_BLOCK_SIZE, len); MPIR_FUNC_EXIT; @@ -290,11 +294,42 @@ int MPIDU_Init_shm_query(int local_rank, void **target_addr) MPIR_FUNC_ENTER; MPIR_Assert(local_rank < local_size); - *target_addr = (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t); + *target_addr = (char *) baseaddr + local_rank * MPIDU_INIT_SHM_BLOCK_SIZE; MPIR_FUNC_EXIT; return mpi_errno; } +int MPIDU_Init_shm_atomic_count(void) +{ + return MPL_atomic_load_int(&barrier->alloc_count); +} + +int MPIDU_Init_shm_atomic_put(void *orig, size_t len) +{ + /* get spin lock */ + while (MPL_atomic_cas_int(&barrier->lock, 0, 1)) { + } + /* set the data */ + MPIR_Assert(len <= MPIDU_INIT_SHM_BLOCK_SIZE); + MPIR_Memcpy(barrier->serialized_hnd, orig, len); + MPL_atomic_store_int(&barrier->alloc_count, 1); + /* unlock */ + MPL_atomic_store_int(&barrier->lock, 0); + return MPI_SUCCESS; +} + +int MPIDU_Init_shm_atomic_get(void *target, size_t len) +{ + /* get spin lock */ + while (MPL_atomic_cas_int(&barrier->lock, 0, 1)) { + } + /* copy the data */ + MPIR_Memcpy(target, barrier->serialized_hnd, len); + /* unlock */ + MPL_atomic_store_int(&barrier->lock, 0); + return MPI_SUCCESS; +} + #endif /* ENABLE_NO_LOCAL */ diff --git a/src/mpid/common/shm/mpidu_init_shm.h b/src/mpid/common/shm/mpidu_init_shm.h index 0de1a1aa5af..43ebaf630a1 100644 --- a/src/mpid/common/shm/mpidu_init_shm.h +++ b/src/mpid/common/shm/mpidu_init_shm.h @@ -17,10 +17,6 @@ shared memory will not work properly. Consider disable it with --enable-nolocal. #endif -typedef struct MPIDU_Init_shm_block { - char block[MPIDU_INIT_SHM_BLOCK_SIZE]; -} MPIDU_Init_shm_block_t; - int MPIDU_Init_shm_init(void); int MPIDU_Init_shm_finalize(void); int MPIDU_Init_shm_barrier(void); @@ -32,4 +28,13 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr); int MPIDU_Init_shm_free(void *ptr); int MPIDU_Init_shm_is_symm(void *ptr); +/* support routines for MPIDU_Init_shm_root_alloc */ +int MPIDU_Init_shm_atomic_count(void); +int MPIDU_Init_shm_atomic_put(void *orig, size_t len); +int MPIDU_Init_shm_atomic_get(void *target, size_t len); + +/* comm root allocate a shared memory and put the serialized handle if previously not allocated, + * Otherwise, retrieve the handle and attach. Broadcast the handle to rest of the comm */ +int MPIDU_Init_shm_comm_alloc(MPIR_Comm * comm, size_t len, void **ptr); + #endif /* MPIDU_INIT_SHM_H_INCLUDED */ diff --git a/src/mpid/common/shm/mpidu_init_shm_alloc.c b/src/mpid/common/shm/mpidu_init_shm_alloc.c index 379c4b1cbaf..2915ab1f86a 100644 --- a/src/mpid/common/shm/mpidu_init_shm_alloc.c +++ b/src/mpid/common/shm/mpidu_init_shm_alloc.c @@ -144,6 +144,121 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) /* --END ERROR HANDLING-- */ } +int MPIDU_Init_shm_comm_alloc(MPIR_Comm * comm, size_t len, void **ptr) +{ + int mpi_errno = MPI_SUCCESS, mpl_err = 0; + void *current_addr; + size_t segment_len = len; + struct memory_seg *memory = NULL; + memory_list_t *memory_node = NULL; + MPIR_CHKPMEM_DECL(3); + + MPIR_FUNC_ENTER; + + MPIR_Comm *node_comm = comm->node_comm; + bool is_root; + if (node_comm) { + is_root = (node_comm->rank == 0); + } else { + is_root = true; + } + + MPIR_Assert(segment_len > 0); + MPIR_CHKPMEM_MALLOC(memory, struct memory_seg *, sizeof(*memory), mpi_errno, "memory_handle", + MPL_MEM_OTHER); + mpl_err = MPL_shm_hnd_init(&(memory->hnd)); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); + + memory->segment_len = segment_len; + + char *serialized_hnd = NULL; + int serialized_hnd_size = 0; + char serialized_hnd_buffer[MPIDU_INIT_SHM_BLOCK_SIZE]; + bool need_attach; + bool need_remove; + if (is_root) { + if (MPIDU_Init_shm_atomic_count() == 0) { + /* We need to create the shm segment */ + mpl_err = MPL_shm_seg_create_and_attach(memory->hnd, memory->segment_len, + (void **) &(memory->base_addr), 0); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); + + mpl_err = MPL_shm_hnd_get_serialized_by_ref(memory->hnd, &serialized_hnd); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); + serialized_hnd_size = strlen(serialized_hnd) + 1; /* add 1 for null char */ + + MPIDU_Init_shm_atomic_put(serialized_hnd, serialized_hnd_size); + need_attach = false; + need_remove = true; + } else { + /* Just retrieve the existing serialized handle */ + MPIDU_Init_shm_atomic_get(serialized_hnd_buffer, MPIDU_INIT_SHM_BLOCK_SIZE); + serialized_hnd = serialized_hnd_buffer; + serialized_hnd_size = strlen(serialized_hnd) + 1; /* add 1 for null char */ + need_attach = true; + need_remove = false; + } + MPIR_Assert(serialized_hnd_size <= MPIDU_INIT_SHM_BLOCK_SIZE); + if (node_comm) { + mpi_errno = MPIR_Bcast_impl(serialized_hnd, MPIDU_INIT_SHM_BLOCK_SIZE, MPI_CHAR, + 0, node_comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + } + } else { + mpi_errno = MPIR_Bcast_impl(serialized_hnd_buffer, MPIDU_INIT_SHM_BLOCK_SIZE, MPI_CHAR, + 0, node_comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + serialized_hnd = serialized_hnd_buffer; + serialized_hnd_size = strlen(serialized_hnd) + 1; /* add 1 for null char */ + need_attach = true; + need_remove = false; + } + if (need_attach) { + mpl_err = MPL_shm_hnd_deserialize(memory->hnd, serialized_hnd, strlen(serialized_hnd)); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); + + mpl_err = MPL_shm_seg_attach(memory->hnd, memory->segment_len, + (void **) &memory->base_addr, 0); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**attach_shar_mem"); + } + + if (node_comm) { + mpi_errno = MPIR_Barrier_impl(node_comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + } + if (need_remove) { + /* memory->hnd no longer needed */ + mpl_err = MPL_shm_seg_remove(memory->hnd); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem"); + } + + current_addr = memory->base_addr; + memory->symmetrical = false; + memory->is_shm = true; + + mpi_errno = check_alloc(memory); + MPIR_ERR_CHECK(mpi_errno); + + /* assign sections of the shared memory segment to their pointers */ + *ptr = current_addr; + + MPIR_CHKPMEM_MALLOC(memory_node, memory_list_t *, sizeof(*memory_node), mpi_errno, + "memory_node", MPL_MEM_OTHER); + memory_node->ptr = *ptr; + memory_node->memory = memory; + LL_APPEND(memory_head, memory_tail, memory_node); + + MPIR_CHKPMEM_COMMIT(); + fn_exit: + MPIR_FUNC_EXIT; + return mpi_errno; + fn_fail: + MPL_shm_seg_remove(memory->hnd); + MPL_shm_hnd_finalize(&(memory->hnd)); + MPIR_CHKPMEM_REAP(); + goto fn_exit; +} + /* MPIDU_SHM_Seg_free() free the shared memory segment */ int MPIDU_Init_shm_free(void *ptr) { @@ -165,7 +280,7 @@ int MPIDU_Init_shm_free(void *ptr) MPIR_Assert(memory != NULL); - if (MPIR_Process.local_size == 1) + if (!memory->is_shm) MPL_free(memory->base_addr); else { mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr), From 67aec2b82cebb089a69973f75d82940033fbe573 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Mon, 30 Dec 2024 10:25:39 -0600 Subject: [PATCH 22/25] posix/iqueue: support comm_set_vcis --- .../ch4/shm/posix/eager/include/posix_eager.h | 4 +- .../ch4/shm/posix/eager/iqueue/iqueue_init.c | 60 ++++++++++--------- .../shm/posix/eager/iqueue/iqueue_noinline.h | 2 +- .../ch4/shm/posix/eager/iqueue/iqueue_types.h | 2 +- .../shm/posix/eager/src/posix_eager_impl.c | 4 +- .../ch4/shm/posix/eager/stub/stub_noinline.h | 2 +- src/mpid/ch4/shm/posix/posix_vci.c | 24 ++++---- 7 files changed, 51 insertions(+), 47 deletions(-) diff --git a/src/mpid/ch4/shm/posix/eager/include/posix_eager.h b/src/mpid/ch4/shm/posix/eager/include/posix_eager.h index e1970c38b66..028da8cb30a 100644 --- a/src/mpid/ch4/shm/posix/eager/include/posix_eager.h +++ b/src/mpid/ch4/shm/posix/eager/include/posix_eager.h @@ -12,7 +12,7 @@ typedef int (*MPIDI_POSIX_eager_init_t) (int rank, int size); typedef int (*MPIDI_POSIX_eager_post_init_t) (void); -typedef int (*MPIDI_POSIX_eager_set_vcis_t) (MPIR_Comm * comm); +typedef int (*MPIDI_POSIX_eager_set_vcis_t) (MPIR_Comm * comm, int num_vcis); typedef int (*MPIDI_POSIX_eager_finalize_t) (void); typedef int (*MPIDI_POSIX_eager_send_t) (int grank, MPIDI_POSIX_am_header_t * msg_hdr, @@ -61,7 +61,7 @@ extern char MPIDI_POSIX_eager_strings[][MPIDI_MAX_POSIX_EAGER_STRING_LEN]; int MPIDI_POSIX_eager_init(int rank, int size); int MPIDI_POSIX_eager_post_init(void); -int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm); +int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm, int num_vcis); int MPIDI_POSIX_eager_finalize(void); MPL_STATIC_INLINE_PREFIX int MPIDI_POSIX_eager_send(int grank, MPIDI_POSIX_am_header_t * msg_hdr, diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c index 415b7ac970f..6ee422cb4e3 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c @@ -101,7 +101,7 @@ int MPIDI_POSIX_iqueue_init(int rank, int size) int nprocs = MPIR_Process.local_size; int pool_size = MPIDU_genq_shmem_pool_size(cell_size, num_cells, nprocs, num_free_queue); - int terminal_size = num_proc * sizeof(MPIDU_genq_shmem_queue_u); + int terminal_size = nprocs * sizeof(MPIDU_genq_shmem_queue_u); int slab_size = pool_size + terminal_size; @@ -133,26 +133,22 @@ int MPIDI_POSIX_iqueue_post_init(void) return mpi_errno; } -int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm) +int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm, int max_vcis) { int mpi_errno = MPI_SUCCESS; MPIR_FUNC_ENTER; - MPIR_Assert(comm == MPIR_Process.comm_world); /* TODO: relax this */ - MPIR_Assert(MPIDI_POSIX_eager_iqueue_global.all_slab == NULL); + MPIR_Assert(MPIDI_POSIX_eager_iqueue_global.all_vci_slab == NULL); - int max_vcis = MPIDI_POSIX_global.num_vcis; - MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis; MPIDU_Init_shm_barrier(); int slab_size = MPIDI_POSIX_eager_iqueue_global.slab_size * max_vcis * max_vcis; /* Create the shared memory regions for all vcis */ - /* TODO: do shm alloc in a comm */ void *slab; - mpi_errno = MPIDU_Init_shm_alloc(slab_size, (void *) &slab); + mpi_errno = MPIDU_Init_shm_comm_alloc(comm, slab_size, (void *) &slab); MPIR_ERR_CHECK(mpi_errno); - MPIDI_POSIX_eager_iqueue_global.all_slab = slab; + MPIDI_POSIX_eager_iqueue_global.all_vci_slab = slab; for (int vci_src = 0; vci_src < max_vcis; vci_src++) { for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) { @@ -166,8 +162,12 @@ int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm) } } - mpi_errno = MPIDU_Init_shm_barrier(); - MPIR_ERR_CHECK(mpi_errno); + MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis; + + if (comm->node_comm) { + mpi_errno = MPIR_Barrier_impl(comm->node_comm, MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); + } fn_exit: return mpi_errno; @@ -181,34 +181,36 @@ int MPIDI_POSIX_iqueue_finalize(void) MPIR_FUNC_ENTER; - if (MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab) { + if (MPIDI_POSIX_eager_iqueue_global.root_slab) { MPIDI_POSIX_eager_iqueue_transport_t *transport; - transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst); + transport = MPIDI_POSIX_eager_iqueue_get_transport(0, 0); mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool); MPIR_ERR_CHECK(mpi_errno); - mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab); + mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.root_slab); MPIR_ERR_CHECK(mpi_errno); - MPIDI_POSIX_eager_iqueue_global.max_vcis.root_slab = NULL; + MPIDI_POSIX_eager_iqueue_global.root_slab = NULL; } - if (!MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab) { - goto fn_exit; - } - int max_vcis = MPIDI_POSIX_eager_iqueue_global.max_vcis; - for (int vci_src = 0; vci_src < max_vcis; vci_src++) { - for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) { - MPIDI_POSIX_eager_iqueue_transport_t *transport; - transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst); - - mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool); - MPIR_ERR_CHECK(mpi_errno); + if (MPIDI_POSIX_eager_iqueue_global.all_vci_slab) { + int max_vcis = MPIDI_POSIX_eager_iqueue_global.max_vcis; + for (int vci_src = 0; vci_src < max_vcis; vci_src++) { + for (int vci_dst = 0; vci_dst < max_vcis; vci_dst++) { + if (vci_src == 0 && vci_dst == 0) { + continue; + } + MPIDI_POSIX_eager_iqueue_transport_t *transport; + transport = MPIDI_POSIX_eager_iqueue_get_transport(vci_src, vci_dst); + + mpi_errno = MPIDU_genq_shmem_pool_destroy(transport->cell_pool); + MPIR_ERR_CHECK(mpi_errno); + } } + mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.all_vci_slab); + MPIR_ERR_CHECK(mpi_errno); + MPIDI_POSIX_eager_iqueue_global.all_vci_slab = NULL; } - mpi_errno = MPIDU_Init_shm_free(MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab); - MPIR_ERR_CHECK(mpi_errno); - MPIDI_POSIX_eager_iqueue_global.max_vcis.all_slab = NULL; fn_exit: MPIR_FUNC_EXIT; diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h index 02dd65f4773..9586c04a0a7 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_noinline.h @@ -11,7 +11,7 @@ int MPIDI_POSIX_iqueue_init(int rank, int size); int MPIDI_POSIX_iqueue_post_init(void); -int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm); +int MPIDI_POSIX_iqueue_set_vcis(MPIR_Comm * comm, int num_vcis); int MPIDI_POSIX_iqueue_finalize(void); #ifdef POSIX_EAGER_INLINE diff --git a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h index 105eb82a4da..f77799a25e8 100644 --- a/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h +++ b/src/mpid/ch4/shm/posix/eager/iqueue/iqueue_types.h @@ -40,7 +40,7 @@ typedef struct MPIDI_POSIX_eager_iqueue_global { int terminal_offset; /* shmem slabs */ void *root_slab; - void *all_slab; + void *all_vci_slab; /* 2d array indexed with [src_vci][dst_vci] */ MPIDI_POSIX_eager_iqueue_transport_t transports[MPIDI_CH4_MAX_VCIS][MPIDI_CH4_MAX_VCIS]; } MPIDI_POSIX_eager_iqueue_global_t; diff --git a/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c b/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c index fe1f60efcc5..5d32a23f95b 100644 --- a/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c +++ b/src/mpid/ch4/shm/posix/eager/src/posix_eager_impl.c @@ -19,9 +19,9 @@ int MPIDI_POSIX_eager_post_init(void) return MPIDI_POSIX_eager_func->post_init(); } -int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm) +int MPIDI_POSIX_eager_set_vcis(MPIR_Comm * comm, int num_vcis) { - return MPIDI_POSIX_eager_func->set_vcis(comm); + return MPIDI_POSIX_eager_func->set_vcis(comm, num_vcis); } int MPIDI_POSIX_eager_finalize(void) diff --git a/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h b/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h index 81c51f6d2ee..566b8d5d201 100644 --- a/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h +++ b/src/mpid/ch4/shm/posix/eager/stub/stub_noinline.h @@ -10,7 +10,7 @@ int MPIDI_POSIX_stub_init(int rank, int size); int MPIDI_POSIX_stub_post_init(void); -int MPIDI_POSIX_stub_set_vcis(MPIR_Comm * comm); +int MPIDI_POSIX_stub_set_vcis(MPIR_Comm * comm, int num_vcis); int MPIDI_POSIX_stub_finalize(void); #ifdef POSIX_EAGER_INLINE diff --git a/src/mpid/ch4/shm/posix/posix_vci.c b/src/mpid/ch4/shm/posix/posix_vci.c index 173ac3be5f0..3c3c4a8c7f6 100644 --- a/src/mpid/ch4/shm/posix/posix_vci.c +++ b/src/mpid/ch4/shm/posix/posix_vci.c @@ -13,25 +13,27 @@ int MPIDI_POSIX_comm_set_vcis(MPIR_Comm * comm, int num_vcis) /* We only set up vcis once */ MPIR_Assert(MPIDI_POSIX_global.num_vcis == 1); + int max_vcis; MPIR_Comm *node_comm = comm->node_comm; if (node_comm == NULL) { - /* nothing to do if there is no local domain */ - goto fn_exit; + max_vcis = num_vcis; + } else { + mpi_errno = MPIR_Allreduce_impl(&num_vcis, &max_vcis, 1, MPI_INT, MPI_MAX, node_comm, + MPIR_ERR_NONE); + MPIR_ERR_CHECK(mpi_errno); } - int max_vcis; - mpi_errno = MPIR_Allreduce_impl(&num_vcis, &max_vcis, 1, MPI_INT, MPI_MAX, node_comm, - MPIR_ERR_NONE); - - MPIDI_POSIX_global.num_vcis = max_vcis; + if (max_vcis > 1) { + for (int i = 1; i < max_vcis; i++) { + mpi_errno = MPIDI_POSIX_init_vci(i); + MPIR_ERR_CHECK(mpi_errno); + } - for (int i = 1; i < MPIDI_POSIX_global.num_vcis; i++) { - mpi_errno = MPIDI_POSIX_init_vci(i); + mpi_errno = MPIDI_POSIX_eager_set_vcis(comm, max_vcis); MPIR_ERR_CHECK(mpi_errno); } - mpi_errno = MPIDI_POSIX_eager_set_vcis(comm); - MPIR_ERR_CHECK(mpi_errno); + MPIDI_POSIX_global.num_vcis = max_vcis; fn_exit: return mpi_errno; From 7e23fe8276c6f8881eaad256d27543ea1867b5a3 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 3 Jan 2025 23:13:54 -0600 Subject: [PATCH 23/25] ch4/proc: use calloc to zero MPIDI_global.avt_mgr.av_table0 We need ensure the extra fields, such as MPIDI_OFI_AV(av, all_dest), are initialized to NULL. --- src/mpid/ch4/src/ch4_proc.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mpid/ch4/src/ch4_proc.c b/src/mpid/ch4/src/ch4_proc.c index 56cb70b48c1..dad44ef8a67 100644 --- a/src/mpid/ch4/src/ch4_proc.c +++ b/src/mpid/ch4/src/ch4_proc.c @@ -164,7 +164,7 @@ int MPIDIU_avt_init(void) int size = MPIR_Process.size; int rank = MPIR_Process.rank; size_t table_size = sizeof(MPIDI_av_table_t) + size * sizeof(MPIDI_av_entry_t); - MPIDI_global.avt_mgr.av_table0 = (MPIDI_av_table_t *) MPL_malloc(table_size, MPL_MEM_ADDRESS); + MPIDI_global.avt_mgr.av_table0 = MPL_calloc(1, table_size, MPL_MEM_ADDRESS); MPIR_Assert(MPIDI_global.avt_mgr.av_table0); #if MPIDI_CH4_AVTABLE_USE_DDR From de52554e2e04fec9a660bfabbb6d8b2172cdf130 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 7 Jan 2025 22:34:09 -0600 Subject: [PATCH 24/25] ch4/ofi: MPIDI_OFI_AV_ADDR --- src/mpid/ch4/netmod/ofi/ofi_impl.h | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index beaf34f768a..9bde0db8c00 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -39,16 +39,22 @@ ATTRIBUTE((unused)); #ifdef MPIDI_OFI_VNI_USE_DOMAIN #define MPIDI_OFI_AV_ADDR_ROOT(av) \ MPIDI_OFI_AV(av).root_dest -#define MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic) \ +#define MPIDI_OFI_AV_ADDR_OFFSET(av, vci, nic) \ + (MPIDI_OFI_AV(av).all_dest[(vci)*MPIDI_OFI_global.num_nics+(nic)] + MPIDI_OFI_AV(av).root_offset) +#define MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic) \ MPIDI_OFI_AV(av).all_dest[(vci)*MPIDI_OFI_global.num_nics+(nic)] #else /* scalable endpoints - all vci share the same addr */ -#define MPIDI_OFI_AV_ADDR_ROOT(av, vci, nic) \ +#define MPIDI_OFI_AV_ADDR_ROOT(av) \ MPIDI_OFI_AV(av).root_dest -#define MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic) \ +#define MPIDI_OFI_AV_ADDR_OFFSET(av, vci, nic) \ + (MPIDI_OFI_AV(av).all_dest[nic] + MPIDI_OFI_AV(av).root_offset) +#define MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic) \ MPIDI_OFI_AV(av).all_dest[nic] #endif -#define MPIDI_OFI_AV_ADDR(av, vci, nic) \ - ((vci==0 && nic==0) ? MPIDI_OFI_AV_ADDR_ROOT(av) : MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic)) +#define MPIDI_OFI_AV_ADDR(av, local_vci, local_nic, vci, nic) \ + ((local_vci==0 && local_nic==0) ? \ + ((vci == 0 && nic == 0) ? MPIDI_OFI_AV_ADDR_ROOT(av) : MPIDI_OFI_AV_ADDR_OFFSET(av, vci, nic)) : \ + MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic)) #define MPIDI_OFI_WIN(win) ((win)->dev.netmod.ofi) @@ -457,18 +463,19 @@ MPL_STATIC_INLINE_PREFIX fi_addr_t MPIDI_OFI_av_to_phys(MPIDI_av_entry_t * av, int local_vci, int local_nic, int vci, int nic) { + fi_addr_t dest = MPIDI_OFI_AV_ADDR(av, local_vci, local_nic, vci, nic); #ifdef MPIDI_OFI_VNI_USE_DOMAIN if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) { - return fi_rx_addr(MPIDI_OFI_AV_ADDR(av, vci, nic), 0, MPIDI_OFI_MAX_ENDPOINTS_BITS); + return fi_rx_addr(dest, 0, MPIDI_OFI_MAX_ENDPOINTS_BITS); } else { - return MPIDI_OFI_AV_ADDR(av, vci, nic); + return dest; } #else /* MPIDI_OFI_VNI_USE_SEPCTX */ if (MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS) { - return fi_rx_addr(MPIDI_OFI_AV_ADDR(av, vci, nic), vci, MPIDI_OFI_MAX_ENDPOINTS_BITS); + return fi_rx_addr(dest, vci, MPIDI_OFI_MAX_ENDPOINTS_BITS); } else { MPIR_Assert(vci == 0); - return MPIDI_OFI_AV_ADDR(av, vci, nic); + return dest; } #endif } From 03831440019db047fd9432e13e226a69bc15d07c Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Tue, 7 Jan 2025 22:59:19 -0600 Subject: [PATCH 25/25] ch4/ofi: fix the av table assumptions Because we insert all remote endpoints to all local endpoints at the same time, thus follow the exact same insertion order, they will share the same av table index except for the local root endpoint because it has inserted other remote root endpoints at init time. The local root to remote non-root endpoints will have a fixed offset from that of local non-root. --- src/mpid/ch4/netmod/ofi/ofi_pre.h | 9 +++++++-- src/mpid/ch4/netmod/ofi/ofi_vci.c | 19 +++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_pre.h b/src/mpid/ch4/netmod/ofi/ofi_pre.h index 92166fe33dc..52f205e104f 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_pre.h +++ b/src/mpid/ch4/netmod/ofi/ofi_pre.h @@ -311,9 +311,14 @@ typedef struct { /* Maximum number of network interfaces CH4 can support. */ #define MPIDI_OFI_MAX_NICS 8 +/* Imagine a dimension of [local_vci][local_nic][rank][vci][nic] - + * all local endpoints will share the same remote address due to the same insertion order + * and use of FI_AV_TABLE except the local root endpoint. + */ typedef struct { - fi_addr_t root_dest; - fi_addr_t *all_dest; /* to be allocated into an array of [nic * vci] */ + fi_addr_t root_dest; /* [0][0][r][0][0] */ + fi_addr_t root_offset; /* [0][0][r][vci][nic] - [*][*][r][vci][nic] */ + fi_addr_t *all_dest; /* [*][*][r][vci][nic] */ } MPIDI_OFI_addr_t; #endif /* OFI_PRE_H_INCLUDED */ diff --git a/src/mpid/ch4/netmod/ofi/ofi_vci.c b/src/mpid/ch4/netmod/ofi/ofi_vci.c index 9f52d12ae7a..baa65ccac3c 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_vci.c +++ b/src/mpid/ch4/netmod/ofi/ofi_vci.c @@ -231,12 +231,13 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm, int *all_num_vcis) /* insert and store non-root nic/vci on the root context */ for (int r = 0; r < nprocs; r++) { + fi_addr_t expect_addr = FI_ADDR_NOTAVAIL; + fi_addr_t root_offset = 0; GET_AV_AND_ADDRNAMES(r); /* for each remote endpoints */ for (int nic = 0; nic < num_nics; nic++) { for (int vci = 0; vci < NUM_VCIS_FOR_RANK(r); vci++) { /* for each local endpoints */ - fi_addr_t expect_addr = FI_ADDR_NOTAVAIL; for (int nic_local = 0; nic_local < num_nics; nic_local++) { for (int vci_local = 0; vci_local < my_num_vcis; vci_local++) { /* skip root */ @@ -245,18 +246,28 @@ static int addr_exchange_all_ctx(MPIR_Comm * comm, int *all_num_vcis) } int ctx_idx = MPIDI_OFI_get_ctx_index(vci_local, nic_local); DO_AV_INSERT(ctx_idx, nic, vci); - /* we expect all resulting addr to be the same */ + /* we expect all resulting addr to be the same except for local root endpoint, which + * will have an offset */ if (expect_addr == FI_ADDR_NOTAVAIL) { expect_addr = addr; + } else if (nic_local == 0 && vci_local == 0) { + if (root_offset == 0) { + root_offset = addr - expect_addr; + } else { + MPIR_Assert(addr == expect_addr + root_offset); + } } else { - MPIR_Assert(expect_addr == addr); + MPIR_Assert(addr == expect_addr); } } } MPIR_Assert(expect_addr != FI_ADDR_NOTAVAIL); - MPIDI_OFI_AV_ADDR_NONROOT(av, vci, nic) = expect_addr; + MPIDI_OFI_AV_ADDR_NO_OFFSET(av, vci, nic) = expect_addr; + /* next */ + expect_addr++; } } + MPIDI_OFI_AV(av).root_offset = root_offset; } mpi_errno = MPIR_Barrier_fallback(comm, MPIR_ERR_NONE);