This commit is contained in:
Alix ANNERAUD 2025-02-28 16:10:15 +01:00
parent 025d9f6e98
commit 6904b0cc64
2 changed files with 235 additions and 146 deletions

View File

@ -0,0 +1,86 @@
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Waker;
use core::cell::UnsafeCell;
pub trait RawRwLock {
fn lock_read(&self);
fn try_lock_read(&self) -> bool;
fn unlock_read(&self);
fn lock_write(&self);
fn try_lock_write(&self) -> bool;
fn unlock_write(&self);
}
pub struct RawRwLockImpl {
state: AtomicUsize,
waker: UnsafeCell<Option<Waker>>,
}
impl RawRwLockImpl {
pub const fn new() -> Self {
Self {
state: AtomicUsize::new(0),
waker: UnsafeCell::new(None),
}
}
}
unsafe impl Send for RawRwLockImpl {}
unsafe impl Sync for RawRwLockImpl {}
impl RawRwLock for RawRwLockImpl {
fn lock_read(&self) {
loop {
let state = self.state.load(Ordering::Acquire);
if state & 1 == 0 {
if self.state.compare_and_swap(state, state + 2, Ordering::AcqRel) == state {
break;
}
}
}
}
fn try_lock_read(&self) -> bool {
let state = self.state.load(Ordering::Acquire);
if state & 1 == 0 {
if self.state.compare_and_swap(state, state + 2, Ordering::AcqRel) == state {
return true;
}
}
false
}
fn unlock_read(&self) {
self.state.fetch_sub(2, Ordering::Release);
if self.state.load(Ordering::Acquire) == 0 {
if let Some(waker) = unsafe { &*self.waker.get() } {
waker.wake_by_ref();
}
}
}
fn lock_write(&self) {
loop {
let state = self.state.load(Ordering::Acquire);
if state == 0 {
if self.state.compare_and_swap(0, 1, Ordering::AcqRel) == 0 {
break;
}
}
}
}
fn try_lock_write(&self) -> bool {
if self.state.compare_and_swap(0, 1, Ordering::AcqRel) == 0 {
return true;
}
false
}
fn unlock_write(&self) {
self.state.store(0, Ordering::Release);
if let Some(waker) = unsafe { &*self.waker.get() } {
waker.wake_by_ref();
}
}
}

View File

