extras: Fix UB in Peripheral
				
					
				
			`Peripheral` assumed that interrupts can't be preempted, when they can be preempted by higher priority interrupts. So I put the interrupt handler inside a critical section, and also added checks for whether the state had been dropped before the critical section was entered. I also added a `'static` bound to `PeripheralState`, since `Pin` only guarantees that the memory it directly references will not be invalidated. It doesn't guarantee that memory its pointee references also won't be invalidated. There were already some implementations of `PeripheralState` that weren't `'static`, though, so I added an unsafe `PeripheralStateUnchecked` trait and forwarded the `unsafe` to the constructors of the implementors.
This commit is contained in:
		
							parent
							
								
									ed83b93b6d
								
							
						
					
					
						commit
						744e2cbb8a
					
				| @ -17,4 +17,5 @@ embassy = { version = "0.1.0", path = "../embassy" } | |||||||
| defmt = { version = "0.2.0", optional = true } | defmt = { version = "0.2.0", optional = true } | ||||||
| log = { version = "0.4.11", optional = true } | log = { version = "0.4.11", optional = true } | ||||||
| cortex-m = "0.7.1" | cortex-m = "0.7.1" | ||||||
|  | critical-section = "0.2.1" | ||||||
| usb-device = "0.2.7" | usb-device = "0.2.7" | ||||||
|  | |||||||
| @ -1,15 +1,38 @@ | |||||||
| use core::cell::UnsafeCell; | use core::cell::UnsafeCell; | ||||||
| use core::marker::{PhantomData, PhantomPinned}; | use core::marker::{PhantomData, PhantomPinned}; | ||||||
| use core::pin::Pin; | use core::pin::Pin; | ||||||
|  | use core::ptr; | ||||||
| 
 | 
 | ||||||
| use embassy::interrupt::{Interrupt, InterruptExt}; | use embassy::interrupt::{Interrupt, InterruptExt}; | ||||||
| 
 | 
 | ||||||
