Rely on atomic load-store on all targets

This commit is contained in:
Dániel Buga 2024-12-16 17:24:17 +01:00
parent b44ef5ccb4
commit b47a631abf
No known key found for this signature in database
2 changed files with 25 additions and 68 deletions

View File

@ -14,58 +14,7 @@ mod run_queue;
#[cfg_attr(all(cortex_m, target_has_atomic = "8"), path = "state_atomics_arm.rs")] #[cfg_attr(all(cortex_m, target_has_atomic = "8"), path = "state_atomics_arm.rs")]
#[cfg_attr(all(not(cortex_m), target_has_atomic = "8"), path = "state_atomics.rs")] #[cfg_attr(all(not(cortex_m), target_has_atomic = "8"), path = "state_atomics.rs")]
#[cfg_attr(not(target_has_atomic = "8"), path = "state_critical_section.rs")] #[cfg_attr(not(target_has_atomic = "8"), path = "state_critical_section.rs")]
pub(crate) mod state; mod state;
#[cfg(target_has_atomic = "ptr")]
mod owner {
use core::sync::atomic::{AtomicPtr, Ordering};
use super::{state::Token, SyncExecutor};
pub(crate) struct ExecutorRef(AtomicPtr<SyncExecutor>);
impl ExecutorRef {
pub const fn new() -> Self {
Self(AtomicPtr::new(core::ptr::null_mut()))
}
pub fn set(&self, executor: Option<&'static SyncExecutor>, _: Token) {
let ptr = executor.map(|e| e as *const SyncExecutor).unwrap_or(core::ptr::null());
self.0.store(ptr.cast_mut(), Ordering::Release);
}
pub fn get(&self, _: Token) -> *const SyncExecutor {
self.0.load(Ordering::Acquire).cast_const()
}
}
}
#[cfg(not(target_has_atomic = "ptr"))]
mod owner {
use super::{state::Token, SyncExecutor};
use core::cell::Cell;
use critical_section::Mutex;
pub(crate) struct ExecutorRef(Mutex<Cell<*const SyncExecutor>>);
unsafe impl Send for ExecutorRef {}
unsafe impl Sync for ExecutorRef {}
impl ExecutorRef {
pub const fn new() -> Self {
Self(Mutex::new(Cell::new(core::ptr::null())))
}
pub fn set(&self, executor: Option<&'static SyncExecutor>, cs: Token) {
let ptr = executor.map(|e| e as *const SyncExecutor).unwrap_or(core::ptr::null());
self.0.borrow(cs).set(ptr);
}
pub fn get(&self, cs: Token) -> *const SyncExecutor {
self.0.borrow(cs).get()
}
}
}
pub mod timer_queue; pub mod timer_queue;
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
@ -79,10 +28,9 @@ use core::marker::PhantomData;
use core::mem; use core::mem;
use core::pin::Pin; use core::pin::Pin;
use core::ptr::NonNull; use core::ptr::NonNull;
use core::sync::atomic::{AtomicPtr, Ordering};
use core::task::{Context, Poll}; use core::task::{Context, Poll};
use crate::raw::owner::ExecutorRef;
use self::run_queue::{RunQueue, RunQueueItem}; use self::run_queue::{RunQueue, RunQueueItem};
use self::state::State; use self::state::State;
use self::util::{SyncUnsafeCell, UninitCell}; use self::util::{SyncUnsafeCell, UninitCell};
@ -93,7 +41,7 @@ use super::SpawnToken;
pub(crate) struct TaskHeader { pub(crate) struct TaskHeader {
pub(crate) state: State, pub(crate) state: State,
pub(crate) run_queue_item: RunQueueItem, pub(crate) run_queue_item: RunQueueItem,
pub(crate) executor: ExecutorRef, pub(crate) executor: AtomicPtr<SyncExecutor>,
poll_fn: SyncUnsafeCell<Option<unsafe fn(TaskRef)>>, poll_fn: SyncUnsafeCell<Option<unsafe fn(TaskRef)>>,
/// Integrated timer queue storage. This field should not be accessed outside of the timer queue. /// Integrated timer queue storage. This field should not be accessed outside of the timer queue.
@ -139,7 +87,7 @@ impl TaskRef {
/// Returns a reference to the executor that the task is currently running on. /// Returns a reference to the executor that the task is currently running on.
pub unsafe fn executor(self) -> Option<&'static Executor> { pub unsafe fn executor(self) -> Option<&'static Executor> {
let executor = state::locked(|token| self.header().executor.get(token)); let executor = self.header().executor.load(Ordering::Relaxed);
executor.as_ref().map(|e| Executor::wrap(e)) executor.as_ref().map(|e| Executor::wrap(e))
} }
@ -207,7 +155,7 @@ impl<F: Future + 'static> TaskStorage<F> {
raw: TaskHeader { raw: TaskHeader {
state: State::new(), state: State::new(),
run_queue_item: RunQueueItem::new(), run_queue_item: RunQueueItem::new(),
executor: ExecutorRef::new(), executor: AtomicPtr::new(core::ptr::null_mut()),
// Note: this is lazily initialized so that a static `TaskStorage` will go in `.bss` // Note: this is lazily initialized so that a static `TaskStorage` will go in `.bss`
poll_fn: SyncUnsafeCell::new(None), poll_fn: SyncUnsafeCell::new(None),
@ -450,9 +398,9 @@ impl SyncExecutor {
} }
pub(super) unsafe fn spawn(&'static self, task: TaskRef) { pub(super) unsafe fn spawn(&'static self, task: TaskRef) {
state::locked(|l| { task.header()
task.header().executor.set(Some(self), l); .executor
}); .store((self as *const Self).cast_mut(), Ordering::Relaxed);
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
trace::task_new(self, &task); trace::task_new(self, &task);
@ -605,7 +553,7 @@ pub fn wake_task(task: TaskRef) {
header.state.run_enqueue(|l| { header.state.run_enqueue(|l| {
// We have just marked the task as scheduled, so enqueue it. // We have just marked the task as scheduled, so enqueue it.
unsafe { unsafe {
let executor = header.executor.get(l).as_ref().unwrap_unchecked(); let executor = header.executor.load(Ordering::Relaxed).as_ref().unwrap_unchecked();
executor.enqueue(task, l); executor.enqueue(task, l);
} }
}); });
@ -619,7 +567,7 @@ pub fn wake_task_no_pend(task: TaskRef) {
header.state.run_enqueue(|l| { header.state.run_enqueue(|l| {
// We have just marked the task as scheduled, so enqueue it. // We have just marked the task as scheduled, so enqueue it.
unsafe { unsafe {
let executor = header.executor.get(l).as_ref().unwrap_unchecked(); let executor = header.executor.load(Ordering::Relaxed).as_ref().unwrap_unchecked();
executor.run_queue.enqueue(task, l); executor.run_queue.enqueue(task, l);
} }
}); });

View File

@ -1,6 +1,7 @@
use core::future::poll_fn; use core::future::poll_fn;
use core::marker::PhantomData; use core::marker::PhantomData;
use core::mem; use core::mem;
use core::sync::atomic::Ordering;
use core::task::Poll; use core::task::Poll;
use super::raw; use super::raw;
@ -92,9 +93,13 @@ impl Spawner {
pub async fn for_current_executor() -> Self { pub async fn for_current_executor() -> Self {
poll_fn(|cx| { poll_fn(|cx| {
let task = raw::task_from_waker(cx.waker()); let task = raw::task_from_waker(cx.waker());
let executor = raw::state::locked(|l| { let executor = unsafe {
unsafe { task.header().executor.get(l).as_ref().unwrap_unchecked() } task.header()
}); .executor
.load(Ordering::Relaxed)
.as_ref()
.unwrap_unchecked()
};
let executor = unsafe { raw::Executor::wrap(executor) }; let executor = unsafe { raw::Executor::wrap(executor) };
Poll::Ready(Self::new(executor)) Poll::Ready(Self::new(executor))
}) })
@ -166,9 +171,13 @@ impl SendSpawner {
pub async fn for_current_executor() -> Self { pub async fn for_current_executor() -> Self {
poll_fn(|cx| { poll_fn(|cx| {
let task = raw::task_from_waker(cx.waker()); let task = raw::task_from_waker(cx.waker());
let executor = raw::state::locked(|l| { let executor = unsafe {
unsafe { task.header().executor.get(l).as_ref().unwrap_unchecked() } task.header()
}); .executor
.load(Ordering::Relaxed)
.as_ref()
.unwrap_unchecked()
};
Poll::Ready(Self::new(executor)) Poll::Ready(Self::new(executor))
}) })
.await .await