Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpmsg: add release cb and refcnt in endpoint to fix ept used-after-free #508

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lib/include/openamp/rpmsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct rpmsg_device;
/* Returns positive value on success or negative error value on failure */
typedef int (*rpmsg_ept_cb)(struct rpmsg_endpoint *ept, void *data,
size_t len, uint32_t src, void *priv);
typedef void (*rpmsg_ept_release_cb)(struct rpmsg_endpoint *ept);
typedef void (*rpmsg_ns_unbind_cb)(struct rpmsg_endpoint *ept);
typedef void (*rpmsg_ns_bind_cb)(struct rpmsg_device *rdev,
const char *name, uint32_t dest);
Expand All @@ -73,6 +74,12 @@ struct rpmsg_endpoint {
/** Address of the default remote endpoint binded */
uint32_t dest_addr;

/** Reference count for determining whether the endpoint can be deallocated */
uint32_t refcnt;

/** Callback to inform the user that the endpoint allocation can be safely removed */
rpmsg_ept_release_cb release_cb;

/**
* User rx callback, return value of this callback is reserved for future
* use, for now, only allow RPMSG_SUCCESS as return value
Expand Down
22 changes: 21 additions & 1 deletion lib/rpmsg/rpmsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,25 @@ static int rpmsg_set_address(unsigned long *bitmap, int size, int addr)
}
}

void rpmsg_ept_incref(struct rpmsg_endpoint *ept)
{
if (ept)
ept->refcnt++;
}

void rpmsg_ept_decref(struct rpmsg_endpoint *ept)
{
if (ept) {
yintao707 marked this conversation as resolved.
Show resolved Hide resolved
ept->refcnt--;
if (!ept->refcnt) {
if (ept->release_cb)
ept->release_cb(ept);
else
ept->rdev = NULL;
}
}
yintao707 marked this conversation as resolved.
Show resolved Hide resolved
}

int rpmsg_send_offchannel_raw(struct rpmsg_endpoint *ept, uint32_t src,
uint32_t dst, const void *data, int len,
int wait)
Expand Down Expand Up @@ -245,7 +264,7 @@ static void rpmsg_unregister_endpoint(struct rpmsg_endpoint *ept)
rpmsg_release_address(rdev->bitmap, RPMSG_ADDR_BMP_SIZE,
ept->addr);
metal_list_del(&ept->node);
ept->rdev = NULL;
rpmsg_ept_decref(ept);
arnopo marked this conversation as resolved.
Show resolved Hide resolved
yintao707 marked this conversation as resolved.
Show resolved Hide resolved
metal_mutex_release(&rdev->lock);
}

Expand All @@ -257,6 +276,7 @@ void rpmsg_register_endpoint(struct rpmsg_device *rdev,
rpmsg_ns_unbind_cb ns_unbind_cb)
{
strncpy(ept->name, name ? name : "", sizeof(ept->name));
ept->refcnt = 1;
ept->addr = src;
ept->dest_addr = dest;
ept->cb = cb;
Expand Down
25 changes: 25 additions & 0 deletions lib/rpmsg/rpmsg_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,31 @@ rpmsg_get_ept_from_addr(struct rpmsg_device *rdev, uint32_t addr)
return rpmsg_get_endpoint(rdev, NULL, addr, RPMSG_ADDR_ANY);
}

/**
* @internal
*
* @brief Increase the endpoint reference count
*
* This function is used to avoid calling ept_cb after release lock causes race condition
* it should be called under lock protection.
*
* @param ept pointer to rpmsg endpoint
*
*/
void rpmsg_ept_incref(struct rpmsg_endpoint *ept);

/**
* @internal
*
* @brief Decrease the end point reference count
*
* This function is used to avoid calling ept_cb after release lock causes race condition
* it should be called under lock protection.
*
* @param ept pointer to rpmsg endpoint
*/
void rpmsg_ept_decref(struct rpmsg_endpoint *ept);

#if defined __cplusplus
}
#endif
Expand Down
16 changes: 16 additions & 0 deletions lib/rpmsg/rpmsg_virtio.c
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq)
/* Get the channel node from the remote device channels list. */
metal_mutex_acquire(&rdev->lock);
ept = rpmsg_get_ept_from_addr(rdev, rp_hdr->dst);
rpmsg_ept_incref(ept);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yintao707, @arnopo does that make sense to move rpmsg_ept_incref() API within rpmsg_get_endpoint API ? If endpoint is retrieved successfully then we increase refcount. It is possible that get_endpoint API is called in future, so we increase refcount notifying endpoint is being used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arnopo Looks like refcnt is tracking more if callback is in progress or not rather than how many times endpoint is being used my multiple threads using rpmsg_get_endpoint correct ?

