Skip to content

Commit

Permalink
rdma: defer connect completion after sending connect message
Browse files Browse the repository at this point in the history
In the current implementation of connect/accept, it is possible for
`accept` to complete (i.e., return a non-NULL communicator) after the
corresponding `connect` returned a NULL communicator (while waiting for
a completion for the connection message). These semantics are surprising,
and evidently cause NCCL to be unhappy, particularly in the multi-recv
case (which is being added in a future commit).

So, after sending the connect message, defer waiting for completion;
block when closing the send comm if necessary.

Signed-off-by: Eric Raut <[email protected]>
  • Loading branch information
rauteric committed Feb 26, 2024
1 parent 02dc354 commit ae106d2
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 47 deletions.
1 change: 1 addition & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ AC_SEARCH_LIBS([dlopen], [dl], [], [AC_MSG_ERROR([NCCL OFI Plugin requires dlope
# Check for GCC builtin functions
CHECK_GCC_BUILTIN([__builtin_expect])
CHECK_GCC_BUILTIN([__builtin_ffsll])
CHECK_GCC_BUILTIN([__sync_synchronize])

# Checks for external packages
CHECK_PKG_LIBFABRIC([], [AC_MSG_ERROR([NCCL OFI Plugin could not find a working Libfabric install.])])
Expand Down
10 changes: 8 additions & 2 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,18 @@ typedef struct nccl_net_ofi_rdma_send_comm {
/* Comm ID provided by remote endpoint */
uint64_t remote_comm_id;

/* Request tracking connect message */
nccl_net_ofi_rdma_req_t *send_conn_req;

/* Request to receive connect response message to finalize
* connection establishment */
nccl_net_ofi_rdma_req_t *conn_resp_req;

/* Indicates if connection establishment is completed */
bool connected;
/* Indicates if connect message was delivered (and req freed) */
bool connect_msg_delivered;

/* Indicates if connection response received and connection finalized */
bool connect_finalized;

/* Message struct used to send the connect message and to receive the
* connect response message */
Expand Down
3 changes: 2 additions & 1 deletion m4/check_gcc_builtin.m4
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ AC_DEFUN([CHECK_GCC_BUILTIN], [
[__builtin_ffs], [$1(0)],
[__builtin_ffsl], [$1(0)],
[__builtin_ffsll], [$1(0)],
[__builtin_expect], [$1(0, 0)]),
[__builtin_expect], [$1(0, 0)],
[__sync_synchronize], [$1()]),
[exit(1)]
])], [result=yes], [result=no])
Expand Down
84 changes: 40 additions & 44 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,18 @@ static inline int process_completions(struct fi_cq_tagged_entry *cq_entry,
return ncclInternalError;
}

if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) {
if (req->type == NCCL_OFI_RDMA_SEND_CONN) {
assert(req->comm->type == NCCL_NET_OFI_SEND_COMM);
nccl_net_ofi_rdma_send_comm_t *s_comm =
(nccl_net_ofi_rdma_send_comm_t *)req->comm;
assert(req == s_comm->send_conn_req);
/* Release connect message request */
req->free(req, false);
req = NULL;
s_comm->send_conn_req = NULL;
__sync_synchronize();
s_comm->connect_msg_delivered = true;
} else if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) {
assert(req->comm->type == NCCL_NET_OFI_SEND_COMM);
/* Complete send communicator */
nccl_net_ofi_rdma_send_comm_t *s_comm =
Expand Down Expand Up @@ -2227,7 +2238,7 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm)
* should be a barrier after the communicator initialization
* is finalized */
__sync_synchronize();
s_comm->connected = true;
s_comm->connect_finalized = true;

return ret;
}
Expand Down Expand Up @@ -4583,7 +4594,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
* Try to finalize the connection if it is not established yet; return a
* NULL request if the connection cannot be finalized.
*/
if (OFI_UNLIKELY(!s_comm->connected)) {
if (OFI_UNLIKELY(!s_comm->connect_finalized)) {
__compiler_barrier();

/* Progress our engine to get completions. If the
Expand All @@ -4594,7 +4605,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
goto error;
}

if (!s_comm->connected) {
if (!s_comm->connect_finalized) {
/* Return NULL request */
*base_req = NULL;
goto exit;
Expand Down Expand Up @@ -4804,8 +4815,9 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm)
return ncclInternalError;
}

