Prevent task from respawning while in the timer queue
This commit is contained in:
		
							parent
							
								
									d45ea43892
								
							
						
					
					
						commit
						ec96395d08
					
				| @ -50,7 +50,7 @@ pub(crate) struct TaskHeader { | ||||
| } | ||||
| 
 | ||||
| /// This is essentially a `&'static TaskStorage<F>` where the type of the future has been erased.
 | ||||
| #[derive(Clone, Copy)] | ||||
| #[derive(Clone, Copy, PartialEq)] | ||||
| pub struct TaskRef { | ||||
|     ptr: NonNull<TaskHeader>, | ||||
| } | ||||
| @ -72,6 +72,16 @@ impl TaskRef { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// # Safety
 | ||||
|     ///
 | ||||
|     /// The result of this function must only be compared
 | ||||
|     /// for equality, or stored, but not used.
 | ||||
|     pub const unsafe fn dangling() -> Self { | ||||
|         Self { | ||||
|             ptr: NonNull::dangling(), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub(crate) fn header(self) -> &'static TaskHeader { | ||||
|         unsafe { self.ptr.as_ref() } | ||||
|     } | ||||
| @ -88,6 +98,30 @@ impl TaskRef { | ||||
|         &self.header().timer_queue_item | ||||
|     } | ||||
| 
 | ||||
|     /// Mark the task as timer-queued. Return whether it was newly queued (i.e. not queued before)
 | ||||
|     ///
 | ||||
|     /// Entering this state prevents the task from being respawned while in a timer queue.
 | ||||
|     ///
 | ||||
|     /// Safety:
 | ||||
|     ///
 | ||||
|     /// This functions should only be called by the timer queue implementation, before
 | ||||
|     /// enqueueing the timer item.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     pub unsafe fn timer_enqueue(&self) -> timer_queue::TimerEnqueueOperation { | ||||
|         self.header().state.timer_enqueue() | ||||
|     } | ||||
| 
 | ||||
|     /// Unmark the task as timer-queued.
 | ||||
|     ///
 | ||||
|     /// Safety:
 | ||||
|     ///
 | ||||
|     /// This functions should only be called by the timer queue implementation, after the task has
 | ||||
|     /// been removed from the timer queue.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     pub unsafe fn timer_dequeue(&self) { | ||||
|         self.header().state.timer_dequeue() | ||||
|     } | ||||
| 
 | ||||
|     /// The returned pointer is valid for the entire TaskStorage.
 | ||||
