Fix FairSemaphore bugs
- `acquire` and `acquire_all` futures were `!Send`, even for `M: RawMutex + Send` due to the captured `Cell`. - If multiple `acquire` tasks were queued, waking the first would not wake the second, even if there were permits remaining after the first `acquire` completed.
This commit is contained in:
parent
1fd260e4b1
commit
c9acebf783
@ -1,8 +1,7 @@
|
|||||||
//! A synchronization primitive for controlling access to a pool of resources.
|
//! A synchronization primitive for controlling access to a pool of resources.
|
||||||
use core::cell::{Cell, RefCell};
|
use core::cell::{Cell, RefCell};
|
||||||
use core::convert::Infallible;
|
use core::convert::Infallible;
|
||||||
use core::future::poll_fn;
|
use core::future::{poll_fn, Future};
|
||||||
use core::mem::MaybeUninit;
|
|
||||||
use core::task::{Poll, Waker};
|
use core::task::{Poll, Waker};
|
||||||
|
|
||||||
use heapless::Deque;
|
use heapless::Deque;
|
||||||
@ -258,9 +257,9 @@ where
|
|||||||
&self,
|
&self,
|
||||||
permits: usize,
|
permits: usize,
|
||||||
acquire_all: bool,
|
acquire_all: bool,
|
||||||
cx: Option<(&Cell<Option<usize>>, &Waker)>,
|
cx: Option<(&mut Option<usize>, &Waker)>,
|
||||||
) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
|
) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
|
||||||
let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None);
|
let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None);
|
||||||
self.state.lock(|cell| {
|
self.state.lock(|cell| {
|
||||||
let mut state = cell.borrow_mut();
|
let mut state = cell.borrow_mut();
|
||||||
if let Some(permits) = state.take(ticket, permits, acquire_all) {
|
if let Some(permits) = state.take(ticket, permits, acquire_all) {
|
||||||
@ -268,10 +267,10 @@ where
|
|||||||
semaphore: self,
|
semaphore: self,
|
||||||
permits,
|
permits,
|
||||||
}))
|
}))
|
||||||
} else if let Some((cell, waker)) = cx {
|
} else if let Some((ticket_ref, waker)) = cx {
|
||||||
match state.register(ticket, waker) {
|
match state.register(ticket, waker) {
|
||||||
Ok(ticket) => {
|
Ok(ticket) => {
|
||||||
cell.set(Some(ticket));
|
*ticket_ref = Some(ticket);
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
Err(err) => Poll::Ready(Err(err)),
|
Err(err) => Poll::Ready(Err(err)),
|
||||||
@ -291,10 +290,12 @@ pub struct WaitQueueFull;
|
|||||||
impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
||||||
type Error = WaitQueueFull;
|
type Error = WaitQueueFull;
|
||||||
|
|
||||||
async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
|
fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
|
||||||
let ticket = Cell::new(None);
|
FairAcquire {
|
||||||
let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
|
sema: self,
|
||||||
poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await
|
permits,
|
||||||
|
ticket: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
|
fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
|
||||||
@ -304,10 +305,12 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
|
fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
|
||||||
let ticket = Cell::new(None);
|
FairAcquireAll {
|
||||||
let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
|
sema: self,
|
||||||
poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await
|
min,
|
||||||
|
ticket: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
|
fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
|
||||||
@ -338,6 +341,52 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct FairAcquire<'a, M: RawMutex, const N: usize> {
|
||||||
|
sema: &'a FairSemaphore<M, N>,
|
||||||
|
permits: usize,
|
||||||
|
ticket: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.sema
|
||||||
|
.state
|
||||||
|
.lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> {
|
||||||
|
type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
|
||||||
|
|
||||||
|
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
|
||||||
|
self.sema
|
||||||
|
.poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FairAcquireAll<'a, M: RawMutex, const N: usize> {
|
||||||
|
sema: &'a FairSemaphore<M, N>,
|
||||||
|
min: usize,
|
||||||
|
ticket: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.sema
|
||||||
|
.state
|
||||||
|
.lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> {
|
||||||
|
type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
|
||||||
|
|
||||||
|
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
|
||||||
|
self.sema
|
||||||
|
.poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct FairSemaphoreState<const N: usize> {
|
struct FairSemaphoreState<const N: usize> {
|
||||||
permits: usize,
|
permits: usize,
|
||||||
next_ticket: usize,
|
next_ticket: usize,
|
||||||
@ -406,6 +455,9 @@ impl<const N: usize> FairSemaphoreState<N> {
|
|||||||
|
|
||||||
if ticket.is_some() {
|
if ticket.is_some() {
|
||||||
self.pop();
|
self.pop();
|
||||||
|
if self.permits > 0 {
|
||||||
|
self.wake();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(permits)
|
Some(permits)
|
||||||
@ -432,25 +484,6 @@ impl<const N: usize> FairSemaphoreState<N> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A type to delay the drop handler invocation.
|
|
||||||
#[must_use = "to delay the drop handler invocation to the end of the scope"]
|
|
||||||
struct OnDrop<F: FnOnce()> {
|
|
||||||
f: MaybeUninit<F>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: FnOnce()> OnDrop<F> {
|
|
||||||
/// Create a new instance.
|
|
||||||
pub fn new(f: F) -> Self {
|
|
||||||
Self { f: MaybeUninit::new(f) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: FnOnce()> Drop for OnDrop<F> {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
unsafe { self.f.as_ptr().read()() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
mod greedy {
|
mod greedy {
|
||||||
@ -574,11 +607,16 @@ mod tests {
|
|||||||
|
|
||||||
mod fair {
|
mod fair {
|
||||||
use core::pin::pin;
|
use core::pin::pin;
|
||||||
|
use core::time::Duration;
|
||||||
|
|
||||||
|
use futures_executor::ThreadPool;
|
||||||
|
use futures_timer::Delay;
|
||||||
use futures_util::poll;
|
use futures_util::poll;
|
||||||
|
use futures_util::task::SpawnExt;
|
||||||
|
use static_cell::StaticCell;
|
||||||
|
|
||||||
use super::super::*;
|
use super::super::*;
|
||||||
use crate::blocking_mutex::raw::NoopRawMutex;
|
use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn try_acquire() {
|
fn try_acquire() {
|
||||||
@ -700,5 +738,35 @@ mod tests {
|
|||||||
let c = poll!(c_fut.as_mut());
|
let c = poll!(c_fut.as_mut());
|
||||||
assert!(c.is_ready());
|
assert!(c.is_ready());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[futures_test::test]
|
||||||
|
async fn wakers() {
|
||||||
|
let executor = ThreadPool::new().unwrap();
|
||||||
|
|
||||||
|
static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
|
||||||
|
let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));
|
||||||
|
|
||||||
|
let a = semaphore.try_acquire(2);
|
||||||
|
assert!(a.is_some());
|
||||||
|
|
||||||
|
let b_task = executor
|
||||||
|
.spawn_with_handle(async move { semaphore.acquire(2).await })
|
||||||
|
.unwrap();
|
||||||
|
while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
|
||||||
|
Delay::new(Duration::from_millis(50)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let c_task = executor
|
||||||
|
.spawn_with_handle(async move { semaphore.acquire(1).await })
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
core::mem::drop(a);
|
||||||
|
|
||||||
|
let b = b_task.await.unwrap();
|
||||||
|
assert_eq!(b.permits(), 2);
|
||||||
|
|
||||||
|
let c = c_task.await.unwrap();
|
||||||
|
assert_eq!(c.permits(), 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user