Skip to content

Commit

Permalink
mpir_mem: simplify MPIR_CHKPMEM macros
Browse files Browse the repository at this point in the history
Reduce the number of trivial parameters to make the MPIR_CHKPMEM macros
easier to use and read.

* Similar to what we did with MPIR_CHKLMEM_MALLOC, but we retain the
class param for MPIR_CHKPMEM_MALLOC.
* Rename MPIR_CHKPMEM_REGISTER to MPIR_CHKPMEM_ADD.
* Always use MPIR_CHKPMEM_MAX (defaults to 10) slots.
* Remove the type-cast since it adds no value.
* Always use mpi_errno and goto fn_fail
* Just use the "**nomem" message.
* Most usages of MPIR_CHKPMEM_COMMIT are unnecessary before exiting. It
only has an effect if there is additional error checking afterwards.
  • Loading branch information
hzhou committed Dec 31, 2024
1 parent 6c75ce6 commit e859442
Show file tree
Hide file tree
Showing 41 changed files with 234 additions and 340 deletions.
99 changes: 41 additions & 58 deletions src/include/mpir_mem.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ extern "C" {
#define MPIR_CHKMEM_SETERR(rc_,nbytes_,name_) rc_=MPI_ERR_OTHER
#endif /* HAVE_ERROR_CHECKING */

/* CHKPMEM_REGISTER is used for memory allocated within another routine */

#define MPIR_CHKLMEM_MAX 10
#define MPIR_CHKLMEM_DECL() \
void *(mpiu_chklmem_stk_[MPIR_CHKLMEM_MAX]) = { NULL }; \
int mpiu_chklmem_stk_sp_=0; \
int mpiu_chklmem_stk_sp_=0;

#define MPIR_CHKLMEM_ADD(pointer_) \
do { \
Expand Down Expand Up @@ -136,67 +135,51 @@ extern "C" {


/* Persistent memory that we may want to recover if something goes wrong */
#define MPIR_CHKPMEM_DECL(n_) \
void *(mpiu_chkpmem_stk_[n_]) = { NULL }; \
int mpiu_chkpmem_stk_sp_=0; \
MPIR_AssertDeclValue(const int mpiu_chkpmem_stk_sz_,n_)
#define MPIR_CHKPMEM_MALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
{ \
pointer_ = (type_)MPL_malloc(nbytes_,class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<mpiu_chkpmem_stk_sz_); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
}
#define MPIR_CHKPMEM_REGISTER(pointer_) \
{ \
MPIR_Assert(mpiu_chkpmem_stk_sp_<mpiu_chkpmem_stk_sz_); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
}
#define MPIR_CHKPMEM_REAP() \
{ \
while (mpiu_chkpmem_stk_sp_ > 0) { \
MPL_free(mpiu_chkpmem_stk_[--mpiu_chkpmem_stk_sp_]); \
} \
}
#define MPIR_CHKPMEM_COMMIT() \
#define MPIR_CHKPMEM_MAX 10
#define MPIR_CHKPMEM_DECL() \
void *(mpiu_chkpmem_stk_[MPIR_CHKPMEM_MAX]) = { NULL }; \
int mpiu_chkpmem_stk_sp_=0;

#define MPIR_CHKPMEM_MALLOC(pointer_,nbytes_,class_) \
do { \
pointer_ = MPL_malloc(nbytes_,class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<MPIR_CHKPMEM_MAX); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); \
} \
} while (0)

#define MPIR_CHKPMEM_ADD(pointer_) \
do { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<MPIR_CHKPMEM_MAX); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} while (0)

#define MPIR_CHKPMEM_REAP() \
do { \
while (mpiu_chkpmem_stk_sp_ > 0) { \
MPL_free(mpiu_chkpmem_stk_[--mpiu_chkpmem_stk_sp_]); \
} \
} while (0)

/* NOTE: committing is unnecessary if all memory allocations need to be freed on failure */
#define MPIR_CHKPMEM_COMMIT() \
mpiu_chkpmem_stk_sp_ = 0
#define MPIR_CHKPMEM_MALLOC(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_MALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_)
#define MPIR_CHKPMEM_MALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_MALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,goto fn_fail)

