From 6faa202fce2069534def6912a81771a20ab9222d Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Fri, 23 Feb 2024 00:09:08 +0000 Subject: [PATCH] rdma: defer connect completion after sending connect message In the current implementation of connect/accept, it is possible for `accept` to complete (i.e., return a non-NULL communicator) after the corresponding `connect` returned a NULL communicator (while waiting for a completion for the connection message). This is a strange semantic, and evidently causes NCCL to be unhappy, particularly in the multi-recv case (which is being added in a future commit). So, after sending the connect message, defer waiting for completion; block when closing the send comm if necessary. Signed-off-by: Eric Raut --- include/nccl_ofi_rdma.h | 6 ++++ src/nccl_ofi_rdma.c | 78 ++++++++++++++++++++--------------------- 2 files changed, 44 insertions(+), 40 deletions(-) diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 36b7d4318..2be54a40f 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -343,6 +343,12 @@ typedef struct nccl_net_ofi_rdma_send_comm { /* Comm ID provided by remote endpoint */ uint64_t remote_comm_id; + /* Request to send connect message */ + nccl_net_ofi_rdma_req_t *send_conn_req; + + /* Indicates if connect message was delivered (and req freed) */ + bool connect_msg_delivered; + /* Request to receive connect response message to finalize * connection establishment */ nccl_net_ofi_rdma_req_t *conn_resp_req; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 451170e79..e55b6ff18 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -1428,7 +1428,18 @@ static inline int process_completions(struct fi_cq_tagged_entry *cq_entry, return ncclInternalError; } - if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) { + if (req->type == NCCL_OFI_RDMA_SEND_CONN) { + assert(req->comm->type == NCCL_NET_OFI_SEND_COMM); + nccl_net_ofi_rdma_send_comm_t *s_comm = + (nccl_net_ofi_rdma_send_comm_t *)req->comm; + assert(req == s_comm->send_conn_req); + /* Release connect message request */ + req->free(req, false); + req = NULL; + s_comm->send_conn_req = NULL; + __sync_synchronize(); + s_comm->connect_msg_delivered = true; + } else if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) { assert(req->comm->type == NCCL_NET_OFI_SEND_COMM); /* Complete send communicator */ nccl_net_ofi_rdma_send_comm_t *s_comm = @@ -4821,8 +4832,9 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm) return ncclInternalError; } - // TODO: We might want to use READ_ONCE to read variable `connected' - while (!s_comm->connected) { + // TODO: We might want to use READ_ONCE to read variables + // `connect_msg_delivered` and `connected' + while (!s_comm->connect_msg_delivered || !s_comm->connected) { __compiler_barrier(); int ret = 0; /* Progress our engine to get completions. If the @@ -5229,14 +5241,12 @@ static int connect(nccl_net_ofi_ep_t *base_ep, nccl_net_ofi_send_comm_t **send_comm) { int ret = 0; - nccl_net_ofi_rdma_req_state_t conn_msg_state; *send_comm = NULL; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; /* Extract connection state of the communicator */ save_comm_state_t *comm_state = &(handle->state); - nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)comm_state->req; nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)comm_state->comm; @@ -5276,23 +5286,22 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->comm = &s_comm->base.base; /* Prepare connect request to be sent to peer */ - req = prepare_send_conn_req(s_comm); - if (OFI_UNLIKELY(req == NULL)) { + s_comm->send_conn_req = prepare_send_conn_req(s_comm); + if (OFI_UNLIKELY(s_comm->send_conn_req == NULL)) { send_close(s_comm); return ncclSystemError; } - comm_state->req = &req->base; comm_state->stage = COMM_SEND_CONN; case COMM_SEND_CONN: /* COMM_SEND_CONN: Post a connect message to send peer connections */ - ret = post_send_conn(s_comm, device, ep, req); + ret = post_send_conn(s_comm, device, ep, s_comm->send_conn_req); if (ret == -FI_EAGAIN) { return 0; } else if (ret != 0) { - req->free(req, false); + s_comm->send_conn_req->free(s_comm->send_conn_req, false); send_close(s_comm); return ret; } @@ -5313,29 +5322,6 @@ static int connect(nccl_net_ofi_ep_t *base_ep, return ret; } - /* Check if the connect message is sent */ - ret = pthread_mutex_lock(&req->req_lock); - if (OFI_UNLIKELY(ret)) { - NCCL_OFI_WARN("Unable to acquire req_lock mutex"); - return ncclInternalError; - } - conn_msg_state = req->state; - ret = pthread_mutex_unlock(&req->req_lock); - if (OFI_UNLIKELY(ret)) { - NCCL_OFI_WARN("Failed to unlock req_lock mutex"); - return ncclInternalError; - } - - /* Wait until connect message is sent */ - if (conn_msg_state != NCCL_OFI_RDMA_REQ_COMPLETED) { - return 0; - } - - /* Release connect message request */ - req->free(req, false); - comm_state->req = NULL; - req = NULL; - /* Prepare request to receive connect response message */ s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm); if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) { @@ -5345,15 +5331,27 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->stage = COMM_RECV_CONN; - case COMM_RECV_CONN: + case COMM_RECV_CONN:; /* COMM_RECV_CONN: Receive connect response message from remote */ - ret = post_recv_conn_resp(s_comm, device, ep); - if (ret == -FI_EAGAIN) { - return 0; - } else if (ret != 0) { - send_close(s_comm); - return ret; + bool recv_conn_resp_posted = false; + while (!recv_conn_resp_posted) { + ret = post_recv_conn_resp(s_comm, device, ep); + if (ret == -FI_EAGAIN) { + /* Block until we post the connection response request. + EAGAIN only involves waiting for local resources to free up, so it + should be safe to block. */ + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { + send_close(s_comm); + return ret; + } + } else if (ret != 0) { + send_close(s_comm); + return ret; + } else { + recv_conn_resp_posted = true; + } } /* Progress our engine to get completions. If the