If this is the case, can we update documentation accordingly ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yintao707, @arnopo does that make sense to move rpmsg_ept_incref() API within rpmsg_get_endpoint API ? If endpoint is retrieved successfully then we increase refcount. It is possible that get_endpoint API is called in future, so we increase refcount notifying endpoint is being used.

I would prefer not to hide it in rpmsg_get_endpoint and address this only if we need to export the rpmsg_get_endpoint API in the future.

@arnopo Looks like refcnt is tracking more if callback is in progress or not rather than how many times endpoint is being used my multiple threads using rpmsg_get_endpoint correct ?

If this is the case, can we update documentation accordingly ?

Is the documentation header for rpmsg_ept_incref not explicit enough for you?
Could you provide more details on which part of the documentation you would like to see updated ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

documentation for ept_incref looks good. But same documentation for refcnt variable should be updated:

/** Reference count of the endpoint */
uint32_t refcnt;

Above comment gives impression that refcnt variable is used for endpoint object reference counts. But, refcnt variable is used to track if callback execution is in progress or not. So, above variable documentation should be updated accordingly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer not to hide it in rpmsg_get_endpoint and address this only if we need to export the rpmsg_get_endpoint API in the future.

Ok sounds good.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

documentation for ept_incref looks good. But same documentation for refcnt variable should be updated:

/** Reference count of the endpoint */
uint32_t refcnt;

Above comment gives impression that refcnt variable is used for endpoint object reference counts. But, refcnt variable is used to track if callback execution is in progress or not. So, above variable documentation should be updated accordingly.

@yintao707 , please could you address @tnmysh comment that we can merge it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arnopo @tnmysh , Thank you for your suggestion, I modified the comments about refcnt. Can you help me review whether this modification is appropriate

metal_mutex_release(&rdev->lock);

if (ept) {
Expand All @@ -532,6 +533,7 @@ static void rpmsg_virtio_rx_callback(struct virtqueue *vq)
}

metal_mutex_acquire(&rdev->lock);
rpmsg_ept_decref(ept);

/* Check whether callback wants to hold buffer */
if (!(rp_hdr->reserved & RPMSG_BUF_HELD)) {
Expand Down Expand Up @@ -571,6 +573,7 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data,
struct rpmsg_endpoint *_ept;
struct rpmsg_ns_msg *ns_msg;
uint32_t dest;
bool ept_to_release;
char name[RPMSG_NAME_SIZE];

(void)priv;
Expand All @@ -589,14 +592,27 @@ static int rpmsg_virtio_ns_callback(struct rpmsg_endpoint *ept, void *data,
metal_mutex_acquire(&rdev->lock);
_ept = rpmsg_get_endpoint(rdev, name, RPMSG_ADDR_ANY, dest);

/*
* If ept-release callback is not implemented, ns_unbind_cb() can free the ept.
* Test _ept->release_cb before calling ns_unbind_cb() callbacks.
*/
ept_to_release = _ept && _ept->release_cb;

if (ns_msg->flags & RPMSG_NS_DESTROY) {
if (_ept)
_ept->dest_addr = RPMSG_ADDR_ANY;
if (ept_to_release)
rpmsg_ept_incref(_ept);
metal_mutex_release(&rdev->lock);
if (_ept && _ept->ns_unbind_cb)
_ept->ns_unbind_cb(_ept);
if (rdev->ns_unbind_cb)
rdev->ns_unbind_cb(rdev, name, dest);
if (ept_to_release) {
metal_mutex_acquire(&rdev->lock);
rpmsg_ept_decref(_ept);
metal_mutex_release(&rdev->lock);
}
} else {
if (!_ept) {
/*
Expand Down