diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index 7cef3e0d7..224e09ed3 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -59,7 +59,8 @@ extern "C" { #define MIN_TAG_BITS_FOR_RING_ID (32 + 1) /* Maximum number of grouped receives */ -#define NCCL_OFI_MAX_RECVS 1 +#define NCCL_OFI_MAX_RECVS 8 +#define NCCL_OFI_MAX_RECVS_SENDRECV 1 /* * This defines a higher value than maximum inflight requests supported by NCCL @@ -195,6 +196,7 @@ typedef struct nccl_ofi_connection_info { typedef struct nccl_net_ofi_conn_handle { char ep_name[MAX_EP_ADDR]; uint64_t comm_id; + uintptr_t l_comm_ptr; /* Save temporary communicator state when creating send communicator */ save_comm_state_t state; } nccl_net_ofi_conn_handle_t; diff --git a/include/nccl_ofi_msgbuff.h b/include/nccl_ofi_msgbuff.h index d53073058..1db62d097 100644 --- a/include/nccl_ofi_msgbuff.h +++ b/include/nccl_ofi_msgbuff.h @@ -68,6 +68,10 @@ typedef struct { // Type of element nccl_ofi_msgbuff_elemtype_t type; void *elem; + // Multi-recv information + uint16_t multi_recv_size; + uint16_t multi_recv_start; + int multi_recv_tag; } nccl_ofi_msgbuff_elem_t; typedef struct { @@ -110,9 +114,14 @@ bool nccl_ofi_msgbuff_destroy(nccl_ofi_msgbuff_t *msgbuff); * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag, + void *elem, nccl_ofi_msgbuff_elemtype_t type, nccl_ofi_msgbuff_status_t *msg_idx_status); +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_ctrl_multirecv(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_base_index, uint16_t multi_recv_size, void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status); + /** * Replace an existing message element * @@ -126,8 +135,9 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, - nccl_ofi_msgbuff_status_t *msg_idx_status); + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status, bool *multi_send_ready); /** * Retrieve message with given index @@ -142,6 +152,12 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void **elem, nccl_ofi_msgbuff_elemtype_t *type, + nccl_ofi_msgbuff_status_t *msg_idx_status); + + +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve_notag(nccl_ofi_msgbuff_t *msgbuff, uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type, nccl_ofi_msgbuff_status_t *msg_idx_status); @@ -156,7 +172,8 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, * NCCL_OFI_MSGBUFF_ERROR, other error */ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status); + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status); #ifdef _cplusplus } // End extern "C" diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 36b7d4318..4944459f7 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -74,20 +74,33 @@ typedef struct nccl_net_ofi_rdma_mr_handle { struct fid_mr *mr[]; } nccl_net_ofi_rdma_mr_handle_t; -/* Contents of ctrl message sent from receiver to sender to advertise - destination buffer */ -typedef struct nccl_net_ofi_rdma_ctrl_msg { +typedef struct nccl_net_ofi_rdma_ctrl_msg_entry { + int multi_recv_tag; uint64_t buff_addr; uint64_t buff_len; uint64_t buff_mr_key[MAX_NUM_RAILS]; +} nccl_net_ofi_rdma_ctrl_msg_entry_t; + +/* Contents of ctrl message sent from receiver to sender to advertise + destination buffer */ +typedef struct nccl_net_ofi_rdma_ctrl_msg { + uint16_t msg_seq_num; + uint16_t multi_recv_size; + nccl_net_ofi_rdma_ctrl_msg_entry_t entries[]; + /* uintptr_t r_comm_ptr; DBG DBG */ } nccl_net_ofi_rdma_ctrl_msg_t; +#define RDMA_CTRL_MSG_ENTRIES_MAX_SIZE (NCCL_OFI_MAX_RECVS * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t)) +#define RDMA_CTRL_MSG_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE) + /* Structure used to store control messages in a free list */ typedef struct nccl_net_ofi_rdma_ctrl_fl_item { nccl_ofi_freelist_reginfo_t fl_reginfo; nccl_net_ofi_rdma_ctrl_msg_t ctrl_msg; } nccl_net_ofi_rdma_ctrl_fl_item_t; +#define RDMA_CTRL_FL_ITEM_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE) + /* For LL/LL128 protocols, bounce buffers (source of RDMA read operations) need to be 128B aligned */ #define BOUNCE_BUFFER_ALIGNMENT 128 @@ -150,6 +163,13 @@ typedef struct { /* Total number of completions. Expect one completion for receiving the * control message and one completion for each send segment. */ int total_num_compls; + + /* Multi-recv information */ + uint16_t multi_recv_size; + uint16_t multi_recv_start; + int multi_recv_tag; + /* This may not match sender-side seq num with multi-recv */ + uint16_t recv_side_msg_seq_num; } rdma_req_send_data_t; /* @@ -164,6 +184,8 @@ typedef struct { nccl_net_ofi_schedule_t *ctrl_schedule; /* Pointer to recv parent request */ nccl_net_ofi_rdma_req_t *recv_req; + /* Size of ctrl message */ + size_t ctrl_msg_size; } rdma_req_send_ctrl_data_t; typedef struct { @@ -204,6 +226,12 @@ typedef struct { * For eager messages, the second completion will be received * when the local read into the destination buffer is complete */ int total_num_compls; + /* Multi-recv information */ + uint16_t multi_recv_size; + uint16_t multi_recv_start; + int multi_recv_tag; + /* Next req in sequence */ + nccl_net_ofi_rdma_req_t *multi_recv_next; } rdma_req_recv_data_t; /* @@ -298,6 +326,8 @@ typedef struct nccl_ofi_rdma_connection_info { side. The receiver must use this ID when sending messages to sender */ uint64_t local_comm_id; + uintptr_t s_comm_ptr; + /* Number of rails */ int num_rails; @@ -625,6 +655,31 @@ typedef struct nccl_net_ofi_rdma_device { nccl_ofi_idpool_t key_pool; } nccl_net_ofi_rdma_device_t; +/* + * @brief Return send data struct of send request + */ +static inline rdma_req_send_data_t *get_send_data(nccl_net_ofi_rdma_req_t *req) { + assert(req->type == NCCL_OFI_RDMA_SEND); + return &req->send_data; +} + +/* + * @brief Return bounce data struct of bounce request + */ +static inline rdma_req_bounce_data_t *get_bounce_data(nccl_net_ofi_rdma_req_t *req) { + assert(req->type == NCCL_OFI_RDMA_BOUNCE); + return &req->bounce_data; +} + +/* + * Get ctrl message from bounce buffer + */ +static inline nccl_net_ofi_rdma_ctrl_msg_t *get_bounce_ctrl_msg + (nccl_net_ofi_rdma_bounce_fl_item_t *bounce_fl_item) +{ + return (nccl_net_ofi_rdma_ctrl_msg_t *)&bounce_fl_item->bounce_msg; +} + /* * @brief Initialize plugin with rdma protocol structures */ diff --git a/src/nccl_ofi_api.c b/src/nccl_ofi_api.c index 3b78f0718..a1ccfcf38 100644 --- a/src/nccl_ofi_api.c +++ b/src/nccl_ofi_api.c @@ -11,8 +11,8 @@ _Static_assert(sizeof(nccl_net_ofi_conn_handle_t) <= NCCL_NET_HANDLE_MAXSIZE, "Size of OFI Handle is too large"); -_Static_assert(offsetof(nccl_net_ofi_conn_handle_t, state) <= NCCL_NET_HANDLE_MAXSIZE_V4, - "Size of OFI Handle (without state) is too large"); +//_Static_assert(offsetof(nccl_net_ofi_conn_handle_t, state) <= NCCL_NET_HANDLE_MAXSIZE_V4, +// "Size of OFI Handle (without state) is too large"); _Static_assert(NCCL_NET_MAX_REQUESTS <= NCCL_OFI_MAX_REQUESTS, "Maximum outstanding requests for plugin is less than what NCCL requires"); diff --git a/src/nccl_ofi_msgbuff.c b/src/nccl_ofi_msgbuff.c index 2e6d6eea5..001b544ed 100644 --- a/src/nccl_ofi_msgbuff.c +++ b/src/nccl_ofi_msgbuff.c @@ -9,6 +9,7 @@ #include "nccl_ofi_msgbuff.h" #include "nccl_ofi.h" #include "nccl_ofi_log.h" +#include "nccl_ofi_rdma.h" nccl_ofi_msgbuff_t *nccl_ofi_msgbuff_init(uint16_t buffer_size) { @@ -25,7 +26,7 @@ nccl_ofi_msgbuff_t *nccl_ofi_msgbuff_init(uint16_t buffer_size) goto error; } msgbuff->buff_size = buffer_size; - if (!(msgbuff->buff = malloc(sizeof(nccl_ofi_msgbuff_elem_t)*buffer_size))) { + if (!(msgbuff->buff = calloc((4*buffer_size), sizeof(nccl_ofi_msgbuff_elem_t)))) { NCCL_OFI_WARN("Memory allocation (msgbuff->buff) failed"); goto error; } @@ -77,7 +78,7 @@ static uint16_t nccl_ofi_msgbuff_num_inflight(const nccl_ofi_msgbuff_t *msgbuff) static inline nccl_ofi_msgbuff_elem_t *buff_idx(const nccl_ofi_msgbuff_t *msgbuff, uint16_t idx) { - return &msgbuff->buff[idx % msgbuff->buff_size]; + return &msgbuff->buff[idx % (4*msgbuff->buff_size)]; } /** @@ -115,19 +116,11 @@ static nccl_ofi_msgbuff_status_t nccl_ofi_msgbuff_get_idx_status return NCCL_OFI_MSGBUFF_UNAVAILABLE; } -nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, +static inline nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_at_idx(nccl_ofi_msgbuff_t *msgbuff, uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, + uint16_t multi_recv_size, uint16_t multi_recv_start, int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status) { - if (!msgbuff) { - NCCL_OFI_WARN("msgbuff is NULL"); - return NCCL_OFI_MSGBUFF_ERROR; - } - if (pthread_mutex_lock(&msgbuff->lock)) { - NCCL_OFI_WARN("Error locking mutex"); - return NCCL_OFI_MSGBUFF_ERROR; - } - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; @@ -135,6 +128,10 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, buff_idx(msgbuff, msg_index)->stat = NCCL_OFI_MSGBUFF_INPROGRESS; buff_idx(msgbuff, msg_index)->elem = elem; buff_idx(msgbuff, msg_index)->type = type; + buff_idx(msgbuff, msg_index)->multi_recv_size = multi_recv_size; + if (multi_recv_size > 1) + buff_idx(msgbuff, msg_index)->multi_recv_start = multi_recv_start; + buff_idx(msgbuff, msg_index)->multi_recv_tag = multi_recv_tag; /* Update msg_next ptr */ while ((uint16_t)(msg_index - msgbuff->msg_next) <= msgbuff->buff_size) { if (msgbuff->msg_next != msg_index) { @@ -148,16 +145,105 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, ret = NCCL_OFI_MSGBUFF_INVALID_IDX; } + return ret; +} + +static inline bool nccl_ofi_msgbuff_multirecv_search(nccl_ofi_msgbuff_t *msgbuff, + uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag, + uint16_t *match_index) +{ + for (uint16_t idx = multi_recv_start; idx != (uint16_t)(multi_recv_start+multi_recv_size); ++idx) { + nccl_ofi_msgbuff_status_t msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, idx); + if (msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { + int present_tag = buff_idx(msgbuff, idx)->multi_recv_tag; + if (present_tag == multi_recv_tag) { + *match_index = idx; + return true; + } + } + } + return false; +} + +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag, + void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status) +{ + nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + + if (pthread_mutex_lock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + + /** Multi-recv specific behavior **/ + assert(type == NCCL_OFI_MSGBUFF_REQ); + + ret = nccl_ofi_msgbuff_insert_at_idx(msgbuff, msg_index, elem, type, + multi_recv_size, multi_recv_start, multi_recv_tag, msg_idx_status); + if (pthread_mutex_unlock(&msgbuff->lock)) { NCCL_OFI_WARN("Error unlocking mutex"); - ret = NCCL_OFI_MSGBUFF_ERROR; + return NCCL_OFI_MSGBUFF_ERROR; + } + return ret; +} + +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_ctrl_multirecv(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_base_index, uint16_t multi_recv_size, void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status) +{ + assert(type == NCCL_OFI_MSGBUFF_BUFF); + + nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + + if (pthread_mutex_lock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + + nccl_net_ofi_rdma_req_t *bounce_req = elem; + nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = get_bounce_ctrl_msg(get_bounce_data(bounce_req)->bounce_fl_item); + + assert(msg_base_index == ctrl_msg->msg_seq_num); + assert(multi_recv_size == ctrl_msg->multi_recv_size); + + for (uint16_t i = 0; i < multi_recv_size; ++i) { + uint16_t msg_index = msg_base_index + i; + ret = nccl_ofi_msgbuff_insert_at_idx(msgbuff, msg_index, elem, type, + multi_recv_size, msg_base_index, ctrl_msg->entries[i].multi_recv_tag, + msg_idx_status); + assert(ret == NCCL_OFI_MSGBUFF_SUCCESS); + } + + if (pthread_mutex_unlock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error unlocking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; } return ret; } +static bool test_ms_ready(nccl_ofi_msgbuff_t *msgbuff, uint16_t multi_recv_start, + uint16_t multi_recv_size) +{ + for (uint16_t i = multi_recv_start; i != (uint16_t)(multi_recv_start + multi_recv_size); + ++i) { + nccl_ofi_msgbuff_status_t msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, i); + if (msg_idx_status != NCCL_OFI_MSGBUFF_INPROGRESS) { + return false; + } + if (buff_idx(msgbuff, i)->type != NCCL_OFI_MSGBUFF_REQ) { + return false; + } + } + return true; +} + nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type, - nccl_ofi_msgbuff_status_t *msg_idx_status) + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void *elem, nccl_ofi_msgbuff_elemtype_t type, + nccl_ofi_msgbuff_status_t *msg_idx_status, bool *multi_send_ready) { if (!msgbuff) { NCCL_OFI_WARN("msgbuff is NULL"); @@ -167,18 +253,33 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, NCCL_OFI_WARN("Error locking mutex"); return NCCL_OFI_MSGBUFF_ERROR; } + assert(type == NCCL_OFI_MSGBUFF_REQ); + assert(multi_send_ready != NULL); + *multi_send_ready = false; - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + bool match_found = nccl_ofi_msgbuff_multirecv_search(msgbuff, multi_recv_start, + multi_recv_size, multi_recv_tag, &msg_index); + if (!match_found) { + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + goto unlock; + } + + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { buff_idx(msgbuff, msg_index)->elem = elem; buff_idx(msgbuff, msg_index)->type = type; + *multi_send_ready = test_ms_ready(msgbuff, multi_recv_start, + multi_recv_size); ret = NCCL_OFI_MSGBUFF_SUCCESS; } else { ret = NCCL_OFI_MSGBUFF_INVALID_IDX; } +unlock: if (pthread_mutex_unlock(&msgbuff->lock)) { NCCL_OFI_WARN("Error unlocking mutex"); ret = NCCL_OFI_MSGBUFF_ERROR; @@ -186,7 +287,7 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff, return ret; } -nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve_notag(nccl_ofi_msgbuff_t *msgbuff, uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type, nccl_ofi_msgbuff_status_t *msg_idx_status) { @@ -199,16 +300,17 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, return NCCL_OFI_MSGBUFF_ERROR; } if (pthread_mutex_lock(&msgbuff->lock)) { - NCCL_OFI_WARN("Error locking mutex"); - return NCCL_OFI_MSGBUFF_ERROR; - } + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { *elem = buff_idx(msgbuff, msg_index)->elem; *type = buff_idx(msgbuff, msg_index)->type; + assert(*type == NCCL_OFI_MSGBUFF_REQ); ret = NCCL_OFI_MSGBUFF_SUCCESS; } else { if (*msg_idx_status == NCCL_OFI_MSGBUFF_UNAVAILABLE) { @@ -225,21 +327,102 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, return ret; } +nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff, + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void **elem, nccl_ofi_msgbuff_elemtype_t *type, + nccl_ofi_msgbuff_status_t *msg_idx_status) +{ + if (!msgbuff) { + NCCL_OFI_WARN("msgbuff is NULL"); + return NCCL_OFI_MSGBUFF_ERROR; + } + if (!elem) { + NCCL_OFI_WARN("elem is NULL"); + return NCCL_OFI_MSGBUFF_ERROR; + } + if (pthread_mutex_lock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } + + nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + + if (multi_recv_size <= 1) { + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status != NCCL_OFI_MSGBUFF_UNAVAILABLE) { + /* Check if this actually should be a multi-recv */ + if (buff_idx(msgbuff, msg_index)->multi_recv_size > 1) { + assert(multi_recv_size == 0); + multi_recv_start = buff_idx(msgbuff, msg_index)->multi_recv_start; + multi_recv_size = buff_idx(msgbuff, msg_index)->multi_recv_size; + } + } + } + + if (multi_recv_size <= 1) { + /* Ok so this actually isn't a multirecv (that we know of) */ + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { + *elem = buff_idx(msgbuff, msg_index)->elem; + *type = buff_idx(msgbuff, msg_index)->type; + ret = NCCL_OFI_MSGBUFF_SUCCESS; + } else { + if (*msg_idx_status == NCCL_OFI_MSGBUFF_UNAVAILABLE) { + // UNAVAILABLE really only applies to insert, so return NOTSTARTED here + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + } + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + } + } else { + /* Multi-recv -- search the index space */ + bool match_found = nccl_ofi_msgbuff_multirecv_search(msgbuff, multi_recv_start, + multi_recv_size, multi_recv_tag, &msg_index); + if (!match_found) { + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + } else { + *msg_idx_status = NCCL_OFI_MSGBUFF_INPROGRESS; + *elem = buff_idx(msgbuff, msg_index)->elem; + *type = buff_idx(msgbuff, msg_index)->type; + + ret = NCCL_OFI_MSGBUFF_SUCCESS; + } + } + + if (pthread_mutex_unlock(&msgbuff->lock)) { + NCCL_OFI_WARN("Error unlocking mutex"); + ret = NCCL_OFI_MSGBUFF_ERROR; + } + return ret; +} + nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, - uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status) + uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status) { if (!msgbuff) { NCCL_OFI_WARN("msgbuff is null"); return NCCL_OFI_MSGBUFF_ERROR; } if (pthread_mutex_lock(&msgbuff->lock)) { - NCCL_OFI_WARN("Error locking mutex"); - return NCCL_OFI_MSGBUFF_ERROR; - } + NCCL_OFI_WARN("Error locking mutex"); + return NCCL_OFI_MSGBUFF_ERROR; + } - *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); nccl_ofi_msgbuff_result_t ret = NCCL_OFI_MSGBUFF_ERROR; + if (multi_recv_size > 1) { + bool match_found = nccl_ofi_msgbuff_multirecv_search(msgbuff, multi_recv_start, + multi_recv_size, multi_recv_tag, &msg_index); + if (!match_found) { + *msg_idx_status = NCCL_OFI_MSGBUFF_NOTSTARTED; + ret = NCCL_OFI_MSGBUFF_INVALID_IDX; + goto unlock; + } + } + + *msg_idx_status = nccl_ofi_msgbuff_get_idx_status(msgbuff, msg_index); + if (*msg_idx_status == NCCL_OFI_MSGBUFF_INPROGRESS) { buff_idx(msgbuff, msg_index)->stat = NCCL_OFI_MSGBUFF_COMPLETED; buff_idx(msgbuff, msg_index)->elem = NULL; @@ -247,6 +430,12 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, while (msgbuff->msg_last_incomplete != msgbuff->msg_next && buff_idx(msgbuff, msgbuff->msg_last_incomplete)->stat == NCCL_OFI_MSGBUFF_COMPLETED) { + /* Clear out relevant info of the now-unavailable message */ + uint16_t unavail_index = msgbuff->msg_last_incomplete - msgbuff->buff_size; + buff_idx(msgbuff, unavail_index)->elem = NULL; + buff_idx(msgbuff, unavail_index)->multi_recv_size = 0; + buff_idx(msgbuff, unavail_index)->multi_recv_start = 0; + buff_idx(msgbuff, unavail_index)->multi_recv_tag = 0; ++(msgbuff->msg_last_incomplete); } ret = NCCL_OFI_MSGBUFF_SUCCESS; @@ -257,6 +446,8 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff, } ret = NCCL_OFI_MSGBUFF_INVALID_IDX; } + +unlock: if (pthread_mutex_unlock(&msgbuff->lock)) { NCCL_OFI_WARN("Error unlocking mutex"); ret = NCCL_OFI_MSGBUFF_ERROR; diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 2654461d3..6df42449c 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -1140,7 +1140,7 @@ static int set_nic_props_default(int dev_id, struct fi_info *nic_prov, * impacted with this feature as NCCL doesn't aggregate receives from * same source. */ - props->max_group_receives = NCCL_OFI_MAX_RECVS; + props->max_group_receives = NCCL_OFI_MAX_RECVS_SENDRECV; if (support_gdr == GDR_SUPPORTED) { props->hmem_support = true; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 3acfb9c70..2eee57739 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -249,15 +249,6 @@ static inline nccl_net_ofi_rdma_recv_comm_t *get_recv_comm(nccl_net_ofi_rdma_ep_ return r_comm; } -/* - * Get ctrl message from bounce buffer - */ -static inline nccl_net_ofi_rdma_ctrl_msg_t *get_bounce_ctrl_msg - (nccl_net_ofi_rdma_bounce_fl_item_t *bounce_fl_item) -{ - return (nccl_net_ofi_rdma_ctrl_msg_t *)&bounce_fl_item->bounce_msg; -} - /* * @brief Return send communicator rail with index `rail_id` */ @@ -612,6 +603,9 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev, struct fi_info *info = device->device_rails[0].info; int ret = nccl_net_ofi_info_properties(info, dev_id, base_dev->plugin->num_devs, props); + /* Multi-recv adjustment */ + props->max_group_receives = NCCL_OFI_MAX_RECVS; + /* Scale speed by the total number of rails. Assume that all * reails have the same speed. */ if (ret == 0) { @@ -623,22 +617,6 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev, return ret; } -/* - * @brief Return bounce data struct of bounce request - */ -static inline rdma_req_bounce_data_t *get_bounce_data(nccl_net_ofi_rdma_req_t *req) { - assert(req->type == NCCL_OFI_RDMA_BOUNCE); - return &req->bounce_data; -} - -/* - * @brief Return send data struct of send request - */ -static inline rdma_req_send_data_t *get_send_data(nccl_net_ofi_rdma_req_t *req) { - assert(req->type == NCCL_OFI_RDMA_SEND); - return &req->send_data; -} - /* * @brief Return recv data struct of recv request */ @@ -929,18 +907,42 @@ static inline int inc_recv_seg_completion(nccl_net_ofi_rdma_req_t *req, return ret; } -static void copy_ctrl_data(nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdma_req_t *req) +static void copy_ctrl_data(nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdma_req_t *req, int tag) { rdma_req_send_data_t *send_data = get_send_data(req); rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = get_bounce_ctrl_msg(bounce_data->bounce_fl_item); + /** Ctrl message size consistency check **/ + assert(bounce_data->recv_len == sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + + ctrl_msg->multi_recv_size * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t)); + + + uint16_t multi_recv_size = ctrl_msg->multi_recv_size; + + /* TODO remove an extra search */ + int ctrl_idx; + for (ctrl_idx = 0; ctrl_idx < multi_recv_size; ++ctrl_idx) { + nccl_net_ofi_rdma_ctrl_msg_entry_t *entry = &ctrl_msg->entries[ctrl_idx]; + if (entry->multi_recv_tag == tag) { + break; + } + } + if (ctrl_idx >= multi_recv_size) { + assert(false); abort(); + } + for (int rail_id = 0; rail_id != MAX_NUM_RAILS; ++rail_id) { - send_data->remote_mr_key[rail_id] = ctrl_msg->buff_mr_key[rail_id]; + send_data->remote_mr_key[rail_id] = ctrl_msg->entries[ctrl_idx].buff_mr_key[rail_id]; } - send_data->remote_buff = ctrl_msg->buff_addr; - send_data->remote_len = ctrl_msg->buff_len; + send_data->remote_buff = ctrl_msg->entries[ctrl_idx].buff_addr; + send_data->remote_len = ctrl_msg->entries[ctrl_idx].buff_len; + + send_data->multi_recv_size = ctrl_msg->multi_recv_size; + send_data->multi_recv_start = ctrl_msg->msg_seq_num; + assert(send_data->multi_recv_tag == ctrl_msg->entries[ctrl_idx].multi_recv_tag); + send_data->recv_side_msg_seq_num = ctrl_msg->msg_seq_num + (uint16_t)ctrl_idx; } /* @@ -1028,19 +1030,26 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdma_ep_t *ep) { - int ret; + // int ret; int bounce_rail_id = get_bounce_data(bounce_req)->bounce_rail_id; + nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = get_bounce_ctrl_msg(get_bounce_data(bounce_req)->bounce_fl_item); + + /* Assert that imm data matches ctrl data for seq num */ + assert(msg_seq_num == ctrl_msg->msg_seq_num); + nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, msg_seq_num, - bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert_ctrl_multirecv(s_comm->msgbuff, msg_seq_num, + ctrl_msg->multi_recv_size, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { /* Inserted! In this case sender has not yet called send() for this message, so return success and initiate RDMA write when sender calls send(). */ return decrease_bounce_buff_cnt(ep, bounce_rail_id); } + assert(false); abort(); /* TODO handle this case */ +#if 0 if (mb_res != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_INPROGRESS) { NCCL_OFI_WARN("Unexpected message insert result (%d) (ctrl recv)", (int)mb_res); return -EINVAL; @@ -1049,7 +1058,9 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, // Already a req entry here void *elem; nccl_ofi_msgbuff_elemtype_t type; - mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem, &type, &stat); + mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, + ctrl_msg->multi_recv_start, ctrl_msg->multi_recv_size, + ctrl_msg->multi_recv_tag, &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ) { NCCL_OFI_WARN("Invalid message retrieval result for msg %hu", msg_seq_num); return -EINVAL; @@ -1058,7 +1069,8 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, rdma_req_send_data_t *send_data = get_send_data(req); if (!send_data->eager) { - copy_ctrl_data(bounce_req, req); + abort(); + copy_ctrl_data(bounce_req, req, -1); /* We need to initiate RDMA write here. */ if (send_data->buff_len > send_data->remote_len) { @@ -1099,7 +1111,7 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, NCCL_OFI_WARN("Failed to repost bounce buff"); return ret; } - +#endif return 0; } @@ -1158,6 +1170,7 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_ofi_msgbuff_status_t stat; nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, msg_seq_num, + 0, 1, 0, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { @@ -1178,7 +1191,8 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, // In this case, there is already a req entry here. Initiate eager copy. void *elem; nccl_ofi_msgbuff_elemtype_t type; - mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, &elem, &type, &stat); + mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, 0, 1, 0, + &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS || type != NCCL_OFI_MSGBUFF_REQ) { NCCL_OFI_WARN("Invalid message retrieval result for msg %hu", msg_seq_num); return -EINVAL; @@ -1186,6 +1200,9 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_req_t *recv_req = elem; rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); + /* Eager not allowed for multi-recv (for now) */ + assert(recv_data->multi_recv_size == 1); + rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); if (bounce_data->recv_len == 0) { /* Special case: for zero-sized messages, we can skip the local read */ @@ -1246,10 +1263,12 @@ static inline int handle_bounce_recv(struct fi_cq_tagged_entry *cq_entry, int ra NCCL_OFI_TRACE_SEND_CTRL_RECV(comm->dev_id, rail_id, comm, msg_seq_num); nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)comm; assert(s_comm->local_comm_id == local_comm_id); - assert(bounce_data->recv_len == sizeof(nccl_net_ofi_rdma_ctrl_msg_t)); + assert(bounce_data->recv_len <= RDMA_CTRL_MSG_MAX_SIZE); return handle_ctrl_recv(s_comm, msg_seq_num, bounce_req, ep); } else if (comm->type == NCCL_NET_OFI_RECV_COMM) { + NCCL_OFI_WARN("Eager receive is not yet supported!"); + assert(false); abort(); /* Eager message */ NCCL_OFI_TRACE_EAGER_RECV(comm->dev_id, rail_id, comm, msg_seq_num); nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)comm; @@ -1278,7 +1297,9 @@ static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, + /* We don't have a multi-recv tag here, so we rely on msg_seq_num matching + our seq num */ + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_retrieve_notag(r_comm->msgbuff, msg_seq_num, &elem, &type, &stat); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { /* Unexpected: we don't have a msgbuff entry corresponding to this message*/ @@ -1305,6 +1326,7 @@ static inline int handle_write_comp(struct fi_cq_tagged_entry *cq_entry, return ncclSystemError; } assert(req->type == NCCL_OFI_RDMA_RECV); + assert(req->msg_seq_num == GET_SEQ_NUM_FROM_IMM(cq_entry->data)); rdma_req_recv_data_t *recv_data = get_recv_data(req); nccl_net_ofi_rdma_req_t *recv_segms_req = recv_data->recv_segms_req; @@ -2221,34 +2243,13 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm) #define __compiler_barrier() do { asm volatile ("" : : : "memory"); } while(0) -static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) +static int test_req(nccl_net_ofi_rdma_req_t *req, int *done, int *size) { - int ret = 0; - nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)base_req; *done = 0; - assert(req->type == NCCL_OFI_RDMA_SEND || - req->type == NCCL_OFI_RDMA_RECV || - req->type == NCCL_OFI_RDMA_FLUSH); - - /* Retrieve and validate comm */ - nccl_net_ofi_comm_t *base_comm = req->comm; - assert(base_comm != NULL); - - /* Retrieve and validate endpoint */ - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_comm->ep; - assert(ep != NULL); - - /* Process more completions unless the current request is - * completed */ - if (req->state != NCCL_OFI_RDMA_REQ_COMPLETED - && OFI_LIKELY(req->state != NCCL_OFI_RDMA_REQ_ERROR)) { - ret = ofi_process_cq(ep); - if (OFI_UNLIKELY(ret != 0)) - goto exit; - } + int ret = 0; /* Determine whether the request has finished without error and free if done */ - if (OFI_LIKELY(req->state == NCCL_OFI_RDMA_REQ_COMPLETED)) { + if (req->state == NCCL_OFI_RDMA_REQ_COMPLETED) { size_t req_size; if (pthread_mutex_lock(&req->req_lock)) { NCCL_OFI_WARN("Unable to acquire req_lock mutex"); @@ -2268,35 +2269,146 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) *size = req_size; /* Mark as done */ *done = 1; + } else if (OFI_UNLIKELY(req->state == NCCL_OFI_RDMA_REQ_ERROR)) { + NCCL_OFI_WARN("Request completed with error"); + ret = ncclSystemError; + goto exit; + } +exit: + return ret; +} + +static int test_free_req(nccl_net_ofi_rdma_req_t *req) +{ + int ret = 0; + if (req->type != NCCL_OFI_RDMA_FLUSH) { + uint16_t multi_recv_start; + uint16_t multi_recv_size; + int multi_recv_tag; + + /* Retrieve and validate comm */ + nccl_net_ofi_comm_t *base_comm = req->comm; + assert(base_comm != NULL); + + /* Mark as complete in message buffer */ + nccl_ofi_msgbuff_t *msgbuff; + + if (req->type == NCCL_OFI_RDMA_SEND) { + msgbuff = ((nccl_net_ofi_rdma_send_comm_t *)base_comm)->msgbuff; + rdma_req_send_data_t *send_data = get_send_data(req); + multi_recv_start = send_data->multi_recv_start; + multi_recv_size = send_data->multi_recv_size; + multi_recv_tag = send_data->multi_recv_tag; + } else if (req->type == NCCL_OFI_RDMA_RECV) { + msgbuff = ((nccl_net_ofi_rdma_recv_comm_t *)base_comm)->msgbuff; + rdma_req_recv_data_t *recv_data = get_recv_data(req); + multi_recv_start = recv_data->multi_recv_start; + multi_recv_size = recv_data->multi_recv_size; + multi_recv_tag = recv_data->multi_recv_tag; + } else { + NCCL_OFI_WARN("Unexpected request type: %d", req->type); + ret = ncclSystemError; + goto exit; + } + + nccl_ofi_msgbuff_status_t stat; + nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, multi_recv_start, + multi_recv_size, multi_recv_tag, &stat); + if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { + NCCL_OFI_WARN("Invalid result (%d,%d) of msgbuff_complete for msg %hu type %d", mb_res, stat, req->msg_seq_num, req->type); + ret = ncclSystemError; + goto exit; + } + } + assert(req->free); + req->free(req, true); + +exit: + return ret; +} + +static int free_multirecv_req(nccl_net_ofi_rdma_req_t *req) +{ + int ret = 0; + while (req) { + nccl_net_ofi_rdma_req_t *next_req = get_recv_data(req)->multi_recv_next; + ret = test_free_req(req); + if (OFI_UNLIKELY(ret != 0)) { + return ret; + } + req = next_req; + } + return ret; +} + +static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) +{ + int ret = 0; + nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)base_req; + *done = 0; + assert(req->type == NCCL_OFI_RDMA_SEND || + req->type == NCCL_OFI_RDMA_RECV || + req->type == NCCL_OFI_RDMA_FLUSH); + + /* Retrieve and validate comm */ + nccl_net_ofi_comm_t *base_comm = req->comm; + assert(base_comm != NULL); - if (req->type != NCCL_OFI_RDMA_FLUSH) { - /* Mark as complete in message buffer */ - nccl_ofi_msgbuff_t *msgbuff; - if (req->type == NCCL_OFI_RDMA_SEND) { - msgbuff = ((nccl_net_ofi_rdma_send_comm_t *)base_comm)->msgbuff; - } else if (req->type == NCCL_OFI_RDMA_RECV) { - msgbuff = ((nccl_net_ofi_rdma_recv_comm_t *)base_comm)->msgbuff; + /* Retrieve and validate endpoint */ + nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_comm->ep; + assert(ep != NULL); + + if (req->type == NCCL_OFI_RDMA_RECV && + get_recv_data(req)->multi_recv_size > 1) { +#ifndef NDEBUG + uint16_t multi_recv_size = get_recv_data(req)->multi_recv_size; +#endif + /* Multi-recv: test each request individually */ + bool processed_cq = false; + int i = 0; + while (req) { + ret = test_req(req, done, &size[i]); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + if (*done) { + req = get_recv_data(req)->multi_recv_next; + ++i; } else { - NCCL_OFI_WARN("Unexpected request type: %d", req->type); - ret = ncclSystemError; + if (!processed_cq) { + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + processed_cq = true; + } else { + break; + } + } + } + if (*done) { + assert(i == multi_recv_size); + req = (nccl_net_ofi_rdma_req_t *)base_req; + ret = free_multirecv_req(req); + } + } else { + ret = test_req(req, done, size); + if (OFI_UNLIKELY(ret)) { + goto exit; + } + if (!(*done)) { + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { goto exit; } - - nccl_ofi_msgbuff_status_t stat; - nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, &stat); - if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { - NCCL_OFI_WARN("Invalid result of msgbuff_complete for msg %hu", req->msg_seq_num); - ret = ncclSystemError; + ret = test_req(req, done, size); + if (OFI_UNLIKELY(ret)) { goto exit; } } - - assert(req->free); - req->free(req, true); - } else if (OFI_UNLIKELY(req->state == NCCL_OFI_RDMA_REQ_ERROR)) { - NCCL_OFI_WARN("Request completed with error"); - ret = ncclSystemError; - goto exit; + if (*done) { + ret = test_free_req(req); + } } exit: @@ -2793,6 +2905,10 @@ static inline int insert_send_ctrl_req( return ncclSystemError; } + rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); + + recv_data->total_num_compls = 2; + send_ctrl_req->comm = &r_comm->base.base; send_ctrl_req->dev_id = dev_id; send_ctrl_req->type = NCCL_OFI_RDMA_SEND_CTRL; @@ -2800,11 +2916,14 @@ static inline int insert_send_ctrl_req( send_ctrl_req->msg_seq_num = msg_seq_num; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req); + send_ctrl_data->ctrl_msg_size = sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + + recv_data->multi_recv_size * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t); send_ctrl_data->recv_req = recv_req; send_ctrl_data->ctrl_fl_item = NULL; send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, - sizeof(nccl_net_ofi_rdma_ctrl_msg_t), - device->num_rails); + sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + + (recv_data->multi_recv_size * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t)), + device->num_rails); if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) { return ncclInternalError; @@ -2835,13 +2954,16 @@ static inline int insert_send_ctrl_req( return ncclInternalError; } - ctrl_fl_item->ctrl_msg.buff_addr = (uint64_t)buff; - ctrl_fl_item->ctrl_msg.buff_len = size; + ctrl_fl_item->ctrl_msg.msg_seq_num = msg_seq_num; + ctrl_fl_item->ctrl_msg.multi_recv_size = recv_data->multi_recv_size; + ctrl_fl_item->ctrl_msg.entries[0].multi_recv_tag = recv_data->multi_recv_tag; + ctrl_fl_item->ctrl_msg.entries[0].buff_addr = (uint64_t)buff; + ctrl_fl_item->ctrl_msg.entries[0].buff_len = size; int rail_id = 0; for (; rail_id < r_comm->num_rails; rail_id++) { - ctrl_fl_item->ctrl_msg.buff_mr_key[rail_id] = fi_mr_key(buff_mr_handle->mr[rail_id]); + ctrl_fl_item->ctrl_msg.entries[0].buff_mr_key[rail_id] = fi_mr_key(buff_mr_handle->mr[rail_id]); - if (ctrl_fl_item->ctrl_msg.buff_mr_key[rail_id] == FI_KEY_NOTAVAIL) { + if (ctrl_fl_item->ctrl_msg.entries[0].buff_mr_key[rail_id] == FI_KEY_NOTAVAIL) { NCCL_OFI_WARN("RDMA write buffers should be pre-registered"); return ncclInternalError; } @@ -2849,7 +2971,6 @@ static inline int insert_send_ctrl_req( send_ctrl_data->ctrl_fl_item = ctrl_fl_item; - rdma_req_recv_data_t *recv_data = get_recv_data(recv_req); recv_data->send_ctrl_req = send_ctrl_req; return 0; @@ -2896,8 +3017,9 @@ static inline int insert_recv_segms_req( static inline int allocate_rdma_recv_req( nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_device_t *device, - int dev_id, uint16_t msg_seq_num, void *buff, - size_t size, + int dev_id, uint16_t msg_seq_num, + uint16_t multi_recv_start, uint16_t multi_recv_size, + int multi_recv_tag, void *buff, size_t size, nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle, nccl_net_ofi_rdma_req_t **ret_req) { @@ -2920,18 +3042,19 @@ static inline int allocate_rdma_recv_req( req->msg_seq_num = msg_seq_num; recv_data = get_recv_data(req); - recv_data->total_num_compls = 2; + recv_data->total_num_compls = 1; recv_data->eager_copy_req = NULL; recv_data->dst_buff = buff; recv_data->dst_len = size; recv_data->dest_mr_handle = buff_mr_handle; - /* TODO consolidate arguments to insert_send_ctrl_req and insert_recv_segms_req */ - ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req); - if (ret) { - NCCL_OFI_WARN("Failed to insert send ctrl request into recv request"); - return ret; - } + /* Populate multi-recv data */ + recv_data->multi_recv_size = multi_recv_size; + recv_data->multi_recv_start = multi_recv_start; + recv_data->multi_recv_tag = multi_recv_tag; + recv_data->multi_recv_next = NULL; + + recv_data->send_ctrl_req = NULL; ret = insert_recv_segms_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req); if (ret) { @@ -2951,15 +3074,21 @@ static inline int insert_rdma_recv_req_into_msgbuff(nccl_net_ofi_rdma_recv_comm_ nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; + rdma_req_recv_data_t *recv_data = get_recv_data(req); + if (eager) { + assert(false); + assert(recv_data->multi_recv_size == 1); /* * There is already a buffer entry in the message buffer, so * replace it with a request. */ mb_res = nccl_ofi_msgbuff_replace(r_comm->msgbuff, - req->msg_seq_num, req, + req->msg_seq_num, recv_data->multi_recv_start, + recv_data->multi_recv_size, + recv_data->multi_recv_tag, req, NCCL_OFI_MSGBUFF_REQ, - &msg_stat); + &msg_stat, NULL); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("Unexpected result of nccl_ofi_msgbuff_replace for msg %hu", req->msg_seq_num); @@ -2967,13 +3096,17 @@ static inline int insert_rdma_recv_req_into_msgbuff(nccl_net_ofi_rdma_recv_comm_ } } else { /* Try inserting the new request */ - mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, req->msg_seq_num, req, - NCCL_OFI_MSGBUFF_REQ, &msg_stat); + mb_res = nccl_ofi_msgbuff_insert(r_comm->msgbuff, req->msg_seq_num, + recv_data->multi_recv_start, + recv_data->multi_recv_size, + recv_data->multi_recv_tag, req, + NCCL_OFI_MSGBUFF_REQ, &msg_stat); if (OFI_UNLIKELY((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && (msg_stat == NCCL_OFI_MSGBUFF_INPROGRESS))) { /* Unlikely: an eager message was received on another thread. Return NULL and let NCCL call recv again. */ + assert(false); /* Reduce testing surface for now. TODO remove. */ req->free(req, false); *ret_req = NULL; } else if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS)) { @@ -3020,7 +3153,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, assert(r_comm != NULL); - if (OFI_UNLIKELY(r_comm->num_inflight_reqs == NCCL_OFI_MAX_REQUESTS)) { + if (OFI_UNLIKELY(r_comm->num_inflight_reqs + n > NCCL_OFI_MAX_REQUESTS)) { ret = -ENOSPC; NCCL_OFI_WARN("Can not support more than %d inflight requests", NCCL_OFI_MAX_REQUESTS); @@ -3045,109 +3178,92 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - uint16_t msg_seq_num = r_comm->next_msg_seq_num; + uint16_t base_msg_seq_num = r_comm->next_msg_seq_num; - bool eager = false; - void *elem; - nccl_ofi_msgbuff_elemtype_t type; - nccl_ofi_msgbuff_status_t msg_stat; - nccl_ofi_msgbuff_result_t mb_res; + nccl_net_ofi_rdma_req_t *multirecv_base_req = NULL; + nccl_net_ofi_rdma_req_t *multirecv_prev_req = NULL; + rdma_req_recv_data_t *base_recv_data = NULL; - mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, &elem, - &type, &msg_stat); - if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { + assert(n <= NCCL_OFI_MAX_RECVS); - if (type == NCCL_OFI_MSGBUFF_REQ) { - /* Shouldn't happen: duplicate request */ - NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num); - ret = -EINVAL; - goto error; - } else if (type == NCCL_OFI_MSGBUFF_BUFF) { - /* This is an eager message */ - eager = true; - } else { - NCCL_OFI_WARN("Invalid type in msg buff"); - ret = -EINVAL; - goto error; - } - } else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && - (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) { - /* Allocate a new req */ - } else { - NCCL_OFI_WARN("Message %hu has invalid status.", msg_seq_num); - ret = -EINVAL; - goto error; - } + for (uint16_t i = 0; i < n; ++i) { + uint16_t msg_idx = base_msg_seq_num + i; + bool eager = false; - ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num, - buffers[0], sizes[0], - mr_handles[0], &req); - if (ret != 0) { - goto error; - } + /* Eager TODO: check for existing request */ - rdma_req_recv_data_t *recv_data = get_recv_data(req); + ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_idx, + base_msg_seq_num, n, tags[i], + buffers[i], sizes[i], + mr_handles[i], &req); + if (ret != 0) { + goto error; + } - if (eager) { - nccl_net_ofi_rdma_req_t *bounce_req = elem; - rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); - if (bounce_data->recv_len == 0) { - /* Special case for zero-sized messages */ - ret = check_post_bounce_req(bounce_req); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to check_post_bounce_req"); + if (i == 0) { + ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_idx, buffers[i], + sizes[i], mr_handles[i], req); + if (ret) { + NCCL_OFI_WARN("Failed to insert send ctrl request into recv request"); return ret; } - recv_data->eager_copy_req = NULL; } else { - ret = alloc_eager_copy_req(req, r_comm, bounce_req); - if (ret != 0) { - goto error; + /* Fill in info for this req */ + assert(multirecv_base_req); + assert(n == base_recv_data->multi_recv_size); + nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = + get_send_ctrl_data(base_recv_data->send_ctrl_req)->ctrl_fl_item; + ctrl_fl_item->ctrl_msg.entries[i].multi_recv_tag = tags[i]; + ctrl_fl_item->ctrl_msg.entries[i].buff_addr = (uint64_t)buffers[i]; + ctrl_fl_item->ctrl_msg.entries[i].buff_len = sizes[i]; + for (int rail_id = 0; rail_id < r_comm->num_rails; ++rail_id) { + ctrl_fl_item->ctrl_msg.entries[i].buff_mr_key[rail_id] = + fi_mr_key(mr_handles[i]->mr[rail_id]); + + if (ctrl_fl_item->ctrl_msg.entries[i].buff_mr_key[rail_id] == FI_KEY_NOTAVAIL) { + NCCL_OFI_WARN("RDMA write buffers should be pre-registered"); + return ncclInternalError; + } } } - } - ret = insert_rdma_recv_req_into_msgbuff(r_comm, eager, &req); - if (ret != 0) { - goto free_req; - } else if (req == NULL) { - ret = -ENOMEM; - goto free_req; - } + if (multirecv_prev_req != NULL) { + get_recv_data(multirecv_prev_req)->multi_recv_next = req; + } + multirecv_prev_req = req; + if (i == 0) { + multirecv_base_req = req; + base_recv_data = get_recv_data(multirecv_base_req); + } - /* At this point, we've successfully inserted a new request, so update the num inflight. */ - (r_comm->num_inflight_reqs)++; + /* Eager TODO: allocate eager copy req */ + + ret = insert_rdma_recv_req_into_msgbuff(r_comm, eager, &req); + if (ret != 0) { + goto free_req; + } else if (req == NULL) { + ret = -ENOMEM; + goto free_req; + } + + /* At this point, we've successfully inserted a new request, so update the num inflight. */ + (r_comm->num_inflight_reqs)++; - NCCL_OFI_TRACE_RECV(dev_id, r_comm->local_comm_id, sizes[0], req, base_req); + NCCL_OFI_TRACE_RECV(dev_id, r_comm->local_tag, sizes[0], req, base_req); - ret = receive_progress(recv_data->send_ctrl_req, true); + /* Eager TODO: post eager copy req */ + } + + ret = receive_progress(base_recv_data->send_ctrl_req, true); if (OFI_UNLIKELY(ret != 0)) { /* TODO: Remove req from message buffer */ goto error; } - if (eager) { - if (recv_data->eager_copy_req == NULL) { - /* If we don't need to do eager copy, this recv is already complete */ - ret = inc_req_completion(req, 0, recv_data->total_num_compls); - if (ret != 0) { - goto error; - } - } else { - /* Post eager copy */ - ret = receive_progress(recv_data->eager_copy_req, true); - if (ret != 0) { - NCCL_OFI_WARN("Failed to issue eager read"); - /* TODO: Remove req from message buffer */ - goto error; - } - } - } - /* Return request to NCCL */ - *base_req = (nccl_net_ofi_req_t *)req; + *base_req = (nccl_net_ofi_req_t *)multirecv_base_req; /* Increment next_msg_seq_num for next call */ - ++(r_comm->next_msg_seq_num); + r_comm->next_msg_seq_num += n; goto exit; @@ -3329,6 +3445,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, int *sizes, nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **base_req) { + assert(n == 1); int ret = 0; nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; @@ -3621,7 +3738,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen return NULL; } - ret = nccl_ofi_freelist_init_mr(sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t), 8, 8, + ret = nccl_ofi_freelist_init_mr(RDMA_CTRL_FL_ITEM_MAX_SIZE, 8, 8, NCCL_OFI_MAX_REQUESTS, freelist_regmr_host_fn, freelist_deregmr_host_fn, ep, 0, 1, &r_comm->ctrl_buff_fl); @@ -3755,6 +3872,7 @@ static int close_listen_recv_comm(nccl_net_ofi_rdma_listen_comm_t *l_comm) } if (l_comm->req.state == NCCL_OFI_RDMA_REQ_PENDING) { + assert(false); NCCL_OFI_WARN("Unable to free request of listen communicator. Request is still pending. Leaking memory."); return ncclInternalError; } @@ -3882,7 +4000,7 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, ret = ncclInternalError; goto exit; } - + /* Set r_comm's (local) comm ID to be sent back to remote */ conn_msg->local_comm_id = r_comm->local_comm_id; @@ -3981,7 +4099,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm) if (l_comm->req.state == NCCL_OFI_RDMA_REQ_PENDING) { NCCL_OFI_WARN("Unable to free request of listen communicator. Request is still pending. Leaking memory."); - return ncclInternalError; + return 0; } if (l_comm->r_comm) { @@ -4056,6 +4174,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, } l_comm->comm_id = (uint64_t)comm_id; handle->comm_id = l_comm->comm_id; + handle->l_comm_ptr = (uintptr_t)l_comm; /* Prepare receive request to accept connections */ ret = prepare_recv_conn_req(l_comm); @@ -4109,7 +4228,7 @@ static int dereg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, } static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm, - uint16_t msg_seq_num, + uint16_t msg_seq_num, int multi_recv_tag, void *buff, size_t size, nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle, bool eager, bool have_ctrl, @@ -4152,6 +4271,11 @@ static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm, req->msg_seq_num, send_data->schedule->num_xfer_infos); + /* Initialize for now. It will be populated later with correct info from receiver*/ + send_data->multi_recv_size = 0; + send_data->multi_recv_start = 0; + send_data->multi_recv_tag = multi_recv_tag; + *ret_req = req; return 0; @@ -4159,36 +4283,46 @@ static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm, static int insert_rdma_send_req_into_msgbuff(nccl_net_ofi_rdma_send_comm_t *s_comm, int dev_id, bool have_ctrl, - nccl_net_ofi_rdma_req_t **ret_req) + nccl_net_ofi_rdma_req_t **ret_req, + bool *multi_send_ready) { nccl_net_ofi_rdma_req_t *req = *ret_req; nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; + rdma_req_send_data_t *send_data = get_send_data(req); + if (have_ctrl) { /* * There is already a buffer entry in the message buffer, * so replace it with a request. */ mb_res = nccl_ofi_msgbuff_replace(s_comm->msgbuff, - req->msg_seq_num, req, + req->msg_seq_num, send_data->multi_recv_start, + send_data->multi_recv_size, + send_data->multi_recv_tag, + req, NCCL_OFI_MSGBUFF_REQ, - &msg_stat); + &msg_stat, multi_send_ready); if (mb_res != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("Unexpected result of nccl_ofi_msgbuff_replace for msg %hu", req->msg_seq_num); return ncclSystemError; } } else { + assert(false); abort(); /* Try inserting the new request */ mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, - req->msg_seq_num, req, + req->msg_seq_num, send_data->multi_recv_start, + send_data->multi_recv_size, + send_data->multi_recv_tag, req, NCCL_OFI_MSGBUFF_REQ, &msg_stat); if (OFI_UNLIKELY((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && (msg_stat == NCCL_OFI_MSGBUFF_INPROGRESS))) { /* Unlikely: a ctrl message was received on another thread. Return NULL and let NCCL call send again. */ + assert(false); req->free(req, false); *ret_req = NULL; } else if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS)) { @@ -4210,6 +4344,13 @@ static int post_rdma_write(nccl_net_ofi_rdma_req_t *req, struct fid_mr *rail_mr_handle = send_data->buff_mr_handle->mr[rail_id]; void *desc = fi_mr_desc(rail_mr_handle); + /* For multi-recv, in wdata, we need to make sure we use the same msg_seq_num as + receiver has, so recompute the wdata */ + send_data->wdata = + GET_RDMA_WRITE_IMM_DATA(((nccl_net_ofi_rdma_send_comm_t*)(req->comm))->remote_comm_id, + send_data->recv_side_msg_seq_num, + send_data->schedule->num_xfer_infos); + ssize_t rc; /* Post RDMA write */ rc = fi_writedata(comm_rail->local_ep, send_data->buff + xfer_info->offset, @@ -4384,7 +4525,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) uint64_t data = GET_RDMA_WRITE_IMM_DATA(r_comm->remote_comm_id, req->msg_seq_num, 0); ssize_t rc = fi_tsenddata(comm_rail->local_ep, &ctrl_fl_item->ctrl_msg, - sizeof(nccl_net_ofi_rdma_ctrl_msg_t), desc, + send_ctrl_data->ctrl_msg_size, desc, data, comm_rail->remote_addr, RDMA_DATA_TAG, req); if ((rc != 0) && (rc != -FI_EAGAIN)) { @@ -4544,6 +4685,57 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req) return ret; } +static int rdma_post_multi_send(nccl_net_ofi_rdma_send_comm_t *s_comm, uint16_t multi_recv_start, + uint16_t multi_recv_size) +{ + int ret = 0; + + nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; + + for (uint16_t idx = multi_recv_start; idx != (uint16_t)(multi_recv_start+multi_recv_size); ++idx) { + void *elem; + nccl_ofi_msgbuff_elemtype_t type; + nccl_ofi_msgbuff_status_t stat; + nccl_ofi_msgbuff_result_t res = nccl_ofi_msgbuff_retrieve_notag(s_comm->msgbuff, + idx, &elem, &type, &stat); + if (res != NCCL_OFI_MSGBUFF_SUCCESS) { + assert(false); + abort(); + } + assert(elem); + assert(type == NCCL_OFI_MSGBUFF_REQ); + assert(stat == NCCL_OFI_MSGBUFF_INPROGRESS); + nccl_net_ofi_rdma_req_t *req = elem; + + ret = send_progress(req); + if (ret == -FI_EAGAIN) { + /* Add to pending reqs queue */ + assert(ep != NULL); + ret = nccl_ofi_deque_insert_back(ep->pending_reqs_queue, &req->pending_reqs_elem); + if (ret != 0) { + assert(false); + NCCL_OFI_WARN("Failed to nccl_ofi_deque_insert_back: %d", ret); + goto exit; + } + } else if (OFI_UNLIKELY(ret != 0)) { + /* TODO: Remove req from message buffer */ + ret = -ENOTSUP; + assert(false); + goto exit; + } + } + assert(!ret); + if (ret) goto exit; + ret = process_cq_if_pending(ep); + if (ret == -FI_EAGAIN) { + ret = 0; + } else if (ret != 0) { + assert(false); + } +exit: + return ret; +} + /** * @brief Send a message. This "interface function" is called, indirectly, from * the application @@ -4603,11 +4795,6 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } - /* - * TODO: Use NCCL provided tags when using grouped receives aka - * props->maxRecvs > 1. - */ - bool have_ctrl = false; uint16_t msg_seq_num = s_comm->next_msg_seq_num; @@ -4616,9 +4803,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; - /* Retrive entry from message buffer for msg_seq_num index */ - mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem, - &type, &msg_stat); + /* Retrive entry from message buffer for msg_seq_num index. + At this point we don't have multi-recv info */ + mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, 0, 0, tag, + &elem, &type, &msg_stat); if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { if (type == NCCL_OFI_MSGBUFF_BUFF) { /* @@ -4628,8 +4816,13 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t have_ctrl = true; } else if (type == NCCL_OFI_MSGBUFF_REQ) { /* Shouldn't happen: we already have a req in the message buffer */ - NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num); - ret = ncclSystemError; + //NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num); + //ret = ncclSystemError; + ret = ofi_process_cq(ep); + if (ret != 0) { + goto error; + } + ret = ncclSuccess; goto error; } else { NCCL_OFI_WARN("Unexpected type of buffer retrieved from message buffer: %d", @@ -4638,13 +4831,21 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } } else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) && - (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) { + (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED || msg_stat == NCCL_OFI_MSGBUFF_UNAVAILABLE)) { /* * We haven't encountered this message sequence number. * Allocate a request so that we are able to send RDMA write * as soon as we receive the RDMA control message. */ have_ctrl = false; + /** Just return a NULL req here **/ + /** Eager TODO: this will be an eager message if small enough */ + ret = ofi_process_cq(ep); + if (ret != 0) { + goto error; + } + ret = ncclSuccess; + goto free_req; } else { NCCL_OFI_WARN("Message %hu has invalid status. res = %d and stat = %d", msg_seq_num, mb_res, msg_stat); @@ -4654,33 +4855,28 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t /* Determine if this should be sent eagerly. */ bool eager = false; - if ((!have_ctrl && size <= eager_max_size) || - (size == 0)) { - eager = true; - } + /* Eager TODO */ - ret = alloc_rdma_send_req(s_comm, msg_seq_num, data, + ret = alloc_rdma_send_req(s_comm, msg_seq_num, tag, data, size, mr_handle, eager, have_ctrl, &req); if (OFI_UNLIKELY(ret != 0)) { + assert(false); goto error; } + assert(have_ctrl); if (have_ctrl) { /* * For already received RDMA control message, populate * the RDMA write metadata from the bounce buffer */ nccl_net_ofi_rdma_req_t *bounce_req = elem; - copy_ctrl_data(bounce_req, req); - - /* Post if needed */ - ret = check_post_bounce_req(bounce_req); - if (OFI_UNLIKELY(ret != 0)) { - goto error; - } + copy_ctrl_data(bounce_req, req, tag); } - ret = insert_rdma_send_req_into_msgbuff(s_comm, dev_id, have_ctrl, &req); + bool multi_send_ready = false; + ret = insert_rdma_send_req_into_msgbuff(s_comm, dev_id, have_ctrl, &req, + &multi_send_ready); if (ret != 0 || req == NULL) { goto free_req; } @@ -4693,25 +4889,25 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t NCCL_OFI_TRACE_SEND(req->dev_id, size, s_comm, msg_seq_num, req, base_req); - /* Try posting RDMA write for received RDMA control messages */ - if (have_ctrl || eager) { + assert(!eager); - ret = send_progress(req); - if (ret == -FI_EAGAIN) { - /* Add to pending reqs queue */ - ret = nccl_ofi_deque_insert_back(ep->pending_reqs_queue, &req->pending_reqs_elem); - if (ret != 0) { - NCCL_OFI_WARN("Failed to nccl_ofi_deque_insert_back: %d", ret); - goto error; - } - NCCL_OFI_TRACE_PENDING_INSERT(req); - } else if (OFI_UNLIKELY(ret != 0)) { - /* TODO: Remove req from message buffer */ - ret = -ENOTSUP; + if (multi_send_ready) { + rdma_req_send_data_t *send_data = get_send_data(req); + + ret = rdma_post_multi_send(s_comm, send_data->multi_recv_start, + send_data->multi_recv_size); + + /* Re-post bounce buffer if needed */ + nccl_net_ofi_rdma_req_t *bounce_req = elem; + ret = check_post_bounce_req(bounce_req); + if (OFI_UNLIKELY(ret != 0)) { + assert(false); goto error; } } + /* Eager TODO: post eager message */ + /* Return request to NCCL */ *base_req = &req->base; /* Increment next_msg_seq_num for next call */ @@ -4724,6 +4920,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t if (req) req->free(req, false); *base_req = NULL; + exit: return ret; } @@ -5058,6 +5255,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, /* Allocate and initialize connect message */ prepare_send_connect_message(ep, dev_id, ret_s_comm->local_comm_id, handle, &ret_s_comm->conn_msg); + ret_s_comm->conn_msg.s_comm_ptr = (uintptr_t)ret_s_comm; /* Allocate message buffer */ ret_s_comm->msgbuff = nccl_ofi_msgbuff_init(NCCL_OFI_RDMA_MSGBUFF_SIZE); @@ -5246,7 +5444,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, /* Build send communicator with one comm rail */ ret = create_send_comm(handle, ep, &s_comm); if (OFI_UNLIKELY(ret != 0)) { - return ret; + assert(false); abort(); return ret; } comm_state->comm = &s_comm->base.base; @@ -5279,27 +5477,29 @@ static int connect(nccl_net_ofi_ep_t *base_ep, * has been sent. Afterwards, reset previously used * request. */ - /* Progress our engine to get completions */ - ret = ofi_process_cq(ep); - if (OFI_UNLIKELY(ret != 0)) { - /* Send communicator cannot be closed since - * send request of send connect message is - * still pending */ - return ret; - } + do { + /* Progress our engine to get completions */ + ret = ofi_process_cq(ep); + if (OFI_UNLIKELY(ret != 0)) { + /* Send communicator cannot be closed since + * send request of send connect message is + * still pending */ + return ret; + } - /* Check if the connect message is sent */ - ret = pthread_mutex_lock(&req->req_lock); - if (OFI_UNLIKELY(ret)) { - NCCL_OFI_WARN("Unable to acquire req_lock mutex"); - return ncclInternalError; - } - conn_msg_state = req->state; - ret = pthread_mutex_unlock(&req->req_lock); - if (OFI_UNLIKELY(ret)) { - NCCL_OFI_WARN("Failed to unlock req_lock mutex"); - return ncclInternalError; - } + /* Check if the connect message is sent */ + ret = pthread_mutex_lock(&req->req_lock); + if (OFI_UNLIKELY(ret)) { + NCCL_OFI_WARN("Unable to acquire req_lock mutex"); + return ncclInternalError; + } + conn_msg_state = req->state; + ret = pthread_mutex_unlock(&req->req_lock); + if (OFI_UNLIKELY(ret)) { + NCCL_OFI_WARN("Failed to unlock req_lock mutex"); + return ncclInternalError; + } + } while (conn_msg_state != NCCL_OFI_RDMA_REQ_COMPLETED); /* Wait until connect message is sent */ if (conn_msg_state != NCCL_OFI_RDMA_REQ_COMPLETED) { @@ -5573,8 +5773,7 @@ static int get_ep(nccl_net_ofi_device_t *base_dev, /* Initialize reference count */ ep->ref_cnt = 0; - ep->bounce_buff_size = NCCL_OFI_MAX(sizeof(nccl_net_ofi_rdma_ctrl_msg_t), - eager_max_size); + ep->bounce_buff_size = NCCL_OFI_MAX(RDMA_CTRL_MSG_MAX_SIZE, eager_max_size); /* Store endpoint in thread-local variable */ pthread_setspecific(device->ep_key, (void *)ep); diff --git a/tests/unit/msgbuff.c b/tests/unit/msgbuff.c index 074dcb217..63dc1fdd0 100644 --- a/tests/unit/msgbuff.c +++ b/tests/unit/msgbuff.c @@ -26,17 +26,17 @@ int main(int argc, char *argv[]) /** Test insert new **/ for (uint16_t i = 0; i < buff_sz; ++i) { - if (nccl_ofi_msgbuff_insert(msgbuff, i, &buff_store[i], type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { + if (nccl_ofi_msgbuff_insert(msgbuff, i, i, 1, 0, &buff_store[i], type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_insert failed when non-full"); return 1; } } - if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz, buff_sz, 1, 0, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_UNAVAILABLE) { NCCL_OFI_WARN("nccl_ofi_msgbuff_insert did not return unavailable when full"); return 1; } - if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz-1, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_insert(msgbuff, buff_sz-1, buff_sz-1, 1, 0, NULL, type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_INPROGRESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_insert did not return inprogress on duplicate insert"); return 1; @@ -45,7 +45,7 @@ int main(int argc, char *argv[]) /** Test retrieve **/ uint16_t *result; for (uint16_t i = 0; i < buff_sz; ++i) { - if (nccl_ofi_msgbuff_retrieve(msgbuff, i, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { + if (nccl_ofi_msgbuff_retrieve(msgbuff, i, i, 1, 0, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_retrieve failed on valid index"); return 1; } @@ -54,12 +54,13 @@ int main(int argc, char *argv[]) return 1; } } - if (nccl_ofi_msgbuff_retrieve(msgbuff, buff_sz, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_retrieve(msgbuff, buff_sz, buff_sz, 1, 0, (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_NOTSTARTED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_retrieve did not return notstarted"); return 1; } - if (nccl_ofi_msgbuff_retrieve(msgbuff, UINT16_C(0) - UINT16_C(1), (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_retrieve(msgbuff, UINT16_C(0) - UINT16_C(1), UINT16_C(0) - UINT16_C(1), 1, 0, + (void**)&result, &type, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_COMPLETED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_retrieve did not return completed"); return 1; @@ -67,17 +68,17 @@ int main(int argc, char *argv[]) /** Test complete **/ for (uint16_t i = 0; i < buff_sz; ++i) { - if (nccl_ofi_msgbuff_complete(msgbuff, i, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { + if (nccl_ofi_msgbuff_complete(msgbuff, i, i, 1, 0, &stat) != NCCL_OFI_MSGBUFF_SUCCESS) { NCCL_OFI_WARN("nccl_ofi_msgbuff_complete failed"); return 1; } } - if (nccl_ofi_msgbuff_complete(msgbuff, buff_sz, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_complete(msgbuff, buff_sz, buff_sz, 1, 0, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_NOTSTARTED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_complete did not return notstarted"); return 1; } - if (nccl_ofi_msgbuff_complete(msgbuff, 0, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || + if (nccl_ofi_msgbuff_complete(msgbuff, 0, 0, 1, 0, &stat) != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_COMPLETED) { NCCL_OFI_WARN("nccl_ofi_msgbuff_complete did not return completed"); return 1;