Skip to content

Commit

Permalink
fix: unload: Cleanup active domains and endpoints
Browse files Browse the repository at this point in the history
When aborting with no connection established, there were QPs leaked if
any domain and endpoint left open. This patch adds domain and endpoint
cleanup logic at the beginning of rdma and sendrecv device release to
prevent the QP leak.

Note the logic can be triggered without #772.

Signed-off-by: Mozar Huang <[email protected]>
  • Loading branch information
mozarhua committed Jan 29, 2025
1 parent 0a96be6 commit 65aa4d7
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 30 deletions.
19 changes: 16 additions & 3 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,13 @@ struct nccl_net_ofi_device {
*/
nccl_net_ofi_domain_t *(*create_domain)(nccl_net_ofi_device_t *dev);

/*
* release all domains and endpoints. This function is a private
* function, which is called only during release() to free allocated
* domains and endpoints.
*/
int (*release_all_domain_and_ep)(nccl_net_ofi_device_t *dev);

/*
* hash table indexed by thread id of active domains.
*/
Expand Down Expand Up @@ -323,7 +330,7 @@ struct nccl_net_ofi_domain {
/*
* Destructor - release resources associated with the domain
*/
int (*release)(nccl_net_ofi_domain_t *domain);
int (*release)(nccl_net_ofi_domain_t *domain, bool skip_device_lock, bool force_cleanup);

/*
* Protocol-agnostic MR cache for this device.
Expand Down Expand Up @@ -424,7 +431,7 @@ struct nccl_net_ofi_ep {
* endpoint if reference counter becomes zero. Must be
* protected by lock stored in base_dev.
*/
int (*release_ep)(nccl_net_ofi_ep_t *ep);
int (*release_ep)(nccl_net_ofi_ep_t *ep, bool skip_lock, bool force_cleanup);

/* private */
/* pure virtual function called when resources associated with
Expand Down Expand Up @@ -616,7 +623,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p);
* override that function pointer and later call this function
* directly.
*/
int nccl_net_ofi_endpoint_release(nccl_net_ofi_ep_t *ep);
int nccl_net_ofi_endpoint_release(nccl_net_ofi_ep_t *ep, bool skip_lock, bool force_cleanup);

/* initialize resources associated with the endpoint base class.
* Expectation is that this will be called by a transport's endpoint
Expand Down Expand Up @@ -649,6 +656,12 @@ int nccl_net_ofi_device_init(nccl_net_ofi_device_t *device, nccl_net_ofi_plugin_
*/
int nccl_net_ofi_device_fini(nccl_net_ofi_device_t *device);

/* release all domains and their enpoints of a device. This is called
* only by device->release() during plugin release to free all fabric
* domain and QPs.
*/
int nccl_net_ofi_device_release_all_domain_and_ep(nccl_net_ofi_device_t *device);

/*
* Constructor for the nccl_net_ofi_plugin class
*
Expand Down
6 changes: 3 additions & 3 deletions src/nccl_ofi_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ ncclResult_t nccl_net_ofi_listen(int dev_id, void *handle, void **lComm)
listen_comm);

if (ret != 0) {
base_ep->release_ep(base_ep);
base_ep->release_ep(base_ep, false, false);
}
return nccl_net_ofi_retval_translate(ret);
}
Expand Down Expand Up @@ -318,7 +318,7 @@ ncclResult_t nccl_net_ofi_connect(int dev_id, void *handle, void **sComm)
int ret = base_ep->connect(base_ep, (nccl_net_ofi_conn_handle_t *)handle, send_comm);

if (ret != 0) {
base_ep->release_ep(base_ep);
base_ep->release_ep(base_ep, false, false);
}

return nccl_net_ofi_retval_translate(ret);
Expand Down Expand Up @@ -520,7 +520,7 @@ ncclResult_t nccl_net_ofi_accept(void *lComm, void **rComm)
ret = -EINVAL;
goto error;
}
ep->release_ep(ep);
ep->release_ep(ep, false, false);
}

error:
Expand Down
92 changes: 81 additions & 11 deletions src/nccl_ofi_net.c
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
(properties.regIsGlobal == 0) ? "false" : "true");
NCCL_OFI_INFO(NCCL_NET | NCCL_INIT, "Support for DMA-BUF registrations: %s",
(properties.dmabuf_support == 0) ? "false" : "true");
ret = base_ep->release_ep(base_ep);
ret = base_ep->release_ep(base_ep, false, false);
if (ret != 0) {
goto exit;
}
Expand Down Expand Up @@ -816,6 +816,7 @@ int nccl_net_ofi_device_init(nccl_net_ofi_device_t *device, nccl_net_ofi_plugin_
device->get_ep = nccl_net_ofi_device_get_ep;
device->get_mr_key = NULL;
device->release = nccl_net_ofi_device_fini;
device->release_all_domain_and_ep = nccl_net_ofi_device_release_all_domain_and_ep;

/* Intiaialize mutex for endpoint access */
ret = nccl_net_ofi_mutex_init(&device->device_lock, NULL);
Expand Down Expand Up @@ -858,6 +859,58 @@ int nccl_net_ofi_device_fini(nccl_net_ofi_device_t *device)
}


int nccl_net_ofi_device_release_all_domain_and_ep(nccl_net_ofi_device_t *device)
{
int ret, first_error = 0, domain_num;

assert(device != NULL);
nccl_net_ofi_domain_t *domain, *domain_tmp;
nccl_net_ofi_ep_t *ep;

nccl_net_ofi_mutex_lock(&device->device_lock);

domain_num = HASH_COUNT(device->domain_table);
assert(domain_num > 0);
HASH_ITER(hh, device->domain_table, domain, domain_tmp) {
/* For each domain, clean up its endpoints. */
nccl_net_ofi_mutex_lock(&domain->domain_lock);
if (domain->endpoint) {
ep = domain->endpoint;
domain->endpoint = NULL;

ret = ep->release_ep(ep, true, true);
if (ret != 0) {
NCCL_OFI_WARN("Freeing endpoint failed: %d", ret);
if (first_error != 0) {
first_error = ret;
}
}
ep = NULL;
}
nccl_net_ofi_mutex_unlock(&domain->domain_lock);

/* domain->release takes the domain lock, and removes itself
* from domain_table. Skipping device lock here.*/
ret = domain->release(domain, true, true);
if (ret != 0 && first_error != 0) {
first_error = ret;
}

}
nccl_net_ofi_mutex_unlock(&device->device_lock);

domain_num = HASH_COUNT(device->domain_table);
if (OFI_UNLIKELY(domain_num > 0)) {
NCCL_OFI_WARN("%u domains still active after cleanup", domain_num);
if (first_error != 0) {
first_error = -FI_EBUSY; // Anything else than above
}
}

return first_error;
}


static int nccl_net_ofi_domain_get_ep(nccl_net_ofi_domain_t *domain,
nccl_net_ofi_ep_t **ep_p)
{
Expand Down Expand Up @@ -892,7 +945,7 @@ static int nccl_net_ofi_domain_get_ep(nccl_net_ofi_domain_t *domain,
}


static int nccl_net_ofi_domain_release(nccl_net_ofi_domain_t *domain)
static int nccl_net_ofi_domain_release(nccl_net_ofi_domain_t *domain, bool skip_device_lock, bool force_cleanup)
{
int ret = 0;
nccl_net_ofi_device_t *device;
Expand All @@ -902,8 +955,11 @@ static int nccl_net_ofi_domain_release(nccl_net_ofi_domain_t *domain)

nccl_net_ofi_mutex_lock(&domain->domain_lock);

if (domain->endpoint == NULL) {
nccl_net_ofi_mutex_lock(&device->device_lock);
if (domain->endpoint == NULL || force_cleanup) {
// The caller takes device_lock when force_cleanup.
if (!skip_device_lock) {
nccl_net_ofi_mutex_lock(&device->device_lock);
}
HASH_DEL(device->domain_table, domain);

// domain->free below is going to free the domain lock
Expand All @@ -913,7 +969,9 @@ static int nccl_net_ofi_domain_release(nccl_net_ofi_domain_t *domain)
nccl_net_ofi_mutex_unlock(&domain->domain_lock);

ret = domain->free(domain);
nccl_net_ofi_mutex_unlock(&device->device_lock);
if (!skip_device_lock) {
nccl_net_ofi_mutex_unlock(&device->device_lock);
}
if (ret != 0) {
NCCL_OFI_WARN("Freeing domain failed: %d", ret);
return ret;
Expand Down Expand Up @@ -996,21 +1054,28 @@ int nccl_net_ofi_domain_fini(nccl_net_ofi_domain_t *domain)
}


int nccl_net_ofi_endpoint_release(nccl_net_ofi_ep_t *ep)
int nccl_net_ofi_endpoint_release(nccl_net_ofi_ep_t *ep, bool skip_lock, bool force_cleanup)
{
int ret = 0;
nccl_net_ofi_domain_t *domain;

assert(ep != NULL);
domain = ep->domain;

nccl_net_ofi_mutex_lock(&domain->domain_lock);
if (!skip_lock) {
nccl_net_ofi_mutex_lock(&domain->domain_lock);
}

ep->ref_cnt--;

if (ep->ref_cnt == 0) {
if (ep->ref_cnt == 0 || force_cleanup) {
domain->endpoint = NULL;

if (force_cleanup && ep->ref_cnt != 0) {
NCCL_OFI_INFO(NCCL_NET, "Endpoint %p still have ref count %d when released",
ep, ep->ref_cnt);
}

ret = ep->free_ep(ep);
if (ret != 0) {
NCCL_OFI_WARN("Freeing endpoint failed: %d", ret);
Expand All @@ -1019,10 +1084,15 @@ int nccl_net_ofi_endpoint_release(nccl_net_ofi_ep_t *ep)
}

cleanup:
nccl_net_ofi_mutex_unlock(&domain->domain_lock);

if (ret == 0) {
ret = domain->release(domain);
if (!skip_lock) {
nccl_net_ofi_mutex_unlock(&domain->domain_lock);
}

/* Skip domain->release when handled by device->release_all_domain_and_ep()
* to avoid domain lock issue after the domain freed */
if (!force_cleanup && ret == 0) {
ret = domain->release(domain, skip_lock, false);
}

return ret;
Expand Down
36 changes: 26 additions & 10 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -3914,7 +3914,7 @@ static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm)

free_rdma_recv_comm(r_comm);

ret = ep->base.release_ep(&ep->base);
ret = ep->base.release_ep(&ep->base, false, false);

return ret;
}
Expand Down Expand Up @@ -4094,7 +4094,7 @@ static int send_comm_destroy(nccl_net_ofi_rdma_send_comm_t *s_comm)

free_rdma_send_comm(s_comm);

ret = ep->base.release_ep(&ep->base);
ret = ep->base.release_ep(&ep->base, false, false);

return ret;
}
Expand Down Expand Up @@ -5178,7 +5178,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm)
}

free(l_comm);
ret = base_ep->release_ep(base_ep);
ret = base_ep->release_ep(base_ep, false, false);

return ret;
}
Expand Down Expand Up @@ -7057,7 +7057,7 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device,
}


static int nccl_net_ofi_rdma_endpoint_release(nccl_net_ofi_ep_t *base_ep)
static int nccl_net_ofi_rdma_endpoint_release(nccl_net_ofi_ep_t *base_ep, bool skip_lock, bool force_cleanup)
{
int ret = 0;
nccl_net_ofi_rdma_ep_t *ep = NULL;
Expand All @@ -7083,9 +7083,15 @@ static int nccl_net_ofi_rdma_endpoint_release(nccl_net_ofi_ep_t *base_ep)
return -EINVAL;
}

nccl_net_ofi_mutex_lock(&domain->base.domain_lock);
if (!skip_lock) {
nccl_net_ofi_mutex_lock(&domain->base.domain_lock);
}

if ((--ep->base.ref_cnt) == 0) {
if ((--ep->base.ref_cnt) == 0 || force_cleanup) {
if (force_cleanup && ep->base.ref_cnt != 0 ) {
NCCL_OFI_INFO(NCCL_NET, "Endpoint %p still have ref count %d when released",
ep, ep->base.ref_cnt);
}
ret = nccl_ofi_ep_addr_list_delete(domain->ep_addr_list, &ep->base);
if (ret != 0) {
NCCL_OFI_WARN("delete ep for addr failed: %d", ret);
Expand All @@ -7100,9 +7106,11 @@ static int nccl_net_ofi_rdma_endpoint_release(nccl_net_ofi_ep_t *base_ep)
}

unlock:
nccl_net_ofi_mutex_unlock(&domain->base.domain_lock);
if (!skip_lock) {
nccl_net_ofi_mutex_unlock(&domain->base.domain_lock);
}
} else {
ret = nccl_net_ofi_endpoint_release(&ep->base);
ret = nccl_net_ofi_endpoint_release(&ep->base, skip_lock, force_cleanup);
}

return ret;
Expand Down Expand Up @@ -7282,7 +7290,7 @@ static int nccl_net_ofi_rdma_domain_create_endpoint(nccl_net_ofi_domain_t *base_

error:
if (ret != 0) {
ep->base.release_ep(&(ep->base));
ep->base.release_ep(&(ep->base), false, false);
}

return ret;
Expand Down Expand Up @@ -7402,7 +7410,7 @@ static nccl_net_ofi_domain_t *nccl_net_ofi_rdma_device_create_domain(nccl_net_of

error:
if (ret != 0) {
domain->base.release(&(domain->base));
domain->base.release(&(domain->base), false, false);
domain = NULL;
}

Expand Down Expand Up @@ -7546,6 +7554,14 @@ nccl_net_ofi_rdma_device_release(nccl_net_ofi_device_t *base_device)
unsigned num_domains = HASH_COUNT(device->base.domain_table);
if (num_domains > 0) {
NCCL_OFI_INFO(NCCL_NET, "%u domains still active at close", num_domains);
ret = base_device->release_all_domain_and_ep(base_device);
if (ret != 0) {
NCCL_OFI_WARN("Cleanup of domain failed. RC: %d, ERROR: %s",
ret, fi_strerror(-ret));
if (first_error == 0) {
first_error = ret;
}
}
}

if (device->device_rails != NULL) {
Expand Down
14 changes: 11 additions & 3 deletions src/nccl_ofi_sendrecv.c
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ static int sendrecv_recv_comm_close(nccl_net_ofi_recv_comm_t *recv_comm)
nccl_ofi_freelist_fini(r_comm->nccl_ofi_reqs_fl);
free(recv_comm);

ret = base_ep->release_ep(base_ep);
ret = base_ep->release_ep(base_ep, false, false);
exit:
return ret;
}
Expand Down Expand Up @@ -1538,7 +1538,7 @@ static int sendrecv_listen_comm_close(nccl_net_ofi_listen_comm_t *listen_comm)
goto exit;
}

ret = base_ep->release_ep(base_ep);
ret = base_ep->release_ep(base_ep, false, false);
free(listen_comm);
exit:
return ret;
Expand Down Expand Up @@ -1816,7 +1816,7 @@ static int sendrecv_send_comm_close(nccl_net_ofi_send_comm_t *send_comm)
free(s_comm->conn_info);
free(send_comm);

ret = base_ep->release_ep(base_ep);
ret = base_ep->release_ep(base_ep, false, false);
exit:
return ret;
}
Expand Down Expand Up @@ -2328,6 +2328,14 @@ nccl_net_ofi_sendrecv_device_release(nccl_net_ofi_device_t *base_device)
unsigned num_domains = HASH_COUNT(device->base.domain_table);
if (num_domains > 0) {
NCCL_OFI_INFO(NCCL_NET, "%u domains still active at close", num_domains);
ret = base_device->release_all_domain_and_ep(base_device);
if (ret != 0) {
NCCL_OFI_WARN("Cleanup of domain failed. RC: %d, ERROR: %s",
ret, fi_strerror(-ret));
if (first_error == 0) {
first_error = ret;
}
}
}

if (device->fabric) {
Expand Down

0 comments on commit 65aa4d7

Please sign in to comment.