Partial borrow for receiver to enforce compile-time mpssc
This commit is contained in:
		| @@ -39,6 +39,7 @@ | ||||
|  | ||||
| use core::cell::UnsafeCell; | ||||
| use core::fmt; | ||||
| use core::marker::PhantomData; | ||||
| use core::mem::MaybeUninit; | ||||
| use core::pin::Pin; | ||||
| use core::ptr; | ||||
| @@ -61,7 +62,7 @@ pub struct Sender<'ch, M, T, const N: usize> | ||||
| where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     channel: &'ch Channel<M, T, N>, | ||||
|     channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>, | ||||
| } | ||||
|  | ||||
| // Safe to pass the sender around | ||||
| @@ -77,7 +78,8 @@ pub struct Receiver<'ch, M, T, const N: usize> | ||||
| where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     channel: &'ch Channel<M, T, N>, | ||||
|     channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>, | ||||
|     _receiver_consumed: &'ch mut PhantomData<()>, | ||||
| } | ||||
|  | ||||
| // Safe to pass the receiver around | ||||
| @@ -111,18 +113,23 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where | ||||
| /// | ||||
| /// let (sender, receiver) = { | ||||
| ///    let mut channel = Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only(); | ||||
| ///     mpsc::split(&channel) | ||||
| ///     mpsc::split(&mut channel) | ||||
| /// }; | ||||
| /// ``` | ||||
| pub fn split<M, T, const N: usize>( | ||||
|     channel: &Channel<M, T, N>, | ||||
|     channel: &mut Channel<M, T, N>, | ||||
| ) -> (Sender<M, T, N>, Receiver<M, T, N>) | ||||
| where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     let sender = Sender { channel: &channel }; | ||||
|     let receiver = Receiver { channel: &channel }; | ||||
|     channel.lock(|c| { | ||||
|     let sender = Sender { | ||||
|         channel_cell: &channel.channel_cell, | ||||
|     }; | ||||
|     let receiver = Receiver { | ||||
|         channel_cell: &channel.channel_cell, | ||||
|         _receiver_consumed: &mut channel.receiver_consumed, | ||||
|     }; | ||||
|     Channel::lock(&channel.channel_cell, |c| { | ||||
|         c.register_receiver(); | ||||
|         c.register_sender(); | ||||
|     }); | ||||
| @@ -154,12 +161,13 @@ where | ||||
|     } | ||||
|  | ||||
|     fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { | ||||
|         self.channel | ||||
|             .lock(|c| match c.try_recv_with_context(Some(cx)) { | ||||
|         Channel::lock(self.channel_cell, |c| { | ||||
|             match c.try_recv_with_context(Some(cx)) { | ||||
|                 Ok(v) => Poll::Ready(Some(v)), | ||||
|                 Err(TryRecvError::Closed) => Poll::Ready(None), | ||||
|                 Err(TryRecvError::Empty) => Poll::Pending, | ||||
|             }) | ||||
|             } | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     /// Attempts to immediately receive a message on this `Receiver` | ||||
| @@ -167,7 +175,7 @@ where | ||||
|     /// This method will either receive a message from the channel immediately or return an error | ||||
|     /// if the channel is empty. | ||||
|     pub fn try_recv(&self) -> Result<T, TryRecvError> { | ||||
|         self.channel.lock(|c| c.try_recv()) | ||||
|         Channel::lock(self.channel_cell, |c| c.try_recv()) | ||||
|     } | ||||
|  | ||||
|     /// Closes the receiving half of a channel without dropping it. | ||||
| @@ -181,7 +189,7 @@ where | ||||
|     /// until those are released. | ||||
|     /// | ||||
|     pub fn close(&mut self) { | ||||
|         self.channel.lock(|c| c.close()) | ||||
|         Channel::lock(self.channel_cell, |c| c.close()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -190,7 +198,7 @@ where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     fn drop(&mut self) { | ||||
|         self.channel.lock(|c| c.deregister_receiver()) | ||||
|         Channel::lock(self.channel_cell, |c| c.deregister_receiver()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -245,7 +253,7 @@ where | ||||
|     /// [`channel`]: channel | ||||
|     /// [`close`]: Receiver::close | ||||
|     pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { | ||||
|         self.channel.lock(|c| c.try_send(message)) | ||||
|         Channel::lock(self.channel_cell, |c| c.try_send(message)) | ||||
|     } | ||||
|  | ||||
|     /// Completes when the receiver has dropped. | ||||
| @@ -266,7 +274,7 @@ where | ||||
|     /// [`Receiver`]: crate::sync::mpsc::Receiver | ||||
|     /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close | ||||
|     pub fn is_closed(&self) -> bool { | ||||
|         self.channel.lock(|c| c.is_closed()) | ||||
|         Channel::lock(self.channel_cell, |c| c.is_closed()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -286,11 +294,9 @@ where | ||||
|  | ||||
|     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||
|         match self.message.take() { | ||||
|             Some(m) => match self | ||||
|                 .sender | ||||
|                 .channel | ||||
|                 .lock(|c| c.try_send_with_context(m, Some(cx))) | ||||
|             { | ||||
|             Some(m) => match Channel::lock(self.sender.channel_cell, |c| { | ||||
|                 c.try_send_with_context(m, Some(cx)) | ||||
|             }) { | ||||
|                 Ok(..) => Poll::Ready(Ok(())), | ||||
|                 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), | ||||
|                 Err(TrySendError::Full(m)) => { | ||||
| @@ -319,11 +325,9 @@ where | ||||
|     type Output = (); | ||||
|  | ||||
|     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||
|         if self | ||||
|             .sender | ||||
|             .channel | ||||
|             .lock(|c| c.is_closed_with_context(Some(cx))) | ||||
|         { | ||||
|         if Channel::lock(self.sender.channel_cell, |c| { | ||||
|             c.is_closed_with_context(Some(cx)) | ||||
|         }) { | ||||
|             Poll::Ready(()) | ||||
|         } else { | ||||
|             Poll::Pending | ||||
| @@ -336,7 +340,7 @@ where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     fn drop(&mut self) { | ||||
|         self.channel.lock(|c| c.deregister_sender()) | ||||
|         Channel::lock(self.channel_cell, |c| c.deregister_sender()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -346,9 +350,9 @@ where | ||||
| { | ||||
|     #[allow(clippy::clone_double_ref)] | ||||
|     fn clone(&self) -> Self { | ||||
|         self.channel.lock(|c| c.register_sender()); | ||||
|         Channel::lock(self.channel_cell, |c| c.register_sender()); | ||||
|         Sender { | ||||
|             channel: self.channel.clone(), | ||||
|             channel_cell: self.channel_cell.clone(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -564,6 +568,7 @@ where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     channel_cell: UnsafeCell<ChannelCell<M, T, N>>, | ||||
|     receiver_consumed: PhantomData<()>, | ||||
| } | ||||
|  | ||||
| struct ChannelCell<M, T, const N: usize> | ||||
| @@ -588,7 +593,7 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> { | ||||
|     /// // Declare a bounded channel of 3 u32s. | ||||
|     /// let mut channel = Channel::<WithCriticalSections, u32, 3>::with_critical_sections(); | ||||
|     /// // once we have a channel, obtain its sender and receiver | ||||
|     /// let (sender, receiver) = mpsc::split(&channel); | ||||
|     /// let (sender, receiver) = mpsc::split(&mut channel); | ||||
|     /// ``` | ||||
|     pub const fn with_critical_sections() -> Self { | ||||
|         let mutex = CriticalSectionMutex::new(()); | ||||
| @@ -596,6 +601,7 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> { | ||||
|         let channel_cell = ChannelCell { mutex, state }; | ||||
|         Channel { | ||||
|             channel_cell: UnsafeCell::new(channel_cell), | ||||
|             receiver_consumed: PhantomData, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -615,7 +621,7 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> { | ||||
|     /// // Declare a bounded channel of 3 u32s. | ||||
|     /// let mut channel = Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only(); | ||||
|     /// // once we have a channel, obtain its sender and receiver | ||||
|     /// let (sender, receiver) = mpsc::split(&channel); | ||||
|     /// let (sender, receiver) = mpsc::split(&mut channel); | ||||
|     /// ``` | ||||
|     pub const fn with_thread_mode_only() -> Self { | ||||
|         let mutex = ThreadModeMutex::new(()); | ||||
| @@ -623,6 +629,7 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> { | ||||
|         let channel_cell = ChannelCell { mutex, state }; | ||||
|         Channel { | ||||
|             channel_cell: UnsafeCell::new(channel_cell), | ||||
|             receiver_consumed: PhantomData, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -639,7 +646,7 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> { | ||||
|     /// // Declare a bounded channel of 3 u32s. | ||||
|     /// let mut channel = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|     /// // once we have a channel, obtain its sender and receiver | ||||
|     /// let (sender, receiver) = mpsc::split(&channel); | ||||
|     /// let (sender, receiver) = mpsc::split(&mut channel); | ||||
|     /// ``` | ||||
|     pub const fn with_no_threads() -> Self { | ||||
|         let mutex = NoopMutex::new(()); | ||||
| @@ -647,6 +654,7 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> { | ||||
|         let channel_cell = ChannelCell { mutex, state }; | ||||
|         Channel { | ||||
|             channel_cell: UnsafeCell::new(channel_cell), | ||||
|             receiver_consumed: PhantomData, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -655,9 +663,12 @@ impl<M, T, const N: usize> Channel<M, T, N> | ||||
| where | ||||
|     M: Mutex<Data = ()>, | ||||
| { | ||||
|     fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R { | ||||
|     fn lock<R>( | ||||
|         channel_cell: &UnsafeCell<ChannelCell<M, T, N>>, | ||||
|         f: impl FnOnce(&mut ChannelState<T, N>) -> R, | ||||
|     ) -> R { | ||||
|         unsafe { | ||||
|             let channel_cell = &mut *(self.channel_cell.get()); | ||||
|             let channel_cell = &mut *(channel_cell.get()); | ||||
|             let mutex = &mut channel_cell.mutex; | ||||
|             let mut state = &mut channel_cell.state; | ||||
|             mutex.lock(|_| f(&mut state)) | ||||
| @@ -747,16 +758,16 @@ mod tests { | ||||
|  | ||||
|     #[test] | ||||
|     fn simple_send_and_receive() { | ||||
|         let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&c); | ||||
|         let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&mut c); | ||||
|         assert!(s.clone().try_send(1).is_ok()); | ||||
|         assert_eq!(r.try_recv().unwrap(), 1); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn should_close_without_sender() { | ||||
|         let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&c); | ||||
|         let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&mut c); | ||||
|         drop(s); | ||||
|         match r.try_recv() { | ||||
|             Err(TryRecvError::Closed) => assert!(true), | ||||
| @@ -766,8 +777,8 @@ mod tests { | ||||
|  | ||||
|     #[test] | ||||
|     fn should_close_once_drained() { | ||||
|         let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&c); | ||||
|         let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&mut c); | ||||
|         assert!(s.try_send(1).is_ok()); | ||||
|         drop(s); | ||||
|         assert_eq!(r.try_recv().unwrap(), 1); | ||||
| @@ -779,8 +790,8 @@ mod tests { | ||||
|  | ||||
|     #[test] | ||||
|     fn should_reject_send_when_receiver_dropped() { | ||||
|         let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&c); | ||||
|         let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, r) = split(&mut c); | ||||
|         drop(r); | ||||
|         match s.try_send(1) { | ||||
|             Err(TrySendError::Closed(1)) => assert!(true), | ||||
| @@ -790,8 +801,8 @@ mod tests { | ||||
|  | ||||
|     #[test] | ||||
|     fn should_reject_send_when_channel_closed() { | ||||
|         let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, mut r) = split(&c); | ||||
|         let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); | ||||
|         let (s, mut r) = split(&mut c); | ||||
|         assert!(s.try_send(1).is_ok()); | ||||
|         r.close(); | ||||
|         assert_eq!(r.try_recv().unwrap(), 1); | ||||
| @@ -808,7 +819,7 @@ mod tests { | ||||
|  | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 3> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||
|         let (s, mut r) = split(unsafe { &mut CHANNEL }); | ||||
|         assert!(executor | ||||
|             .spawn(async move { | ||||
|                 drop(s); | ||||
| @@ -823,7 +834,7 @@ mod tests { | ||||
|  | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 3> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||
|         let (s, mut r) = split(unsafe { &mut CHANNEL }); | ||||
|         assert!(executor | ||||
|             .spawn(async move { | ||||
|                 assert!(s.try_send(1).is_ok()); | ||||
| @@ -836,7 +847,7 @@ mod tests { | ||||
|     async fn sender_send_completes_if_capacity() { | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||
|         let (s, mut r) = split(unsafe { &mut CHANNEL }); | ||||
|         assert!(s.send(1).await.is_ok()); | ||||
|         assert_eq!(r.recv().await, Some(1)); | ||||
|     } | ||||
| @@ -845,7 +856,7 @@ mod tests { | ||||
|     async fn sender_send_completes_if_closed() { | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s, r) = split(unsafe { &CHANNEL }); | ||||
|         let (s, r) = split(unsafe { &mut CHANNEL }); | ||||
|         drop(r); | ||||
|         match s.send(1).await { | ||||
|             Err(SendError(1)) => assert!(true), | ||||
| @@ -859,7 +870,7 @@ mod tests { | ||||
|  | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s0, mut r) = split(unsafe { &CHANNEL }); | ||||
|         let (s0, mut r) = split(unsafe { &mut CHANNEL }); | ||||
|         assert!(s0.try_send(1).is_ok()); | ||||
|         let s1 = s0.clone(); | ||||
|         let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); | ||||
| @@ -879,7 +890,7 @@ mod tests { | ||||
|     async fn sender_close_completes_if_closing() { | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s, mut r) = split(unsafe { &CHANNEL }); | ||||
|         let (s, mut r) = split(unsafe { &mut CHANNEL }); | ||||
|         r.close(); | ||||
|         s.closed().await; | ||||
|     } | ||||
| @@ -888,7 +899,7 @@ mod tests { | ||||
|     async fn sender_close_completes_if_closed() { | ||||
|         static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = | ||||
|             Channel::with_critical_sections(); | ||||
|         let (s, r) = split(unsafe { &CHANNEL }); | ||||
|         let (s, r) = split(unsafe { &mut CHANNEL }); | ||||
|         drop(r); | ||||
|         s.closed().await; | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user