group: refactor MPIR_Group #7235

Open
wants to merge 8 commits into base: main
62 changes: 32 additions & 30 deletions src/include/mpir_group.h
@@ -11,18 +11,7 @@
* only because they are required for the group operations (e.g.,
* MPI_Group_intersection) and for the scalable RMA synchronization
*---------------------------------------------------------------------------*/
/* This structure is used to implement the group operations such as
MPI_Group_translate_ranks */
/* note: next_lpid (with idx_of_first_lpid in MPIR_Group) gives a linked list
* in a sorted lpid ascending order */
typedef struct MPII_Group_pmap_t {
uint64_t lpid; /* local process id, from VCONN */
int next_lpid; /* Index of next lpid (in lpid order) */
} MPII_Group_pmap_t;

/* Any changes in the MPIR_Group structure must be made to the
predefined value in MPIR_Group_builtin for MPI_GROUP_EMPTY in
src/mpi/group/grouputil.c */

/*S
MPIR_Group - Description of the Group data structure

@@ -53,22 +42,32 @@ typedef struct MPII_Group_pmap_t {
Group-DS

S*/

/* Abstract the integer type for lpid (process id). It is possible to use 32-bit
* in principle, but 64-bit is simpler since we can trivially combine
* (world_idx, world_rank).
*/
typedef uint64_t MPIR_Lpid;

struct MPIR_Pmap {
int size; /* same as group->size, duplicate here so Pmap is logically complete */
bool use_map;
union {
MPIR_Lpid *map;
struct {
MPIR_Lpid offset;
MPIR_Lpid stride;
MPIR_Lpid blocksize;
} stride;
} u;
};

struct MPIR_Group {
MPIR_OBJECT_HEADER; /* adds handle and ref_count fields */
int size; /* Size of a group */
int rank; /* rank of this process relative to this
* group */
int idx_of_first_lpid;
MPII_Group_pmap_t *lrank_to_lpid; /* Array mapping a local rank to local
* process number */
int is_local_dense_monotonic; /* see NOTE-G1 */

/* We may want some additional data for the RMA synchronization calls */
/* Other, device-specific information */
#ifdef MPID_DEV_GROUP_DECL
MPID_DEV_GROUP_DECL
#endif
MPIR_Session * session_ptr; /* Pointer to session to which this group belongs */
int rank; /* rank of this process relative to this group */
struct MPIR_Pmap pmap;
MPIR_Session *session_ptr; /* Pointer to session to which this group belongs */
};

/* NOTE-G1: is_local_dense_monotonic will be true iff the group meets the
@@ -97,18 +96,21 @@ extern MPIR_Group *const MPIR_Group_empty;
#define MPIR_Group_release_ref(_group, _inuse) \
do { MPIR_Object_release_ref(_group, _inuse); } while (0)

void MPII_Group_setup_lpid_list(MPIR_Group *);
int MPIR_Group_check_valid_ranks(MPIR_Group *, const int[], int);
int MPIR_Group_check_valid_ranges(MPIR_Group *, int[][3], int);
void MPIR_Group_setup_lpid_pairs(MPIR_Group *, MPIR_Group *);
int MPIR_Group_create(int, MPIR_Group **);
int MPIR_Group_release(MPIR_Group * group_ptr);

int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_Lpid * map,
MPIR_Group ** new_group_ptr);
int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr,
MPIR_Lpid offset, MPIR_Lpid stride, MPIR_Lpid blocksize,
MPIR_Group ** new_group_ptr);
MPIR_Lpid MPIR_Group_rank_to_lpid(MPIR_Group * group, int rank);
int MPIR_Group_lpid_to_rank(MPIR_Group * group, MPIR_Lpid lpid);

int MPIR_Group_check_subset(MPIR_Group * group_ptr, MPIR_Comm * comm_ptr);
void MPIR_Group_set_session_ptr(MPIR_Group * group_ptr, MPIR_Session * session_out);
int MPIR_Group_init(void);

/* internal functions */
void MPII_Group_setup_lpid_list(MPIR_Group *);

#endif /* MPIR_GROUP_H_INCLUDED */
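
The new struct MPIR_Pmap stores the rank-to-lpid mapping either as an explicit array (use_map true) or as an arithmetic (offset, stride, blocksize) description. The translation routine declared above, MPIR_Group_rank_to_lpid, is presumably implemented in src/mpi/group/grouputil.c, which this excerpt does not include; the following is only a minimal sketch of how the two representations could be resolved, using just the fields declared in this header.

    /* Hypothetical sketch, not code from this PR: resolve a local rank to an
     * lpid for either MPIR_Pmap representation. */
    static inline MPIR_Lpid pmap_rank_to_lpid(const struct MPIR_Pmap *pmap, int rank)
    {
        MPIR_Assert(rank >= 0 && rank < pmap->size);
        if (pmap->use_map) {
            /* explicit table: one lpid stored per local rank */
            return pmap->u.map[rank];
        } else {
            /* strided form: ranks come in blocks of `blocksize` consecutive
             * lpids, consecutive blocks are `stride` apart, and the first
             * block starts at `offset` */
            MPIR_Lpid block = rank / pmap->u.stride.blocksize;
            MPIR_Lpid index = rank % pmap->u.stride.blocksize;
            return pmap->u.stride.offset + block * pmap->u.stride.stride + index;
        }
    }

The map case is a plain table lookup; the stride case needs no per-rank storage, so a dense group can be described in constant space.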
64 changes: 18 additions & 46 deletions src/mpi/comm/comm_impl.c
@@ -68,36 +68,19 @@ int MPIR_Comm_test_threadcomm_impl(MPIR_Comm * comm_ptr, int *flag)
static int comm_create_local_group(MPIR_Comm * comm_ptr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_Group *group_ptr;
int n = comm_ptr->local_size;

mpi_errno = MPIR_Group_create(n, &group_ptr);
MPIR_ERR_CHECK(mpi_errno);

/* Group belongs to the same session as communicator */
MPIR_Group_set_session_ptr(group_ptr, comm_ptr->session_ptr);

group_ptr->is_local_dense_monotonic = TRUE;
int n = comm_ptr->local_size;
MPIR_Lpid *map = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_GROUP);

int comm_world_size = MPIR_Process.size;
for (int i = 0; i < n; i++) {
uint64_t lpid;
(void) MPID_Comm_get_lpid(comm_ptr, i, &lpid, FALSE);
group_ptr->lrank_to_lpid[i].lpid = lpid;
if (lpid > comm_world_size || (i > 0 && group_ptr->lrank_to_lpid[i - 1].lpid != (lpid - 1))) {
group_ptr->is_local_dense_monotonic = FALSE;
}
map[i] = lpid;
}

group_ptr->size = n;
group_ptr->rank = comm_ptr->rank;
group_ptr->idx_of_first_lpid = -1;

comm_ptr->local_group = group_ptr;

/* FIXME : Add a sanity check that the size of the group is the same as
* the size of the communicator. This helps catch corrupted
* communicators */
mpi_errno = MPIR_Group_create_map(n, comm_ptr->rank, comm_ptr->session_ptr, map,
&comm_ptr->local_group);
MPIR_ERR_CHECK(mpi_errno);

fn_exit:
return mpi_errno;
@@ -215,16 +198,13 @@ int MPII_Comm_create_calculate_mapping(MPIR_Group * group_ptr,
* exactly the same as the ranks in comm world.
*/

/* we examine the group's lpids in both the intracomm and non-comm_world cases */
MPII_Group_setup_lpid_list(group_ptr);

/* Optimize for groups contained within MPI_COMM_WORLD. */
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
int wsize;
subsetOfWorld = 1;
wsize = MPIR_Process.size;
for (i = 0; i < n; i++) {
uint64_t g_lpid = group_ptr->lrank_to_lpid[i].lpid;
MPIR_Lpid g_lpid = MPIR_Group_rank_to_lpid(group_ptr, i);

/* This mapping is relative to comm world */
MPL_DBG_MSG_FMT(MPIR_DBG_COMM, VERBOSE,
@@ -261,7 +241,7 @@ int MPII_Comm_create_calculate_mapping(MPIR_Group * group_ptr,
for (j = 0; j < comm_ptr->local_size; j++) {
uint64_t comm_lpid;
MPID_Comm_get_lpid(comm_ptr, j, &comm_lpid, FALSE);
if (comm_lpid == group_ptr->lrank_to_lpid[i].lpid) {
if (comm_lpid == MPIR_Group_rank_to_lpid(group_ptr, i)) {
mapping[i] = j;
break;
}
@@ -800,7 +780,7 @@ int MPIR_Intercomm_create_from_groups_impl(MPIR_Group * local_group_ptr, int loc

int tag = get_tag_from_stringtag(stringtag);
/* FIXME: ensure lpid is from comm_world */
uint64_t remote_lpid = remote_group_ptr->lrank_to_lpid[remote_leader].lpid;
MPIR_Lpid remote_lpid = MPIR_Group_rank_to_lpid(remote_group_ptr, remote_leader);
MPIR_Assert(remote_lpid < MPIR_Process.size);
mpi_errno = MPIR_Intercomm_create_impl(local_comm, local_leader,
MPIR_Process.comm_world, (int) remote_lpid,
@@ -931,31 +911,23 @@ int MPIR_Comm_idup_with_info_impl(MPIR_Comm * comm_ptr, MPIR_Info * info,
int MPIR_Comm_remote_group_impl(MPIR_Comm * comm_ptr, MPIR_Group ** group_ptr)
{
int mpi_errno = MPI_SUCCESS;
int i, n;

MPIR_FUNC_ENTER;

/* Create a group and populate it with the local process ids */
if (!comm_ptr->remote_group) {
n = comm_ptr->remote_size;
mpi_errno = MPIR_Group_create(n, group_ptr);
MPIR_ERR_CHECK(mpi_errno);
int n = comm_ptr->remote_size;
MPIR_Lpid *map = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_GROUP);

for (i = 0; i < n; i++) {
for (int i = 0; i < n; i++) {
uint64_t lpid;
(void) MPID_Comm_get_lpid(comm_ptr, i, &lpid, TRUE);
(*group_ptr)->lrank_to_lpid[i].lpid = lpid;
/* TODO calculate is_local_dense_monotonic */
map[i] = lpid;
}
(*group_ptr)->size = n;
(*group_ptr)->rank = MPI_UNDEFINED;
(*group_ptr)->idx_of_first_lpid = -1;

MPIR_Group_set_session_ptr(*group_ptr, comm_ptr->session_ptr);

comm_ptr->remote_group = *group_ptr;
} else {
*group_ptr = comm_ptr->remote_group;
mpi_errno = MPIR_Group_create_map(n, MPI_UNDEFINED, comm_ptr->session_ptr, map,
&comm_ptr->remote_group);
MPIR_ERR_CHECK(mpi_errno);
}
*group_ptr = comm_ptr->remote_group;
MPIR_Group_add_ref(comm_ptr->remote_group);

fn_exit:
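
Both call sites above (comm_create_local_group and MPIR_Comm_remote_group_impl) now follow the same pattern: allocate an MPIR_Lpid array, fill it via MPID_Comm_get_lpid, and hand it to MPIR_Group_create_map together with the calling process's rank and the session pointer. A condensed sketch of that pattern follows; the helper name is hypothetical, and it assumes MPIR_Group_create_map takes ownership of the map array on success, which matches the callers in this diff never freeing it.

    /* Hypothetical helper condensing the caller pattern used above: build an
     * lpid map covering every local rank of comm_ptr and wrap it in a group.
     * Malloc failure handling is omitted, as in the callers above. */
    static int group_from_local_ranks(MPIR_Comm * comm_ptr, MPIR_Group ** group_out)
    {
        int mpi_errno = MPI_SUCCESS;
        int n = comm_ptr->local_size;

        MPIR_Lpid *map = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
        for (int i = 0; i < n; i++) {
            uint64_t lpid;
            (void) MPID_Comm_get_lpid(comm_ptr, i, &lpid, FALSE);
            map[i] = lpid;
        }

        /* rank of the calling process within the new group; the session
         * pointer keeps the group tied to the communicator's session */
        mpi_errno = MPIR_Group_create_map(n, comm_ptr->rank, comm_ptr->session_ptr,
                                          map, group_out);
        MPIR_ERR_CHECK(mpi_errno);

      fn_exit:
        return mpi_errno;
      fn_fail:
        goto fn_exit;
    }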
15 changes: 8 additions & 7 deletions src/mpi/comm/ulfm_impl.c
@@ -87,21 +87,22 @@ int MPIR_Comm_get_failed_impl(MPIR_Comm * comm_ptr, MPIR_Group ** failed_group_p
/* create failed_group */
int n = utarray_len(failed_procs);

MPIR_Lpid *map = MPL_malloc(n * sizeof(MPIR_Lpid), MPL_MEM_GROUP);

MPIR_Group *new_group;
mpi_errno = MPIR_Group_create(n, &new_group);
MPIR_ERR_CHECK(mpi_errno);

new_group->rank = MPI_UNDEFINED;
int myrank = MPI_UNDEFINED;
for (int i = 0; i < utarray_len(failed_procs); i++) {
int *p = (int *) utarray_eltptr(failed_procs, i);
new_group->lrank_to_lpid[i].lpid = *p;
map[i] = *p;
/* if calling process is part of the group, set the rank */
if (*p == MPIR_Process.rank) {
new_group->rank = i;
myrank = i;
}
}
new_group->size = n;
new_group->idx_of_first_lpid = -1;

mpi_errno = MPIR_Group_create_map(n, myrank, comm_ptr->session_ptr, map, &new_group);
MPIR_ERR_CHECK(mpi_errno);

MPIR_Group *comm_group;
MPIR_Comm_group_impl(comm_ptr, &comm_group);
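
The stride form of MPIR_Pmap is not exercised in the files shown here, but it should let dense groups avoid an O(size) map entirely. A hypothetical sketch of using MPIR_Group_create_stride for a group spanning all of comm world follows; the NULL session pointer and the (0, 1, 1) parameters meaning local rank i maps to lpid i are assumptions for this example, not taken from the PR.

    /* Hypothetical sketch: describe a group over all of comm world without
     * allocating a map.  With offset 0, stride 1, blocksize 1, local rank i
     * maps directly to lpid i. */
    static int make_world_group(MPIR_Group ** group_out)
    {
        return MPIR_Group_create_stride(MPIR_Process.size, MPIR_Process.rank,
                                        NULL /* session */,
                                        0 /* offset */, 1 /* stride */, 1 /* blocksize */,
                                        group_out);
    }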