/*
 * Patch overview: git unified diff touching prov/shm/src/smr_av.c, smr_ep.c,
 * smr_progress.c, smr_util.c and smr_util.h.
 *
 * What the hunks below change:
 *  - smr_map_cleanup(): the loop that called smr_map_del() for every index
 *    below SMR_MAX_PEERS is replaced by one ofi_rbmap_foreach() call using
 *    smr_map_unmap as the callback; a non-zero result is logged via FI_WARN.
 *  - smr_av_insert() / smr_av_remove(): the peer is now looked up with
 *    ofi_rbmap_find() keyed by peers[id].peer.name and the resulting node is
 *    handed to smr_map_del(); a non-zero return is logged, and smr_av_remove()
 *    additionally breaks out of its loop. smr_av_remove() no longer calls
 *    smr_unmap_from_endpoint() itself.
 *  - smr_verify_peer() / smr_progress_connreq(): calls to smr_map_to_region()
 *    are wrapped in ofi_spin_lock()/ofi_spin_unlock() on ep->region->map->lock,
 *    and smr_map_to_region() stops taking map->lock internally.
 *  - smr_progress_connreq(): on a pid mismatch the bare munmap() is replaced by
 *    smr_unmap_region() followed by smr_map_to_region(), both under the lock.
 *  - New smr_unmap_region(): returns early when the peer has no region;
 *    otherwise calls smr_unmap_from_endpoint() for each endpoint on the AV's
 *    ep_list, and when SMR_FLAG_HMEM_ENABLED is set calls
 *    ofi_hmem_host_unregister() and closes pid_fd (resetting it to -1); then
 *    munmap()s the region and sets peer->region to NULL.
 *  - smr_unmap_from_endpoint(): statements reordered, assert(peer_smr) added.
 *  - smr_map_del(map, int64_t id) returning void is split into
 *    smr_map_unmap(rbmap, node, context) and smr_map_del(map, node), both
 *    returning int; the header prototypes are updated to match.
 *
 * NOTE(review): points to confirm before merging
 *  - printf() calls are added in smr_av_insert, smr_av_remove, smr_ep_close
 *    (two), smr_create and smr_map_to_region (two). They look like temporary
 *    debug traces -- confirm they are meant to ship.
 *  - In smr_map_to_region() the shm_open failure path now does
 *    `return -errno` after FI_WARN_ONCE, whereas the removed code stored
 *    `ret = -errno` before logging. Verify the logging call cannot modify
 *    errno, or capture it first.
 *  - The new warnings print int64_t values (shm_id, id) with "%ld"; confirm
 *    this matches int64_t on every supported platform (PRId64 may be needed).
 *  - In the pid-mismatch branch of smr_progress_connreq() the return value of
 *    smr_map_to_region() is not checked, right after the region was unmapped;
 *    presumably peer_smr can then be NULL -- verify the code that follows.
 *  - The node returned by ofi_rbmap_find() is guarded only by assert(node)
 *    before being passed to smr_map_del(); TODO confirm behaviour when
 *    assertions are compiled out.
 *  - smr_map_del() and smr_map_unmap() as written always return FI_SUCCESS, so
 *    the new "Failed to remove shm_id" branches cannot fire through them.
 *  - smr_map_cleanup() runs smr_map_unmap() through ofi_rbmap_foreach()
 *    without taking map->lock, while smr_map_del() takes it; confirm no other
 *    user of the map can run during cleanup.
 */
