From 4ce0b9929986d36dd127ada547f49400d3e57cb4 Mon Sep 17 00:00:00 2001 From: Shi Jin Date: Tue, 12 Sep 2023 00:05:24 +0000 Subject: [PATCH] prov/efa: Add missing locks in efa_cntr_wait. Currently, efa_cntr_wait calls cntr->progress, which eventually calls efa_ep_progress_internal(). However, efa_ep_progress_internal() must be called with srx->lock held. This patch fixes this issue. It also adds a check to handle the case where srx_ctx could be NULL. Signed-off-by: Shi Jin --- prov/efa/src/efa_cntr.c | 48 ++++++++++++++++++++++++++++++----------- prov/efa/src/efa_cntr.h | 15 +++++++++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/prov/efa/src/efa_cntr.c b/prov/efa/src/efa_cntr.c index d70481eb464..3f76d689614 100644 --- a/prov/efa/src/efa_cntr.c +++ b/prov/efa/src/efa_cntr.c @@ -44,6 +44,12 @@ static int efa_cntr_wait(struct fid_cntr *cntr_fid, uint64_t threshold, int time int numtry = 5; int tryid = 0; int waitim = 1; + struct util_srx_ctx *srx_ctx; + + srx_ctx = efa_cntr_get_srx_ctx(cntr_fid); + + if (srx_ctx) + ofi_genlock_lock(srx_ctx->lock); cntr = container_of(cntr_fid, struct util_cntr, cntr_fid); assert(cntr->wait); @@ -52,16 +58,22 @@ static int efa_cntr_wait(struct fid_cntr *cntr_fid, uint64_t threshold, int time for (tryid = 0; tryid < numtry; ++tryid) { cntr->progress(cntr); - if (threshold <= ofi_atomic_get64(&cntr->cnt)) - return FI_SUCCESS; + if (threshold <= ofi_atomic_get64(&cntr->cnt)) { + ret = FI_SUCCESS; + goto unlock; + } - if (errcnt != ofi_atomic_get64(&cntr->err)) - return -FI_EAVAIL; + if (errcnt != ofi_atomic_get64(&cntr->err)) { + ret = -FI_EAVAIL; + goto unlock; + } if (timeout >= 0) { timeout -= (int)(ofi_gettime_ms() - start); - if (timeout <= 0) - return -FI_ETIMEDOUT; + if (timeout <= 0) { + ret = -FI_ETIMEDOUT; + goto unlock; + } } ret = fi_wait(&cntr->wait->wait_fid, waitim); @@ -71,6 +83,9 @@ static int efa_cntr_wait(struct fid_cntr *cntr_fid, uint64_t threshold, int time waitim *= 2; } +unlock: + if (srx_ctx) + 
ofi_genlock_unlock(srx_ctx->lock); return ret; } @@ -81,13 +96,18 @@ static uint64_t efa_cntr_read(struct fid_cntr *cntr_fid) uint64_t ret; efa_cntr = container_of(cntr_fid, struct efa_cntr, util_cntr.cntr_fid); - srx_ctx = efa_cntr->util_cntr.domain->srx->ep_fid.fid.context; - ofi_genlock_lock(srx_ctx->lock); + srx_ctx = efa_cntr_get_srx_ctx(cntr_fid); + + if (srx_ctx) + ofi_genlock_lock(srx_ctx->lock); + if (efa_cntr->shm_cntr) fi_cntr_read(efa_cntr->shm_cntr); ret = ofi_cntr_read(cntr_fid); - ofi_genlock_unlock(srx_ctx->lock); + + if (srx_ctx) + ofi_genlock_unlock(srx_ctx->lock); return ret; } @@ -99,13 +119,17 @@ static uint64_t efa_cntr_readerr(struct fid_cntr *cntr_fid) uint64_t ret; efa_cntr = container_of(cntr_fid, struct efa_cntr, util_cntr.cntr_fid); - srx_ctx = efa_cntr->util_cntr.domain->srx->ep_fid.fid.context; - ofi_genlock_lock(srx_ctx->lock); + srx_ctx = efa_cntr_get_srx_ctx(cntr_fid); + + if (srx_ctx) + ofi_genlock_lock(srx_ctx->lock); if (efa_cntr->shm_cntr) fi_cntr_read(efa_cntr->shm_cntr); ret = ofi_cntr_readerr(cntr_fid); - ofi_genlock_unlock(srx_ctx->lock); + + if (srx_ctx) + ofi_genlock_unlock(srx_ctx->lock); return ret; } diff --git a/prov/efa/src/efa_cntr.h b/prov/efa/src/efa_cntr.h index 3db456b81a1..4dd1eed800e 100644 --- a/prov/efa/src/efa_cntr.h +++ b/prov/efa/src/efa_cntr.h @@ -52,5 +52,20 @@ void efa_cntr_report_rx_completion(struct util_ep *ep, uint64_t flags); void efa_cntr_report_error(struct util_ep *ep, uint64_t flags); +static inline +void *efa_cntr_get_srx_ctx(struct fid_cntr *cntr_fid) +{ + struct efa_cntr *efa_cntr; + struct fid_peer_srx *srx = NULL; + + efa_cntr = container_of(cntr_fid, struct efa_cntr, util_cntr.cntr_fid); + + srx = efa_cntr->util_cntr.domain->srx; + if (!srx) + return NULL; + + return srx->ep_fid.fid.context; +} + #endif