Skip to content

Commit

Permalink
init_shm: shortcut the trivial case where local_size is 1
Browse files Browse the repository at this point in the history
It's simpler to shortcut the trivial case.

Set the globals MPIDU_Init_shm_local_size and MPIDU_Init_shm_local_rank. This
provides the flexibility to later extend Init_shm to dynamic
processes.
  • Loading branch information
hzhou committed Jan 1, 2025
1 parent 12c89ba commit d774dd3
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 81 deletions.
100 changes: 47 additions & 53 deletions src/mpid/common/shm/mpidu_init_shm.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

static int init_shm_initialized;

int MPIDU_Init_shm_local_size;
int MPIDU_Init_shm_local_rank;

#ifdef ENABLE_NO_LOCAL
/* shared memory disabled, just stubs */

Expand Down Expand Up @@ -55,8 +58,6 @@ typedef struct Init_shm_barrier {
MPL_atomic_int_t wait;
} Init_shm_barrier_t;

static int local_size;
static int my_local_rank;
static MPIDU_shm_seg_t memory;
static Init_shm_barrier_t *barrier;
static void *baseaddr;
Expand Down Expand Up @@ -88,12 +89,12 @@ static int Init_shm_barrier(void)

MPIR_FUNC_ENTER;

if (local_size == 1)
if (MPIDU_Init_shm_local_size == 1)
goto fn_exit;

MPIR_ERR_CHKINTERNAL(!barrier_init, mpi_errno, "barrier not initialized");