// TODO: We might want to use READ_ONCE to read variable `connected'
while (!s_comm->connected) {
// TODO: We might want to use READ_ONCE to read variables
// `connect_msg_delivered` and `connect_finalized`
while (!s_comm->connect_msg_delivered || !s_comm->connect_finalized) {
__compiler_barrier();
int ret = 0;
/* Progress our engine to get completions. If the
Expand Down Expand Up @@ -5212,14 +5224,12 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
nccl_net_ofi_send_comm_t **send_comm)
{
int ret = 0;
nccl_net_ofi_rdma_req_state_t conn_msg_state;
*send_comm = NULL;
nccl_net_ofi_rdma_ep_t *ep =
(nccl_net_ofi_rdma_ep_t *)base_ep;

/* Extract connection state of the communicator */
save_comm_state_t *comm_state = &(handle->state);
nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)comm_state->req;
nccl_net_ofi_rdma_send_comm_t *s_comm =
(nccl_net_ofi_rdma_send_comm_t *)comm_state->comm;

Expand Down Expand Up @@ -5259,23 +5269,22 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
comm_state->comm = &s_comm->base.base;

/* Prepare connect request to be sent to peer */
req = prepare_send_conn_req(s_comm);
if (OFI_UNLIKELY(req == NULL)) {
s_comm->send_conn_req = prepare_send_conn_req(s_comm);
if (OFI_UNLIKELY(s_comm->send_conn_req == NULL)) {
send_close(s_comm);
return ncclSystemError;
}
comm_state->req = &req->base;

comm_state->stage = COMM_SEND_CONN;

case COMM_SEND_CONN:
/* COMM_SEND_CONN: Post a connect message to send peer connections */
ret = post_send_conn(s_comm, device, ep, req);
ret = post_send_conn(s_comm, device, ep, s_comm->send_conn_req);
if (ret == -FI_EAGAIN) {
return 0;
}
else if (ret != 0) {
req->free(req, false);
s_comm->send_conn_req->free(s_comm->send_conn_req, false);
send_close(s_comm);
return ret;
}
Expand All @@ -5296,29 +5305,6 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
return ret;
}

/* Check if the connect message is sent */
ret = pthread_mutex_lock(&req->req_lock);
if (OFI_UNLIKELY(ret)) {
NCCL_OFI_WARN("Unable to acquire req_lock mutex");
return ncclInternalError;
}
conn_msg_state = req->state;
ret = pthread_mutex_unlock(&req->req_lock);
if (OFI_UNLIKELY(ret)) {
NCCL_OFI_WARN("Failed to unlock req_lock mutex");
return ncclInternalError;
}

/* Wait until connect message is sent */
if (conn_msg_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
return 0;
}

/* Release connect message request */
req->free(req, false);
comm_state->req = NULL;
req = NULL;

/* Prepare request to receive connect response message */
s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm);
if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) {
Expand All @@ -5328,15 +5314,25 @@ static int connect(nccl_net_ofi_ep_t *base_ep,

comm_state->stage = COMM_RECV_CONN;

case COMM_RECV_CONN:
case COMM_RECV_CONN: {
/* COMM_RECV_CONN: Receive connect response message from remote */

ret = post_recv_conn_resp(s_comm, device, ep);
if (ret == -FI_EAGAIN) {
return 0;
} else if (ret != 0) {
send_close(s_comm);
return ret;
bool recv_conn_resp_posted = false;
while (!recv_conn_resp_posted) {
ret = post_recv_conn_resp(s_comm, device, ep);
if (ret == -FI_EAGAIN) {
/* Block until we post the connection response request.
EAGAIN only involves waiting for local resources to free up, so it
should be safe to block. */
ret = ofi_process_cq(ep);
if (OFI_UNLIKELY(ret != 0)) {
return ret;
}
} else if (ret != 0) {
return ret;
} else {
recv_conn_resp_posted = true;
}
}

/* Progress our engine to get completions. If the
Expand All @@ -5350,7 +5346,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
comm_state->stage = COMM_CONN_RESP_REQ_PENDING;

break;

}
case COMM_CONN_RESP_REQ_PENDING:
case COMM_CONNECTED:
default:
Expand Down

0 comments on commit ae106d2

Please sign in to comment.