Refactor blocking read-write lock module structure and improve assertions in ThreadModeRawRwLock

This commit is contained in:
Alix ANNERAUD 2025-02-28 16:32:12 +01:00
parent a7ecf14259
commit 33cf27adf6
3 changed files with 274 additions and 138 deletions

View File

@ -1,11 +1,11 @@
//! Blocking read-write lock. //! Blocking read-write lock.
//! //!
//! This module provides a blocking read-write lock that can be used to synchronize data. //! This module provides a blocking read-write lock that can be used to synchronize data.
pub mod raw_rwlock; pub mod raw;
use core::cell::UnsafeCell; use core::cell::UnsafeCell;
use self::raw_rwlock::RawRwLock; use self::raw::RawRwLock;
/// Blocking read-write lock (not async) /// Blocking read-write lock (not async)
/// ///

View File

@ -126,13 +126,19 @@ mod thread_mode {
unsafe impl RawRwLock for ThreadModeRawRwLock { unsafe impl RawRwLock for ThreadModeRawRwLock {
const INIT: Self = Self::new(); const INIT: Self = Self::new();
fn read_lock<R>(&self, f: impl FnOnce() -> R) -> R { fn read_lock<R>(&self, f: impl FnOnce() -> R) -> R {
assert!(in_thread_mode(), "ThreadModeRwLock can only be locked from thread mode."); assert!(
in_thread_mode(),
"ThreadModeRwLock can only be locked from thread mode."
);
f() f()
} }
fn write_lock<R>(&self, f: impl FnOnce() -> R) -> R { fn write_lock<R>(&self, f: impl FnOnce() -> R) -> R {
assert!(in_thread_mode(), "ThreadModeRwLock can only be locked from thread mode."); assert!(
in_thread_mode(),
"ThreadModeRwLock can only be locked from thread mode."
);
f() f()
} }

View File

@ -1,134 +1,160 @@
use core::cell::UnsafeCell; //! Async read-write lock.
use core::future::poll_fn; //!
//! This module provides a read-write lock that can be used to synchronize data between asynchronous tasks.
use core::cell::{RefCell, UnsafeCell};
use core::future::{poll_fn, Future};
use core::ops::{Deref, DerefMut}; use core::ops::{Deref, DerefMut};
use core::task::Poll; use core::task::Poll;
use core::{fmt, mem};
use crate::blocking_mutex::Mutex as BlockingMutex; use crate::blocking_mutex::raw::RawRwLock;
use crate::blocking_mutex::RwLock as BlockingRwLock;
use crate::waitqueue::WakerRegistration; use crate::waitqueue::WakerRegistration;
use crate::raw_rwlock::RawRwLock;
pub struct RwLock<M, T> /// Error returned by [`RwLock::try_read_lock`] and [`RwLock::try_write_lock`]
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;
struct State {
readers: usize,
writer: bool,
waker: WakerRegistration,
}
/// Async read-write lock.
///
/// The read-write lock is generic over a blocking [`RawRwLock`](crate::blocking_mutex::raw_rwlock::RawRwLock).
/// The raw read-write lock is used to guard access to the internal state. It
/// is held for very short periods only, while locking and unlocking. It is *not* held
/// for the entire time the async RwLock is locked.
///
/// Which implementation you select depends on the context in which you're using the read-write lock.
///
/// Use [`CriticalSectionRawRwLock`](crate::blocking_mutex::raw_rwlock::CriticalSectionRawRwLock) when data can be shared between threads and interrupts.
///
/// Use [`NoopRawRwLock`](crate::blocking_mutex::raw_rwlock::NoopRawRwLock) when data is only shared between tasks running on the same executor.
///
/// Use [`ThreadModeRawRwLock`](crate::blocking_mutex::raw_rwlock::ThreadModeRawRwLock) when data is shared between tasks running on the same executor but you want a singleton.
///
pub struct RwLock<R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized, T: ?Sized,
{ {
state: BlockingMutex<M, RwLockState>, state: BlockingRwLock<R, RefCell<State>>,
inner: UnsafeCell<T>, inner: UnsafeCell<T>,
} }
unsafe impl<M: RawRwLock + Send, T: ?Sized + Send> Send for RwLock<M, T> {} unsafe impl<R: RawRwLock + Send, T: ?Sized + Send> Send for RwLock<R, T> {}
unsafe impl<M: RawRwLock + Sync, T: ?Sized + Send> Sync for RwLock<M, T> {} unsafe impl<R: RawRwLock + Sync, T: ?Sized + Send> Sync for RwLock<R, T> {}
impl<M, T> RwLock<M, T> /// Async read-write lock.
impl<R, T> RwLock<R, T>
where where
M: RawRwLock, R: RawRwLock,
{ {
/// Create a new read-write lock with the given value.
pub const fn new(value: T) -> Self { pub const fn new(value: T) -> Self {
Self { Self {
inner: UnsafeCell::new(value), inner: UnsafeCell::new(value),
state: BlockingMutex::new(RwLockState { state: BlockingRwLock::new(RefCell::new(State {
locked: LockedState::Unlocked, readers: 0,
writer_pending: 0, writer: false,
readers_pending: 0,
waker: WakerRegistration::new(), waker: WakerRegistration::new(),
}), })),
} }
} }
} }
impl<M, T> RwLock<M, T> impl<R, T> RwLock<R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized, T: ?Sized,
{ {
pub fn read(&self) -> impl Future<Output = RwLockReadGuard<'_, M, T>> { /// Lock the read-write lock for reading.
///
/// This will wait for the lock to be available if it's already locked for writing.
pub fn read_lock(&self) -> impl Future<Output = RwLockReadGuard<'_, R, T>> {
poll_fn(|cx| { poll_fn(|cx| {
let ready = self.state.lock(|s| { let ready = self.state.lock(|s| {
let mut s = s.borrow_mut(); let mut s = s.borrow_mut();
match s.locked { if s.writer {
LockedState::Unlocked => {
s.locked = LockedState::ReadLocked(1);
true
}
LockedState::ReadLocked(ref mut count) => {
*count += 1;
true
}
LockedState::WriteLocked => {
s.readers_pending += 1;
s.waker.register(cx.waker()); s.waker.register(cx.waker());
false false
} } else {
s.readers += 1;
true
} }
}); });
if ready { if ready {
Poll::Ready(RwLockReadGuard { lock: self }) Poll::Ready(RwLockReadGuard { rwlock: self })
} else { } else {
Poll::Pending Poll::Pending
} }
}) })
} }
pub fn write(&self) -> impl Future<Output = RwLockWriteGuard<'_, M, T>> { /// Lock the read-write lock for writing.
///
/// This will wait for the lock to be available if it's already locked for reading or writing.
pub fn write_lock(&self) -> impl Future<Output = RwLockWriteGuard<'_, R, T>> {
poll_fn(|cx| { poll_fn(|cx| {
let ready = self.state.lock(|s| { let ready = self.state.lock(|s| {
let mut s = s.borrow_mut(); let mut s = s.borrow_mut();
match s.locked { if s.readers > 0 || s.writer {
LockedState::Unlocked => {
s.locked = LockedState::WriteLocked;
true
}
_ => {
s.writer_pending += 1;
s.waker.register(cx.waker()); s.waker.register(cx.waker());
false false
} } else {
s.writer = true;
true
} }
}); });
if ready { if ready {
Poll::Ready(RwLockWriteGuard { lock: self }) Poll::Ready(RwLockWriteGuard { rwlock: self })
} else { } else {
Poll::Pending Poll::Pending
} }
}) })
} }
pub fn try_read(&self) -> Result<RwLockReadGuard<'_, M, T>, TryLockError> { /// Attempt to immediately lock the read-write lock for reading.
///
/// If the lock is already locked for writing, this will return an error instead of waiting.
pub fn try_read_lock(&self) -> Result<RwLockReadGuard<'_, R, T>, TryLockError> {
self.state.lock(|s| { self.state.lock(|s| {
let mut s = s.borrow_mut(); let mut s = s.borrow_mut();
match s.locked { if s.writer {
LockedState::Unlocked => { Err(TryLockError)
s.locked = LockedState::ReadLocked(1); } else {
s.readers += 1;
Ok(()) Ok(())
} }
LockedState::ReadLocked(ref mut count) => {
*count += 1;
Ok(())
}
LockedState::WriteLocked => Err(TryLockError),
}
})?; })?;
Ok(RwLockReadGuard { lock: self }) Ok(RwLockReadGuard { rwlock: self })
} }
pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, M, T>, TryLockError> { /// Attempt to immediately lock the read-write lock for writing.
///
/// If the lock is already locked for reading or writing, this will return an error instead of waiting.
pub fn try_write_lock(&self) -> Result<RwLockWriteGuard<'_, R, T>, TryLockError> {
self.state.lock(|s| { self.state.lock(|s| {
let mut s = s.borrow_mut(); let mut s = s.borrow_mut();
match s.locked { if s.readers > 0 || s.writer {
LockedState::Unlocked => { Err(TryLockError)
s.locked = LockedState::WriteLocked; } else {
s.writer = true;
Ok(()) Ok(())
} }
_ => Err(TryLockError),
}
})?; })?;
Ok(RwLockWriteGuard { lock: self }) Ok(RwLockWriteGuard { rwlock: self })
} }
/// Consumes this read-write lock, returning the underlying data.
pub fn into_inner(self) -> T pub fn into_inner(self) -> T
where where
T: Sized, T: Sized,
@ -136,20 +162,24 @@ where
self.inner.into_inner() self.inner.into_inner()
} }
/// Returns a mutable reference to the underlying data.
///
/// Since this call borrows the RwLock mutably, no actual locking needs to
/// take place -- the mutable borrow statically guarantees no locks exist.
pub fn get_mut(&mut self) -> &mut T { pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut() self.inner.get_mut()
} }
} }
impl<M: RawRwLock, T> From<T> for RwLock<M, T> { impl<R: RawRwLock, T> From<T> for RwLock<R, T> {
fn from(from: T) -> Self { fn from(from: T) -> Self {
Self::new(from) Self::new(from)
} }
} }
impl<M, T> Default for RwLock<M, T> impl<R, T> Default for RwLock<R, T>
where where
M: RawRwLock, R: RawRwLock,
T: Default, T: Default,
{ {
fn default() -> Self { fn default() -> Self {
@ -157,103 +187,203 @@ where
} }
} }
pub struct RwLockReadGuard<'a, M, T> impl<R, T> fmt::Debug for RwLock<R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut d = f.debug_struct("RwLock");
match self.try_write_lock() {
Ok(value) => {
d.field("inner", &&*value);
}
Err(TryLockError) => {
d.field("inner", &format_args!("<locked>"));
}
}
d.finish_non_exhaustive()
}
}
/// Async read lock guard.
///
/// Owning an instance of this type indicates having
/// successfully locked the read-write lock for reading, and grants access to the contents.
///
/// Dropping it unlocks the read-write lock.
#[clippy::has_significant_drop]
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, R, T>
where
R: RawRwLock,
T: ?Sized, T: ?Sized,
{ {
lock: &'a RwLock<M, T>, rwlock: &'a RwLock<R, T>,
} }
impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T> impl<'a, R, T> Drop for RwLockReadGuard<'a, R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized,
{
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.inner.get() }
}
}
impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
where
M: RawRwLock,
T: ?Sized, T: ?Sized,
{ {
fn drop(&mut self) { fn drop(&mut self) {
self.lock.state.lock(|s| { self.rwlock.state.lock(|s| {
let mut s = s.borrow_mut(); let mut s = unwrap!(s.try_borrow_mut());
match s.locked { s.readers -= 1;
LockedState::ReadLocked(ref mut count) => { if s.readers == 0 {
*count -= 1;
if *count == 0 {
s.locked = LockedState::Unlocked;
s.waker.wake(); s.waker.wake();
} }
} })
_ => unreachable!(),
}
});
} }
} }
pub struct RwLockWriteGuard<'a, M, T> impl<'a, R, T> Deref for RwLockReadGuard<'a, R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized,
{
lock: &'a RwLock<M, T>,
}
impl<'a, M, T> Deref for RwLockWriteGuard<'a, M, T>
where
M: RawRwLock,
T: ?Sized, T: ?Sized,
{ {
type Target = T; type Target = T;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.inner.get() } // Safety: the RwLockReadGuard represents shared access to the contents
// of the read-write lock, so it's OK to get it.
unsafe { &*(self.rwlock.inner.get() as *const T) }
} }
} }
impl<'a, M, T> DerefMut for RwLockWriteGuard<'a, M, T> impl<'a, R, T> fmt::Debug for RwLockReadGuard<'a, R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}
impl<'a, R, T> fmt::Display for RwLockReadGuard<'a, R, T>
where
R: RawRwLock,
T: ?Sized + fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&**self, f)
}
}
/// Async write lock guard.
///
/// Owning an instance of this type indicates having
/// successfully locked the read-write lock for writing, and grants access to the contents.
///
/// Dropping it unlocks the read-write lock.
#[clippy::has_significant_drop]
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockWriteGuard<'a, R, T>
where
R: RawRwLock,
T: ?Sized,
{
rwlock: &'a RwLock<R, T>,
}
impl<'a, R, T> Drop for RwLockWriteGuard<'a, R, T>
where
R: RawRwLock,
T: ?Sized,
{
fn drop(&mut self) {
self.rwlock.state.lock(|s| {
let mut s = unwrap!(s.try_borrow_mut());
s.writer = false;
s.waker.wake();
})
}
}
impl<'a, R, T> Deref for RwLockWriteGuard<'a, R, T>
where
R: RawRwLock,
T: ?Sized,
{
type Target = T;
fn deref(&self) -> &Self::Target {
// Safety: the RwLockWriteGuard represents exclusive access to the contents
// of the read-write lock, so it's OK to get it.
unsafe { &*(self.rwlock.inner.get() as *mut T) }
}
}
impl<'a, R, T> DerefMut for RwLockWriteGuard<'a, R, T>
where
R: RawRwLock,
T: ?Sized, T: ?Sized,
{ {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.lock.inner.get() } // Safety: the RwLockWriteGuard represents exclusive access to the contents
// of the read-write lock, so it's OK to get it.
unsafe { &mut *(self.rwlock.inner.get()) }
} }
} }
impl<'a, M, T> Drop for RwLockWriteGuard<'a, M, T> impl<'a, R, T> fmt::Debug for RwLockWriteGuard<'a, R, T>
where where
M: RawRwLock, R: RawRwLock,
T: ?Sized, T: ?Sized + fmt::Debug,
{ {
fn drop(&mut self) { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.lock.state.lock(|s| { fmt::Debug::fmt(&**self, f)
let mut s = s.borrow_mut();
s.locked = LockedState::Unlocked;
s.waker.wake();
});
} }
} }
struct RwLockState { impl<'a, R, T> fmt::Display for RwLockWriteGuard<'a, R, T>
locked: LockedState, where
writer_pending: usize, R: RawRwLock,
readers_pending: usize, T: ?Sized + fmt::Display,
waker: WakerRegistration, {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&**self, f)
}
} }
enum LockedState { #[cfg(test)]
Unlocked, mod tests {
ReadLocked(usize), use crate::blocking_mutex::raw_rwlock::NoopRawRwLock;
WriteLocked, use crate::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
#[futures_test::test]
async fn read_guard_releases_lock_when_dropped() {
let rwlock: RwLock<NoopRawRwLock, [i32; 2]> = RwLock::new([0, 1]);
{
let guard = rwlock.read_lock().await;
assert_eq!(*guard, [0, 1]);
} }
pub struct TryLockError; {
let guard = rwlock.read_lock().await;
assert_eq!(*guard, [0, 1]);
}
assert_eq!(*rwlock.read_lock().await, [0, 1]);
}
#[futures_test::test]
async fn write_guard_releases_lock_when_dropped() {
let rwlock: RwLock<NoopRawRwLock, [i32; 2]> = RwLock::new([0, 1]);
{
let mut guard = rwlock.write_lock().await;
assert_eq!(*guard, [0, 1]);
guard[1] = 2;
}
{
let guard = rwlock.read_lock().await;
assert_eq!(*guard, [0, 2]);
}
assert_eq!(*rwlock.read_lock().await, [0, 2]);
}
}