Replace UnsafeCell
Using a new ChannelCell so that there's no leaking of the abstraction
This commit is contained in:
		
							parent
							
								
									1b49acc2f7
								
							
						
					
					
						commit
						ae62948d6c
					
				| @ -51,11 +51,36 @@ use super::CriticalSectionMutex; | |||||||
| use super::Mutex; | use super::Mutex; | ||||||
| use super::ThreadModeMutex; | use super::ThreadModeMutex; | ||||||
| 
 | 
 | ||||||
|  | /// A ChannelCell permits a channel to be shared between senders and their receivers.
 | ||||||
|  | // Derived from UnsafeCell.
 | ||||||
|  | #[repr(transparent)] | ||||||
|  | pub struct ChannelCell<T: ?Sized> { | ||||||
|  |     _value: T, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<T> ChannelCell<T> { | ||||||
|  |     #[inline(always)] | ||||||
|  |     pub const fn new(value: T) -> ChannelCell<T> { | ||||||
|  |         ChannelCell { _value: value } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<T: ?Sized> ChannelCell<T> { | ||||||
|  |     #[inline(always)] | ||||||
|  |     const fn get(&self) -> *mut T { | ||||||
|  |         // As per UnsafeCell:
 | ||||||
|  |         // We can just cast the pointer from `ChannelCell<T>` to `T` because of
 | ||||||
|  |         // #[repr(transparent)]. This exploits libstd's special status, there is
 | ||||||
|  |         // no guarantee for user code that this will work in future versions of the compiler!
 | ||||||
|  |         self as *const ChannelCell<T> as *const T as *mut T | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /// Send values to the associated `Receiver`.
 | /// Send values to the associated `Receiver`.
 | ||||||
| ///
 | ///
 | ||||||
| /// Instances are created by the [`split`](split) function.
 | /// Instances are created by the [`split`](split) function.
 | ||||||
| pub struct Sender<'ch, T> { | pub struct Sender<'ch, T> { | ||||||
|     channel: &'ch UnsafeCell<dyn ChannelLike<T>>, |     channel: &'ch ChannelCell<dyn ChannelLike<T>>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Safe to pass the sender around
 | // Safe to pass the sender around
 | ||||||
| @ -66,7 +91,7 @@ unsafe impl<'ch, T> Sync for Sender<'ch, T> {} | |||||||
| ///
 | ///
 | ||||||
| /// Instances are created by the [`split`](split) function.
 | /// Instances are created by the [`split`](split) function.
 | ||||||
| pub struct Receiver<'ch, T> { | pub struct Receiver<'ch, T> { | ||||||
|     channel: &'ch UnsafeCell<dyn ChannelLike<T>>, |     channel: &'ch ChannelCell<dyn ChannelLike<T>>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Safe to pass the receiver around
 | // Safe to pass the receiver around
 | ||||||
| @ -89,16 +114,15 @@ unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} | |||||||
| /// their channel. The following will therefore fail compilation:
 | /// their channel. The following will therefore fail compilation:
 | ||||||
| ////
 | ////
 | ||||||
| /// ```compile_fail
 | /// ```compile_fail
 | ||||||
| /// use core::cell::UnsafeCell;
 |  | ||||||
| /// use embassy::util::mpsc;
 | /// use embassy::util::mpsc;
 | ||||||
| /// use embassy::util::mpsc::{Channel, WithThreadModeOnly};
 | /// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly};
 | ||||||
| ///
 | ///
 | ||||||
| /// let (sender, receiver) = {
 | /// let (sender, receiver) = {
 | ||||||
| ///    let mut channel = UnsafeCell::new(Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only());
 | ///    let mut channel = ChannelCell::new(Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only());
 | ||||||
| ///     mpsc::split(&channel)
 | ///     mpsc::split(&channel)
 | ||||||
| /// };
 | /// };
 | ||||||
| /// ```
 | /// ```
 | ||||||
| pub fn split<T>(channel: &UnsafeCell<dyn ChannelLike<T>>) -> (Sender<T>, Receiver<T>) { | pub fn split<T>(channel: &ChannelCell<dyn ChannelLike<T>>) -> (Sender<T>, Receiver<T>) { | ||||||
|     let sender = Sender { channel: &channel }; |     let sender = Sender { channel: &channel }; | ||||||
|     let receiver = Receiver { channel: &channel }; |     let receiver = Receiver { channel: &channel }; | ||||||
|     { |     { | ||||||
| @ -439,12 +463,11 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> { | |||||||
|     /// from exception mode e.g. interrupt handlers. To create one:
 |     /// from exception mode e.g. interrupt handlers. To create one:
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// ```
 |     /// ```
 | ||||||
|     /// use core::cell::UnsafeCell;
 |  | ||||||
|     /// use embassy::util::mpsc;
 |     /// use embassy::util::mpsc;
 | ||||||
|     /// use embassy::util::mpsc::{Channel, WithCriticalSections};
 |     /// use embassy::util::mpsc::{Channel, ChannelCell, WithCriticalSections};
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// // Declare a bounded channel of 3 u32s.
 |     /// // Declare a bounded channel of 3 u32s.
 | ||||||
|     /// let mut channel = UnsafeCell::new(mpsc::Channel::<WithCriticalSections, u32, 3>::with_critical_sections());
 |     /// let mut channel = ChannelCell::new(mpsc::Channel::<WithCriticalSections, u32, 3>::with_critical_sections());
 | ||||||
|     /// // once we have a channel, obtain its sender and receiver
 |     /// // once we have a channel, obtain its sender and receiver
 | ||||||
|     /// let (sender, receiver) = mpsc::split(&channel);
 |     /// let (sender, receiver) = mpsc::split(&channel);
 | ||||||
|     /// ```
 |     /// ```
 | ||||||
| @ -464,12 +487,11 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> { | |||||||
|     /// channel avoids all locks. To create one:
 |     /// channel avoids all locks. To create one:
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// ``` no_run
 |     /// ``` no_run
 | ||||||
|     /// use core::cell::UnsafeCell;
 |  | ||||||
|     /// use embassy::util::mpsc;
 |     /// use embassy::util::mpsc;
 | ||||||
|     /// use embassy::util::mpsc::{Channel, WithThreadModeOnly};
 |     /// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly};
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// // Declare a bounded channel of 3 u32s.
 |     /// // Declare a bounded channel of 3 u32s.
 | ||||||
|     /// let mut channel = UnsafeCell::new(Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only());
 |     /// let mut channel = ChannelCell::new(Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only());
 | ||||||
|     /// // once we have a channel, obtain its sender and receiver
 |     /// // once we have a channel, obtain its sender and receiver
 | ||||||
|     /// let (sender, receiver) = mpsc::split(&channel);
 |     /// let (sender, receiver) = mpsc::split(&channel);
 | ||||||
|     /// ```
 |     /// ```
 | ||||||
| @ -744,7 +766,7 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|     fn simple_send_and_receive() { |     fn simple_send_and_receive() { | ||||||
|         let c = UnsafeCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); |         let c = ChannelCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); | ||||||
|         let (s, r) = split(&c); |         let (s, r) = split(&c); | ||||||
|         assert!(s.clone().try_send(1).is_ok()); |         assert!(s.clone().try_send(1).is_ok()); | ||||||
|         assert_eq!(r.try_recv().unwrap(), 1); |         assert_eq!(r.try_recv().unwrap(), 1); | ||||||
| @ -752,7 +774,7 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|     fn should_close_without_sender() { |     fn should_close_without_sender() { | ||||||
|         let c = UnsafeCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); |         let c = ChannelCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); | ||||||
|         let (s, r) = split(&c); |         let (s, r) = split(&c); | ||||||
|         drop(s); |         drop(s); | ||||||
|         match r.try_recv() { |         match r.try_recv() { | ||||||
| @ -763,7 +785,7 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|     fn should_close_once_drained() { |     fn should_close_once_drained() { | ||||||
|         let c = UnsafeCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); |         let c = ChannelCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); | ||||||
|         let (s, r) = split(&c); |         let (s, r) = split(&c); | ||||||
|         assert!(s.try_send(1).is_ok()); |         assert!(s.try_send(1).is_ok()); | ||||||
|         drop(s); |         drop(s); | ||||||
| @ -776,7 +798,7 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|     fn should_reject_send_when_receiver_dropped() { |     fn should_reject_send_when_receiver_dropped() { | ||||||
|         let c = UnsafeCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); |         let c = ChannelCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); | ||||||
|         let (s, r) = split(&c); |         let (s, r) = split(&c); | ||||||
|         drop(r); |         drop(r); | ||||||
|         match s.try_send(1) { |         match s.try_send(1) { | ||||||
| @ -787,7 +809,7 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[test] |     #[test] | ||||||
|     fn should_reject_send_when_channel_closed() { |     fn should_reject_send_when_channel_closed() { | ||||||
|         let c = UnsafeCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); |         let c = ChannelCell::new(Channel::<WithNoThreads, u32, 3>::with_no_threads()); | ||||||
|         let (s, mut r) = split(&c); |         let (s, mut r) = split(&c); | ||||||
|         assert!(s.try_send(1).is_ok()); |         assert!(s.try_send(1).is_ok()); | ||||||
|         r.close(); |         r.close(); | ||||||
| @ -803,8 +825,8 @@ mod tests { | |||||||
|     async fn receiver_closes_when_sender_dropped_async() { |     async fn receiver_closes_when_sender_dropped_async() { | ||||||
|         let executor = ThreadPool::new().unwrap(); |         let executor = ThreadPool::new().unwrap(); | ||||||
| 
 | 
 | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 3>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 3>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); |         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||||
|         assert!(executor |         assert!(executor | ||||||
|             .spawn(async move { |             .spawn(async move { | ||||||
| @ -818,8 +840,8 @@ mod tests { | |||||||
|     async fn receiver_receives_given_try_send_async() { |     async fn receiver_receives_given_try_send_async() { | ||||||
|         let executor = ThreadPool::new().unwrap(); |         let executor = ThreadPool::new().unwrap(); | ||||||
| 
 | 
 | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 3>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 3>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); |         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||||
|         assert!(executor |         assert!(executor | ||||||
|             .spawn(async move { |             .spawn(async move { | ||||||
| @ -831,8 +853,8 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[futures_test::test] |     #[futures_test::test] | ||||||
|     async fn sender_send_completes_if_capacity() { |     async fn sender_send_completes_if_capacity() { | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 1>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 1>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); |         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||||
|         assert!(s.send(1).await.is_ok()); |         assert!(s.send(1).await.is_ok()); | ||||||
|         assert_eq!(r.recv().await, Some(1)); |         assert_eq!(r.recv().await, Some(1)); | ||||||
| @ -840,8 +862,8 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[futures_test::test] |     #[futures_test::test] | ||||||
|     async fn sender_send_completes_if_closed() { |     async fn sender_send_completes_if_closed() { | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 1>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 1>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s, r) = split(unsafe { &CHANNEL }); |         let (s, r) = split(unsafe { &CHANNEL }); | ||||||
|         drop(r); |         drop(r); | ||||||
|         match s.send(1).await { |         match s.send(1).await { | ||||||
| @ -854,8 +876,8 @@ mod tests { | |||||||
|     async fn senders_sends_wait_until_capacity() { |     async fn senders_sends_wait_until_capacity() { | ||||||
|         let executor = ThreadPool::new().unwrap(); |         let executor = ThreadPool::new().unwrap(); | ||||||
| 
 | 
 | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 1>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 1>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s0, mut r) = split(unsafe { &CHANNEL }); |         let (s0, mut r) = split(unsafe { &CHANNEL }); | ||||||
|         assert!(s0.try_send(1).is_ok()); |         assert!(s0.try_send(1).is_ok()); | ||||||
|         let s1 = s0.clone(); |         let s1 = s0.clone(); | ||||||
| @ -874,8 +896,8 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[futures_test::test] |     #[futures_test::test] | ||||||
|     async fn sender_close_completes_if_closing() { |     async fn sender_close_completes_if_closing() { | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 1>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 1>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); |         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||||
|         r.close(); |         r.close(); | ||||||
|         s.closed().await; |         s.closed().await; | ||||||
| @ -883,8 +905,8 @@ mod tests { | |||||||
| 
 | 
 | ||||||
|     #[futures_test::test] |     #[futures_test::test] | ||||||
|     async fn sender_close_completes_if_closed() { |     async fn sender_close_completes_if_closed() { | ||||||
|         static mut CHANNEL: UnsafeCell<Channel<WithCriticalSections, u32, 1>> = |         static mut CHANNEL: ChannelCell<Channel<WithCriticalSections, u32, 1>> = | ||||||
|             UnsafeCell::new(Channel::with_critical_sections()); |             ChannelCell::new(Channel::with_critical_sections()); | ||||||
|         let (s, r) = split(unsafe { &CHANNEL }); |         let (s, r) = split(unsafe { &CHANNEL }); | ||||||
|         drop(r); |         drop(r); | ||||||
|         s.closed().await; |         s.closed().await; | ||||||
|  | |||||||
| @ -8,12 +8,10 @@ | |||||||
| #[path = "../example_common.rs"] | #[path = "../example_common.rs"] | ||||||
| mod example_common; | mod example_common; | ||||||
| 
 | 
 | ||||||
| use core::cell::UnsafeCell; |  | ||||||
| 
 |  | ||||||
| use defmt::panic; | use defmt::panic; | ||||||
| use embassy::executor::Spawner; | use embassy::executor::Spawner; | ||||||
| use embassy::time::{Duration, Timer}; | use embassy::time::{Duration, Timer}; | ||||||
| use embassy::util::mpsc::TryRecvError; | use embassy::util::mpsc::{ChannelCell, TryRecvError}; | ||||||
| use embassy::util::{mpsc, Forever}; | use embassy::util::{mpsc, Forever}; | ||||||
| use embassy_nrf::gpio::{Level, Output, OutputDrive}; | use embassy_nrf::gpio::{Level, Output, OutputDrive}; | ||||||
| use embassy_nrf::Peripherals; | use embassy_nrf::Peripherals; | ||||||
| @ -25,7 +23,7 @@ enum LedState { | |||||||
|     Off, |     Off, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static CHANNEL: Forever<UnsafeCell<Channel<WithThreadModeOnly, LedState, 1>>> = Forever::new(); | static CHANNEL: Forever<ChannelCell<Channel<WithThreadModeOnly, LedState, 1>>> = Forever::new(); | ||||||
| 
 | 
 | ||||||
| #[embassy::task(pool_size = 1)] | #[embassy::task(pool_size = 1)] | ||||||
| async fn my_task(sender: Sender<'static, LedState>) { | async fn my_task(sender: Sender<'static, LedState>) { | ||||||
| @ -41,7 +39,7 @@ async fn my_task(sender: Sender<'static, LedState>) { | |||||||
| async fn main(spawner: Spawner, p: Peripherals) { | async fn main(spawner: Spawner, p: Peripherals) { | ||||||
|     let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); |     let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); | ||||||
| 
 | 
 | ||||||
|     let channel = CHANNEL.put(UnsafeCell::new(Channel::with_thread_mode_only())); |     let channel = CHANNEL.put(ChannelCell::new(Channel::with_thread_mode_only())); | ||||||
|     let (sender, mut receiver) = mpsc::split(channel); |     let (sender, mut receiver) = mpsc::split(channel); | ||||||
| 
 | 
 | ||||||
|     spawner.spawn(my_task(sender)).unwrap(); |     spawner.spawn(my_task(sender)).unwrap(); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user