diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index d24eb00b..d8a010d7 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -39,7 +39,6 @@ use core::cell::UnsafeCell; use core::fmt; -use core::marker::PhantomData; use core::mem::MaybeUninit; use core::pin::Pin; use core::task::Context; @@ -55,32 +54,24 @@ use super::ThreadModeMutex; /// Send values to the associated `Receiver`. /// /// Instances are created by the [`split`](split) function. -pub struct Sender<'ch, M, T, const N: usize> -where - M: Mutex, -{ - channel: *mut Channel, - phantom_data: &'ch PhantomData, +pub struct Sender<'ch, T> { + channel: &'ch UnsafeCell>, } // Safe to pass the sender around -unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} -unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, T> Send for Sender<'ch, T> {} +unsafe impl<'ch, T> Sync for Sender<'ch, T> {} /// Receive values from the associated `Sender`. /// /// Instances are created by the [`split`](split) function. -pub struct Receiver<'ch, M, T, const N: usize> -where - M: Mutex, -{ - channel: *mut Channel, - _phantom_data: &'ch PhantomData, +pub struct Receiver<'ch, T> { + channel: &'ch UnsafeCell>, } // Safe to pass the receiver around -unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} -unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, T> Send for Receiver<'ch, T> {} +unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. /// @@ -98,37 +89,29 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: /// their channel. The following will therefore fail compilation: //// /// ```compile_fail +/// use core::cell::UnsafeCell; /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; /// /// let (sender, receiver) = { -/// let mut channel = Channel::::with_thread_mode_only(); -/// mpsc::split(&mut channel) +/// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); +/// mpsc::split(&channel) /// }; /// ``` -pub fn split<'ch, M, T, const N: usize>( - channel: &'ch mut Channel, -) -> (Sender<'ch, M, T, N>, Receiver<'ch, M, T, N>) -where - M: Mutex, -{ - let sender = Sender { - channel, - phantom_data: &PhantomData, - }; - let receiver = Receiver { - channel, - _phantom_data: &PhantomData, - }; - channel.register_receiver(); - channel.register_sender(); +pub fn split<'ch, T>( + channel: &'ch UnsafeCell>, +) -> (Sender<'ch, T>, Receiver<'ch, T>) { + let sender = Sender { channel: &channel }; + let receiver = Receiver { channel: &channel }; + { + let c = unsafe { &mut *channel.get() }; + c.register_receiver(); + c.register_sender(); + } (sender, receiver) } -impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Receiver<'ch, T> { /// Receives the next value for this receiver. /// /// This method returns `None` if the channel has been closed and there are @@ -154,7 +137,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - unsafe { self.channel.as_mut().unwrap().try_recv() } + unsafe { &mut *self.channel.get() }.try_recv() } /// Closes the receiving half of a channel without dropping it. @@ -168,14 +151,11 @@ where /// until those are released. /// pub fn close(&mut self) { - unsafe { self.channel.as_mut().unwrap().close() } + unsafe { &mut *self.channel.get() }.close() } } -impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Future for Receiver<'ch, T> { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -183,31 +163,20 @@ where Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => { - unsafe { - self.channel - .as_mut() - .unwrap() - .set_receiver_waker(cx.waker().clone()); - }; + unsafe { &mut *self.channel.get() }.set_receiver_waker(cx.waker().clone()); Poll::Pending } } } } -impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Drop for Receiver<'ch, T> { fn drop(&mut self) { - unsafe { self.channel.as_mut().unwrap().deregister_receiver() } + unsafe { &mut *self.channel.get() }.deregister_receiver() } } -impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Sender<'ch, T> { /// Sends a value, waiting until there is capacity. /// /// A successful send occurs when it is determined that the other end of the @@ -255,7 +224,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - unsafe { self.channel.as_mut().unwrap().try_send(message) } + unsafe { &mut *self.channel.get() }.try_send(message) } /// Completes when the receiver has dropped. @@ -276,22 +245,16 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - unsafe { self.channel.as_mut().unwrap().is_closed() } + unsafe { &mut *self.channel.get() }.is_closed() } } -struct SendFuture<'ch, M, T, const N: usize> -where - M: Mutex, -{ - sender: Sender<'ch, M, T, N>, +struct SendFuture<'ch, T> { + sender: Sender<'ch, T>, message: UnsafeCell, } -impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Future for SendFuture<'ch, T> { type Output = Result<(), SendError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -299,13 +262,7 @@ where Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - unsafe { - self.sender - .channel - .as_mut() - .unwrap() - .set_senders_waker(cx.waker().clone()); - }; + unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -315,53 +272,34 @@ where } } -struct CloseFuture<'ch, M, T, const N: usize> -where - M: Mutex, -{ - sender: Sender<'ch, M, T, N>, +struct CloseFuture<'ch, T> { + sender: Sender<'ch, T>, } -impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Future for CloseFuture<'ch, T> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.sender.is_closed() { Poll::Ready(()) } else { - unsafe { - self.sender - .channel - .as_mut() - .unwrap() - .set_senders_waker(cx.waker().clone()); - }; + unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); Poll::Pending } } } -impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Drop for Sender<'ch, T> { fn drop(&mut self) { - unsafe { self.channel.as_mut().unwrap().deregister_sender() } + unsafe { &mut *self.channel.get() }.deregister_sender() } } -impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Clone for Sender<'ch, T> { fn clone(&self) -> Self { - unsafe { self.channel.as_mut().unwrap().register_sender() }; + unsafe { &mut *self.channel.get() }.register_sender(); Sender { - channel: self.channel, - phantom_data: self.phantom_data, + channel: self.channel.clone(), } } } @@ -414,6 +352,28 @@ impl fmt::Display for TrySendError { } } +pub trait ChannelLike { + fn try_recv(&mut self) -> Result; + + fn try_send(&mut self, message: T) -> Result<(), TrySendError>; + + fn close(&mut self); + + fn is_closed(&mut self) -> bool; + + fn register_receiver(&mut self); + + fn deregister_receiver(&mut self); + + fn register_sender(&mut self); + + fn deregister_sender(&mut self); + + fn set_receiver_waker(&mut self, receiver_waker: Waker); + + fn set_senders_waker(&mut self, senders_waker: Waker); +} + pub struct ChannelState { buf: [MaybeUninit>; N], read_pos: usize, @@ -480,13 +440,14 @@ impl Channel { /// from exception mode e.g. interrupt handlers. To create one: /// /// ``` + /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithCriticalSections}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = mpsc::Channel::::with_critical_sections(); + /// let mut channel = UnsafeCell::new(mpsc::Channel::::with_critical_sections()); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&mut channel); + /// let (sender, receiver) = mpsc::split(&channel); /// ``` pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); @@ -504,13 +465,14 @@ impl Channel { /// channel avoids all locks. To create one: /// /// ``` no_run + /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::with_thread_mode_only(); + /// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&mut channel); + /// let (sender, receiver) = mpsc::split(&channel); /// ``` pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); @@ -519,7 +481,7 @@ impl Channel { } } -impl Channel +impl ChannelLike for Channel where M: Mutex, { @@ -771,16 +733,16 @@ mod tests { #[test] fn simple_send_and_receive() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); } #[test] fn should_close_without_sender() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); drop(s); match r.try_recv() { Err(TryRecvError::Closed) => assert!(true), @@ -790,8 +752,8 @@ mod tests { #[test] fn should_close_once_drained() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); assert!(s.try_send(1).is_ok()); drop(s); assert_eq!(r.try_recv().unwrap(), 1); @@ -803,8 +765,8 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); drop(r); match s.try_send(1) { Err(TrySendError::Closed(1)) => assert!(true), @@ -814,8 +776,8 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let mut c = Channel::::with_no_threads(); - let (s, mut r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, mut r) = split(&c); assert!(s.try_send(1).is_ok()); r.close(); assert_eq!(r.try_recv().unwrap(), 1); @@ -830,9 +792,9 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { drop(s); @@ -845,12 +807,12 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { - let _ = s.try_send(1); + assert!(s.try_send(1).is_ok()); }) .is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -858,18 +820,18 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); } #[futures_test::test] async fn sender_send_completes_if_closed() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, r) = split(unsafe { &CHANNEL }); drop(r); match s.send(1).await { Err(SendError(1)) => assert!(true), @@ -881,9 +843,9 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s0, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s0, mut r) = split(unsafe { &CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); @@ -901,18 +863,18 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); r.close(); s.closed().await; } #[futures_test::test] async fn sender_close_completes_if_closed() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, r) = split(unsafe { &CHANNEL }); drop(r); s.closed().await; } diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index 6a0f8f47..d692abee 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -8,6 +8,8 @@ #[path = "../example_common.rs"] mod example_common; +use core::cell::UnsafeCell; + use defmt::panic; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; @@ -23,10 +25,10 @@ enum LedState { Off, } -static CHANNEL: Forever> = Forever::new(); +static CHANNEL: Forever>> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { +async fn my_task(sender: Sender<'static, LedState>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await; @@ -39,7 +41,7 @@ async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(Channel::with_thread_mode_only()); + let channel = CHANNEL.put(UnsafeCell::new(Channel::with_thread_mode_only())); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap();