From 3b27f651172c85bebc3afff2cea45dd531526968 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 13 Sep 2023 11:07:54 +0800 Subject: [PATCH] rpmsg: add release cb and refcnt in end pointto fix ept used-after-free if rpmsg service free the ept when has got the ept from the ept list in rpmsg_virtio_rx_callback, there is a used after free about the ept, so add refcnt to end point and call the rpmsg service release callback when ept callback fininshed. Signed-off-by: yintao Signed-off-by: Bowen Wang --- lib/include/openamp/rpmsg.h | 40 +++++++++++++++++++++++++++++++++++++ lib/rpmsg/rpmsg.c | 20 +++++++++++++++++++ lib/rpmsg/rpmsg_virtio.c | 7 ++++++- 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/lib/include/openamp/rpmsg.h b/lib/include/openamp/rpmsg.h index 9cf1e7444..61d10a0b2 100644 --- a/lib/include/openamp/rpmsg.h +++ b/lib/include/openamp/rpmsg.h @@ -19,6 +19,7 @@ #include #include #include +#include #if defined __cplusplus extern "C" { @@ -50,6 +51,9 @@ struct rpmsg_device; /* Returns positive value on success or negative error value on failure */ typedef int (*rpmsg_ept_cb)(struct rpmsg_endpoint *ept, void *data, size_t len, uint32_t src, void *priv); +typedef void (*rpmsg_ept_inc_ref_cb)(void *priv); +typedef void (*rpmsg_ept_dec_ref_cb)(void *priv); +typedef void (*rpmsg_ept_release_cb)(struct rpmsg_endpoint *ept); typedef void (*rpmsg_ns_unbind_cb)(struct rpmsg_endpoint *ept); typedef void (*rpmsg_ns_bind_cb)(struct rpmsg_device *rdev, const char *name, uint32_t dest); @@ -73,6 +77,18 @@ struct rpmsg_endpoint { /** Address of the default remote endpoint binded */ uint32_t dest_addr; + /** Reference count of the endpoint */ + atomic_uint refcnt; + + /** Callback to increase the endpoint reference count for service */ + rpmsg_ept_inc_ref_cb inc_ref_cb; + + /** Callback to decrease the endpoint reference count for service */ + rpmsg_ept_dec_ref_cb dec_ref_cb; + + /** Callback to free the endpoint */ + rpmsg_ept_release_cb rel_cb; + /** * User rx callback, return value of this callback is reserved for future * use, for now, only allow RPMSG_SUCCESS as return value @@ -142,6 +158,30 @@ struct rpmsg_device { bool support_ns; }; +/** + * @brief Increase the endpoint reference count + * + * This function increases reference count of the endpoint, if rpmsg service + * registered inc_ref_cb of @ept, the reference count is processed at the + * service; otherwise, it will directly increase refcnt of @ept. + * + * @ept: pointer to rpmsg endpoint + * + */ +void rpmsg_ept_inc_ref(struct rpmsg_endpoint *ept); + +/** + * @brief Decrease the end point reference count + * + * This function decreases reference count of the endpoint, if rpmsg service + * registered dec_ref_cb of @ept, the reference count is processed at the + * service; otherwise, it will directly increase refcnt of @ept. Release ept + * when the reference count decreases to zero. + * + * @ept: pointer to rpmsg endpoint + */ +void rpmsg_ept_dec_ref(struct rpmsg_endpoint *ept); + /** * @brief Send a message across to the remote processor, * specifying source and destination address. diff --git a/lib/rpmsg/rpmsg.c b/lib/rpmsg/rpmsg.c index 5a9237f47..8eeaa87cf 100644 --- a/lib/rpmsg/rpmsg.c +++ b/lib/rpmsg/rpmsg.c @@ -97,6 +97,24 @@ static int rpmsg_set_address(unsigned long *bitmap, int size, int addr) } } +void rpmsg_ept_inc_ref(struct rpmsg_endpoint *ept) +{ + if (ept->inc_ref_cb) { + ept->inc_ref_cb(ept->priv); + } else { + atomic_fetch_add(&ept->refcnt, 1); + } +} + +void rpmsg_ept_dec_ref(struct rpmsg_endpoint *ept) +{ + if (ept->dec_ref_cb) { + ept->dec_ref_cb(ept->priv); + } else if (atomic_fetch_sub(&ept->refcnt, 1) == 1 && ept->rel_cb) { + ept->rel_cb(ept); + } +} + int rpmsg_send_offchannel_raw(struct rpmsg_endpoint *ept, uint32_t src, uint32_t dst, const void *data, int len, int wait) @@ -247,6 +265,7 @@ static void rpmsg_unregister_endpoint(struct rpmsg_endpoint *ept) metal_list_del(&ept->node); ept->rdev = NULL; metal_mutex_release(&rdev->lock); + rpmsg_ept_dec_ref(ept); } void rpmsg_register_endpoint(struct rpmsg_device *rdev, @@ -256,6 +275,7 @@ void rpmsg_register_endpoint(struct rpmsg_device *rdev, rpmsg_ept_cb cb, rpmsg_ns_unbind_cb ns_unbind_cb) { + rpmsg_ept_inc_ref(ept); strncpy(ept->name, name ? name : "", sizeof(ept->name)); ept->addr = src; ept->dest_addr = dest; diff --git a/lib/rpmsg/rpmsg_virtio.c b/lib/rpmsg/rpmsg_virtio.c index ea4cc0d9e..57cdf3ddf 100644 --- a/lib/rpmsg/rpmsg_virtio.c +++ b/lib/rpmsg/rpmsg_virtio.c @@ -568,8 +568,10 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq) */ ept->dest_addr = rp_hdr->src; } + rpmsg_ept_inc_ref(ept); status = ept->cb(ept, RPMSG_LOCATE_DATA(rp_hdr), rp_hdr->len, rp_hdr->src, ept->priv); + rpmsg_ept_dec_ref(ept); RPMSG_ASSERT(status >= 0, "unexpected callback status\r\n"); @@ -637,8 +639,11 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data, if (_ept) _ept->dest_addr = RPMSG_ADDR_ANY; metal_mutex_release(&rdev->lock); - if (_ept && _ept->ns_unbind_cb) + if (_ept && _ept->ns_unbind_cb) { + rpmsg_ept_inc_ref(_ept); _ept->ns_unbind_cb(_ept); + rpmsg_ept_dec_ref(_ept); + } if (rdev->ns_unbind_cb) rdev->ns_unbind_cb(rdev, name, dest); } else {