Merge #1182
1182: executor: Replace `NonNull<TaskHeader>` with `TaskRef` r=Dirbaio a=GrantM11235

Co-authored-by: Grant Miller <GrantM11235@gmail.com>
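In short, this PR replaces the raw `NonNull<TaskHeader>` task handle with a small `Copy` newtype, `TaskRef`, so the pointer casts and validity reasoning live in one place and `wake_task` can become a safe function. The full diff follows below; first, a minimal standalone sketch of the pattern (simplified, not the embassy code itself; `Header`, `Storage`, and `Ref` are illustrative stand-ins):

```rust
use core::ptr::NonNull;

struct Header { /* state, queue links, ... */ }

#[repr(C)] // header is the first field, so casting the storage pointer is sound
struct Storage {
    header: Header,
    // future storage ...
}

/// Type-erased, copyable handle to a `&'static Storage`.
#[derive(Clone, Copy)]
struct Ref {
    ptr: NonNull<Header>,
}

impl Ref {
    fn new(storage: &'static Storage) -> Self {
        // Erase the concrete storage type, keeping only a header pointer.
        Self { ptr: NonNull::from(storage).cast() }
    }

    fn header(self) -> &'static Header {
        // A `Ref` only ever comes from a `&'static Storage`, so this is sound.
        unsafe { self.ptr.as_ref() }
    }
}
```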
commit c21cc21c62
@@ -43,14 +43,11 @@ pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1;
 pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2;
 
-/// Raw task header for use in task pointers.
-///
-/// This is an opaque struct, used for raw pointers to tasks, for use
-/// with funtions like [`wake_task`] and [`task_from_waker`].
-pub struct TaskHeader {
+pub(crate) struct TaskHeader {
     pub(crate) state: AtomicU32,
     pub(crate) run_queue_item: RunQueueItem,
-    pub(crate) executor: Cell<*const Executor>, // Valid if state != 0
-    pub(crate) poll_fn: UninitCell<unsafe fn(NonNull<TaskHeader>)>, // Valid if STATE_SPAWNED
+    pub(crate) executor: Cell<*const Executor>,         // Valid if state != 0
+    pub(crate) poll_fn: UninitCell<unsafe fn(TaskRef)>, // Valid if STATE_SPAWNED
 
     #[cfg(feature = "integrated-timers")]
     pub(crate) expires_at: Cell<Instant>,
@@ -59,7 +56,7 @@ pub struct TaskHeader {
 }
 
 impl TaskHeader {
-    pub(crate) const fn new() -> Self {
+    const fn new() -> Self {
         Self {
             state: AtomicU32::new(0),
             run_queue_item: RunQueueItem::new(),
@@ -74,6 +71,36 @@ impl TaskHeader {
     }
 }
 
+/// This is essentially a `&'static TaskStorage<F>` where the type of the future has been erased.
+#[derive(Clone, Copy)]
+pub struct TaskRef {
+    ptr: NonNull<TaskHeader>,
+}
+
+impl TaskRef {
+    fn new<F: Future + 'static>(task: &'static TaskStorage<F>) -> Self {
+        Self {
+            ptr: NonNull::from(task).cast(),
+        }
+    }
+
+    /// Safety: The pointer must have been obtained with `Task::as_ptr`
+    pub(crate) unsafe fn from_ptr(ptr: *const TaskHeader) -> Self {
+        Self {
+            ptr: NonNull::new_unchecked(ptr as *mut TaskHeader),
+        }
+    }
+
+    pub(crate) fn header(self) -> &'static TaskHeader {
+        unsafe { self.ptr.as_ref() }
+    }
+
+    /// The returned pointer is valid for the entire TaskStorage.
+    pub(crate) fn as_ptr(self) -> *const TaskHeader {
+        self.ptr.as_ptr()
+    }
+}
+
 /// Raw storage in which a task can be spawned.
 ///
 /// This struct holds the necessary memory to spawn one task whose future is `F`.
@@ -135,14 +162,14 @@ impl<F: Future + 'static> TaskStorage<F> {
             .is_ok()
     }
 
-    unsafe fn spawn_initialize(&'static self, future: impl FnOnce() -> F) -> NonNull<TaskHeader> {
+    unsafe fn spawn_initialize(&'static self, future: impl FnOnce() -> F) -> TaskRef {
         // Initialize the task
         self.raw.poll_fn.write(Self::poll);
         self.future.write(future());
-        NonNull::new_unchecked(self as *const TaskStorage<F> as *const TaskHeader as *mut TaskHeader)
+        TaskRef::new(self)
     }
 
-    unsafe fn poll(p: NonNull<TaskHeader>) {
+    unsafe fn poll(p: TaskRef) {
         let this = &*(p.as_ptr() as *const TaskStorage<F>);
 
         let future = Pin::new_unchecked(this.future.as_mut());
@@ -307,7 +334,7 @@ impl Executor {
     /// - `task` must be set up to run in this executor.
     /// - `task` must NOT be already enqueued (in this executor or another one).
     #[inline(always)]
-    unsafe fn enqueue(&self, cs: CriticalSection, task: NonNull<TaskHeader>) {
+    unsafe fn enqueue(&self, cs: CriticalSection, task: TaskRef) {
        #[cfg(feature = "rtos-trace")]
        trace::task_ready_begin(task.as_ptr() as u32);
 
@@ -325,8 +352,8 @@ impl Executor {
     /// It is OK to use `unsafe` to call this from a thread that's not the executor thread.
     /// In this case, the task's Future must be Send. This is because this is effectively
     /// sending the task to the executor thread.
-    pub(super) unsafe fn spawn(&'static self, task: NonNull<TaskHeader>) {
-        task.as_ref().executor.set(self);
+    pub(super) unsafe fn spawn(&'static self, task: TaskRef) {
+        task.header().executor.set(self);
 
         #[cfg(feature = "rtos-trace")]
         trace::task_new(task.as_ptr() as u32);
@@ -359,7 +386,7 @@ impl Executor {
             self.timer_queue.dequeue_expired(Instant::now(), |task| wake_task(task));
 
             self.run_queue.dequeue_all(|p| {
-                let task = p.as_ref();
+                let task = p.header();
 
                 #[cfg(feature = "integrated-timers")]
                 task.expires_at.set(Instant::MAX);
@@ -378,7 +405,7 @@ impl Executor {
                 trace::task_exec_begin(p.as_ptr() as u32);
 
                 // Run the task
-                task.poll_fn.read()(p as _);
+                task.poll_fn.read()(p);
 
                 #[cfg(feature = "rtos-trace")]
                 trace::task_exec_end();
@@ -417,16 +444,12 @@ impl Executor {
     }
 }
 
-/// Wake a task by raw pointer.
+/// Wake a task by `TaskRef`.
 ///
-/// You can obtain task pointers from `Waker`s using [`task_from_waker`].
-///
-/// # Safety
-///
-/// `task` must be a valid task pointer obtained from [`task_from_waker`].
-pub unsafe fn wake_task(task: NonNull<TaskHeader>) {
+/// You can obtain a `TaskRef` from a `Waker` using [`task_from_waker`].
+pub fn wake_task(task: TaskRef) {
     critical_section::with(|cs| {
-        let header = task.as_ref();
+        let header = task.header();
         let state = header.state.load(Ordering::Relaxed);
 
         // If already scheduled, or if not started,
@@ -438,8 +461,10 @@ pub unsafe fn wake_task(task: NonNull<TaskHeader>) {
         header.state.store(state | STATE_RUN_QUEUED, Ordering::Relaxed);
 
         // We have just marked the task as scheduled, so enqueue it.
-        let executor = &*header.executor.get();
-        executor.enqueue(cs, task);
+        unsafe {
+            let executor = &*header.executor.get();
+            executor.enqueue(cs, task);
+        }
     })
 }
 
@@ -450,7 +475,7 @@ struct TimerQueue;
 impl embassy_time::queue::TimerQueue for TimerQueue {
     fn schedule_wake(&'static self, at: Instant, waker: &core::task::Waker) {
         let task = waker::task_from_waker(waker);
-        let task = unsafe { task.as_ref() };
+        let task = task.header();
         let expires_at = task.expires_at.get();
         task.expires_at.set(expires_at.min(at));
     }
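Net effect in this file: a `TaskRef` is valid by construction, so `wake_task` no longer needs to be `unsafe`. A hedged usage sketch of the new surface (hypothetical caller code, not from this PR; it assumes `task_from_waker` and `wake_task` remain reachable through the crate's `raw` module, as the doc links above suggest):

```rust
use core::task::Waker;
use embassy_executor::raw::{task_from_waker, wake_task, TaskRef};

// Grab a copyable handle to the task behind a waker and wake it,
// without the caller writing any `unsafe`.
fn remember_and_wake(waker: &Waker) {
    let task: TaskRef = task_from_waker(waker);
    wake_task(task);
}
```

The run-queue changes below follow the same mechanical pattern: take `TaskRef`, call `.header()` instead of dereferencing a raw pointer.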
@@ -4,7 +4,7 @@ use core::ptr::NonNull;
 use atomic_polyfill::{AtomicPtr, Ordering};
 use critical_section::CriticalSection;
 
-use super::TaskHeader;
+use super::{TaskHeader, TaskRef};
 
 pub(crate) struct RunQueueItem {
     next: AtomicPtr<TaskHeader>,
@@ -46,25 +46,26 @@ impl RunQueue {
     ///
     /// `item` must NOT be already enqueued in any queue.
     #[inline(always)]
-    pub(crate) unsafe fn enqueue(&self, _cs: CriticalSection, task: NonNull<TaskHeader>) -> bool {
+    pub(crate) unsafe fn enqueue(&self, _cs: CriticalSection, task: TaskRef) -> bool {
         let prev = self.head.load(Ordering::Relaxed);
-        task.as_ref().run_queue_item.next.store(prev, Ordering::Relaxed);
-        self.head.store(task.as_ptr(), Ordering::Relaxed);
+        task.header().run_queue_item.next.store(prev, Ordering::Relaxed);
+        self.head.store(task.as_ptr() as _, Ordering::Relaxed);
         prev.is_null()
     }
 
     /// Empty the queue, then call `on_task` for each task that was in the queue.
     /// NOTE: It is OK for `on_task` to enqueue more tasks. In this case they're left in the queue
     /// and will be processed by the *next* call to `dequeue_all`, *not* the current one.
-    pub(crate) fn dequeue_all(&self, on_task: impl Fn(NonNull<TaskHeader>)) {
+    pub(crate) fn dequeue_all(&self, on_task: impl Fn(TaskRef)) {
         // Atomically empty the queue.
         let mut ptr = self.head.swap(ptr::null_mut(), Ordering::AcqRel);
 
         // Iterate the linked list of tasks that were previously in the queue.
         while let Some(task) = NonNull::new(ptr) {
+            let task = unsafe { TaskRef::from_ptr(task.as_ptr()) };
             // If the task re-enqueues itself, the `next` pointer will get overwritten.
             // Therefore, first read the next pointer, and only then process the task.
-            let next = unsafe { task.as_ref() }.run_queue_item.next.load(Ordering::Relaxed);
+            let next = task.header().run_queue_item.next.load(Ordering::Relaxed);
 
             on_task(task);
 
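The ordering constraint in `dequeue_all` is the subtle part: the callback may re-enqueue the task, overwriting its `next` link, so the link must be read first. A simplified standalone sketch of that pattern (illustrative only, not embassy code; it omits the critical section that guards the real `enqueue`):

```rust
use std::ptr;
use std::sync::atomic::{AtomicPtr, Ordering};

struct Node {
    next: AtomicPtr<Node>,
}

struct Queue {
    head: AtomicPtr<Node>,
}

impl Queue {
    /// Push a node (the real executor does this inside a critical section).
    fn enqueue(&self, node: &'static Node) {
        let prev = self.head.load(Ordering::Relaxed);
        node.next.store(prev, Ordering::Relaxed);
        self.head.store(node as *const Node as *mut Node, Ordering::Relaxed);
    }

    /// Detach the whole list atomically, then walk it.
    fn dequeue_all(&self, on_node: impl Fn(&'static Node)) {
        let mut ptr = self.head.swap(ptr::null_mut(), Ordering::AcqRel);
        while !ptr.is_null() {
            // Only `&'static Node`s are ever enqueued in this sketch.
            let node: &'static Node = unsafe { &*ptr };
            // Read `next` BEFORE the callback: the callback may re-enqueue
            // `node`, which rewrites `node.next`.
            let next = node.next.load(Ordering::Relaxed);
            on_node(node);
            ptr = next;
        }
    }
}
```

The timer-queue file below gets a slightly larger rewrite: its intrusive links switch from nullable raw pointers to `Cell<Option<TaskRef>>`.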
@@ -1,45 +1,39 @@
 use core::cell::Cell;
 use core::cmp::min;
-use core::ptr;
-use core::ptr::NonNull;
 
 use atomic_polyfill::Ordering;
 use embassy_time::Instant;
 
-use super::{TaskHeader, STATE_TIMER_QUEUED};
+use super::{TaskRef, STATE_TIMER_QUEUED};
 
 pub(crate) struct TimerQueueItem {
-    next: Cell<*mut TaskHeader>,
+    next: Cell<Option<TaskRef>>,
 }
 
 impl TimerQueueItem {
     pub const fn new() -> Self {
-        Self {
-            next: Cell::new(ptr::null_mut()),
-        }
+        Self { next: Cell::new(None) }
     }
 }
 
 pub(crate) struct TimerQueue {
-    head: Cell<*mut TaskHeader>,
+    head: Cell<Option<TaskRef>>,
 }
 
 impl TimerQueue {
     pub const fn new() -> Self {
-        Self {
-            head: Cell::new(ptr::null_mut()),
-        }
+        Self { head: Cell::new(None) }
     }
 
-    pub(crate) unsafe fn update(&self, p: NonNull<TaskHeader>) {
-        let task = p.as_ref();
+    pub(crate) unsafe fn update(&self, p: TaskRef) {
+        let task = p.header();
         if task.expires_at.get() != Instant::MAX {
             let old_state = task.state.fetch_or(STATE_TIMER_QUEUED, Ordering::AcqRel);
             let is_new = old_state & STATE_TIMER_QUEUED == 0;
 
             if is_new {
                 task.timer_queue_item.next.set(self.head.get());
-                self.head.set(p.as_ptr());
+                self.head.set(Some(p));
             }
         }
     }
@@ -47,7 +41,7 @@ impl TimerQueue {
     pub(crate) unsafe fn next_expiration(&self) -> Instant {
         let mut res = Instant::MAX;
         self.retain(|p| {
-            let task = p.as_ref();
+            let task = p.header();
             let expires = task.expires_at.get();
             res = min(res, expires);
             expires != Instant::MAX
@@ -55,9 +49,9 @@ impl TimerQueue {
         res
     }
 
-    pub(crate) unsafe fn dequeue_expired(&self, now: Instant, on_task: impl Fn(NonNull<TaskHeader>)) {
+    pub(crate) unsafe fn dequeue_expired(&self, now: Instant, on_task: impl Fn(TaskRef)) {
         self.retain(|p| {
-            let task = p.as_ref();
+            let task = p.header();
             if task.expires_at.get() <= now {
                 on_task(p);
                 false
@@ -67,11 +61,10 @@ impl TimerQueue {
         });
     }
 
-    pub(crate) unsafe fn retain(&self, mut f: impl FnMut(NonNull<TaskHeader>) -> bool) {
+    pub(crate) unsafe fn retain(&self, mut f: impl FnMut(TaskRef) -> bool) {
         let mut prev = &self.head;
-        while !prev.get().is_null() {
-            let p = NonNull::new_unchecked(prev.get());
-            let task = &*p.as_ptr();
+        while let Some(p) = prev.get() {
+            let task = p.header();
             if f(p) {
                 // Skip to next
                 prev = &task.timer_queue_item.next;
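Switching the timer-queue links from `*mut TaskHeader` with null sentinels to `Cell<Option<TaskRef>>` costs nothing in size: `TaskRef` wraps a `NonNull`, so `Option<TaskRef>` gets the null-pointer niche and stays pointer-sized. A small sketch with a stand-in `Header` type to check that:

```rust
use core::mem::size_of;
use core::ptr::NonNull;

struct Header;

#[derive(Clone, Copy)]
struct Ref {
    ptr: NonNull<Header>,
}

fn main() {
    // `None` is represented by the forbidden null value, so no extra tag byte.
    assert_eq!(size_of::<Option<Ref>>(), size_of::<*mut Header>());
    assert_eq!(size_of::<Option<Ref>>(), size_of::<Ref>());
}
```

The waker glue is updated next; it keeps the same data-pointer trick, just routed through `TaskRef`.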
@@ -1,8 +1,7 @@
 use core::mem;
-use core::ptr::NonNull;
 use core::task::{RawWaker, RawWakerVTable, Waker};
 
-use super::{wake_task, TaskHeader};
+use super::{wake_task, TaskHeader, TaskRef};
 
 const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop);
 
@@ -11,14 +10,14 @@ unsafe fn clone(p: *const ()) -> RawWaker {
 }
 
 unsafe fn wake(p: *const ()) {
-    wake_task(NonNull::new_unchecked(p as *mut TaskHeader))
+    wake_task(TaskRef::from_ptr(p as *const TaskHeader))
 }
 
 unsafe fn drop(_: *const ()) {
     // nop
 }
 
-pub(crate) unsafe fn from_task(p: NonNull<TaskHeader>) -> Waker {
+pub(crate) unsafe fn from_task(p: TaskRef) -> Waker {
     Waker::from_raw(RawWaker::new(p.as_ptr() as _, &VTABLE))
 }
 
@@ -33,7 +32,7 @@ pub(crate) unsafe fn from_task(p: NonNull<TaskHeader>) -> Waker {
 /// # Panics
 ///
 /// Panics if the waker is not created by the Embassy executor.
-pub fn task_from_waker(waker: &Waker) -> NonNull<TaskHeader> {
+pub fn task_from_waker(waker: &Waker) -> TaskRef {
     // safety: OK because WakerHack has the same layout as Waker.
     // This is not really guaranteed because the structs are `repr(Rust)`, it is
     // indeed the case in the current implementation.
@@ -43,8 +42,8 @@ pub fn task_from_waker(waker: &Waker) -> NonNull<TaskHeader> {
         panic!("Found waker not created by the Embassy executor. `embassy_time::Timer` only works with the Embassy executor.")
     }
 
-    // safety: we never create a waker with a null data pointer.
-    unsafe { NonNull::new_unchecked(hack.data as *mut TaskHeader) }
+    // safety: our wakers are always created with `TaskRef::as_ptr`
+    unsafe { TaskRef::from_ptr(hack.data as *const TaskHeader) }
 }
 
 struct WakerHack {
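The waker side stays a pure pointer round-trip: the `RawWaker` data pointer is the task pointer, `wake` turns it back into a `TaskRef`, and `task_from_waker` recovers it from the `Waker`'s data field. A minimal standalone illustration of that round-trip (not embassy code; the "task" here is just an address):

```rust
use core::task::{RawWaker, RawWakerVTable, Waker};

static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop);

unsafe fn clone(p: *const ()) -> RawWaker {
    RawWaker::new(p, &VTABLE)
}

unsafe fn wake(p: *const ()) {
    // embassy would do `wake_task(TaskRef::from_ptr(p as *const TaskHeader))`.
    println!("waking task at {:p}", p);
}

unsafe fn drop(_: *const ()) {}

fn main() {
    let fake_task = 0x1000 as *const ();
    let waker = unsafe { Waker::from_raw(RawWaker::new(fake_task, &VTABLE)) };
    waker.wake_by_ref(); // prints the same address we stored as the data pointer
}
```

Finally, the spawner module drops its last direct uses of `NonNull<raw::TaskHeader>`.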
@@ -1,7 +1,6 @@
 use core::future::poll_fn;
 use core::marker::PhantomData;
 use core::mem;
-use core::ptr::NonNull;
 use core::task::Poll;
 
 use super::raw;
@@ -22,12 +21,12 @@ use super::raw;
 /// Once you've invoked a task function and obtained a SpawnToken, you *must* spawn it.
 #[must_use = "Calling a task function does nothing on its own. You must spawn the returned SpawnToken, typically with Spawner::spawn()"]
 pub struct SpawnToken<S> {
-    raw_task: Option<NonNull<raw::TaskHeader>>,
+    raw_task: Option<raw::TaskRef>,
     phantom: PhantomData<*mut S>,
 }
 
 impl<S> SpawnToken<S> {
-    pub(crate) unsafe fn new(raw_task: NonNull<raw::TaskHeader>) -> Self {
+    pub(crate) unsafe fn new(raw_task: raw::TaskRef) -> Self {
         Self {
             raw_task: Some(raw_task),
             phantom: PhantomData,
@@ -92,7 +91,7 @@ impl Spawner {
     pub async fn for_current_executor() -> Self {
         poll_fn(|cx| unsafe {
             let task = raw::task_from_waker(cx.waker());
-            let executor = (*task.as_ptr()).executor.get();
+            let executor = task.header().executor.get();
             Poll::Ready(Self::new(&*executor))
         })
         .await
@@ -168,7 +167,7 @@ impl SendSpawner {
     pub async fn for_current_executor() -> Self {
         poll_fn(|cx| unsafe {
             let task = raw::task_from_waker(cx.waker());
-            let executor = (*task.as_ptr()).executor.get();
+            let executor = task.header().executor.get();
             Poll::Ready(Self::new(&*executor))
         })
         .await
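Both spawners discover their executor the same way: poll once, take the current task's waker, turn it into a `TaskRef`, and read the `executor` field from the header. A hedged usage sketch of that entry point (hypothetical application code; assumes the `embassy_executor::task` attribute macro and a running executor, neither of which appears in this diff):

```rust
#[embassy_executor::task]
async fn child() {
    // ...
}

#[embassy_executor::task]
async fn parent() {
    // Works from inside any task, regardless of which executor is polling it.
    let spawner = embassy_executor::Spawner::for_current_executor().await;
    spawner.spawn(child()).unwrap();
}
```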