| pub trait PeripheralState { | /// # Safety
 | ||||||
|  | /// When types implementing this trait are used with `Peripheral` or `PeripheralMutex`,
 | ||||||
|  | /// their lifetime must not end without first calling `Drop` on the `Peripheral` or `PeripheralMutex`.
 | ||||||
|  | pub unsafe trait PeripheralStateUnchecked { | ||||||
|     type Interrupt: Interrupt; |     type Interrupt: Interrupt; | ||||||
|     fn on_interrupt(&mut self); |     fn on_interrupt(&mut self); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub struct PeripheralMutex<S: PeripheralState> { | // `PeripheralMutex` is safe because `Pin` guarantees that the memory it references will not be invalidated or reused
 | ||||||
|  | // without calling `Drop`. However, it provides no guarantees about references contained within the state still being valid,
 | ||||||
|  | // so this `'static` bound is necessary.
 | ||||||
|  | pub trait PeripheralState: 'static { | ||||||
|  |     type Interrupt: Interrupt; | ||||||
|  |     fn on_interrupt(&mut self); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SAFETY: `T` has to live for `'static` to implement `PeripheralState`, thus its lifetime cannot end.
 | ||||||
|  | unsafe impl<T> PeripheralStateUnchecked for T | ||||||
|  | where | ||||||
|  |     T: PeripheralState, | ||||||
|  | { | ||||||
|  |     type Interrupt = T::Interrupt; | ||||||
|  |     fn on_interrupt(&mut self) { | ||||||
|  |         self.on_interrupt() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct PeripheralMutex<S: PeripheralStateUnchecked> { | ||||||
|     state: UnsafeCell<S>, |     state: UnsafeCell<S>, | ||||||
| 
 | 
 | ||||||
|     irq_setup_done: bool, |     irq_setup_done: bool, | ||||||
| @ -19,7 +42,7 @@ pub struct PeripheralMutex<S: PeripheralState> { | |||||||
|     _pinned: PhantomPinned, |     _pinned: PhantomPinned, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<S: PeripheralState> PeripheralMutex<S> { | impl<S: PeripheralStateUnchecked> PeripheralMutex<S> { | ||||||
|     pub fn new(state: S, irq: S::Interrupt) -> Self { |     pub fn new(state: S, irq: S::Interrupt) -> Self { | ||||||
|         Self { |         Self { | ||||||
|             irq, |             irq, | ||||||
| @ -39,11 +62,17 @@ impl<S: PeripheralState> PeripheralMutex<S> { | |||||||
| 
 | 
 | ||||||
|         this.irq.disable(); |         this.irq.disable(); | ||||||
|         this.irq.set_handler(|p| { |         this.irq.set_handler(|p| { | ||||||
|  |             critical_section::with(|_| { | ||||||
|  |                 if p.is_null() { | ||||||
|  |                     // The state was dropped, so we can't operate on it.
 | ||||||
|  |                     return; | ||||||
|  |                 } | ||||||
|                 // Safety: it's OK to get a &mut to the state, since
 |                 // Safety: it's OK to get a &mut to the state, since
 | ||||||
|             // - We're in the IRQ, no one else can't preempt us
 |                 // - We're in a critical section, no one can preempt us (and call with())
 | ||||||
|                 // - We can't have preempted a with() call because the irq is disabled during it.
 |                 // - We can't have preempted a with() call because the irq is disabled during it.
 | ||||||
|                 let state = unsafe { &mut *(p as *mut S) }; |                 let state = unsafe { &mut *(p as *mut S) }; | ||||||
|                 state.on_interrupt(); |                 state.on_interrupt(); | ||||||
|  |             }) | ||||||
|         }); |         }); | ||||||
|         this.irq |         this.irq | ||||||
|             .set_handler_context((&mut this.state) as *mut _ as *mut ()); |             .set_handler_context((&mut this.state) as *mut _ as *mut ()); | ||||||
| @ -67,9 +96,12 @@ impl<S: PeripheralState> PeripheralMutex<S> { | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<S: PeripheralState> Drop for PeripheralMutex<S> { | impl<S: PeripheralStateUnchecked> Drop for PeripheralMutex<S> { | ||||||
|     fn drop(&mut self) { |     fn drop(&mut self) { | ||||||
|         self.irq.disable(); |         self.irq.disable(); | ||||||
|         self.irq.remove_handler(); |         self.irq.remove_handler(); | ||||||
|  |         // Set the context to null so that the interrupt will know we're dropped
 | ||||||
|  |         // if we pre-empted it before it entered a critical section.
 | ||||||
|  |         self.irq.set_handler_context(ptr::null_mut()); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,16 +1,27 @@ | |||||||
| use core::cell::UnsafeCell; |  | ||||||
| use core::marker::{PhantomData, PhantomPinned}; | use core::marker::{PhantomData, PhantomPinned}; | ||||||
| use core::pin::Pin; | use core::pin::Pin; | ||||||
|  | use core::ptr; | ||||||
| 
 | 
 | ||||||
| use embassy::interrupt::{Interrupt, InterruptExt}; | use embassy::interrupt::{Interrupt, InterruptExt}; | ||||||
| 
 | 
 | ||||||
| pub trait PeripheralState { | /// # Safety
 | ||||||
|  | /// When types implementing this trait are used with `Peripheral` or `PeripheralMutex`,
 | ||||||
|  | /// their lifetime must not end without first calling `Drop` on the `Peripheral` or `PeripheralMutex`.
 | ||||||
|  | pub unsafe trait PeripheralStateUnchecked { | ||||||
|     type Interrupt: Interrupt; |     type Interrupt: Interrupt; | ||||||
|     fn on_interrupt(&self); |     fn on_interrupt(&self); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub struct Peripheral<S: PeripheralState> { | // `Peripheral` is safe because `Pin` guarantees that the memory it references will not be invalidated or reused
 | ||||||
|     state: UnsafeCell<S>, | // without calling `Drop`. However, it provides no guarantees about references contained within the state still being valid,
 | ||||||
|  | // so this `'static` bound is necessary.
 | ||||||
|  | pub trait PeripheralState: 'static { | ||||||
|  |     type Interrupt: Interrupt; | ||||||
|  |     fn on_interrupt(&self); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct Peripheral<S: PeripheralStateUnchecked> { | ||||||
|  |     state: S, | ||||||
| 
 | 
 | ||||||
|     irq_setup_done: bool, |     irq_setup_done: bool, | ||||||
|     irq: S::Interrupt, |     irq: S::Interrupt, | ||||||
| @ -19,13 +30,13 @@ pub struct Peripheral<S: PeripheralState> { | |||||||
|     _pinned: PhantomPinned, |     _pinned: PhantomPinned, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<S: PeripheralState> Peripheral<S> { | impl<S: PeripheralStateUnchecked> Peripheral<S> { | ||||||
|     pub fn new(irq: S::Interrupt, state: S) -> Self { |     pub fn new(irq: S::Interrupt, state: S) -> Self { | ||||||
|         Self { |         Self { | ||||||
|             irq, |             irq, | ||||||
|             irq_setup_done: false, |             irq_setup_done: false, | ||||||
| 
 | 
 | ||||||
|             state: UnsafeCell::new(state), |             state, | ||||||
|             _not_send: PhantomData, |             _not_send: PhantomData, | ||||||
|             _pinned: PhantomPinned, |             _pinned: PhantomPinned, | ||||||
|         } |         } | ||||||
| @ -39,9 +50,17 @@ impl<S: PeripheralState> Peripheral<S> { | |||||||
| 
 | 
 | ||||||
|         this.irq.disable(); |         this.irq.disable(); | ||||||
|         this.irq.set_handler(|p| { |         this.irq.set_handler(|p| { | ||||||
|  |             // We need to be in a critical section so that no one can preempt us
 | ||||||
|  |             // and drop the state after we check whether `p.is_null()`.
 | ||||||
|  |             critical_section::with(|_| { | ||||||
|  |                 if p.is_null() { | ||||||
|  |                     // The state was dropped, so we can't operate on it.
 | ||||||
|  |                     return; | ||||||
|  |                 } | ||||||
|                 let state = unsafe { &*(p as *const S) }; |                 let state = unsafe { &*(p as *const S) }; | ||||||
|                 state.on_interrupt(); |                 state.on_interrupt(); | ||||||
|             }); |             }); | ||||||
|  |         }); | ||||||
|         this.irq |         this.irq | ||||||
|             .set_handler_context((&this.state) as *const _ as *mut ()); |             .set_handler_context((&this.state) as *const _ as *mut ()); | ||||||
|         this.irq.enable(); |         this.irq.enable(); | ||||||
| @ -49,15 +68,17 @@ impl<S: PeripheralState> Peripheral<S> { | |||||||
|         this.irq_setup_done = true; |         this.irq_setup_done = true; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn state(self: Pin<&mut Self>) -> &S { |     pub fn state<'a>(self: Pin<&'a mut Self>) -> &'a S { | ||||||
|         let this = unsafe { self.get_unchecked_mut() }; |         &self.into_ref().get_ref().state | ||||||
|         unsafe { &*this.state.get() } |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<S: PeripheralState> Drop for Peripheral<S> { | impl<S: PeripheralStateUnchecked> Drop for Peripheral<S> { | ||||||
|     fn drop(&mut self) { |     fn drop(&mut self) { | ||||||
|         self.irq.disable(); |         self.irq.disable(); | ||||||
|         self.irq.remove_handler(); |         self.irq.remove_handler(); | ||||||
|  |         // Set the context to null so that the interrupt will know we're dropped
 | ||||||
|  |         // if we pre-empted it before it entered a critical section.
 | ||||||
|  |         self.irq.set_handler_context(ptr::null_mut()); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ use usb_device::device::UsbDevice; | |||||||
| mod cdc_acm; | mod cdc_acm; | ||||||
| pub mod usb_serial; | pub mod usb_serial; | ||||||
| 
 | 
 | ||||||
| use crate::peripheral::{PeripheralMutex, PeripheralState}; | use crate::peripheral::{PeripheralMutex, PeripheralStateUnchecked}; | ||||||
| use embassy::interrupt::Interrupt; | use embassy::interrupt::Interrupt; | ||||||
| use usb_serial::{ReadInterface, UsbSerial, WriteInterface}; | use usb_serial::{ReadInterface, UsbSerial, WriteInterface}; | ||||||
| 
 | 
 | ||||||
| @ -55,10 +55,12 @@ where | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn start(self: Pin<&mut Self>) { |     /// # Safety
 | ||||||
|         let this = unsafe { self.get_unchecked_mut() }; |     /// The `UsbDevice` passed to `Self::new` must not be dropped without calling `Drop` on this `Usb` first.
 | ||||||
|  |     pub unsafe fn start(self: Pin<&mut Self>) { | ||||||
|  |         let this = self.get_unchecked_mut(); | ||||||
|         let mut mutex = this.inner.borrow_mut(); |         let mut mutex = this.inner.borrow_mut(); | ||||||
|         let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; |         let mutex = Pin::new_unchecked(&mut *mutex); | ||||||
| 
 | 
 | ||||||
|         // Use inner to register the irq
 |         // Use inner to register the irq
 | ||||||
|         mutex.register_interrupt(); |         mutex.register_interrupt(); | ||||||
| @ -125,7 +127,8 @@ where | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<'bus, B, T, I> PeripheralState for State<'bus, B, T, I> | // SAFETY: The safety contract of `PeripheralStateUnchecked` is forwarded to `Usb::start`.
 | ||||||
|  | unsafe impl<'bus, B, T, I> PeripheralStateUnchecked for State<'bus, B, T, I> | ||||||
| where | where | ||||||
|     B: UsbBus, |     B: UsbBus, | ||||||
|     T: ClassSet<B>, |     T: ClassSet<B>, | ||||||
|  | |||||||
| @ -7,7 +7,7 @@ use core::task::{Context, Poll}; | |||||||
| use embassy::interrupt::InterruptExt; | use embassy::interrupt::InterruptExt; | ||||||
| use embassy::io::{AsyncBufRead, AsyncWrite, Result}; | use embassy::io::{AsyncBufRead, AsyncWrite, Result}; | ||||||
| use embassy::util::{Unborrow, WakerRegistration}; | use embassy::util::{Unborrow, WakerRegistration}; | ||||||
| use embassy_extras::peripheral::{PeripheralMutex, PeripheralState}; | use embassy_extras::peripheral::{PeripheralMutex, PeripheralStateUnchecked}; | ||||||
| use embassy_extras::ring_buffer::RingBuffer; | use embassy_extras::ring_buffer::RingBuffer; | ||||||
| use embassy_extras::{low_power_wait_until, unborrow}; | use embassy_extras::{low_power_wait_until, unborrow}; | ||||||
| 
 | 
 | ||||||
| @ -283,7 +283,8 @@ impl<'a, U: UarteInstance, T: TimerInstance> Drop for State<'a, U, T> { | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<'a, U: UarteInstance, T: TimerInstance> PeripheralState for State<'a, U, T> { | // SAFETY: the safety contract of `PeripheralStateUnchecked` is forwarded to `BufferedUarte::new`.
 | ||||||
|  | unsafe impl<'a, U: UarteInstance, T: TimerInstance> PeripheralStateUnchecked for State<'a, U, T> { | ||||||
|     type Interrupt = U::Interrupt; |     type Interrupt = U::Interrupt; | ||||||
|     fn on_interrupt(&mut self) { |     fn on_interrupt(&mut self) { | ||||||
|         trace!("irq: start"); |         trace!("irq: start"); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user