diff --git a/embassy-sync/src/raw_rwlock.rs b/embassy-sync/src/raw_rwlock.rs new file mode 100644 index 000000000..de4bd1dc5 --- /dev/null +++ b/embassy-sync/src/raw_rwlock.rs @@ -0,0 +1,86 @@ +use core::sync::atomic::{AtomicUsize, Ordering}; +use core::task::Waker; +use core::cell::UnsafeCell; + +pub trait RawRwLock { + fn lock_read(&self); + fn try_lock_read(&self) -> bool; + fn unlock_read(&self); + fn lock_write(&self); + fn try_lock_write(&self) -> bool; + fn unlock_write(&self); +} + +pub struct RawRwLockImpl { + state: AtomicUsize, + waker: UnsafeCell>, +} + +impl RawRwLockImpl { + pub const fn new() -> Self { + Self { + state: AtomicUsize::new(0), + waker: UnsafeCell::new(None), + } + } +} + +unsafe impl Send for RawRwLockImpl {} +unsafe impl Sync for RawRwLockImpl {} + +impl RawRwLock for RawRwLockImpl { + fn lock_read(&self) { + loop { + let state = self.state.load(Ordering::Acquire); + if state & 1 == 0 { + if self.state.compare_and_swap(state, state + 2, Ordering::AcqRel) == state { + break; + } + } + } + } + + fn try_lock_read(&self) -> bool { + let state = self.state.load(Ordering::Acquire); + if state & 1 == 0 { + if self.state.compare_and_swap(state, state + 2, Ordering::AcqRel) == state { + return true; + } + } + false + } + + fn unlock_read(&self) { + self.state.fetch_sub(2, Ordering::Release); + if self.state.load(Ordering::Acquire) == 0 { + if let Some(waker) = unsafe { &*self.waker.get() } { + waker.wake_by_ref(); + } + } + } + + fn lock_write(&self) { + loop { + let state = self.state.load(Ordering::Acquire); + if state == 0 { + if self.state.compare_and_swap(0, 1, Ordering::AcqRel) == 0 { + break; + } + } + } + } + + fn try_lock_write(&self) -> bool { + if self.state.compare_and_swap(0, 1, Ordering::AcqRel) == 0 { + return true; + } + false + } + + fn unlock_write(&self) { + self.state.store(0, Ordering::Release); + if let Some(waker) = unsafe { &*self.waker.get() } { + waker.wake_by_ref(); + } + } +} diff --git a/embassy-sync/src/rwlock.rs b/embassy-sync/src/rwlock.rs index 15ea8468e..30e1e74ad 100644 --- a/embassy-sync/src/rwlock.rs +++ b/embassy-sync/src/rwlock.rs @@ -1,136 +1,134 @@ -use core::cell::RefCell; -use core::future::{poll_fn, Future}; +use core::cell::UnsafeCell; +use core::future::poll_fn; use core::ops::{Deref, DerefMut}; use core::task::Poll; -use crate::blocking_mutex::raw::RawMutex; use crate::blocking_mutex::Mutex as BlockingMutex; -use crate::waitqueue::MultiWakerRegistration; +use crate::waitqueue::WakerRegistration; +use crate::raw_rwlock::RawRwLock; -/// Error returned by [`RwLock::try_read`] and [`RwLock::try_write`] -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct TryLockError; - -/// Async read-write lock. -/// -/// The lock is generic over a blocking [`RawMutex`](crate::blocking_mutex::raw::RawMutex). -/// The raw mutex is used to guard access to the internal state. It -/// is held for very short periods only, while locking and unlocking. It is *not* held -/// for the entire time the async RwLock is locked. -/// -/// Which implementation you select depends on the context in which you're using the lock. -/// -/// Use [`CriticalSectionRawMutex`](crate::blocking_mutex::raw::CriticalSectionRawMutex) when data can be shared between threads and interrupts. -/// -/// Use [`NoopRawMutex`](crate::blocking_mutex::raw::NoopRawMutex) when data is only shared between tasks running on the same executor. -/// -/// Use [`ThreadModeRawMutex`](crate::blocking_mutex::raw::ThreadModeRawMutex) when data is shared between tasks running on the same executor but you want a singleton. -/// pub struct RwLock where - M: RawMutex, + M: RawRwLock, T: ?Sized, { - state: BlockingMutex>, - inner: RefCell, + state: BlockingMutex, + inner: UnsafeCell, } -struct State { - readers: usize, - writer: bool, - writer_waker: MultiWakerRegistration<1>, - reader_wakers: MultiWakerRegistration<8>, -} - -impl State { - fn new() -> Self { - Self { - readers: 0, - writer: false, - writer_waker: MultiWakerRegistration::new(), - reader_wakers: MultiWakerRegistration::new(), - } - } -} +unsafe impl Send for RwLock {} +unsafe impl Sync for RwLock {} impl RwLock where - M: RawMutex, + M: RawRwLock, { - /// Create a new read-write lock with the given value. pub const fn new(value: T) -> Self { Self { - inner: RefCell::new(value), - state: BlockingMutex::new(RefCell::new(State::new())), + inner: UnsafeCell::new(value), + state: BlockingMutex::new(RwLockState { + locked: LockedState::Unlocked, + writer_pending: 0, + readers_pending: 0, + waker: WakerRegistration::new(), + }), } } } impl RwLock where - M: RawMutex, + M: RawRwLock, T: ?Sized, { - /// Acquire a read lock. - /// - /// This will wait for the lock to be available if it's already locked for writing. pub fn read(&self) -> impl Future> { poll_fn(|cx| { - let mut state = self.state.lock(|s| s.borrow_mut()); - if state.writer { - state.reader_wakers.register(cx.waker()); - Poll::Pending - } else { - state.readers += 1; + let ready = self.state.lock(|s| { + let mut s = s.borrow_mut(); + match s.locked { + LockedState::Unlocked => { + s.locked = LockedState::ReadLocked(1); + true + } + LockedState::ReadLocked(ref mut count) => { + *count += 1; + true + } + LockedState::WriteLocked => { + s.readers_pending += 1; + s.waker.register(cx.waker()); + false + } + } + }); + + if ready { Poll::Ready(RwLockReadGuard { lock: self }) + } else { + Poll::Pending } }) } - /// Acquire a write lock. - /// - /// This will wait for the lock to be available if it's already locked for reading or writing. pub fn write(&self) -> impl Future> { poll_fn(|cx| { - let mut state = self.state.lock(|s| s.borrow_mut()); - if state.writer || state.readers > 0 { - state.writer_waker.register(cx.waker()); - Poll::Pending - } else { - state.writer = true; + let ready = self.state.lock(|s| { + let mut s = s.borrow_mut(); + match s.locked { + LockedState::Unlocked => { + s.locked = LockedState::WriteLocked; + true + } + _ => { + s.writer_pending += 1; + s.waker.register(cx.waker()); + false + } + } + }); + + if ready { Poll::Ready(RwLockWriteGuard { lock: self }) + } else { + Poll::Pending } }) } - /// Attempt to immediately acquire a read lock. - /// - /// If the lock is already locked for writing, this will return an error instead of waiting. pub fn try_read(&self) -> Result, TryLockError> { - let mut state = self.state.lock(|s| s.borrow_mut()); - if state.writer { - Err(TryLockError) - } else { - state.readers += 1; - Ok(RwLockReadGuard { lock: self }) - } + self.state.lock(|s| { + let mut s = s.borrow_mut(); + match s.locked { + LockedState::Unlocked => { + s.locked = LockedState::ReadLocked(1); + Ok(()) + } + LockedState::ReadLocked(ref mut count) => { + *count += 1; + Ok(()) + } + LockedState::WriteLocked => Err(TryLockError), + } + })?; + + Ok(RwLockReadGuard { lock: self }) } - /// Attempt to immediately acquire a write lock. - /// - /// If the lock is already locked for reading or writing, this will return an error instead of waiting. pub fn try_write(&self) -> Result, TryLockError> { - let mut state = self.state.lock(|s| s.borrow_mut()); - if state.writer || state.readers > 0 { - Err(TryLockError) - } else { - state.writer = true; - Ok(RwLockWriteGuard { lock: self }) - } + self.state.lock(|s| { + let mut s = s.borrow_mut(); + match s.locked { + LockedState::Unlocked => { + s.locked = LockedState::WriteLocked; + Ok(()) + } + _ => Err(TryLockError), + } + })?; + + Ok(RwLockWriteGuard { lock: self }) } - /// Consumes this lock, returning the underlying data. pub fn into_inner(self) -> T where T: Sized, @@ -138,19 +136,12 @@ where self.inner.into_inner() } - /// Returns a mutable reference to the underlying data. - /// - /// Since this call borrows the RwLock mutably, no actual locking needs to - /// take place -- the mutable borrow statically guarantees no locks exist. pub fn get_mut(&mut self) -> &mut T { self.inner.get_mut() } } -impl From for RwLock -where - M: RawMutex, -{ +impl From for RwLock { fn from(from: T) -> Self { Self::new(from) } @@ -158,7 +149,7 @@ where impl Default for RwLock where - M: RawMutex, + M: RawRwLock, T: Default, { fn default() -> Self { @@ -166,91 +157,103 @@ where } } -/// Async read lock guard. -/// -/// Owning an instance of this type indicates having -/// successfully locked the RwLock for reading, and grants access to the contents. -/// -/// Dropping it unlocks the RwLock. -#[must_use = "if unused the RwLock will immediately unlock"] pub struct RwLockReadGuard<'a, M, T> where - M: RawMutex, + M: RawRwLock, T: ?Sized, { lock: &'a RwLock, } -impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T> -where - M: RawMutex, - T: ?Sized, -{ - fn drop(&mut self) { - let mut state = self.lock.state.lock(|s| s.borrow_mut()); - state.readers -= 1; - if state.readers == 0 { - state.writer_waker.wake(); - } - } -} - impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T> where - M: RawMutex, + M: RawRwLock, T: ?Sized, { type Target = T; + fn deref(&self) -> &Self::Target { - self.lock.inner.borrow() + unsafe { &*self.lock.inner.get() } + } +} + +impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T> +where + M: RawRwLock, + T: ?Sized, +{ + fn drop(&mut self) { + self.lock.state.lock(|s| { + let mut s = s.borrow_mut(); + match s.locked { + LockedState::ReadLocked(ref mut count) => { + *count -= 1; + if *count == 0 { + s.locked = LockedState::Unlocked; + s.waker.wake(); + } + } + _ => unreachable!(), + } + }); } } -/// Async write lock guard. -/// -/// Owning an instance of this type indicates having -/// successfully locked the RwLock for writing, and grants access to the contents. -/// -/// Dropping it unlocks the RwLock. -#[must_use = "if unused the RwLock will immediately unlock"] pub struct RwLockWriteGuard<'a, M, T> where - M: RawMutex, + M: RawRwLock, T: ?Sized, { lock: &'a RwLock, } -impl<'a, M, T> Drop for RwLockWriteGuard<'a, M, T> -where - M: RawMutex, - T: ?Sized, -{ - fn drop(&mut self) { - let mut state = self.lock.state.lock(|s| s.borrow_mut()); - state.writer = false; - state.reader_wakers.wake(); - state.writer_waker.wake(); - } -} - impl<'a, M, T> Deref for RwLockWriteGuard<'a, M, T> where - M: RawMutex, + M: RawRwLock, T: ?Sized, { type Target = T; + fn deref(&self) -> &Self::Target { - self.lock.inner.borrow() + unsafe { &*self.lock.inner.get() } } } impl<'a, M, T> DerefMut for RwLockWriteGuard<'a, M, T> where - M: RawMutex, + M: RawRwLock, T: ?Sized, { fn deref_mut(&mut self) -> &mut Self::Target { - self.lock.inner.borrow_mut() + unsafe { &mut *self.lock.inner.get() } } } + +impl<'a, M, T> Drop for RwLockWriteGuard<'a, M, T> +where + M: RawRwLock, + T: ?Sized, +{ + fn drop(&mut self) { + self.lock.state.lock(|s| { + let mut s = s.borrow_mut(); + s.locked = LockedState::Unlocked; + s.waker.wake(); + }); + } +} + +struct RwLockState { + locked: LockedState, + writer_pending: usize, + readers_pending: usize, + waker: WakerRegistration, +} + +enum LockedState { + Unlocked, + ReadLocked(usize), + WriteLocked, +} + +pub struct TryLockError;