if (MPL_atomic_fetch_add_int(&barrier->val, 1) == local_size - 1) {
if (MPL_atomic_fetch_add_int(&barrier->val, 1) == MPIDU_Init_shm_local_size - 1) {
MPL_atomic_store_int(&barrier->val, 0);
MPL_atomic_store_int(&barrier->wait, 1 - sense);
} else {
Expand All @@ -117,35 +118,24 @@ int MPIDU_Init_shm_init(void)

MPIR_FUNC_ENTER;

local_size = MPIR_Process.local_size;
my_local_rank = MPIR_Process.local_rank;

size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + sizeof(MPIDU_Init_shm_block_t) * local_size;

char *serialized_hnd = NULL;
int serialized_hnd_size = 0;
MPIDU_Init_shm_local_size = MPIR_Process.local_size;
MPIDU_Init_shm_local_rank = MPIR_Process.local_rank;

mpl_err = MPL_shm_hnd_init(&(memory.hnd));
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");

memory.segment_len = segment_len;
if (MPIDU_Init_shm_local_size == 1) {
/* We'll special case this trivial case */
} else {
size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN +
sizeof(MPIDU_Init_shm_block_t) * MPIDU_Init_shm_local_size;

if (local_size == 1) {
char *addr;
char *serialized_hnd = NULL;
int serialized_hnd_size = 0;

MPIR_CHKPMEM_MALLOC(addr, char *, segment_len + MPIDU_SHM_CACHE_LINE_LEN, mpi_errno,
"segment", MPL_MEM_SHM);
mpl_err = MPL_shm_hnd_init(&(memory.hnd));
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");

memory.base_addr = addr;
baseaddr =
(char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
(~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
memory.symmetrical = 0;
memory.segment_len = segment_len;

mpi_errno = Init_shm_barrier_init(TRUE);
MPIR_ERR_CHECK(mpi_errno);
} else {
if (my_local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
/* root prepare shm segment */
mpl_err = MPL_shm_seg_create_and_attach(memory.hnd, memory.segment_len,
(void **) &(memory.base_addr), 0);
Expand All @@ -166,15 +156,13 @@ int MPIDU_Init_shm_init(void)
MPIR_CHKLMEM_MALLOC(serialized_hnd, char *, serialized_hnd_size, mpi_errno, "val",
MPL_MEM_OTHER);
}
}
/* All processes need to call MPIR_pmi_bcast. This is because it may call MPIR_pmi_barrier
* internally, depending on the PMI version, and all processes need to participate.
*/
mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
MPIR_ERR_CHECK(mpi_errno);
if (local_size != 1) {
MPIR_Assert(local_size > 1);
if (my_local_rank > 0) {
/* All processes need to call MPIR_pmi_bcast. This is because it may call MPIR_pmi_barrier
* internally, depending on the PMI version, and all processes need to participate.
*/
mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
MPIR_ERR_CHECK(mpi_errno);

if (MPIDU_Init_shm_local_rank > 0) {
/* non-root attach shm segment */
mpl_err = MPL_shm_hnd_deserialize(memory.hnd, serialized_hnd, strlen(serialized_hnd));
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
Expand All @@ -190,18 +178,17 @@ int MPIDU_Init_shm_init(void)
mpi_errno = Init_shm_barrier();
MPIR_ERR_CHECK(mpi_errno);

if (my_local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
/* memory->hnd no longer needed */
mpl_err = MPL_shm_seg_remove(memory.hnd);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
}

baseaddr = memory.base_addr + MPIDU_SHM_CACHE_LINE_LEN;
memory.symmetrical = 0;
}

mpi_errno = Init_shm_barrier();
MPIR_CHKPMEM_COMMIT();
mpi_errno = Init_shm_barrier();
}

init_shm_initialized = 1;

Expand All @@ -220,16 +207,12 @@ int MPIDU_Init_shm_finalize(void)

MPIR_FUNC_ENTER;

if (!init_shm_initialized) {
if (!init_shm_initialized || MPIDU_Init_shm_local_size == 1) {
goto fn_exit;
}

if (local_size == 1)
MPL_free(memory.base_addr);
else {
mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
}
mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");

MPL_shm_hnd_finalize(&(memory.hnd));

Expand All @@ -248,7 +231,9 @@ int MPIDU_Init_shm_barrier(void)

MPIR_FUNC_ENTER;

mpi_errno = Init_shm_barrier();
if (MPIDU_Init_shm_local_size > 1) {
mpi_errno = Init_shm_barrier();
}

MPIR_FUNC_EXIT;

Expand All @@ -261,8 +246,11 @@ int MPIDU_Init_shm_put(void *orig, size_t len)

MPIR_FUNC_ENTER;

MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
MPIR_Memcpy((char *) baseaddr + my_local_rank * sizeof(MPIDU_Init_shm_block_t), orig, len);
if (MPIDU_Init_shm_local_size > 1) {
MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
MPIR_Memcpy((char *) baseaddr + MPIDU_Init_shm_local_rank * sizeof(MPIDU_Init_shm_block_t),
orig, len);
}

MPIR_FUNC_EXIT;

Expand All @@ -275,7 +263,10 @@ int MPIDU_Init_shm_get(int local_rank, size_t len, void *target)

MPIR_FUNC_ENTER;

MPIR_Assert(local_rank < local_size && len <= sizeof(MPIDU_Init_shm_block_t));
/* a single process should not get its own put */
MPIR_Assert(MPIDU_Init_shm_local_size > 1);

MPIR_Assert(local_rank < MPIDU_Init_shm_local_size && len <= sizeof(MPIDU_Init_shm_block_t));
MPIR_Memcpy(target, (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t), len);

MPIR_FUNC_EXIT;
Expand All @@ -289,7 +280,10 @@ int MPIDU_Init_shm_query(int local_rank, void **target_addr)

MPIR_FUNC_ENTER;

MPIR_Assert(local_rank < local_size);
/* a single process should not get its own put */
MPIR_Assert(MPIDU_Init_shm_local_size > 1);

MPIR_Assert(local_rank < MPIDU_Init_shm_local_size);
*target_addr = (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t);

MPIR_FUNC_EXIT;
Expand Down
60 changes: 32 additions & 28 deletions src/mpid/common/shm/mpidu_init_shm_alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include <sys/shm.h>
#endif

extern int MPIDU_Init_shm_local_size;
extern int MPIDU_Init_shm_local_rank;

typedef struct memory_list {
void *ptr;
MPIDU_shm_seg_t *memory;
Expand All @@ -39,8 +42,6 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
int mpi_errno = MPI_SUCCESS, mpl_err = 0;
void *current_addr;
size_t segment_len = len;
int local_rank = MPIR_Process.local_rank;
int num_local = MPIR_Process.local_size;
MPIDU_shm_seg_t *memory = NULL;
memory_list_t *memory_node = NULL;
MPIR_CHKPMEM_DECL(3);
Expand All @@ -49,6 +50,12 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)

MPIR_Assert(segment_len > 0);

if (MPIDU_Init_shm_local_size == 1) {
*ptr = MPL_aligned_alloc(MPL_CACHELINE_SIZE, len, MPL_MEM_SHM);
MPIR_ERR_CHKANDJUMP(!*ptr, mpi_errno, MPI_ERR_OTHER, "**nomem");
goto fn_exit;
}

MPIR_CHKPMEM_MALLOC(memory, MPIDU_shm_seg_t *, sizeof(*memory), mpi_errno, "memory_handle",
MPL_MEM_OTHER);

Expand All @@ -59,20 +66,9 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)

char *serialized_hnd = NULL;
int serialized_hnd_size = 0;
/* if there is only one process on this processor, don't use shared memory */
if (num_local == 1) {
char *addr;

MPIR_CHKPMEM_MALLOC(addr, char *, segment_len + MPIDU_SHM_CACHE_LINE_LEN, mpi_errno,
"segment", MPL_MEM_SHM);

memory->base_addr = addr;
current_addr =
(char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
(~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
memory->symmetrical = 1;
} else {
if (local_rank == 0) {
{
if (MPIDU_Init_shm_local_rank == 0) {
/* root prepare shm segment */
mpl_err = MPL_shm_seg_create_and_attach(memory->hnd, memory->segment_len,
(void **) &(memory->base_addr), 0);
Expand Down Expand Up @@ -100,7 +96,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)

MPIDU_Init_shm_barrier();

if (local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
/* memory->hnd no longer needed */
mpl_err = MPL_shm_seg_remove(memory->hnd);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
Expand Down Expand Up @@ -128,8 +124,10 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
return mpi_errno;
fn_fail:
/* --BEGIN ERROR HANDLING-- */
MPL_shm_seg_remove(memory->hnd);
MPL_shm_hnd_finalize(&(memory->hnd));
if (MPIDU_Init_shm_local_size > 1) {
MPL_shm_seg_remove(memory->hnd);
MPL_shm_hnd_finalize(&(memory->hnd));
}
MPIR_CHKPMEM_REAP();
goto fn_exit;
/* --END ERROR HANDLING-- */
Expand All @@ -144,6 +142,11 @@ int MPIDU_Init_shm_free(void *ptr)

MPIR_FUNC_ENTER;

if (MPIDU_Init_shm_local_size == 1) {
MPL_free(ptr);
goto fn_exit;
}

/* retrieve memory handle for baseaddr */
LL_FOREACH(memory_head, el) {
if (el->ptr == ptr) {
Expand All @@ -156,17 +159,14 @@ int MPIDU_Init_shm_free(void *ptr)

MPIR_Assert(memory != NULL);

if (MPIR_Process.local_size == 1)
MPL_free(memory->base_addr);
else {
mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr),
memory->segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
}
mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr), memory->segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");

fn_exit:
MPL_shm_hnd_finalize(&(memory->hnd));
MPL_free(memory);
if (MPIDU_Init_shm_local_size > 1) {
MPL_shm_hnd_finalize(&(memory->hnd));
MPL_free(memory);
}
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
Expand All @@ -178,6 +178,10 @@ int MPIDU_Init_shm_is_symm(void *ptr)
int ret = -1;
memory_list_t *el;

if (MPIDU_Init_shm_local_size == 1) {
return 1;
}

/* retrieve memory handle for baseaddr */
LL_FOREACH(memory_head, el) {
if (el->ptr == ptr) {
Expand All @@ -200,7 +204,7 @@ static int check_alloc(MPIDU_shm_seg_t * memory)

MPIR_FUNC_ENTER;

if (MPIR_Process.local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
MPIDU_Init_shm_put(memory->base_addr, sizeof(void *));
}

Expand Down

0 comments on commit d774dd3

Please sign in to comment.