/* now the CALLOC version for zeroed memory */
#define MPIR_CHKPMEM_CALLOC(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_CALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_)
#define MPIR_CHKPMEM_CALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_CALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,goto fn_fail)
#define MPIR_CHKPMEM_CALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
do { \
pointer_ = (type_)MPL_calloc(1, (nbytes_), (class_)); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<mpiu_chkpmem_stk_sz_); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} \
else if (nbytes_ > 0) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
#define MPIR_CHKPMEM_CALLOC(pointer_,nbytes_,class_) \
do { \
pointer_ = MPL_calloc(1, nbytes_, class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<MPIR_CHKPMEM_MAX); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); \
} \
} while (0)

/* A special version for routines that only allocate one item */
#define MPIR_CHKPMEM_MALLOC1(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
{ \
pointer_ = (type_)MPL_malloc(nbytes_,class_); \
if (!(pointer_) && (nbytes_ > 0)) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
}

/* Provides an easy way to use realloc safely and avoid the temptation to use
* realloc unsafely (direct ptr assignment). Zero-size reallocs returning NULL
* are handled and are not considered an error. */
Expand Down
5 changes: 2 additions & 3 deletions src/mpi/comm/comm_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,15 @@ int MPII_Comm_create_calculate_mapping(MPIR_Group * group_ptr,
int i, j;
int n;
int *mapping = 0;
MPIR_CHKPMEM_DECL(1);
MPIR_CHKPMEM_DECL();

MPIR_FUNC_ENTER;

*mapping_out = NULL;
*mapping_comm = comm_ptr;

n = group_ptr->size;
MPIR_CHKPMEM_MALLOC(mapping, int *, n * sizeof(int), mpi_errno, "mapping", MPL_MEM_ADDRESS);
MPIR_CHKPMEM_MALLOC(mapping, n * sizeof(int), MPL_MEM_COMM);

/* Make sure that the processes for this group are contained within
* the input communicator. Also identify the mapping from the ranks of
Expand Down Expand Up @@ -275,7 +275,6 @@ int MPII_Comm_create_calculate_mapping(MPIR_Group * group_ptr,
*mapping_out = mapping;
MPL_VG_CHECK_MEM_IS_DEFINED(*mapping_out, n * sizeof(**mapping_out));

MPIR_CHKPMEM_COMMIT();
fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
Expand Down
21 changes: 7 additions & 14 deletions src/mpi/comm/commutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,11 @@ int MPIR_Comm_map_irregular(MPIR_Comm * newcomm, MPIR_Comm * src_comm,
{
int mpi_errno = MPI_SUCCESS;
MPIR_Comm_map_t *mapper;
MPIR_CHKPMEM_DECL(3);
MPIR_CHKPMEM_DECL();

MPIR_FUNC_ENTER;

MPIR_CHKPMEM_MALLOC(mapper, MPIR_Comm_map_t *, sizeof(MPIR_Comm_map_t), mpi_errno, "mapper",
MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(mapper, sizeof(MPIR_Comm_map_t), MPL_MEM_COMM);

mapper->type = MPIR_COMM_MAP_TYPE__IRREGULAR;
mapper->src_comm = src_comm;
Expand All @@ -446,9 +445,7 @@ int MPIR_Comm_map_irregular(MPIR_Comm * newcomm, MPIR_Comm * src_comm,
mapper->src_mapping = src_mapping;
mapper->free_mapping = 0;
} else {
MPIR_CHKPMEM_MALLOC(mapper->src_mapping, int *,
src_mapping_size * sizeof(int), mpi_errno, "mapper mapping",
MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(mapper->src_mapping, src_mapping_size * sizeof(int), MPL_MEM_COMM);
mapper->free_mapping = 1;
}

Expand All @@ -460,7 +457,6 @@ int MPIR_Comm_map_irregular(MPIR_Comm * newcomm, MPIR_Comm * src_comm,
*map = mapper;

fn_exit:
MPIR_CHKPMEM_COMMIT();
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
Expand All @@ -472,12 +468,11 @@ int MPIR_Comm_map_dup(MPIR_Comm * newcomm, MPIR_Comm * src_comm, MPIR_Comm_map_d
{
int mpi_errno = MPI_SUCCESS;
MPIR_Comm_map_t *mapper;
MPIR_CHKPMEM_DECL(1);
MPIR_CHKPMEM_DECL();

MPIR_FUNC_ENTER;

MPIR_CHKPMEM_MALLOC(mapper, MPIR_Comm_map_t *, sizeof(MPIR_Comm_map_t), mpi_errno, "mapper",
MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(mapper, sizeof(MPIR_Comm_map_t), MPL_MEM_COMM);

mapper->type = MPIR_COMM_MAP_TYPE__DUP;
mapper->src_comm = src_comm;
Expand All @@ -488,7 +483,6 @@ int MPIR_Comm_map_dup(MPIR_Comm * newcomm, MPIR_Comm * src_comm, MPIR_Comm_map_d
LL_APPEND(newcomm->mapper_head, newcomm->mapper_tail, mapper);

fn_exit:
MPIR_CHKPMEM_COMMIT();
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
Expand Down Expand Up @@ -1326,7 +1320,7 @@ int MPII_Comm_is_node_balanced(MPIR_Comm * comm, int *num_nodes, bool * node_bal
int *ranks_per_node;
*num_nodes = 0;

MPIR_CHKPMEM_DECL(1);
MPIR_CHKPMEM_DECL();

if (!MPIR_Comm_is_parent_comm(comm)) {
*node_balanced = false;
Expand All @@ -1342,8 +1336,7 @@ int MPII_Comm_is_node_balanced(MPIR_Comm * comm, int *num_nodes, bool * node_bal
/* number of nodes is max_node_id + 1 */
(*num_nodes)++;

MPIR_CHKPMEM_CALLOC(ranks_per_node, int *,
*num_nodes * sizeof(int), mpi_errno, "ranks per node", MPL_MEM_OTHER);
MPIR_CHKPMEM_CALLOC(ranks_per_node, *num_nodes * sizeof(int), MPL_MEM_OTHER);

for (i = 0; i < comm->local_size; i++) {
ranks_per_node[comm->internode_table[i]]++;
Expand Down
6 changes: 2 additions & 4 deletions src/mpi/comm/contextid.c
Original file line number Diff line number Diff line change
Expand Up @@ -890,10 +890,9 @@ static int sched_get_cid_nonblock(MPIR_Comm * comm_ptr, MPIR_Comm * newcomm,
{
int mpi_errno = MPI_SUCCESS;
struct gcn_state *st = NULL;
MPIR_CHKPMEM_DECL(1);
MPIR_CHKPMEM_DECL();

MPIR_CHKPMEM_MALLOC(st, struct gcn_state *, sizeof(struct gcn_state), mpi_errno, "gcn_state",
MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(st, sizeof(struct gcn_state), MPL_MEM_COMM);
st->ctx0 = ctx0;
st->ctx1 = ctx1;
if (gcn_cid_kind == MPIR_COMM_KIND__INTRACOMM) {
Expand Down Expand Up @@ -921,7 +920,6 @@ static int sched_get_cid_nonblock(MPIR_Comm * comm_ptr, MPIR_Comm * newcomm,
MPIR_ERR_CHECK(mpi_errno);
MPIR_SCHED_BARRIER(s);

MPIR_CHKPMEM_COMMIT();
fn_exit:
return mpi_errno;
/* --BEGIN ERROR HANDLING-- */
Expand Down
7 changes: 3 additions & 4 deletions src/mpi/request/mpir_greq.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ int MPIR_Grequest_start_impl(MPI_Grequest_query_function * query_fn,
void *extra_state, MPIR_Request ** request_ptr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_CHKPMEM_DECL(1);
MPIR_CHKPMEM_DECL();

/* MT FIXME this routine is not thread-safe in the non-global case */

Expand All @@ -68,8 +68,8 @@ int MPIR_Grequest_start_impl(MPI_Grequest_query_function * query_fn,
(*request_ptr)->cc_ptr = &(*request_ptr)->cc;
MPIR_cc_set((*request_ptr)->cc_ptr, 1);
(*request_ptr)->comm = NULL;
MPIR_CHKPMEM_MALLOC((*request_ptr)->u.ureq.greq_fns, struct MPIR_Grequest_fns *,
sizeof(struct MPIR_Grequest_fns), mpi_errno, "greq_fns", MPL_MEM_GREQ);
MPIR_CHKPMEM_MALLOC((*request_ptr)->u.ureq.greq_fns, sizeof(struct MPIR_Grequest_fns),
MPL_MEM_GREQ);
(*request_ptr)->u.ureq.greq_fns->U.C.cancel_fn = cancel_fn;
(*request_ptr)->u.ureq.greq_fns->U.C.free_fn = free_fn;
(*request_ptr)->u.ureq.greq_fns->U.C.query_fn = query_fn;
Expand All @@ -83,7 +83,6 @@ int MPIR_Grequest_start_impl(MPI_Grequest_query_function * query_fn,
* we test or wait on it. */
MPIR_Request_add_ref((*request_ptr));

MPIR_CHKPMEM_COMMIT();
fn_exit:
return mpi_errno;
fn_fail:
Expand Down
22 changes: 8 additions & 14 deletions src/mpi/topo/dist_graph_create.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr,

int comm_size = comm_ptr->local_size;
MPIR_CHKLMEM_DECL();
MPIR_CHKPMEM_DECL(1);
MPIR_CHKPMEM_DECL();

/* following the spirit of the old topo interface, attributes do not
* propagate to the new communicator (see MPI-2.1 pp. 243 line 11) */
Expand Down Expand Up @@ -159,8 +159,7 @@ int MPIR_Dist_graph_create_impl(MPIR_Comm * comm_ptr,

/* Create the topology structure */
MPIR_Topology *topo_ptr = NULL;
MPIR_CHKPMEM_MALLOC(topo_ptr, MPIR_Topology *, sizeof(MPIR_Topology), mpi_errno, "topo_ptr",
MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(topo_ptr, sizeof(MPIR_Topology), MPL_MEM_COMM);
topo_ptr->kind = MPI_DIST_GRAPH;
dist_graph_ptr = &topo_ptr->topo.dist_graph;
dist_graph_ptr->indegree = 0;
Expand Down Expand Up @@ -300,7 +299,7 @@ int MPIR_Dist_graph_create_adjacent_impl(MPIR_Comm * comm_ptr,
MPIR_Comm ** comm_dist_graph_ptr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_CHKPMEM_DECL(5);
MPIR_CHKPMEM_DECL();

/* Implementation based on Torsten Hoefler's reference implementation
* attached to MPI-2.2 ticket #33. */
Expand All @@ -312,8 +311,7 @@ int MPIR_Dist_graph_create_adjacent_impl(MPIR_Comm * comm_ptr,

/* Create the topology structure */
MPIR_Topology *topo_ptr;
MPIR_CHKPMEM_MALLOC(topo_ptr, MPIR_Topology *, sizeof(MPIR_Topology), mpi_errno, "topo_ptr",
MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(topo_ptr, sizeof(MPIR_Topology), MPL_MEM_COMM);
topo_ptr->kind = MPI_DIST_GRAPH;
MPII_Dist_graph_topology *dist_graph_ptr = &topo_ptr->topo.dist_graph;
dist_graph_ptr->indegree = indegree;
Expand All @@ -325,23 +323,19 @@ int MPIR_Dist_graph_create_adjacent_impl(MPIR_Comm * comm_ptr,
dist_graph_ptr->is_weighted = (sourceweights != MPI_UNWEIGHTED);

if (indegree > 0) {
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->in, int *, indegree * sizeof(int), mpi_errno,
"dist_graph_ptr->in", MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->in, indegree * sizeof(int), MPL_MEM_COMM);
MPIR_Memcpy(dist_graph_ptr->in, sources, indegree * sizeof(int));
if (dist_graph_ptr->is_weighted) {
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->in_weights, int *, indegree * sizeof(int),
mpi_errno, "dist_graph_ptr->in_weights", MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->in_weights, indegree * sizeof(int), MPL_MEM_COMM);
MPIR_Memcpy(dist_graph_ptr->in_weights, sourceweights, indegree * sizeof(int));
}
}

if (outdegree > 0) {
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->out, int *, outdegree * sizeof(int), mpi_errno,
"dist_graph_ptr->out", MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->out, outdegree * sizeof(int), MPL_MEM_COMM);
MPIR_Memcpy(dist_graph_ptr->out, destinations, outdegree * sizeof(int));
if (dist_graph_ptr->is_weighted) {
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->out_weights, int *, outdegree * sizeof(int),
mpi_errno, "dist_graph_ptr->out_weights", MPL_MEM_COMM);
MPIR_CHKPMEM_MALLOC(dist_graph_ptr->out_weights, outdegree * sizeof(int), MPL_MEM_COMM);
MPIR_Memcpy(dist_graph_ptr->out_weights, destweights, outdegree * sizeof(int));
}
}
Expand Down
Loading

0 comments on commit e859442

Please sign in to comment.