Skip to content

Commit

Permalink
rdma: support NCCL multi-recv interface
Browse files Browse the repository at this point in the history
The multi-recv interface allows aggregating up to 8 receive requests into
a single request. This commit also makes the msgbuff tag-aware.

* Temporarily disables eager; it will be re-enabled in a future commit.
* Makes `connect()` blocking; we need to better understand why this is
necessary.

Signed-off-by: Eric Raut <[email protected]>
  • Loading branch information
rauteric committed Feb 22, 2024
1 parent ed4d6e7 commit 349a396
Show file tree
Hide file tree
Showing 7 changed files with 762 additions and 301 deletions.
3 changes: 2 additions & 1 deletion include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ extern "C" {
#define MIN_TAG_BITS_FOR_RING_ID (32 + 1)

/* Maximum number of grouped receives */
#define NCCL_OFI_MAX_RECVS 1
#define NCCL_OFI_MAX_RECVS 8
#define NCCL_OFI_MAX_RECVS_SENDRECV 1

/*
* This defines a higher value than maximum inflight requests supported by NCCL
Expand Down
25 changes: 21 additions & 4 deletions include/nccl_ofi_msgbuff.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ typedef struct {
// Type of element
nccl_ofi_msgbuff_elemtype_t type;
void *elem;
// Multi-recv information
uint16_t multi_recv_size;
uint16_t multi_recv_start;
int multi_recv_tag;
} nccl_ofi_msgbuff_elem_t;

typedef struct {
Expand Down Expand Up @@ -110,9 +114,14 @@ bool nccl_ofi_msgbuff_destroy(nccl_ofi_msgbuff_t *msgbuff);
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type,
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size, int multi_recv_tag,
void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status);

/**
 * Insert a ctrl element for a multi-recv into the message buffer
 *
 * NOTE(review): the implementation is not visible from this header, so the
 * details below are assumptions to confirm against nccl_ofi_msgbuff.c.
 *
 * @param msg_base_index, presumably the index of the first message of the
 *        multi-recv group
 * @param multi_recv_size, presumably the number of receives in the group
 * @param elem, pointer to store in the buffer
 * @param type, type of element
 * @param msg_idx_status, output; presumably set as for nccl_ofi_msgbuff_insert
 */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert_ctrl_multirecv(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_base_index, uint16_t multi_recv_size, void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status);

/**
* Replace an existing message element
*
Expand All @@ -126,8 +135,9 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_insert(nccl_ofi_msgbuff_t *msgbuff,
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status);
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size,
int multi_recv_tag, void *elem, nccl_ofi_msgbuff_elemtype_t type,
nccl_ofi_msgbuff_status_t *msg_idx_status, bool *multi_send_ready);

/**
* Retrieve message with given index
Expand All @@ -142,6 +152,12 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_replace(nccl_ofi_msgbuff_t *msgbuff,
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size,
int multi_recv_tag, void **elem, nccl_ofi_msgbuff_elemtype_t *type,
nccl_ofi_msgbuff_status_t *msg_idx_status);


/**
 * Retrieve message with given index, without supplying multi-recv
 * start/size/tag arguments
 *
 * Takes the same msgbuff, msg_index, elem, type and msg_idx_status
 * parameters as nccl_ofi_msgbuff_retrieve.
 *
 * NOTE(review): how an element is selected when several tagged elements
 * share msg_index is not visible from this header; confirm against the
 * implementation.
 */
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve_notag(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, void **elem, nccl_ofi_msgbuff_elemtype_t *type,
nccl_ofi_msgbuff_status_t *msg_idx_status);

Expand All @@ -156,7 +172,8 @@ nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_retrieve(nccl_ofi_msgbuff_t *msgbuff,
* NCCL_OFI_MSGBUFF_ERROR, other error
*/
nccl_ofi_msgbuff_result_t nccl_ofi_msgbuff_complete(nccl_ofi_msgbuff_t *msgbuff,
uint16_t msg_index, nccl_ofi_msgbuff_status_t *msg_idx_status);
uint16_t msg_index, uint16_t multi_recv_start, uint16_t multi_recv_size,
int multi_recv_tag, nccl_ofi_msgbuff_status_t *msg_idx_status);

#ifdef _cplusplus
} // End extern "C"
Expand Down
58 changes: 55 additions & 3 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,32 @@ typedef struct nccl_net_ofi_rdma_mr_handle {
struct fid_mr *mr[];
} nccl_net_ofi_rdma_mr_handle_t;

