Skip to content

Commit

Permalink
sync: relax type constraints on all Mutex types
Browse files Browse the repository at this point in the history
  • Loading branch information
Qix- committed Jan 12, 2025
1 parent d66f309 commit 3239594
Showing 1 changed file with 76 additions and 52 deletions.
128 changes: 76 additions & 52 deletions oro-sync/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Synchronization primitives for the Oro Kernel.
#![cfg_attr(not(test), no_std)]
#![cfg_attr(doc, feature(doc_cfg, doc_auto_cfg))]
// SAFETY(qix-): This is accepted but is taking ages to stabilize. In theory
// SAFETY(qix-): marker fields could be used but for now I want to keep things
// SAFETY(qix-): cleaner and more readable.
#![feature(negative_impls)]

use core::{
cell::UnsafeCell,
Expand All @@ -17,7 +21,7 @@ const TICKET_MUTEX_TIMEOUT: usize = 1000;
/// Standardized lock interface implemented for all lock types.
pub trait Lock {
/// The target type of value being guarded.
type Target: Send + 'static;
type Target: ?Sized;

/// The lock guard type used by the lock implementation.
type Guard<'a>: Drop + Deref<Target = Self::Target> + DerefMut
Expand All @@ -30,17 +34,17 @@ pub trait Lock {

/// A simple unfair, greedy spinlock. The most efficient spinlock
/// available in this library.
pub struct Mutex<T: Send + 'static> {
/// The guarded value.
value: UnsafeCell<T>,
pub struct Mutex<T: ?Sized> {
/// Whether or not the lock is taken.
locked: AtomicBool,
/// The guarded value.
value: UnsafeCell<T>,
}

// SAFETY: We are implementing a safe interface around a mutex so we can assert `Sync`.
unsafe impl<T: Send + 'static> Sync for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}

impl<T: Send + 'static> Mutex<T> {
impl<T> Mutex<T> {
/// Creates a new spinlock mutex for the given value.
pub const fn new(value: T) -> Self {
Self {
Expand All @@ -50,15 +54,18 @@ impl<T: Send + 'static> Mutex<T> {
}
}

impl<T: Send + 'static> Lock for Mutex<T> {
type Guard<'a> = MutexGuard<'a, T>;
impl<T: ?Sized> Lock for Mutex<T> {
type Guard<'a>
= MutexGuard<'a, T>
where
T: 'a;
type Target = T;

fn lock(&self) -> Self::Guard<'_> {
loop {
if !self.locked.swap(true, Acquire) {
#[cfg(debug_assertions)]
::oro_dbgutil::__oro_dbgutil_lock_acquire(self.value.get() as usize);
::oro_dbgutil::__oro_dbgutil_lock_acquire(self.value.get() as *const () as usize);
return MutexGuard { lock: self };
}

Expand All @@ -67,30 +74,33 @@ impl<T: Send + 'static> Lock for Mutex<T> {
}
}

impl<T: Default + Send + 'static> Default for Mutex<T> {
impl<T: Default> Default for Mutex<T> {
fn default() -> Self {
Self::new(T::default())
}
}

/// A mutex guard for the simple [`Mutex`] type.
pub struct MutexGuard<'a, T: Send + 'static>
pub struct MutexGuard<'a, T: ?Sized + 'a>
where
Self: 'a,
{
/// A reference to the lock for which we have a guard.
lock: &'a Mutex<T>,
}

impl<T: Send + 'static> Drop for MutexGuard<'_, T> {
impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}

impl<T: ?Sized> Drop for MutexGuard<'_, T> {
fn drop(&mut self) {
#[cfg(debug_assertions)]
::oro_dbgutil::__oro_dbgutil_lock_release(self.lock.value.get() as usize);
::oro_dbgutil::__oro_dbgutil_lock_release(self.lock.value.get() as *const () as usize);
self.lock.locked.store(false, Release);
}
}

impl<T: Send + 'static> Deref for MutexGuard<'_, T> {
impl<T: ?Sized> Deref for MutexGuard<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
Expand All @@ -99,42 +109,45 @@ impl<T: Send + 'static> Deref for MutexGuard<'_, T> {
}
}

impl<T: Send + 'static> DerefMut for MutexGuard<'_, T> {
impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: We have guaranteed singular access as we're locked.
unsafe { &mut *self.lock.value.get() }
}
}

/// A ticketed, fair mutex implementation.
pub struct TicketMutex<T: Send + 'static> {
/// The guarded value.
value: UnsafeCell<T>,
pub struct TicketMutex<T: ?Sized> {
/// The currently served ticket.
now_serving: AtomicUsize,
/// The next ticket.
next_ticket: AtomicUsize,
/// Whether or not we've locked the lock.
locked: AtomicBool,
/// The guarded value.
value: UnsafeCell<T>,
}

// SAFETY: We are implementing a safe interface around a mutex so we can assert `Sync`.
unsafe impl<T: Send + 'static> Sync for TicketMutex<T> {}
unsafe impl<T: ?Sized + Send> Send for TicketMutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for TicketMutex<T> {}

impl<T: Send + 'static> TicketMutex<T> {
impl<T> TicketMutex<T> {
/// Creates a new ticket mutex.
pub const fn new(value: T) -> Self {
Self {
value: UnsafeCell::new(value),
now_serving: AtomicUsize::new(0),
next_ticket: AtomicUsize::new(0),
locked: AtomicBool::new(false),
value: UnsafeCell::new(value),
}
}
}

impl<T: Send + 'static> Lock for TicketMutex<T> {
type Guard<'a> = TicketMutexGuard<'a, T>;
impl<T: ?Sized> Lock for TicketMutex<T> {
type Guard<'a>
= TicketMutexGuard<'a, T>
where
T: 'a;
type Target = T;

fn lock(&self) -> Self::Guard<'_> {
Expand All @@ -159,7 +172,9 @@ impl<T: Send + 'static> Lock for TicketMutex<T> {

if position == 0 && !self.locked.swap(true, AcqRel) {
#[cfg(debug_assertions)]
::oro_dbgutil::__oro_dbgutil_lock_acquire(self.value.get() as usize);
::oro_dbgutil::__oro_dbgutil_lock_acquire(
self.value.get() as *const () as usize
);
return TicketMutexGuard { lock: self, ticket };
}

Expand Down Expand Up @@ -201,27 +216,27 @@ impl<T: Send + 'static> Lock for TicketMutex<T> {
}
}

impl<T: Default + Send + 'static> Default for TicketMutex<T> {
impl<T: Default> Default for TicketMutex<T> {
fn default() -> Self {
Self::new(T::default())
}
}

/// A lock guard for a [`TicketMutex`].
pub struct TicketMutexGuard<'a, T: Send + 'static>
pub struct TicketMutexGuard<'a, T: ?Sized + 'a>
where
Self: 'a,
{
/// The lock we are guarding.
lock: &'a TicketMutex<T>,
/// Our ticket
ticket: usize,
/// The lock we are guarding.
lock: &'a TicketMutex<T>,
}

impl<T: Send + 'static> Drop for TicketMutexGuard<'_, T> {
impl<T: ?Sized> Drop for TicketMutexGuard<'_, T> {
fn drop(&mut self) {
#[cfg(debug_assertions)]
::oro_dbgutil::__oro_dbgutil_lock_release(self.lock.value.get() as usize);
::oro_dbgutil::__oro_dbgutil_lock_release(self.lock.value.get() as *const () as usize);
let _ = self.lock.now_serving.compare_exchange(
self.ticket,
self.ticket.wrapping_add(1),
Expand All @@ -232,7 +247,7 @@ impl<T: Send + 'static> Drop for TicketMutexGuard<'_, T> {
}
}

impl<T: Send + 'static> Deref for TicketMutexGuard<'_, T> {
impl<T: ?Sized> Deref for TicketMutexGuard<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
Expand All @@ -241,13 +256,16 @@ impl<T: Send + 'static> Deref for TicketMutexGuard<'_, T> {
}
}

impl<T: Send + 'static> DerefMut for TicketMutexGuard<'_, T> {
impl<T: ?Sized> DerefMut for TicketMutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: We have guaranteed singular access as we're locked.
unsafe { &mut *self.lock.value.get() }
}
}

impl<T: ?Sized> !Send for TicketMutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for TicketMutexGuard<'_, T> {}

#[doc(hidden)]
#[cfg(feature = "reentrant_mutex")]
mod reentrant {
Expand Down Expand Up @@ -280,29 +298,32 @@ mod reentrant {
///
/// **NOTE:** This implementation spins (and does not lock) if the refcount
/// reaches `u32::MAX`. This is usually not a problem.
pub struct ReentrantMutex<T: Send + 'static> {
/// The inner value.
inner: UnsafeCell<T>,
pub struct ReentrantMutex<T: ?Sized> {
/// The lock state.
///
/// The upper 32 bits are the core ID of the lock holder, and the lower 32 bits
/// are the lock count.
lock: AtomicU64,
/// The inner value.
inner: UnsafeCell<T>,
}

impl<T: Send + 'static> ReentrantMutex<T> {
impl<T> ReentrantMutex<T> {
/// Constructs a new reentrant mutex.
pub const fn new(inner: T) -> Self {
Self {
inner: UnsafeCell::new(inner),
lock: AtomicU64::new(0),
inner: UnsafeCell::new(inner),
}
}
}

impl<T: Send + 'static> Lock for ReentrantMutex<T> {
impl<T: ?Sized> Lock for ReentrantMutex<T> {
/// The lock guard type used by the lock implementation.
type Guard<'a> = ReentrantMutexGuard<'a, Self::Target>;
type Guard<'a>
= ReentrantMutexGuard<'a, T>
where
T: 'a;
/// The target type of value being guarded.
type Target = T;

Expand Down Expand Up @@ -333,12 +354,18 @@ mod reentrant {
}
}

impl<T: Default> Default for ReentrantMutex<T> {
fn default() -> Self {
Self::new(T::default())
}
}

/// A guard for a reentrant mutex.
pub struct ReentrantMutexGuard<'a, T: Send + 'static> {
pub struct ReentrantMutexGuard<'a, T: ?Sized + 'a> {
inner: &'a ReentrantMutex<T>,
}

impl<T: Send + 'static> core::ops::Deref for ReentrantMutexGuard<'_, T> {
impl<T: ?Sized> core::ops::Deref for ReentrantMutexGuard<'_, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
Expand All @@ -347,14 +374,14 @@ mod reentrant {
}
}

impl<T: Send + 'static> core::ops::DerefMut for ReentrantMutexGuard<'_, T> {
impl<T: ?Sized> core::ops::DerefMut for ReentrantMutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: The guard is only created if the lock is held.
unsafe { &mut *self.inner.inner.get() }
}
}

impl<T: Send + 'static> Drop for ReentrantMutexGuard<'_, T> {
impl<T: ?Sized> Drop for ReentrantMutexGuard<'_, T> {
fn drop(&mut self) {
loop {
let current = self.inner.lock.load(Relaxed);
Expand Down Expand Up @@ -383,13 +410,10 @@ mod reentrant {
}
}

unsafe impl<T: Send + 'static> Sync for ReentrantMutex<T> {}

impl<T: Default + Send + 'static> Default for ReentrantMutex<T> {
fn default() -> Self {
Self::new(T::default())
}
}
unsafe impl<T: ?Sized + Send> Send for ReentrantMutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for ReentrantMutex<T> {}
impl<T: ?Sized> !Send for ReentrantMutexGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for ReentrantMutexGuard<'_, T> {}
}

#[cfg(feature = "reentrant_mutex")]
Expand Down

0 comments on commit 3239594

Please sign in to comment.