diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index f350c6e5..246bd27e 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -39,6 +39,7 @@ use core::cell::UnsafeCell; use core::fmt; +use core::marker::PhantomData; use core::mem::MaybeUninit; use core::pin::Pin; use core::ptr; @@ -61,7 +62,7 @@ pub struct Sender<'ch, M, T, const N: usize> where M: Mutex, { - channel: &'ch Channel, + channel_cell: &'ch UnsafeCell>, } // Safe to pass the sender around @@ -77,7 +78,8 @@ pub struct Receiver<'ch, M, T, const N: usize> where M: Mutex, { - channel: &'ch Channel, + channel_cell: &'ch UnsafeCell>, + _receiver_consumed: &'ch mut PhantomData<()>, } // Safe to pass the receiver around @@ -111,18 +113,23 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where /// /// let (sender, receiver) = { /// let mut channel = Channel::::with_thread_mode_only(); -/// mpsc::split(&channel) +/// mpsc::split(&mut channel) /// }; /// ``` pub fn split( - channel: &Channel, + channel: &mut Channel, ) -> (Sender, Receiver) where M: Mutex, { - let sender = Sender { channel: &channel }; - let receiver = Receiver { channel: &channel }; - channel.lock(|c| { + let sender = Sender { + channel_cell: &channel.channel_cell, + }; + let receiver = Receiver { + channel_cell: &channel.channel_cell, + _receiver_consumed: &mut channel.receiver_consumed, + }; + Channel::lock(&channel.channel_cell, |c| { c.register_receiver(); c.register_sender(); }); @@ -154,12 +161,13 @@ where } fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll> { - self.channel - .lock(|c| match c.try_recv_with_context(Some(cx)) { + Channel::lock(self.channel_cell, |c| { + match c.try_recv_with_context(Some(cx)) { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => Poll::Pending, - }) + } + }) } /// Attempts to immediately receive a message on this `Receiver` @@ -167,7 +175,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - self.channel.lock(|c| c.try_recv()) + Channel::lock(self.channel_cell, |c| c.try_recv()) } /// Closes the receiving half of a channel without dropping it. @@ -181,7 +189,7 @@ where /// until those are released. /// pub fn close(&mut self) { - self.channel.lock(|c| c.close()) + Channel::lock(self.channel_cell, |c| c.close()) } } @@ -190,7 +198,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.lock(|c| c.deregister_receiver()) + Channel::lock(self.channel_cell, |c| c.deregister_receiver()) } } @@ -245,7 +253,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - self.channel.lock(|c| c.try_send(message)) + Channel::lock(self.channel_cell, |c| c.try_send(message)) } /// Completes when the receiver has dropped. @@ -266,7 +274,7 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - self.channel.lock(|c| c.is_closed()) + Channel::lock(self.channel_cell, |c| c.is_closed()) } } @@ -286,11 +294,9 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.message.take() { - Some(m) => match self - .sender - .channel - .lock(|c| c.try_send_with_context(m, Some(cx))) - { + Some(m) => match Channel::lock(self.sender.channel_cell, |c| { + c.try_send_with_context(m, Some(cx)) + }) { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(m)) => { @@ -319,11 +325,9 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self - .sender - .channel - .lock(|c| c.is_closed_with_context(Some(cx))) - { + if Channel::lock(self.sender.channel_cell, |c| { + c.is_closed_with_context(Some(cx)) + }) { Poll::Ready(()) } else { Poll::Pending @@ -336,7 +340,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.lock(|c| c.deregister_sender()) + Channel::lock(self.channel_cell, |c| c.deregister_sender()) } } @@ -346,9 +350,9 @@ where { #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - self.channel.lock(|c| c.register_sender()); + Channel::lock(self.channel_cell, |c| c.register_sender()); Sender { - channel: self.channel.clone(), + channel_cell: self.channel_cell.clone(), } } } @@ -564,6 +568,7 @@ where M: Mutex, { channel_cell: UnsafeCell>, + receiver_consumed: PhantomData<()>, } struct ChannelCell @@ -588,7 +593,7 @@ impl Channel { /// // Declare a bounded channel of 3 u32s. /// let mut channel = Channel::::with_critical_sections(); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&channel); + /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); @@ -596,6 +601,7 @@ impl Channel { let channel_cell = ChannelCell { mutex, state }; Channel { channel_cell: UnsafeCell::new(channel_cell), + receiver_consumed: PhantomData, } } } @@ -615,7 +621,7 @@ impl Channel { /// // Declare a bounded channel of 3 u32s. /// let mut channel = Channel::::with_thread_mode_only(); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&channel); + /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); @@ -623,6 +629,7 @@ impl Channel { let channel_cell = ChannelCell { mutex, state }; Channel { channel_cell: UnsafeCell::new(channel_cell), + receiver_consumed: PhantomData, } } } @@ -639,7 +646,7 @@ impl Channel { /// // Declare a bounded channel of 3 u32s. /// let mut channel = Channel::::with_no_threads(); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&channel); + /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub const fn with_no_threads() -> Self { let mutex = NoopMutex::new(()); @@ -647,6 +654,7 @@ impl Channel { let channel_cell = ChannelCell { mutex, state }; Channel { channel_cell: UnsafeCell::new(channel_cell), + receiver_consumed: PhantomData, } } } @@ -655,9 +663,12 @@ impl Channel where M: Mutex, { - fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + fn lock( + channel_cell: &UnsafeCell>, + f: impl FnOnce(&mut ChannelState) -> R, + ) -> R { unsafe { - let channel_cell = &mut *(self.channel_cell.get()); + let channel_cell = &mut *(channel_cell.get()); let mutex = &mut channel_cell.mutex; let mut state = &mut channel_cell.state; mutex.lock(|_| f(&mut state)) @@ -747,16 +758,16 @@ mod tests { #[test] fn simple_send_and_receive() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); } #[test] fn should_close_without_sender() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); drop(s); match r.try_recv() { Err(TryRecvError::Closed) => assert!(true), @@ -766,8 +777,8 @@ mod tests { #[test] fn should_close_once_drained() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); assert!(s.try_send(1).is_ok()); drop(s); assert_eq!(r.try_recv().unwrap(), 1); @@ -779,8 +790,8 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); drop(r); match s.try_send(1) { Err(TrySendError::Closed(1)) => assert!(true), @@ -790,8 +801,8 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let c = Channel::::with_no_threads(); - let (s, mut r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, mut r) = split(&mut c); assert!(s.try_send(1).is_ok()); r.close(); assert_eq!(r.try_recv().unwrap(), 1); @@ -808,7 +819,7 @@ mod tests { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); assert!(executor .spawn(async move { drop(s); @@ -823,7 +834,7 @@ mod tests { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); assert!(executor .spawn(async move { assert!(s.try_send(1).is_ok()); @@ -836,7 +847,7 @@ mod tests { async fn sender_send_completes_if_capacity() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); } @@ -845,7 +856,7 @@ mod tests { async fn sender_send_completes_if_closed() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, r) = split(unsafe { &CHANNEL }); + let (s, r) = split(unsafe { &mut CHANNEL }); drop(r); match s.send(1).await { Err(SendError(1)) => assert!(true), @@ -859,7 +870,7 @@ mod tests { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s0, mut r) = split(unsafe { &CHANNEL }); + let (s0, mut r) = split(unsafe { &mut CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); @@ -879,7 +890,7 @@ mod tests { async fn sender_close_completes_if_closing() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); r.close(); s.closed().await; } @@ -888,7 +899,7 @@ mod tests { async fn sender_close_completes_if_closed() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, r) = split(unsafe { &CHANNEL }); + let (s, r) = split(unsafe { &mut CHANNEL }); drop(r); s.closed().await; }