/* Contents of ctrl message sent from receiver to sender to advertise
destination buffer */
typedef struct nccl_net_ofi_rdma_ctrl_msg {
/* One entry of a ctrl message. A ctrl message carries an array of these
 * (nccl_net_ofi_rdma_ctrl_msg_t.entries), sized for at most
 * NCCL_OFI_MAX_RECVS entries. */
typedef struct nccl_net_ofi_rdma_ctrl_msg_entry {
/* Tag of the receive this entry belongs to.
 * NOTE(review): presumably matched against the sender's tag; confirm. */
int multi_recv_tag;
/* Destination buffer address and length.
 * NOTE(review): presumably the receiver's buffer and its size in bytes; confirm. */
uint64_t buff_addr;
uint64_t buff_len;
/* Array with room for MAX_NUM_RAILS keys.
 * NOTE(review): presumably one memory-registration key per rail; confirm. */
uint64_t buff_mr_key[MAX_NUM_RAILS];
} nccl_net_ofi_rdma_ctrl_msg_entry_t;

/* Contents of ctrl message sent from receiver to sender to advertise
 * destination buffer(s). The fixed header is followed by a variable number
 * of entries. */
typedef struct nccl_net_ofi_rdma_ctrl_msg {
/* Message sequence number.
 * NOTE(review): presumably the receiver-side sequence number; confirm. */
uint16_t msg_seq_num;
/* NOTE(review): presumably the number of valid elements in entries[]; confirm. */
uint16_t multi_recv_size;
/* Flexible array member: contributes no storage to sizeof(), so the
 * allocation must add room for the entries. */
nccl_net_ofi_rdma_ctrl_msg_entry_t entries[];
} nccl_net_ofi_rdma_ctrl_msg_t;

/* Bytes needed for a full set of NCCL_OFI_MAX_RECVS entries, and for a ctrl
 * message (header plus entries) of that maximum size. */
#define RDMA_CTRL_MSG_ENTRIES_MAX_SIZE (NCCL_OFI_MAX_RECVS * sizeof(nccl_net_ofi_rdma_ctrl_msg_entry_t))
#define RDMA_CTRL_MSG_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_msg_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE)

/* Structure used to store control messages in a free list */
typedef struct nccl_net_ofi_rdma_ctrl_fl_item {
nccl_ofi_freelist_reginfo_t fl_reginfo;
/* ctrl_msg ends in a flexible array member (entries[]), so it has to stay
 * the last field of this struct. Embedding such a struct inside another
 * struct is a compiler extension rather than strict ISO C.
 * NOTE(review): confirm all supported toolchains accept it. */
nccl_net_ofi_rdma_ctrl_msg_t ctrl_msg;
} nccl_net_ofi_rdma_ctrl_fl_item_t;

/* Size of a free list item with room for NCCL_OFI_MAX_RECVS ctrl message
 * entries after the fixed part. */
#define RDMA_CTRL_FL_ITEM_MAX_SIZE (sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t) + RDMA_CTRL_MSG_ENTRIES_MAX_SIZE)

/* For LL/LL128 protocols, bounce buffers (source of RDMA read operations) need to be 128B aligned */
#define BOUNCE_BUFFER_ALIGNMENT 128

Expand Down Expand Up @@ -150,6 +162,13 @@ typedef struct {
/* Total number of completions. Expect one completion for receiving the
* control message and one completion for each send segment. */
int total_num_compls;

/* Multi-recv information */
uint16_t multi_recv_size;
uint16_t multi_recv_start;
int multi_recv_tag;
/* This may not match sender-side seq num with multi-recv */
uint16_t recv_side_msg_seq_num;
} rdma_req_send_data_t;

/*
Expand All @@ -164,6 +183,8 @@ typedef struct {
nccl_net_ofi_schedule_t *ctrl_schedule;
/* Pointer to recv parent request */
nccl_net_ofi_rdma_req_t *recv_req;
/* Size of ctrl message */
size_t ctrl_msg_size;
} rdma_req_send_ctrl_data_t;

typedef struct {
Expand Down Expand Up @@ -204,6 +225,12 @@ typedef struct {
* For eager messages, the second completion will be received
* when the local read into the destination buffer is complete */
int total_num_compls;
/* Multi-recv information */
uint16_t multi_recv_size;
uint16_t multi_recv_start;
int multi_recv_tag;
/* Next req in sequence */
nccl_net_ofi_rdma_req_t *multi_recv_next;
} rdma_req_recv_data_t;

/*
Expand Down Expand Up @@ -625,6 +652,31 @@ typedef struct nccl_net_ofi_rdma_device {
nccl_ofi_idpool_t key_pool;
} nccl_net_ofi_rdma_device_t;

/*
 * @brief	Access the send-specific data of a request
 *
 * The request must be of type NCCL_OFI_RDMA_SEND; this is checked with
 * assert() only.
 *
 * @param	req
 *		Send request
 * @return	Pointer to the send_data member of `req`
 */
static inline rdma_req_send_data_t *get_send_data(nccl_net_ofi_rdma_req_t *req)
{
	assert(req->type == NCCL_OFI_RDMA_SEND);

	rdma_req_send_data_t *send_data = &req->send_data;
	return send_data;
}

/*
 * @brief	Access the bounce-specific data of a request
 *
 * The request must be of type NCCL_OFI_RDMA_BOUNCE; this is checked with
 * assert() only.
 *
 * @param	req
 *		Bounce request
 * @return	Pointer to the bounce_data member of `req`
 */
static inline rdma_req_bounce_data_t *get_bounce_data(nccl_net_ofi_rdma_req_t *req)
{
	assert(req->type == NCCL_OFI_RDMA_BOUNCE);

	rdma_req_bounce_data_t *bounce_data = &req->bounce_data;
	return bounce_data;
}

/*
 * @brief	View the message storage of a bounce buffer as a ctrl message
 *
 * The address of the item's bounce_msg member is cast to a ctrl message
 * pointer. No check is made here that the buffer actually holds a ctrl
 * message.
 *
 * @param	bounce_fl_item
 *		Bounce buffer freelist item
 * @return	Pointer to bounce_msg, typed as a ctrl message
 */
static inline nccl_net_ofi_rdma_ctrl_msg_t *get_bounce_ctrl_msg
	(nccl_net_ofi_rdma_bounce_fl_item_t *bounce_fl_item)
{
	nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg =
		(nccl_net_ofi_rdma_ctrl_msg_t *)&bounce_fl_item->bounce_msg;

	return ctrl_msg;
}

/*
* @brief Initialize plugin with rdma protocol structures
*/
Expand Down
Loading

0 comments on commit 349a396

Please sign in to comment.