diff --git a/prov/shm/src/smr_av.c b/prov/shm/src/smr_av.c index 355d3bcad64..e827afcea8c 100644 --- a/prov/shm/src/smr_av.c +++ b/prov/shm/src/smr_av.c @@ -67,10 +67,13 @@ static int smr_map_init(const struct fi_provider *prov, struct smr_map *map, static void smr_map_cleanup(struct smr_map *map) { - int64_t i; + int ret; - for (i = 0; i < SMR_MAX_PEERS; i++) - smr_map_del(map, i); + ret = ofi_rbmap_foreach(&map->rbmap, map->rbmap.root, smr_map_unmap, + NULL); + if (ret) + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove all entries from the map\n"); ofi_rbmap_cleanup(&map->rbmap); } @@ -115,11 +118,13 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count, struct smr_ep *smr_ep; struct fid_peer_srx *srx; struct dlist_entry *av_entry; + struct ofi_rbnode *node; fi_addr_t util_addr; int64_t shm_id = -1; int i, ret; int succ_count = 0; + printf("start av insert\n"); util_av = container_of(av_fid, struct util_av, av_fid); smr_av = container_of(util_av, struct smr_av, util_av); @@ -148,8 +153,16 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count, if (ret) { if (fi_addr) fi_addr[i] = util_addr; - if (shm_id >= 0) - smr_map_del(&smr_av->smr_map, shm_id); + if (shm_id >= 0) { + node = ofi_rbmap_find(&smr_av->smr_map.rbmap, + &smr_av->smr_map.peers[shm_id].peer.name); + assert(node); + ret = smr_map_del(&smr_av->smr_map, node); + if (ret) + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove shm_id %ld\n", + shm_id); + } continue; } @@ -190,9 +203,11 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count struct smr_av *smr_av; struct smr_ep *smr_ep; struct dlist_entry *av_entry; + struct ofi_rbnode *node; int i, ret = 0; int64_t id; + printf("av remove\n"); util_av = container_of(av_fid, struct util_av, av_fid); smr_av = container_of(util_av, struct smr_av, util_av); @@ -207,11 +222,18 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count break; } - 
smr_map_del(&smr_av->smr_map, id); + node = ofi_rbmap_find(&smr_av->smr_map.rbmap, + &smr_av->smr_map.peers[id].peer.name); + assert(node); + ret = smr_map_del(&smr_av->smr_map, node); + if (ret) { + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove shm_id %ld\n", id); + break; + } dlist_foreach(&util_av->ep_list, av_entry) { util_ep = container_of(av_entry, struct util_ep, av_entry); smr_ep = container_of(util_ep, struct smr_ep, util_ep); - smr_unmap_from_endpoint(smr_ep->region, id); if (smr_av->smr_map.num_peers > 0) smr_ep->region->max_sar_buf_per_peer = SMR_MAX_PEERS / diff --git a/prov/shm/src/smr_ep.c b/prov/shm/src/smr_ep.c index 8803495e382..39ca6678e7a 100644 --- a/prov/shm/src/smr_ep.c +++ b/prov/shm/src/smr_ep.c @@ -223,7 +223,9 @@ int64_t smr_verify_peer(struct smr_ep *ep, fi_addr_t fi_addr) return id; if (!ep->region->map->peers[id].region) { + ofi_spin_lock(&ep->region->map->lock); ret = smr_map_to_region(&smr_prov, ep->region->map, id); + ofi_spin_unlock(&ep->region->map->lock); if (ret) return -1; } @@ -791,6 +793,7 @@ static int smr_ep_close(struct fid *fid) { struct smr_ep *ep; + printf("start ep close\n"); ep = container_of(fid, struct smr_ep, util_ep.ep_fid.fid); if (smr_env.use_dsa_sar) @@ -831,6 +834,7 @@ static int smr_ep_close(struct fid *fid) free((void *)ep->name); free(ep); + printf("done\n"); return 0; } diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c index 141826b9bba..14e9a42136d 100644 --- a/prov/shm/src/smr_progress.c +++ b/prov/shm/src/smr_progress.c @@ -878,7 +878,9 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) peer_smr = smr_peer_region(ep->region, idx); if (!peer_smr) { + ofi_spin_lock(&ep->region->map->lock); ret = smr_map_to_region(&smr_prov, ep->region->map, idx); + ofi_spin_unlock(&ep->region->map->lock); if (ret) { FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "Could not map peer region\n"); @@ -891,14 +893,11 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd 
*cmd) if (peer_smr->pid != (int) cmd->msg.hdr.data) { /* TODO track and update/complete in error any transfers * to or from old mapping - * - * TODO create smr_unmap_region - * this needs to close peer_smr->map->peers[idx].pid_fd - * This case will also return an unmapped region because the idx - * is valid but the region was unmapped */ - munmap(peer_smr, peer_smr->total_size); + ofi_spin_lock(&ep->region->map->lock); + smr_unmap_region(&smr_prov, ep->region->map, idx); smr_map_to_region(&smr_prov, ep->region->map, idx); + ofi_spin_unlock(&ep->region->map->lock); peer_smr = smr_peer_region(ep->region, idx); } diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c index 2924ddaa6f2..bbb5ac30ab6 100644 --- a/prov/shm/src/smr_util.c +++ b/prov/shm/src/smr_util.c @@ -325,6 +325,7 @@ int smr_create(const struct fi_provider *prov, struct smr_map *map, close: close(fd); shm_unlink(attr->name); + printf("done create\n"); return ret; } @@ -357,6 +358,7 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, const char *name = smr_no_prefix(peer_buf->peer.name); char tmp[SMR_PATH_MAX]; + printf("map to region\n"); pthread_mutex_lock(&ep_list_lock); entry = dlist_find_first_match(&ep_name_list, smr_match_name, name); if (entry) { @@ -367,16 +369,14 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, } pthread_mutex_unlock(&ep_list_lock); - ofi_spin_lock(&map->lock); if (peer_buf->region) - goto unlock; + return FI_SUCCESS; fd = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); if (fd < 0) { - ret = -errno; FI_WARN_ONCE(prov, FI_LOG_AV, "shm_open error: name %s errno %d\n", name, errno); - goto unlock; + return -errno; } memset(tmp, 0, sizeof(tmp)); @@ -437,8 +437,7 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, out: close(fd); -unlock: - ofi_spin_unlock(&map->lock); + printf("done map to region\n"); return ret; } @@ -479,24 +478,62 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t 
id) return; } +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t peer_id) +{ + struct smr_region *peer_region; + struct smr_peer *peer; + struct util_ep *util_ep; + struct smr_ep *smr_ep; + struct smr_av *av; + int ret = 0; + + peer_region = map->peers[peer_id].region; + if (!peer_region) + return; + + peer = &map->peers[peer_id]; + av = container_of(map, struct smr_av, smr_map); + dlist_foreach_container(&av->util_av.ep_list, struct util_ep, util_ep, + av_entry) { + smr_ep = container_of(util_ep, struct smr_ep, util_ep); + smr_unmap_from_endpoint(smr_ep->region, peer_id); + } + + if (map->flags & SMR_FLAG_HMEM_ENABLED) { + ret = ofi_hmem_host_unregister(peer_region); + if (ret) + FI_WARN(prov, FI_LOG_EP_CTRL, + "unable to unregister shm with iface\n"); + + if (peer->pid_fd != -1) { + close(peer->pid_fd); + peer->pid_fd = -1; + } + } + + munmap(peer_region, peer_region->total_size); + peer->region = NULL; +} + void smr_unmap_from_endpoint(struct smr_region *region, int64_t id) { struct smr_region *peer_smr; struct smr_peer_data *local_peers, *peer_peers; int64_t peer_id; - local_peers = smr_peer_data(region); if (region->map->peers[id].peer.id < 0) return; peer_smr = smr_peer_region(region, id); - peer_id = smr_peer_data(region)[id].addr.id; - + assert(peer_smr); peer_peers = smr_peer_data(peer_smr); + peer_id = smr_peer_data(region)[id].addr.id; peer_peers[peer_id].addr.id = -1; peer_peers[peer_id].name_sent = 0; + local_peers = smr_peer_data(region); ofi_xpmem_release(&local_peers[peer_id].xpmem); } @@ -544,40 +581,29 @@ int smr_map_add(const struct fi_provider *prov, struct smr_map *map, return FI_SUCCESS; } -void smr_map_del(struct smr_map *map, int64_t id) +int smr_map_unmap(struct ofi_rbmap *rbmap, struct ofi_rbnode *node, + void *context) { - struct dlist_entry *entry; + struct smr_map *map = container_of(rbmap, struct smr_map, rbmap); + int64_t id = (uintptr_t) node->data; assert(id >= 0 && id < SMR_MAX_PEERS); - - 
pthread_mutex_lock(&ep_list_lock); - entry = dlist_find_first_match(&ep_name_list, smr_match_name, - smr_no_prefix(map->peers[id].peer.name)); - pthread_mutex_unlock(&ep_list_lock); - - ofi_spin_lock(&map->lock); - (void) ofi_rbmap_find_delete(&map->rbmap, - (void *) map->peers[id].peer.name); - + smr_unmap_region(&smr_prov, map, id); map->peers[id].fiaddr = FI_ADDR_NOTAVAIL; map->peers[id].peer.id = -1; map->num_peers--; - if (!map->peers[id].region) - goto unlock; - - if (!entry) { - if (map->flags & SMR_FLAG_HMEM_ENABLED) { - if (map->peers[id].pid_fd != -1) - close(map->peers[id].pid_fd); + return FI_SUCCESS; +} - (void) ofi_hmem_host_unregister(map->peers[id].region); - } - munmap(map->peers[id].region, map->peers[id].region->total_size); - map->peers[id].region = NULL; - } -unlock: +int smr_map_del(struct smr_map *map, struct ofi_rbnode *node) +{ + ofi_spin_lock(&map->lock); + smr_map_unmap(&map->rbmap, node, NULL); + ofi_rbmap_delete(&map->rbmap, node); ofi_spin_unlock(&map->lock); + + return FI_SUCCESS; } struct smr_region *smr_map_get(struct smr_map *map, int64_t id) diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h index c5bf8124873..fa902f1fec6 100644 --- a/prov/shm/src/smr_util.h +++ b/prov/shm/src/smr_util.h @@ -356,11 +356,15 @@ void smr_cleanup(void); int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, int64_t id); void smr_map_to_endpoint(struct smr_region *region, int64_t id); +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t id); void smr_unmap_from_endpoint(struct smr_region *region, int64_t id); void smr_exchange_all_peers(struct smr_region *region); int smr_map_add(const struct fi_provider *prov, struct smr_map *map, const char *name, int64_t *id); -void smr_map_del(struct smr_map *map, int64_t id); +int smr_map_unmap(struct ofi_rbmap *rbmap, struct ofi_rbnode *node, + void *context); +int smr_map_del(struct smr_map *map, struct ofi_rbnode *node); struct smr_region 
*smr_map_get(struct smr_map *map, int64_t id);