diff --git a/embassy-executor/src/raw/mod.rs b/embassy-executor/src/raw/mod.rs index bcbd214a9..808a78389 100644 --- a/embassy-executor/src/raw/mod.rs +++ b/embassy-executor/src/raw/mod.rs @@ -14,7 +14,58 @@ mod run_queue; #[cfg_attr(all(cortex_m, target_has_atomic = "8"), path = "state_atomics_arm.rs")] #[cfg_attr(all(not(cortex_m), target_has_atomic = "8"), path = "state_atomics.rs")] #[cfg_attr(not(target_has_atomic = "8"), path = "state_critical_section.rs")] -mod state; +pub(crate) mod state; + +#[cfg(target_has_atomic = "ptr")] +mod owner { + use core::sync::atomic::{AtomicPtr, Ordering}; + + use super::{state::Token, SyncExecutor}; + + pub(crate) struct ExecutorRef(AtomicPtr); + + impl ExecutorRef { + pub const fn new() -> Self { + Self(AtomicPtr::new(core::ptr::null_mut())) + } + + pub fn set(&self, executor: Option<&'static SyncExecutor>, _: Token) { + let ptr = executor.map(|e| e as *const SyncExecutor).unwrap_or(core::ptr::null()); + self.0.store(ptr.cast_mut(), Ordering::Release); + } + + pub fn get(&self, _: Token) -> *const SyncExecutor { + self.0.load(Ordering::Acquire).cast_const() + } + } +} +#[cfg(not(target_has_atomic = "ptr"))] +mod owner { + use super::{state::Token, SyncExecutor}; + use core::cell::Cell; + + use critical_section::Mutex; + + pub(crate) struct ExecutorRef(Mutex>); + + unsafe impl Send for ExecutorRef {} + unsafe impl Sync for ExecutorRef {} + + impl ExecutorRef { + pub const fn new() -> Self { + Self(Mutex::new(Cell::new(core::ptr::null()))) + } + + pub fn set(&self, executor: Option<&'static SyncExecutor>, cs: Token) { + let ptr = executor.map(|e| e as *const SyncExecutor).unwrap_or(core::ptr::null()); + self.0.borrow(cs).set(ptr); + } + + pub fn get(&self, cs: Token) -> *const SyncExecutor { + self.0.borrow(cs).get() + } + } +} pub mod timer_queue; #[cfg(feature = "trace")] @@ -30,6 +81,8 @@ use core::pin::Pin; use core::ptr::NonNull; use core::task::{Context, Poll}; +use crate::raw::owner::ExecutorRef; + use self::run_queue::{RunQueue, RunQueueItem}; use self::state::State; use self::util::{SyncUnsafeCell, UninitCell}; @@ -40,7 +93,7 @@ use super::SpawnToken; pub(crate) struct TaskHeader { pub(crate) state: State, pub(crate) run_queue_item: RunQueueItem, - pub(crate) executor: SyncUnsafeCell>, + pub(crate) executor: ExecutorRef, poll_fn: SyncUnsafeCell>, /// Integrated timer queue storage. This field should not be accessed outside of the timer queue. @@ -86,7 +139,8 @@ impl TaskRef { /// Returns a reference to the executor that the task is currently running on. pub unsafe fn executor(self) -> Option<&'static Executor> { - self.header().executor.get().map(|e| Executor::wrap(e)) + let executor = state::locked(|token| self.header().executor.get(token)); + executor.as_ref().map(|e| Executor::wrap(e)) } /// Returns a reference to the timer queue item. @@ -153,7 +207,7 @@ impl TaskStorage { raw: TaskHeader { state: State::new(), run_queue_item: RunQueueItem::new(), - executor: SyncUnsafeCell::new(None), + executor: ExecutorRef::new(), // Note: this is lazily initialized so that a static `TaskStorage` will go in `.bss` poll_fn: SyncUnsafeCell::new(None), @@ -396,7 +450,9 @@ impl SyncExecutor { } pub(super) unsafe fn spawn(&'static self, task: TaskRef) { - task.header().executor.set(Some(self)); + state::locked(|l| { + task.header().executor.set(Some(self), l); + }); #[cfg(feature = "trace")] trace::task_new(self, &task); @@ -549,7 +605,7 @@ pub fn wake_task(task: TaskRef) { header.state.run_enqueue(|l| { // We have just marked the task as scheduled, so enqueue it. unsafe { - let executor = header.executor.get().unwrap_unchecked(); + let executor = header.executor.get(l).as_ref().unwrap_unchecked(); executor.enqueue(task, l); } }); @@ -563,7 +619,7 @@ pub fn wake_task_no_pend(task: TaskRef) { header.state.run_enqueue(|l| { // We have just marked the task as scheduled, so enqueue it. unsafe { - let executor = header.executor.get().unwrap_unchecked(); + let executor = header.executor.get(l).as_ref().unwrap_unchecked(); executor.run_queue.enqueue(task, l); } }); diff --git a/embassy-executor/src/raw/state_atomics.rs b/embassy-executor/src/raw/state_atomics.rs index abfe94486..d7350464f 100644 --- a/embassy-executor/src/raw/state_atomics.rs +++ b/embassy-executor/src/raw/state_atomics.rs @@ -2,13 +2,14 @@ use core::sync::atomic::{AtomicU32, Ordering}; use super::timer_queue::TimerEnqueueOperation; +#[derive(Clone, Copy)] pub(crate) struct Token(()); /// Creates a token and passes it to the closure. /// /// This is a no-op replacement for `CriticalSection::with` because we don't need any locking. -pub(crate) fn locked(f: impl FnOnce(Token)) { - f(Token(())); +pub(crate) fn locked(f: impl FnOnce(Token) -> R) -> R { + f(Token(())) } /// Task is spawned (has a future) diff --git a/embassy-executor/src/raw/state_atomics_arm.rs b/embassy-executor/src/raw/state_atomics_arm.rs index f0f014652..c1e8f69ab 100644 --- a/embassy-executor/src/raw/state_atomics_arm.rs +++ b/embassy-executor/src/raw/state_atomics_arm.rs @@ -3,13 +3,14 @@ use core::sync::atomic::{compiler_fence, AtomicBool, AtomicU32, Ordering}; use super::timer_queue::TimerEnqueueOperation; +#[derive(Clone, Copy)] pub(crate) struct Token(()); /// Creates a token and passes it to the closure. /// /// This is a no-op replacement for `CriticalSection::with` because we don't need any locking. -pub(crate) fn locked(f: impl FnOnce(Token)) { - f(Token(())); +pub(crate) fn locked(f: impl FnOnce(Token) -> R) -> R { + f(Token(())) } // Must be kept in sync with the layout of `State`! diff --git a/embassy-executor/src/spawner.rs b/embassy-executor/src/spawner.rs index 271606244..bc243bee7 100644 --- a/embassy-executor/src/spawner.rs +++ b/embassy-executor/src/spawner.rs @@ -92,7 +92,9 @@ impl Spawner { pub async fn for_current_executor() -> Self { poll_fn(|cx| { let task = raw::task_from_waker(cx.waker()); - let executor = unsafe { task.header().executor.get().unwrap_unchecked() }; + let executor = raw::state::locked(|l| { + unsafe { task.header().executor.get(l).as_ref().unwrap_unchecked() } + }); let executor = unsafe { raw::Executor::wrap(executor) }; Poll::Ready(Self::new(executor)) }) @@ -164,7 +166,9 @@ impl SendSpawner { pub async fn for_current_executor() -> Self { poll_fn(|cx| { let task = raw::task_from_waker(cx.waker()); - let executor = unsafe { task.header().executor.get().unwrap_unchecked() }; + let executor = raw::state::locked(|l| { + unsafe { task.header().executor.get(l).as_ref().unwrap_unchecked() } + }); Poll::Ready(Self::new(executor)) }) .await