@ -1,136 +1,134 @@
use core::cell::RefCell;
use core::future::{poll_fn, Future};
use core::cell::UnsafeCell;
use core::future::poll_fn;
use core::ops::{Deref, DerefMut};
use core::task::Poll;
use crate::blocking_mutex::raw::RawMutex;
use crate::blocking_mutex::Mutex as BlockingMutex;
use crate::waitqueue::MultiWakerRegistration;
use crate::waitqueue::WakerRegistration;
use crate::raw_rwlock::RawRwLock;
/// Error returned by [`RwLock::try_read`] and [`RwLock::try_write`]
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;
/// Async read-write lock.
///
/// The lock is generic over a blocking [`RawMutex`](crate::blocking_mutex::raw::RawMutex).
/// The raw mutex is used to guard access to the internal state. It
/// is held for very short periods only, while locking and unlocking. It is *not* held
/// for the entire time the async RwLock is locked.
///
/// Which implementation you select depends on the context in which you're using the lock.
///
/// Use [`CriticalSectionRawMutex`](crate::blocking_mutex::raw::CriticalSectionRawMutex) when data can be shared between threads and interrupts.
///
/// Use [`NoopRawMutex`](crate::blocking_mutex::raw::NoopRawMutex) when data is only shared between tasks running on the same executor.
///
/// Use [`ThreadModeRawMutex`](crate::blocking_mutex::raw::ThreadModeRawMutex) when data is shared between tasks running on the same executor but you want a singleton.
///
pub struct RwLock<M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
state: BlockingMutex<M, RefCell<State>>,
inner: RefCell<T>,
state: BlockingMutex<M, RwLockState>,
inner: UnsafeCell<T>,
}
struct State {
readers: usize,
writer: bool,
writer_waker: MultiWakerRegistration<1>,
reader_wakers: MultiWakerRegistration<8>,
}
impl State {
fn new() -> Self {
Self {
readers: 0,
writer: false,
writer_waker: MultiWakerRegistration::new(),
reader_wakers: MultiWakerRegistration::new(),
}
}
}
unsafe impl<M: RawRwLock + Send, T: ?Sized + Send> Send for RwLock<M, T> {}
unsafe impl<M: RawRwLock + Sync, T: ?Sized + Send> Sync for RwLock<M, T> {}
impl<M, T> RwLock<M, T>
where
M: RawMutex,
M: RawRwLock,
{
/// Create a new read-write lock with the given value.
pub const fn new(value: T) -> Self {
Self {
inner: RefCell::new(value),
state: BlockingMutex::new(RefCell::new(State::new())),
inner: UnsafeCell::new(value),
state: BlockingMutex::new(RwLockState {
locked: LockedState::Unlocked,
writer_pending: 0,
readers_pending: 0,
waker: WakerRegistration::new(),
}),
}
}
}
impl<M, T> RwLock<M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
/// Acquire a read lock.
///
/// This will wait for the lock to be available if it's already locked for writing.
pub fn read(&self) -> impl Future<Output = RwLockReadGuard<'_, M, T>> {
poll_fn(|cx| {
let mut state = self.state.lock(|s| s.borrow_mut());
if state.writer {
state.reader_wakers.register(cx.waker());
Poll::Pending
} else {
state.readers += 1;
let ready = self.state.lock(|s| {
let mut s = s.borrow_mut();
match s.locked {
LockedState::Unlocked => {
s.locked = LockedState::ReadLocked(1);
true
}
LockedState::ReadLocked(ref mut count) => {
*count += 1;
true
}
LockedState::WriteLocked => {
s.readers_pending += 1;
s.waker.register(cx.waker());
false
}
}
});
if ready {
Poll::Ready(RwLockReadGuard { lock: self })
} else {
Poll::Pending
}
})
}
/// Acquire a write lock.
///
/// This will wait for the lock to be available if it's already locked for reading or writing.
pub fn write(&self) -> impl Future<Output = RwLockWriteGuard<'_, M, T>> {
poll_fn(|cx| {
let mut state = self.state.lock(|s| s.borrow_mut());
if state.writer || state.readers > 0 {
state.writer_waker.register(cx.waker());
Poll::Pending
} else {
state.writer = true;
let ready = self.state.lock(|s| {
let mut s = s.borrow_mut();
match s.locked {
LockedState::Unlocked => {
s.locked = LockedState::WriteLocked;
true
}
_ => {
s.writer_pending += 1;
s.waker.register(cx.waker());
false
}
}
});
if ready {
Poll::Ready(RwLockWriteGuard { lock: self })
} else {
Poll::Pending
}
})
}
/// Attempt to immediately acquire a read lock.
///
/// If the lock is already locked for writing, this will return an error instead of waiting.
pub fn try_read(&self) -> Result<RwLockReadGuard<'_, M, T>, TryLockError> {
let mut state = self.state.lock(|s| s.borrow_mut());
if state.writer {
Err(TryLockError)
} else {
state.readers += 1;
Ok(RwLockReadGuard { lock: self })
}
self.state.lock(|s| {
let mut s = s.borrow_mut();
match s.locked {
LockedState::Unlocked => {
s.locked = LockedState::ReadLocked(1);
Ok(())
}
LockedState::ReadLocked(ref mut count) => {
*count += 1;
Ok(())
}
LockedState::WriteLocked => Err(TryLockError),
}
})?;
Ok(RwLockReadGuard { lock: self })
}
/// Attempt to immediately acquire a write lock.
///
/// If the lock is already locked for reading or writing, this will return an error instead of waiting.
pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, M, T>, TryLockError> {
let mut state = self.state.lock(|s| s.borrow_mut());
if state.writer || state.readers > 0 {
Err(TryLockError)
} else {
state.writer = true;
Ok(RwLockWriteGuard { lock: self })
}
self.state.lock(|s| {
let mut s = s.borrow_mut();
match s.locked {
LockedState::Unlocked => {
s.locked = LockedState::WriteLocked;
Ok(())
}
_ => Err(TryLockError),
}
})?;
Ok(RwLockWriteGuard { lock: self })
}
/// Consumes this lock, returning the underlying data.
pub fn into_inner(self) -> T
where
T: Sized,
@ -138,19 +136,12 @@ where
self.inner.into_inner()
}
/// Returns a mutable reference to the underlying data.
///
/// Since this call borrows the RwLock mutably, no actual locking needs to
/// take place -- the mutable borrow statically guarantees no locks exist.
pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut()
}
}
impl<M, T> From<T> for RwLock<M, T>
where
M: RawMutex,
{
impl<M: RawRwLock, T> From<T> for RwLock<M, T> {
fn from(from: T) -> Self {
Self::new(from)
}
@ -158,7 +149,7 @@ where
impl<M, T> Default for RwLock<M, T>
where
M: RawMutex,
M: RawRwLock,
T: Default,
{
fn default() -> Self {
@ -166,91 +157,103 @@ where
}
}
/// Async read lock guard.
///
/// Owning an instance of this type indicates having
/// successfully locked the RwLock for reading, and grants access to the contents.
///
/// Dropping it unlocks the RwLock.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
lock: &'a RwLock<M, T>,
}
impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
fn drop(&mut self) {
let mut state = self.lock.state.lock(|s| s.borrow_mut());
state.readers -= 1;
if state.readers == 0 {
state.writer_waker.wake();
}
}
}
impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
type Target = T;
fn deref(&self) -> &Self::Target {
self.lock.inner.borrow()
unsafe { &*self.lock.inner.get() }
}
}
impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T>
where
M: RawRwLock,
T: ?Sized,
{
fn drop(&mut self) {
self.lock.state.lock(|s| {
let mut s = s.borrow_mut();
match s.locked {
LockedState::ReadLocked(ref mut count) => {
*count -= 1;
if *count == 0 {
s.locked = LockedState::Unlocked;
s.waker.wake();
}
}
_ => unreachable!(),
}
});
}
}
/// Async write lock guard.
///
/// Owning an instance of this type indicates having
/// successfully locked the RwLock for writing, and grants access to the contents.
///
/// Dropping it unlocks the RwLock.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockWriteGuard<'a, M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
lock: &'a RwLock<M, T>,
}
impl<'a, M, T> Drop for RwLockWriteGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
fn drop(&mut self) {
let mut state = self.lock.state.lock(|s| s.borrow_mut());
state.writer = false;
state.reader_wakers.wake();
state.writer_waker.wake();
}
}
impl<'a, M, T> Deref for RwLockWriteGuard<'a, M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
type Target = T;
fn deref(&self) -> &Self::Target {
self.lock.inner.borrow()
unsafe { &*self.lock.inner.get() }
}
}
impl<'a, M, T> DerefMut for RwLockWriteGuard<'a, M, T>
where
M: RawMutex,
M: RawRwLock,
T: ?Sized,
{
fn deref_mut(&mut self) -> &mut Self::Target {
self.lock.inner.borrow_mut()
unsafe { &mut *self.lock.inner.get() }
}
}
impl<'a, M, T> Drop for RwLockWriteGuard<'a, M, T>
where
M: RawRwLock,
T: ?Sized,
{
fn drop(&mut self) {
self.lock.state.lock(|s| {
let mut s = s.borrow_mut();
s.locked = LockedState::Unlocked;
s.waker.wake();
});
}
}
struct RwLockState {
locked: LockedState,
writer_pending: usize,
readers_pending: usize,
waker: WakerRegistration,
}
enum LockedState {
Unlocked,
ReadLocked(usize),
WriteLocked,
}
pub struct TryLockError;