|     pub(crate) fn as_ptr(self) -> *const TaskHeader { | ||||
|         self.ptr.as_ptr() | ||||
|  | ||||
| @ -1,9 +1,15 @@ | ||||
| use core::sync::atomic::{AtomicU32, Ordering}; | ||||
| 
 | ||||
| #[cfg(feature = "integrated-timers")] | ||||
| use super::timer_queue::TimerEnqueueOperation; | ||||
| 
 | ||||
| /// Task is spawned (has a future)
 | ||||
| pub(crate) const STATE_SPAWNED: u32 = 1 << 0; | ||||
| /// Task is in the executor run queue
 | ||||
| pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1; | ||||
| /// Task is in the executor timer queue
 | ||||
| #[cfg(feature = "integrated-timers")] | ||||
| pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2; | ||||
| 
 | ||||
| pub(crate) struct State { | ||||
|     state: AtomicU32, | ||||
| @ -52,4 +58,34 @@ impl State { | ||||
|         let state = self.state.fetch_and(!STATE_RUN_QUEUED, Ordering::AcqRel); | ||||
|         state & STATE_SPAWNED != 0 | ||||
|     } | ||||
| 
 | ||||
|     /// Mark the task as timer-queued. Return whether it can be enqueued.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     #[inline(always)] | ||||
|     pub fn timer_enqueue(&self) -> TimerEnqueueOperation { | ||||
|         if self | ||||
|             .state | ||||
|             .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { | ||||
|                 // If not started, ignore it
 | ||||
|                 if state & STATE_SPAWNED == 0 { | ||||
|                     None | ||||
|                 } else { | ||||
|                     // Mark it as enqueued
 | ||||
|                     Some(state | STATE_TIMER_QUEUED) | ||||
|                 } | ||||
|             }) | ||||
|             .is_ok() | ||||
|         { | ||||
|             TimerEnqueueOperation::Enqueue | ||||
|         } else { | ||||
|             TimerEnqueueOperation::Ignore | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Unmark the task as timer-queued.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     #[inline(always)] | ||||
|     pub fn timer_dequeue(&self) { | ||||
|         self.state.fetch_and(!STATE_TIMER_QUEUED, Ordering::Relaxed); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -1,9 +1,14 @@ | ||||
| use core::arch::asm; | ||||
| use core::sync::atomic::{compiler_fence, AtomicBool, AtomicU32, Ordering}; | ||||
| 
 | ||||
| #[cfg(feature = "integrated-timers")] | ||||
| use super::timer_queue::TimerEnqueueOperation; | ||||
| 
 | ||||
| // Must be kept in sync with the layout of `State`!
 | ||||
| pub(crate) const STATE_SPAWNED: u32 = 1 << 0; | ||||
| pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 8; | ||||
| #[cfg(feature = "integrated-timers")] | ||||
| pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 16; | ||||
| 
 | ||||
| #[repr(C, align(4))] | ||||
| pub(crate) struct State { | ||||
| @ -11,8 +16,9 @@ pub(crate) struct State { | ||||
|     spawned: AtomicBool, | ||||
|     /// Task is in the executor run queue
 | ||||
|     run_queued: AtomicBool, | ||||
|     /// Task is in the executor timer queue
 | ||||
|     timer_queued: AtomicBool, | ||||
|     pad: AtomicBool, | ||||
|     pad2: AtomicBool, | ||||
| } | ||||
| 
 | ||||
| impl State { | ||||
| @ -20,8 +26,8 @@ impl State { | ||||
|         Self { | ||||
|             spawned: AtomicBool::new(false), | ||||
|             run_queued: AtomicBool::new(false), | ||||
|             timer_queued: AtomicBool::new(false), | ||||
|             pad: AtomicBool::new(false), | ||||
|             pad2: AtomicBool::new(false), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| @ -85,4 +91,34 @@ impl State { | ||||
|         self.run_queued.store(false, Ordering::Relaxed); | ||||
|         r | ||||
|     } | ||||
| 
 | ||||
|     /// Mark the task as timer-queued. Return whether it can be enqueued.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     #[inline(always)] | ||||
|     pub fn timer_enqueue(&self) -> TimerEnqueueOperation { | ||||
|         if self | ||||
|             .as_u32() | ||||
|             .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| { | ||||
|                 // If not started, ignore it
 | ||||
|                 if state & STATE_SPAWNED == 0 { | ||||
|                     None | ||||
|                 } else { | ||||
|                     // Mark it as enqueued
 | ||||
|                     Some(state | STATE_TIMER_QUEUED) | ||||
|                 } | ||||
|             }) | ||||
|             .is_ok() | ||||
|         { | ||||
|             TimerEnqueueOperation::Enqueue | ||||
|         } else { | ||||
|             TimerEnqueueOperation::Ignore | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Unmark the task as timer-queued.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     #[inline(always)] | ||||
|     pub fn timer_dequeue(&self) { | ||||
|         self.timer_queued.store(false, Ordering::Relaxed); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -2,10 +2,16 @@ use core::cell::Cell; | ||||
| 
 | ||||
| use critical_section::Mutex; | ||||
| 
 | ||||
| #[cfg(feature = "integrated-timers")] | ||||
| use super::timer_queue::TimerEnqueueOperation; | ||||
| 
 | ||||
| /// Task is spawned (has a future)
 | ||||
| pub(crate) const STATE_SPAWNED: u32 = 1 << 0; | ||||
| /// Task is in the executor run queue
 | ||||
| pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1; | ||||
| /// Task is in the executor timer queue
 | ||||
| #[cfg(feature = "integrated-timers")] | ||||
| pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2; | ||||
| 
 | ||||
| pub(crate) struct State { | ||||
|     state: Mutex<Cell<u32>>, | ||||
| @ -69,4 +75,27 @@ impl State { | ||||
|             ok | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     /// Mark the task as timer-queued. Return whether it can be enqueued.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     #[inline(always)] | ||||
|     pub fn timer_enqueue(&self) -> TimerEnqueueOperation { | ||||
|         self.update(|s| { | ||||
|             // FIXME: we need to split SPAWNED into two phases, to prevent enqueueing a task that is
 | ||||
|             // just being spawned, because its executor pointer may still be changing.
 | ||||
|             if *s & STATE_SPAWNED == STATE_SPAWNED { | ||||
|                 *s |= STATE_TIMER_QUEUED; | ||||
|                 TimerEnqueueOperation::Enqueue | ||||
|             } else { | ||||
|                 TimerEnqueueOperation::Ignore | ||||
|             } | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     /// Unmark the task as timer-queued.
 | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     #[inline(always)] | ||||
|     pub fn timer_dequeue(&self) { | ||||
|         self.update(|s| *s &= !STATE_TIMER_QUEUED); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -7,6 +7,9 @@ use super::TaskRef; | ||||
| /// An item in the timer queue.
 | ||||
| pub struct TimerQueueItem { | ||||
|     /// The next item in the queue.
 | ||||
|     ///
 | ||||
|     /// If this field contains `Some`, the item is in the queue. The last item in the queue has a
 | ||||
|     /// value of `Some(dangling_pointer)`
 | ||||
|     pub next: Cell<Option<TaskRef>>, | ||||
| 
 | ||||
|     /// The time at which this item expires.
 | ||||
| @ -19,7 +22,17 @@ impl TimerQueueItem { | ||||
|     pub(crate) const fn new() -> Self { | ||||
|         Self { | ||||
|             next: Cell::new(None), | ||||
|             expires_at: Cell::new(0), | ||||
|             expires_at: Cell::new(u64::MAX), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The operation to perform after `timer_enqueue` is called.
 | ||||
| #[derive(Debug, Copy, Clone, PartialEq)] | ||||
| #[cfg_attr(feature = "defmt", derive(defmt::Format))] | ||||
| pub enum TimerEnqueueOperation { | ||||
|     /// Enqueue the task.
 | ||||
|     Enqueue, | ||||
|     /// Update the task's expiration time.
 | ||||
|     Ignore, | ||||
| } | ||||
|  | ||||
| @ -73,6 +73,20 @@ extern "Rust" { | ||||
| 
 | ||||
| /// Schedule the given waker to be woken at `at`.
 | ||||
| pub fn schedule_wake(at: u64, waker: &Waker) { | ||||
|     #[cfg(feature = "integrated-timers")] | ||||
|     { | ||||
|         use embassy_executor::raw::task_from_waker; | ||||
|         use embassy_executor::raw::timer_queue::TimerEnqueueOperation; | ||||
|         // The very first thing we must do, before we even access the timer queue, is to
 | ||||
|         // mark the task a TIMER_QUEUED. This ensures that the task that is being scheduled
 | ||||
|         // can not be respawn while we are accessing the timer queue.
 | ||||
|         let task = task_from_waker(waker); | ||||
|         if unsafe { task.timer_enqueue() } == TimerEnqueueOperation::Ignore { | ||||
|             // We are not allowed to enqueue the task in the timer queue. This is because the
 | ||||
|             // task is not spawned, and so it makes no sense to schedule it.
 | ||||
|             return; | ||||
|         } | ||||
|     } | ||||
|     unsafe { _embassy_time_schedule_wake(at, waker) } | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -24,16 +24,21 @@ impl TimerQueue { | ||||
|         if item.next.get().is_none() { | ||||
|             // If not in the queue, add it and update.
 | ||||
|             let prev = self.head.replace(Some(p)); | ||||
|             item.next.set(prev); | ||||
|             item.next.set(if prev.is_none() { | ||||
|                 Some(unsafe { TaskRef::dangling() }) | ||||
|             } else { | ||||
|                 prev | ||||
|             }); | ||||
|             item.expires_at.set(at); | ||||
|             true | ||||
|         } else if at <= item.expires_at.get() { | ||||
|             // If expiration is sooner than previously set, update.
 | ||||
|             item.expires_at.set(at); | ||||
|             true | ||||
|         } else { | ||||
|             // Task does not need to be updated.
 | ||||
|             return false; | ||||
|             false | ||||
|         } | ||||
| 
 | ||||
|         item.expires_at.set(at); | ||||
|         true | ||||
|     } | ||||
| 
 | ||||
|     /// Dequeues expired timers and returns the next alarm time.
 | ||||
| @ -64,6 +69,10 @@ impl TimerQueue { | ||||
|     fn retain(&self, mut f: impl FnMut(TaskRef) -> bool) { | ||||
|         let mut prev = &self.head; | ||||
|         while let Some(p) = prev.get() { | ||||
|             if unsafe { p == TaskRef::dangling() } { | ||||
|                 // prev was the last item, stop
 | ||||
|                 break; | ||||
|             } | ||||
|             let item = p.timer_queue_item(); | ||||
|             if f(p) { | ||||
|                 // Skip to next
 | ||||
| @ -72,6 +81,7 @@ impl TimerQueue { | ||||
|                 // Remove it
 | ||||
|                 prev.set(item.next.get()); | ||||
|                 item.next.set(None); | ||||
|                 unsafe { p.timer_dequeue() }; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user