From f4a9fb72145fa658a842437abdd1cb5559296db6 Mon Sep 17 00:00:00 2001 From: Shi Jin Date: Thu, 2 Jan 2025 18:32:07 +0000 Subject: [PATCH] prov/efa: Add missing locks in efa_msg and efa_rma efa_post_send, efa_rma_post_write, and efa_rma_post_read access the base_ep->is_wr_started bool, which needs to be protected by a lock; otherwise there will be a race condition when multiple threads call them. The same issue applies to efa_post_recv, which accesses recv_wr_index. This patch adds the required locking to protect these resources. This lock is a no-op unless FI_THREAD_SAFE. Signed-off-by: Shi Jin --- prov/efa/src/efa_msg.c | 21 +++++++++++++++------ prov/efa/src/efa_rma.c | 17 ++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/prov/efa/src/efa_msg.c b/prov/efa/src/efa_msg.c index 7920afbf531..fbd4adb2bd9 100644 --- a/prov/efa/src/efa_msg.c +++ b/prov/efa/src/efa_msg.c @@ -67,10 +67,12 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi struct ibv_recv_wr *wr; uintptr_t addr; ssize_t err, post_recv_err; - size_t i, wr_index = base_ep->recv_wr_index; + size_t i, wr_index; efa_tracepoint(recv_begin_msg_context, (size_t) msg->context, (size_t) msg->addr); + ofi_genlock_lock(&base_ep->util_ep.lock); + wr_index = base_ep->recv_wr_index; if (wr_index >= base_ep->info->rx_attr->size) { EFA_INFO(FI_LOG_EP_DATA, "recv_wr_index exceeds the rx limit, " @@ -118,8 +120,10 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi base_ep->recv_wr_index++; - if (flags & FI_MORE) - return 0; + if (flags & FI_MORE) { + err = 0; + goto out; + } efa_tracepoint(post_recv, wr->wr_id, (uintptr_t)msg->context); @@ -134,6 +138,9 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi base_ep->recv_wr_index = 0; +out: + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; out_err: @@ -148,6 +155,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi 
base_ep->recv_wr_index = 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; } @@ -209,6 +218,7 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi assert(len <= base_ep->info->ep_attr->max_msg_size); + ofi_genlock_lock(&base_ep->util_ep.lock); if (!base_ep->is_wr_started) { ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; @@ -260,10 +270,9 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi ret = ibv_wr_complete(qp->ibv_qp_ex); base_ep->is_wr_started = false; } - if (OFI_UNLIKELY(ret)) - return ret; - return 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return ret; } static ssize_t efa_ep_sendmsg(struct fid_ep *ep_fid, const struct fi_msg *msg, uint64_t flags) diff --git a/prov/efa/src/efa_rma.c b/prov/efa/src/efa_rma.c index a7bad7d3877..052e2aa89d7 100644 --- a/prov/efa/src/efa_rma.c +++ b/prov/efa/src/efa_rma.c @@ -83,6 +83,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep, base_ep->domain->device->max_rdma_size); qp = base_ep->qp; + + ofi_genlock_lock(&base_ep->util_ep.lock); + if (!base_ep->is_wr_started) { ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; @@ -113,10 +116,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep, err = ibv_wr_complete(qp->ibv_qp_ex); base_ep->is_wr_started = false; } - if (OFI_UNLIKELY(err)) - return err; - return 0; + ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; } static @@ -212,6 +214,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep, efa_tracepoint(write_begin_msg_context, (size_t) msg->context, (size_t) msg->addr); qp = base_ep->qp; + + ofi_genlock_lock(&base_ep->util_ep.lock); + if (!base_ep->is_wr_started) { ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; @@ -256,10 +261,8 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep, base_ep->is_wr_started = false; } - if (OFI_UNLIKELY(err)) - return err; - - return 0; + 
ofi_genlock_unlock(&base_ep->util_ep.lock); + return err; } ssize_t efa_rma_writemsg(struct fid_ep *ep_fid, const struct fi_msg_rma *msg,