From bbed42fad574c1db5047d57694f5b47e96e48df7 Mon Sep 17 00:00:00 2001
From: Eric Raut
Date: Thu, 9 Jan 2025 23:34:09 +0000
Subject: [PATCH] rdma: add separate request types for eager/ctrl rx buffers

These two request types currently share an underlying data structure.

Signed-off-by: Eric Raut
---
 include/nccl_ofi_rdma.h | 10 ++++-
 src/nccl_ofi_rdma.c     | 87 +++++++++++++++++++++++++++++++++--------
 2 files changed, 78 insertions(+), 19 deletions(-)

diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 6e1719366..16ecb16c6 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -80,8 +80,10 @@ typedef enum nccl_net_ofi_rdma_req_type {
         NCCL_OFI_RDMA_RECV_SEGMS,
         /* Eager local copy request. Subrequest of NCCL_OFI_RDMA_RECV */
         NCCL_OFI_RDMA_EAGER_COPY,
-        /* Rx buff post request */
-        NCCL_OFI_RDMA_RX_BUFF,
+        /* Ctrl rx buff post request */
+        NCCL_OFI_RDMA_CTRL_RX_BUFF,
+        /* Eager rx buff post request */
+        NCCL_OFI_RDMA_EAGER_RX_BUFF,
         /* Flush request */
         NCCL_OFI_RDMA_FLUSH,
         /* Connect message send request */
@@ -690,6 +692,10 @@ struct nccl_net_ofi_ep_rail {
         size_t max_rx_buff_posted;
         /* Mutex for rx buffer operations */
         pthread_mutex_t rx_buff_mutex;
+
+        /* Allocate a receive buffer request for this rail (eager or ctrl) */
+        nccl_net_ofi_rdma_req_t* (*rx_buff_req_alloc)(nccl_net_ofi_rdma_ep_t *ep,
+                                                      nccl_net_ofi_ep_rail_t *rail);
 };
 
 /*
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index 2ec6a8690..592e5050b 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -633,7 +633,8 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev,
  * @brief Return rx data struct of rx request
  */
 static inline rdma_req_rx_buff_data_t *get_rx_buff_data(nccl_net_ofi_rdma_req_t *req) {
-        assert(req->type == NCCL_OFI_RDMA_RX_BUFF);
+        assert(req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF ||
+               req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF);
         return &req->rx_buff_data;
 }
 
@@ -1242,7 +1243,7 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm);
 
 static int handle_close_msg_recv(nccl_net_ofi_rdma_req_t *rx_buff_req)
 {
-        assert(rx_buff_req->type == NCCL_OFI_RDMA_RX_BUFF);
+        assert(rx_buff_req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF);
 
         rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
 
@@ -1287,7 +1288,8 @@ static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, int ra
                 NCCL_OFI_WARN("RECV event had NULL ctx!");
                 return -EINVAL;
         }
-        if (OFI_UNLIKELY(rx_buff_req->type != NCCL_OFI_RDMA_RX_BUFF)) {
+        if (OFI_UNLIKELY((eager && (rx_buff_req->type != NCCL_OFI_RDMA_EAGER_RX_BUFF))
+                         || ((!eager) && (rx_buff_req->type != NCCL_OFI_RDMA_CTRL_RX_BUFF)))) {
                 NCCL_OFI_WARN("Invalid non-rx_buff request as ctx!");
                 return -EINVAL;
         }
@@ -1530,8 +1532,10 @@ static const char *req_type_str(nccl_net_ofi_rdma_req_type_t type)
                 return "SEND_CLOSE";
         case NCCL_OFI_RDMA_RECV_SEGMS:
                 return "RECV_SEGMS";
-        case NCCL_OFI_RDMA_RX_BUFF:
-                return "RX_BUFF";
+        case NCCL_OFI_RDMA_EAGER_RX_BUFF:
+                return "EAGER_RX_BUFF";
+        case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+                return "CTRL_RX_BUFF";
         case NCCL_OFI_RDMA_FLUSH:
                 return "FLUSH";
         case NCCL_OFI_RDMA_EAGER_COPY:
@@ -1656,7 +1660,8 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
         case NCCL_OFI_RDMA_SEND_CLOSE:
         case NCCL_OFI_RDMA_RECV_SEGMS:
         case NCCL_OFI_RDMA_EAGER_COPY:
-        case NCCL_OFI_RDMA_RX_BUFF:
+        case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+        case NCCL_OFI_RDMA_EAGER_RX_BUFF:
         case NCCL_OFI_RDMA_FLUSH:
         case NCCL_OFI_RDMA_SEND_CONN:
         case NCCL_OFI_RDMA_RECV_CONN:
@@ -1691,7 +1696,8 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
         case NCCL_OFI_RDMA_SEND_CTRL:
         case NCCL_OFI_RDMA_SEND_CLOSE:
         case NCCL_OFI_RDMA_RECV_SEGMS:
-        case NCCL_OFI_RDMA_RX_BUFF:
+        case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+        case NCCL_OFI_RDMA_EAGER_RX_BUFF:
         case NCCL_OFI_RDMA_SEND_CONN:
         case NCCL_OFI_RDMA_RECV_CONN:
         case NCCL_OFI_RDMA_RECV_CONN_RESP:
@@ -1774,7 +1780,7 @@ static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device,
                       err_entry.prov_errno,
                       fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, NULL, 0),
                       (long)err_entry.len, nccl_net_ofi_req_str(req));
-        if (req->type == NCCL_OFI_RDMA_RX_BUFF) {
+        if (req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF || req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF) {
                 /* A rx buffer receive failed -- this is an internal error so bail out */
                 NCCL_OFI_WARN("Fatal: rx buffer recv completed with error");
         } else {
@@ -1850,7 +1856,8 @@ static int receive_progress(nccl_net_ofi_rdma_req_t *req, bool add_to_pending)
         case NCCL_OFI_RDMA_RECV:
         case NCCL_OFI_RDMA_SEND:
         case NCCL_OFI_RDMA_RECV_SEGMS:
-        case NCCL_OFI_RDMA_RX_BUFF:
+        case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+        case NCCL_OFI_RDMA_EAGER_RX_BUFF:
         case NCCL_OFI_RDMA_SEND_CONN:
         case NCCL_OFI_RDMA_RECV_CONN:
         case NCCL_OFI_RDMA_RECV_CONN_RESP:
@@ -1910,7 +1917,8 @@ static int process_pending_reqs(nccl_net_ofi_rdma_ep_t *ep)
                 switch (req->type) {
                 case NCCL_OFI_RDMA_WRITE:
                 case NCCL_OFI_RDMA_SEND:
-                case NCCL_OFI_RDMA_RX_BUFF:
+                case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+                case NCCL_OFI_RDMA_EAGER_RX_BUFF:
                         rc = send_progress(req);
                         break;
                 case NCCL_OFI_RDMA_READ:
@@ -2290,7 +2298,7 @@ static inline int free_invalid(nccl_net_ofi_rdma_req_t *req,
         return -EINVAL;
 }
 
-static inline int free_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
+static inline int free_eager_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
                                    bool dec_inflight_reqs)
 {
         assert(!dec_inflight_reqs);
@@ -2303,16 +2311,58 @@ static inline int free_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
         return free_base_req(NULL, ep->rx_buff_reqs_fl, req, false);
 }
 
-static inline nccl_net_ofi_rdma_req_t *alloc_rx_buff_req(nccl_net_ofi_rdma_ep_t *ep,
-                                                          nccl_net_ofi_ep_rail_t *rail)
+static inline nccl_net_ofi_rdma_req_t *eager_rx_buff_req_alloc(nccl_net_ofi_rdma_ep_t *ep,
+                                                               nccl_net_ofi_ep_rail_t *rail)
 {
         nccl_net_ofi_rdma_req_t *req = allocate_req(ep->rx_buff_reqs_fl);
         if (!req) return NULL;
 
         req->comm = NULL;
-        req->type = NCCL_OFI_RDMA_RX_BUFF;
+        req->type = NCCL_OFI_RDMA_EAGER_RX_BUFF;
         req->dev_id = rdma_endpoint_get_device(ep)->base.dev_id;
-        req->free = free_rx_buff_req;
+        req->free = free_eager_rx_buff_req;
+
+        rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
+
+        nccl_ofi_freelist_elem_t *rx_buff_fl_elem =
+                nccl_ofi_freelist_entry_alloc(ep->rx_buff_fl);
+        if (!rx_buff_fl_elem) {
+                NCCL_OFI_WARN("Failed to allocate rx_buff_fl_elem");
+                req->free(req, false);
+                return NULL;
+        }
+        assert(NCCL_OFI_IS_PTR_ALIGNED(rx_buff_fl_elem->ptr, EAGER_RX_BUFFER_ALIGNMENT));
+
+        rx_buff_data->rx_buff_fl_elem = rx_buff_fl_elem;
+        rx_buff_data->buff_len = ep->rx_buff_size;
+        rx_buff_data->rail = rail;
+        rx_buff_data->ep = ep;
+        return req;
+}
+
+static inline int ctrl_rx_buff_req_free(nccl_net_ofi_rdma_req_t *req,
+                                        bool dec_inflight_reqs)
+{
+        assert(!dec_inflight_reqs);
+        rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
+        nccl_net_ofi_rdma_ep_t *ep = rx_buff_data->ep;
+        /* Free buffer */
+        if (rx_buff_data->rx_buff_fl_elem) {
+                nccl_ofi_freelist_entry_free(ep->rx_buff_fl, rx_buff_data->rx_buff_fl_elem);
+        }
+        return free_base_req(NULL, ep->rx_buff_reqs_fl, req, false);
+}
+
+static inline nccl_net_ofi_rdma_req_t *ctrl_rx_buff_req_alloc(nccl_net_ofi_rdma_ep_t *ep,
+                                                              nccl_net_ofi_ep_rail_t *rail)
+{
+        nccl_net_ofi_rdma_req_t *req = allocate_req(ep->rx_buff_reqs_fl);
+        if (!req) return NULL;
+
+        req->comm = NULL;
+        req->type = NCCL_OFI_RDMA_CTRL_RX_BUFF;
+        req->dev_id = rdma_endpoint_get_device(ep)->base.dev_id;
+        req->free = ctrl_rx_buff_req_free;
 
         rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
 
@@ -2371,7 +2421,7 @@ static inline int post_rx_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep,
         for (size_t i = 0; i < buffers_needed; ++i) {
                 bool is_last_req = (i == (buffers_needed - 1));
                 nccl_net_ofi_rdma_req_t *req =
-                        alloc_rx_buff_req(ep, rail);
+                        rail->rx_buff_req_alloc(ep, rail);
                 if (!req) {
                         NCCL_OFI_WARN("Failed to allocate rx_buff req");
                         return -ENOMEM;
                 }
@@ -5550,7 +5600,8 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req)
                         // Successfully sent the xfer with this rail
                         rma_op_data->xferred_rail_id++;
                 }
-        } else if (req->type == NCCL_OFI_RDMA_RX_BUFF) { // Post rx Buffer
+        } else if (req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF ||
+                   req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF) { // Post rx Buffer
                 rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
                 /* Get ep rail information to xfer the req */
                 assert(rx_buff_data->rail != NULL);
@@ -6171,6 +6222,7 @@ static inline int init_rx_buffers(nccl_net_ofi_rdma_ep_t *ep)
                 );
                 rail->num_rx_buff_posted = 0;
                 nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL);
+                rail->rx_buff_req_alloc = ctrl_rx_buff_req_alloc;
         }
 
         for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) {
@@ -6183,6 +6235,7 @@ static inline int init_rx_buffers(nccl_net_ofi_rdma_ep_t *ep)
                 );
                 rail->num_rx_buff_posted = 0;
                 nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL);
+                rail->rx_buff_req_alloc = eager_rx_buff_req_alloc;
         }
 
         return ret;
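
For illustration only: the pattern this patch introduces is that each rail carries
an rx_buff_req_alloc function pointer, set once in init_rx_buffers() (one set of
rails gets ctrl_rx_buff_req_alloc, the other gets eager_rx_buff_req_alloc), and
post_rx_buffs_on_rail() simply calls through it. The standalone C sketch below
shows that dispatch pattern only; req_t, ep_t, rail_t, and post_one_rx_buff are
simplified, hypothetical stand-ins, not the plugin's real structures or helpers.

/* Minimal standalone sketch of per-rail rx buffer allocator dispatch.
 * Types and names are simplified stand-ins, not plugin code. */
#include <stdio.h>
#include <stdlib.h>

typedef enum { CTRL_RX_BUFF, EAGER_RX_BUFF } req_type_t;

typedef struct { req_type_t type; } req_t;

struct ep;
typedef struct rail {
        /* Set once at init: ctrl allocator on control rails,
         * eager allocator on data rails. */
        req_t *(*rx_buff_req_alloc)(struct ep *ep, struct rail *rail);
} rail_t;

typedef struct ep {
        rail_t ctrl_rail;
        rail_t data_rail;
} ep_t;

static req_t *ctrl_rx_buff_req_alloc(ep_t *ep, rail_t *rail)
{
        (void)ep; (void)rail;
        req_t *req = malloc(sizeof(*req));
        if (req) req->type = CTRL_RX_BUFF;
        return req;
}

static req_t *eager_rx_buff_req_alloc(ep_t *ep, rail_t *rail)
{
        (void)ep; (void)rail;
        req_t *req = malloc(sizeof(*req));
        if (req) req->type = EAGER_RX_BUFF;
        return req;
}

/* The posting path no longer needs to know which kind of rx buffer a
 * rail carries; it just calls through the rail's pointer. */
static req_t *post_one_rx_buff(ep_t *ep, rail_t *rail)
{
        return rail->rx_buff_req_alloc(ep, rail);
}

int main(void)
{
        ep_t ep = {
                .ctrl_rail = { .rx_buff_req_alloc = ctrl_rx_buff_req_alloc },
                .data_rail = { .rx_buff_req_alloc = eager_rx_buff_req_alloc },
        };

        req_t *ctrl = post_one_rx_buff(&ep, &ep.ctrl_rail);
        req_t *eager = post_one_rx_buff(&ep, &ep.data_rail);
        if (ctrl && eager)
                printf("ctrl rail -> type %d, data rail -> type %d\n",
                       ctrl->type, eager->type);
        free(ctrl);
        free(eager);
        return 0;
}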