Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpir_mem: simplify MPIR_CHKLMEM_ macros #7249

Merged
merged 4 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
===============================================================================
Changes in 5.0
===============================================================================
# MPIR_CHKLMEM_ and MPIR_CHKPMEM_ macros are simplified, removing non-essential
argument such as type case and custom error messages.

===============================================================================
Changes in 4.3
===============================================================================
Expand Down
141 changes: 63 additions & 78 deletions src/include/mpir_mem.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,98 +103,83 @@ extern "C" {
#define MPIR_CHKMEM_SETERR(rc_,nbytes_,name_) rc_=MPI_ERR_OTHER
#endif /* HAVE_ERROR_CHECKING */

/* CHKPMEM_REGISTER is used for memory allocated within another routine */

#define MPIR_CHKLMEM_DECL(n_) \
void *(mpiu_chklmem_stk_[n_]) = { NULL }; \
int mpiu_chklmem_stk_sp_=0; \
MPIR_AssertDeclValue(const int mpiu_chklmem_stk_sz_,n_)

#define MPIR_CHKLMEM_MALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
{ \
pointer_ = (type_)MPL_malloc(nbytes_,class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chklmem_stk_sp_<mpiu_chklmem_stk_sz_); \
mpiu_chklmem_stk_[mpiu_chklmem_stk_sp_++] = (void *) pointer_; \
} else if (nbytes_ > 0) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
}

#define MPIR_CHKLMEM_MAX 10
#define MPIR_CHKLMEM_DECL() \
void *(mpiu_chklmem_stk_[MPIR_CHKLMEM_MAX]) = { NULL }; \
int mpiu_chklmem_stk_sp_=0;

#define MPIR_CHKLMEM_REGISTER(pointer_) \
do { \
MPIR_Assert(mpiu_chklmem_stk_sp_<MPIR_CHKLMEM_MAX); \
mpiu_chklmem_stk_[mpiu_chklmem_stk_sp_++] = pointer_; \
} while (0)

#define MPIR_CHKLMEM_MALLOC(pointer_,nbytes_) \
do { \
pointer_ = MPL_malloc(nbytes_,MPL_MEM_LOCAL); \
if (pointer_) { \
MPIR_Assert(mpiu_chklmem_stk_sp_<MPIR_CHKLMEM_MAX); \
mpiu_chklmem_stk_[mpiu_chklmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); \
} \
} while (0)

#define MPIR_CHKLMEM_FREEALL() \
do { \
while (mpiu_chklmem_stk_sp_ > 0) { \
MPL_free(mpiu_chklmem_stk_[--mpiu_chklmem_stk_sp_]); \
} \
} while (0)

#define MPIR_CHKLMEM_MALLOC(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKLMEM_MALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_)
#define MPIR_CHKLMEM_MALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKLMEM_MALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,goto fn_fail)

/* Persistent memory that we may want to recover if something goes wrong */
#define MPIR_CHKPMEM_DECL(n_) \
void *(mpiu_chkpmem_stk_[n_]) = { NULL }; \
int mpiu_chkpmem_stk_sp_=0; \
MPIR_AssertDeclValue(const int mpiu_chkpmem_stk_sz_,n_)
#define MPIR_CHKPMEM_MALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
{ \
pointer_ = (type_)MPL_malloc(nbytes_,class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<mpiu_chkpmem_stk_sz_); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
}
#define MPIR_CHKPMEM_REGISTER(pointer_) \
{ \
MPIR_Assert(mpiu_chkpmem_stk_sp_<mpiu_chkpmem_stk_sz_); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
}
#define MPIR_CHKPMEM_REAP() \
{ \
while (mpiu_chkpmem_stk_sp_ > 0) { \
MPL_free(mpiu_chkpmem_stk_[--mpiu_chkpmem_stk_sp_]); \
} \
}
#define MPIR_CHKPMEM_COMMIT() \
#define MPIR_CHKPMEM_MAX 10
#define MPIR_CHKPMEM_DECL() \
void *(mpiu_chkpmem_stk_[MPIR_CHKPMEM_MAX]) = { NULL }; \
int mpiu_chkpmem_stk_sp_=0;

