diff --git a/src/mpid/common/shm/mpidu_init_shm.c b/src/mpid/common/shm/mpidu_init_shm.c index 899865e2488..51d9596b8f4 100644 --- a/src/mpid/common/shm/mpidu_init_shm.c +++ b/src/mpid/common/shm/mpidu_init_shm.c @@ -12,6 +12,9 @@ static int init_shm_initialized; +int MPIDU_Init_shm_local_size; +int MPIDU_Init_shm_local_rank; + #ifdef ENABLE_NO_LOCAL /* shared memory disabled, just stubs */ @@ -55,8 +58,6 @@ typedef struct Init_shm_barrier { MPL_atomic_int_t wait; } Init_shm_barrier_t; -static int local_size; -static int my_local_rank; static MPIDU_shm_seg_t memory; static Init_shm_barrier_t *barrier; static void *baseaddr; @@ -88,12 +89,12 @@ static int Init_shm_barrier(void) MPIR_FUNC_ENTER; - if (local_size == 1) + if (MPIDU_Init_shm_local_size == 1) goto fn_exit; MPIR_ERR_CHKINTERNAL(!barrier_init, mpi_errno, "barrier not initialized"); - if (MPL_atomic_fetch_add_int(&barrier->val, 1) == local_size - 1) { + if (MPL_atomic_fetch_add_int(&barrier->val, 1) == MPIDU_Init_shm_local_size - 1) { MPL_atomic_store_int(&barrier->val, 0); MPL_atomic_store_int(&barrier->wait, 1 - sense); } else { @@ -117,35 +118,24 @@ int MPIDU_Init_shm_init(void) MPIR_FUNC_ENTER; - local_size = MPIR_Process.local_size; - my_local_rank = MPIR_Process.local_rank; - - size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + sizeof(MPIDU_Init_shm_block_t) * local_size; - - char *serialized_hnd = NULL; - int serialized_hnd_size = 0; + MPIDU_Init_shm_local_size = MPIR_Process.local_size; + MPIDU_Init_shm_local_rank = MPIR_Process.local_rank; - mpl_err = MPL_shm_hnd_init(&(memory.hnd)); - MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); - - memory.segment_len = segment_len; + if (MPIDU_Init_shm_local_size == 1) { + /* We'll special case this trivial case */ + } else { + size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + + sizeof(MPIDU_Init_shm_block_t) * MPIDU_Init_shm_local_size; - if (local_size == 1) { - char *addr; + char *serialized_hnd = NULL; + int serialized_hnd_size = 0; - MPIR_CHKPMEM_MALLOC(addr, char *, segment_len + MPIDU_SHM_CACHE_LINE_LEN, mpi_errno, - "segment", MPL_MEM_SHM); + mpl_err = MPL_shm_hnd_init(&(memory.hnd)); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); - memory.base_addr = addr; - baseaddr = - (char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) & - (~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1))); - memory.symmetrical = 0; + memory.segment_len = segment_len; - mpi_errno = Init_shm_barrier_init(TRUE); - MPIR_ERR_CHECK(mpi_errno); - } else { - if (my_local_rank == 0) { + if (MPIDU_Init_shm_local_rank == 0) { /* root prepare shm segment */ mpl_err = MPL_shm_seg_create_and_attach(memory.hnd, memory.segment_len, (void **) &(memory.base_addr), 0); @@ -166,15 +156,13 @@ int MPIDU_Init_shm_init(void) MPIR_CHKLMEM_MALLOC(serialized_hnd, char *, serialized_hnd_size, mpi_errno, "val", MPL_MEM_OTHER); } - } - /* All processes need call MPIR_pmi_bcast. This is because we may need call MPIR_pmi_barrier - * inside depend on PMI versions, and all processes need participate. - */ - mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL); - MPIR_ERR_CHECK(mpi_errno); - if (local_size != 1) { - MPIR_Assert(local_size > 1); - if (my_local_rank > 0) { + /* All processes need call MPIR_pmi_bcast. This is because we may need call MPIR_pmi_barrier + * inside depend on PMI versions, and all processes need participate. + */ + mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL); + MPIR_ERR_CHECK(mpi_errno); + + if (MPIDU_Init_shm_local_rank > 0) { /* non-root attach shm segment */ mpl_err = MPL_shm_hnd_deserialize(memory.hnd, serialized_hnd, strlen(serialized_hnd)); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem"); @@ -190,7 +178,7 @@ int MPIDU_Init_shm_init(void) mpi_errno = Init_shm_barrier(); MPIR_ERR_CHECK(mpi_errno); - if (my_local_rank == 0) { + if (MPIDU_Init_shm_local_rank == 0) { /* memory->hnd no longer needed */ mpl_err = MPL_shm_seg_remove(memory.hnd); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem"); @@ -198,10 +186,9 @@ int MPIDU_Init_shm_init(void) baseaddr = memory.base_addr + MPIDU_SHM_CACHE_LINE_LEN; memory.symmetrical = 0; - } - mpi_errno = Init_shm_barrier(); - MPIR_CHKPMEM_COMMIT(); + mpi_errno = Init_shm_barrier(); + } init_shm_initialized = 1; @@ -220,16 +207,12 @@ int MPIDU_Init_shm_finalize(void) MPIR_FUNC_ENTER; - if (!init_shm_initialized) { + if (!init_shm_initialized || MPIDU_Init_shm_local_size == 1) { goto fn_exit; } - if (local_size == 1) - MPL_free(memory.base_addr); - else { - mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len); - MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem"); - } + mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem"); MPL_shm_hnd_finalize(&(memory.hnd)); @@ -248,7 +231,9 @@ int MPIDU_Init_shm_barrier(void) MPIR_FUNC_ENTER; - mpi_errno = Init_shm_barrier(); + if (MPIDU_Init_shm_local_size > 1) { + mpi_errno = Init_shm_barrier(); + } MPIR_FUNC_EXIT; @@ -261,8 +246,11 @@ int MPIDU_Init_shm_put(void *orig, size_t len) MPIR_FUNC_ENTER; - MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t)); - MPIR_Memcpy((char *) baseaddr + my_local_rank * sizeof(MPIDU_Init_shm_block_t), orig, len); + if (MPIDU_Init_shm_local_size > 1) { + MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t)); + MPIR_Memcpy((char *) baseaddr + MPIDU_Init_shm_local_rank * sizeof(MPIDU_Init_shm_block_t), + orig, len); + } MPIR_FUNC_EXIT; @@ -275,7 +263,10 @@ int MPIDU_Init_shm_get(int local_rank, size_t len, void *target) MPIR_FUNC_ENTER; - MPIR_Assert(local_rank < local_size && len <= sizeof(MPIDU_Init_shm_block_t)); + /* a single process should not get its own put */ + MPIR_Assert(MPIDU_Init_shm_local_size > 1); + + MPIR_Assert(local_rank < MPIDU_Init_shm_local_size && len <= sizeof(MPIDU_Init_shm_block_t)); MPIR_Memcpy(target, (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t), len); MPIR_FUNC_EXIT; @@ -289,7 +280,10 @@ int MPIDU_Init_shm_query(int local_rank, void **target_addr) MPIR_FUNC_ENTER; - MPIR_Assert(local_rank < local_size); + /* a single process should not get its own put */ + MPIR_Assert(MPIDU_Init_shm_local_size > 1); + + MPIR_Assert(local_rank < MPIDU_Init_shm_local_size); *target_addr = (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t); MPIR_FUNC_EXIT; diff --git a/src/mpid/common/shm/mpidu_init_shm_alloc.c b/src/mpid/common/shm/mpidu_init_shm_alloc.c index 61b3c7d8943..76735199f39 100644 --- a/src/mpid/common/shm/mpidu_init_shm_alloc.c +++ b/src/mpid/common/shm/mpidu_init_shm_alloc.c @@ -19,6 +19,9 @@ #include #endif +extern int MPIDU_Init_shm_local_size; +extern int MPIDU_Init_shm_local_rank; + typedef struct memory_list { void *ptr; MPIDU_shm_seg_t *memory; @@ -39,8 +42,6 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) int mpi_errno = MPI_SUCCESS, mpl_err = 0; void *current_addr; size_t segment_len = len; - int local_rank = MPIR_Process.local_rank; - int num_local = MPIR_Process.local_size; MPIDU_shm_seg_t *memory = NULL; memory_list_t *memory_node = NULL; MPIR_CHKPMEM_DECL(3); @@ -49,6 +50,12 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) MPIR_Assert(segment_len > 0); + if (MPIDU_Init_shm_local_size == 1) { + *ptr = MPL_aligned_alloc(MPL_CACHELINE_SIZE, len, MPL_MEM_SHM); + MPIR_ERR_CHKANDJUMP(!*ptr, mpi_errno, MPI_ERR_OTHER, "**nomem"); + goto fn_exit; + } + MPIR_CHKPMEM_MALLOC(memory, MPIDU_shm_seg_t *, sizeof(*memory), mpi_errno, "memory_handle", MPL_MEM_OTHER); @@ -59,20 +66,9 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) char *serialized_hnd = NULL; int serialized_hnd_size = 0; - /* if there is only one process on this processor, don't use shared memory */ - if (num_local == 1) { - char *addr; - - MPIR_CHKPMEM_MALLOC(addr, char *, segment_len + MPIDU_SHM_CACHE_LINE_LEN, mpi_errno, - "segment", MPL_MEM_SHM); - memory->base_addr = addr; - current_addr = - (char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) & - (~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1))); - memory->symmetrical = 1; - } else { - if (local_rank == 0) { + { + if (MPIDU_Init_shm_local_rank == 0) { /* root prepare shm segment */ mpl_err = MPL_shm_seg_create_and_attach(memory->hnd, memory->segment_len, (void **) &(memory->base_addr), 0); @@ -100,7 +96,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) MPIDU_Init_shm_barrier(); - if (local_rank == 0) { + if (MPIDU_Init_shm_local_rank == 0) { /* memory->hnd no longer needed */ mpl_err = MPL_shm_seg_remove(memory->hnd); MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem"); @@ -128,8 +124,10 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr) return mpi_errno; fn_fail: /* --BEGIN ERROR HANDLING-- */ - MPL_shm_seg_remove(memory->hnd); - MPL_shm_hnd_finalize(&(memory->hnd)); + if (MPIDU_Init_shm_local_size > 1) { + MPL_shm_seg_remove(memory->hnd); + MPL_shm_hnd_finalize(&(memory->hnd)); + } MPIR_CHKPMEM_REAP(); goto fn_exit; /* --END ERROR HANDLING-- */ @@ -144,6 +142,11 @@ int MPIDU_Init_shm_free(void *ptr) MPIR_FUNC_ENTER; + if (MPIDU_Init_shm_local_size == 1) { + MPL_free(ptr); + goto fn_exit; + } + /* retrieve memory handle for baseaddr */ LL_FOREACH(memory_head, el) { if (el->ptr == ptr) { @@ -156,17 +159,14 @@ int MPIDU_Init_shm_free(void *ptr) MPIR_Assert(memory != NULL); - if (MPIR_Process.local_size == 1) - MPL_free(memory->base_addr); - else { - mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr), - memory->segment_len); - MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem"); - } + mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr), memory->segment_len); + MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem"); fn_exit: - MPL_shm_hnd_finalize(&(memory->hnd)); - MPL_free(memory); + if (MPIDU_Init_shm_local_size > 1) { + MPL_shm_hnd_finalize(&(memory->hnd)); + MPL_free(memory); + } MPIR_FUNC_EXIT; return mpi_errno; fn_fail: @@ -178,6 +178,10 @@ int MPIDU_Init_shm_is_symm(void *ptr) int ret = -1; memory_list_t *el; + if (MPIDU_Init_shm_local_size == 1) { + return 1; + } + /* retrieve memory handle for baseaddr */ LL_FOREACH(memory_head, el) { if (el->ptr == ptr) { @@ -200,7 +204,7 @@ static int check_alloc(MPIDU_shm_seg_t * memory) MPIR_FUNC_ENTER; - if (MPIR_Process.local_rank == 0) { + if (MPIDU_Init_shm_local_rank == 0) { MPIDU_Init_shm_put(memory->base_addr, sizeof(void *)); }