#define MPIR_CHKPMEM_MALLOC(pointer_,nbytes_,class_) \
do { \
pointer_ = MPL_malloc(nbytes_,class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<MPIR_CHKPMEM_MAX); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); \
} \
} while (0)

#define MPIR_CHKPMEM_REGISTER(pointer_) \
do { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<MPIR_CHKPMEM_MAX); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} while (0)

#define MPIR_CHKPMEM_REAP() \
do { \
while (mpiu_chkpmem_stk_sp_ > 0) { \
MPL_free(mpiu_chkpmem_stk_[--mpiu_chkpmem_stk_sp_]); \
} \
} while (0)

/* NOTE: unnecessary to commit if all memory allocations need be freed at fail */
#define MPIR_CHKPMEM_COMMIT() \
mpiu_chkpmem_stk_sp_ = 0
#define MPIR_CHKPMEM_MALLOC(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_MALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_)
#define MPIR_CHKPMEM_MALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_MALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,goto fn_fail)

/* now the CALLOC version for zeroed memory */
#define MPIR_CHKPMEM_CALLOC(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_CALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_)
#define MPIR_CHKPMEM_CALLOC_ORJUMP(pointer_,type_,nbytes_,rc_,name_,class_) \
MPIR_CHKPMEM_CALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,goto fn_fail)
#define MPIR_CHKPMEM_CALLOC_ORSTMT(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
do { \
pointer_ = (type_)MPL_calloc(1, (nbytes_), (class_)); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<mpiu_chkpmem_stk_sz_); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} \
else if (nbytes_ > 0) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
#define MPIR_CHKPMEM_CALLOC(pointer_,nbytes_,class_) \
do { \
pointer_ = MPL_calloc(1, nbytes_, class_); \
if (pointer_) { \
MPIR_Assert(mpiu_chkpmem_stk_sp_<MPIR_CHKPMEM_MAX); \
mpiu_chkpmem_stk_[mpiu_chkpmem_stk_sp_++] = pointer_; \
} else if (nbytes_ > 0) { \
MPIR_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**nomem"); \
} \
} while (0)

/* A special version for routines that only allocate one item */
#define MPIR_CHKPMEM_MALLOC1(pointer_,type_,nbytes_,rc_,name_,class_,stmt_) \
{ \
pointer_ = (type_)MPL_malloc(nbytes_,class_); \
if (!(pointer_) && (nbytes_ > 0)) { \
MPIR_CHKMEM_SETERR(rc_,nbytes_,name_); \
stmt_; \
} \
}

/* Provides a easy way to use realloc safely and avoid the temptation to use
* realloc unsafely (direct ptr assignment). Zero-size reallocs returning NULL
* are handled and are not considered an error. */
Expand Down
8 changes: 3 additions & 5 deletions src/mpi/coll/algorithms/recexchalgo/recexchalgo.c
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ int MPII_Recexchalgo_reverse_digits_step2(int rank, int comm_size, int k)
int pofk = 1, log_pofk = 0;
int *digit, *digit_reverse;
int mpi_errno ATTRIBUTE((unused)) = MPI_SUCCESS;
MPIR_CHKLMEM_DECL(2);
MPIR_CHKLMEM_DECL();

MPIR_FUNC_ENTER;

Expand All @@ -350,10 +350,8 @@ int MPII_Recexchalgo_reverse_digits_step2(int rank, int comm_size, int k)
step2rank = MPII_Recexchalgo_origrank_to_step2rank(rank, rem, T, k);

/* calculate the digits in base k representation of step2rank */
MPIR_CHKLMEM_MALLOC(digit, int *, sizeof(int) * log_pofk,
mpi_errno, "digit buffer", MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(digit_reverse, int *, sizeof(int) * log_pofk,
mpi_errno, "digit_reverse buffer", MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(digit, sizeof(int) * log_pofk);
MPIR_CHKLMEM_MALLOC(digit_reverse, sizeof(int) * log_pofk);
for (i = 0; i < log_pofk; i++)
digit[i] = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint
void *tmp_buf = NULL;
MPIR_Comm *newcomm_ptr = NULL;

MPIR_CHKLMEM_DECL(1);
MPIR_CHKLMEM_DECL();

local_size = comm_ptr->local_size;
remote_size = comm_ptr->remote_size;
Expand All @@ -32,8 +32,7 @@ int MPIR_Allgather_inter_local_gather_remote_bcast(const void *sendbuf, MPI_Aint
/* In each group, rank 0 allocates temp. buffer for local
* gather */
MPIR_Datatype_get_size_macro(sendtype, sendtype_sz);
MPIR_CHKLMEM_MALLOC(tmp_buf, void *, sendcount * sendtype_sz * local_size, mpi_errno,
"tmp_buf", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, sendcount * sendtype_sz * local_size);
} else {
/* silence -Wmaybe-uninitialized due to MPIR_{Gather,Bcast} calls by non-zero ranks */
sendtype_sz = 0;
Expand Down
5 changes: 2 additions & 3 deletions src/mpi/coll/allgather/allgather_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
void *tmp_buf = NULL;
int dst;

MPIR_CHKLMEM_DECL(1);
MPIR_CHKLMEM_DECL();

if (((sendcount == 0) && (sendbuf != MPI_IN_PLACE)) || (recvcount == 0))
goto fn_exit;
Expand All @@ -39,8 +39,7 @@ int MPIR_Allgather_intra_brucks(const void *sendbuf,
MPIR_Datatype_get_size_macro(recvtype, recvtype_sz);

/* allocate a temporary buffer of the same size as recvbuf. */
MPIR_CHKLMEM_MALLOC(tmp_buf, void *, recvcount * comm_size * recvtype_sz, mpi_errno,
"tmp_buf", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, recvcount * comm_size * recvtype_sz);

/* copy local data to the top of tmp_buf */
if (sendbuf != MPI_IN_PLACE) {
Expand Down
5 changes: 2 additions & 3 deletions src/mpi/coll/allgather/allgather_intra_k_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ MPIR_Allgather_intra_k_brucks(const void *sendbuf, MPI_Aint sendcount,

int delta = 1;
void *tmp_recvbuf = NULL;
MPIR_CHKLMEM_DECL(2);
MPIR_CHKLMEM_MALLOC(reqs, MPIR_Request **, (2 * (k - 1) * sizeof(MPIR_Request *)), mpi_errno,
"reqs", MPL_MEM_BUFFER);
MPIR_CHKLMEM_DECL();
MPIR_CHKLMEM_MALLOC(reqs, (2 * (k - 1) * sizeof(MPIR_Request *)));

MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST,
"Allgather_brucks_radix_k algorithm: num_ranks: %d, k: %d",
Expand Down
5 changes: 2 additions & 3 deletions src/mpi/coll/allgatherv/allgatherv_intra_brucks.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
int pof2, src, dst, rem;
MPI_Aint curr_cnt, send_cnt, recv_cnt, total_count;
void *tmp_buf;
MPIR_CHKLMEM_DECL(1);
MPIR_CHKLMEM_DECL();

MPIR_THREADCOMM_RANK_SIZE(comm_ptr, rank, comm_size);

Expand All @@ -48,8 +48,7 @@ int MPIR_Allgatherv_intra_brucks(const void *sendbuf,
/* allocate a temporary buffer that can hold all the data */
MPIR_Datatype_get_size_macro(recvtype, recvtype_sz);

MPIR_CHKLMEM_MALLOC(tmp_buf, void *, total_count * recvtype_sz, mpi_errno, "tmp_buf",
MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, total_count * recvtype_sz);

/* copy local data to the top of tmp_buf */
if (sendbuf != MPI_IN_PLACE) {
Expand Down
5 changes: 2 additions & 3 deletions src/mpi/coll/allgatherv/allgatherv_intra_recursive_doubling.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf,
void *tmp_buf;
int mask, dst_tree_root, my_tree_root, nprocs_completed, k, tmp_mask, tree_root;
MPI_Aint position, send_offset, recv_offset, offset;
MPIR_CHKLMEM_DECL(1);
MPIR_CHKLMEM_DECL();

comm_size = comm_ptr->local_size;
rank = comm_ptr->rank;
Expand All @@ -58,8 +58,7 @@ int MPIR_Allgatherv_intra_recursive_doubling(const void *sendbuf,
MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent);
MPIR_Datatype_get_size_macro(recvtype, recvtype_sz);

MPIR_CHKLMEM_MALLOC(tmp_buf, void *,
total_count * recvtype_sz, mpi_errno, "tmp_buf", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, total_count * recvtype_sz);

/* copy local data into right location in tmp_buf */
position = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@ int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbu
MPI_Datatype datatype, MPI_Op op,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
{
int mpi_errno;
int mpi_errno = MPI_SUCCESS;
MPI_Aint true_extent, true_lb, extent;
void *tmp_buf = NULL;
MPIR_Comm *newcomm_ptr = NULL;
MPIR_CHKLMEM_DECL(1);
MPIR_CHKLMEM_DECL();

if (comm_ptr->rank == 0) {
MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
MPIR_Datatype_get_extent_macro(datatype, extent);
MPIR_CHKLMEM_MALLOC(tmp_buf, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno,
"temporary buffer", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, count * (MPL_MAX(extent, true_extent)));
/* adjust for potential negative lower bound in datatype */
tmp_buf = (void *) ((char *) tmp_buf - true_lb);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf,
MPIR_Request **send_reqs = NULL, **recv_reqs = NULL;
int num_sreq = 0, num_rreq = 0, total_phases = 0;
void *tmp_recvbuf = NULL;
MPIR_CHKLMEM_DECL(2);
MPIR_CHKLMEM_DECL();

MPIR_Assert(k > 1);

Expand Down Expand Up @@ -123,10 +123,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf,
/* Main recursive exchange step */
if (in_step2) {
MPI_Aint *cnts = NULL, *displs = NULL;
MPIR_CHKLMEM_MALLOC(cnts, MPI_Aint *, sizeof(MPI_Aint) * nranks, mpi_errno, "cnts",
MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(displs, MPI_Aint *, sizeof(MPI_Aint) * nranks, mpi_errno, "displs",
MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(cnts, sizeof(MPI_Aint) * nranks);
MPIR_CHKLMEM_MALLOC(displs, sizeof(MPI_Aint) * nranks);
idx = 0;
rem = nranks - p_of_k;

Expand Down
5 changes: 2 additions & 3 deletions src/mpi/coll/allreduce/allreduce_intra_recursive_doubling.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf,
MPI_Datatype datatype,
MPI_Op op, MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag)
{
MPIR_CHKLMEM_DECL(1);
MPIR_CHKLMEM_DECL();
int comm_size, rank;
int mpi_errno = MPI_SUCCESS;
int mask, dst, is_commutative, pof2, newrank, rem, newdst;
Expand All @@ -39,8 +39,7 @@ int MPIR_Allreduce_intra_recursive_doubling(const void *sendbuf,
MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
MPIR_Datatype_get_extent_macro(datatype, extent);

MPIR_CHKLMEM_MALLOC(tmp_buf, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno,
"temporary buffer", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, count * (MPL_MAX(extent, true_extent)));

/* adjust for potential negative lower bound in datatype */
tmp_buf = (void *) ((char *) tmp_buf - true_lb);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf,
pofk *= k;
}

MPIR_CHKLMEM_DECL(2);
MPIR_CHKLMEM_DECL();
void *tmp_buf;

/*Allocate for nb requests */
MPIR_Request **reqs;
int num_reqs = 0;
MPIR_CHKLMEM_MALLOC(reqs, MPIR_Request **, (2 * (k - 1) * sizeof(MPIR_Request *)), mpi_errno,
"reqs", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(reqs, (2 * (k - 1) * sizeof(MPIR_Request *)));

/* need to allocate temporary buffer to store incoming data */
MPI_Aint true_extent, true_lb, extent;
Expand All @@ -65,8 +64,7 @@ int MPIR_Allreduce_intra_recursive_multiplying(const void *sendbuf,
}
}

MPIR_CHKLMEM_MALLOC(tmp_buf, void *, (k - 1) * single_size, mpi_errno,
"temporary buffer", MPL_MEM_BUFFER);
MPIR_CHKLMEM_MALLOC(tmp_buf, (k - 1) * single_size);

/* adjust for potential negative lower bound in datatype */
tmp_buf = (void *) ((char *) tmp_buf - true_lb);
Expand Down
Loading