diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index fc8006c3..65e4bf7b 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -51,55 +51,33 @@ use super::CriticalSectionMutex; use super::Mutex; use super::ThreadModeMutex; -/// A ChannelCell permits a channel to be shared between senders and their receivers. -// Derived from UnsafeCell. -#[repr(transparent)] -pub struct ChannelCell { - _value: T, -} - -impl ChannelCell { - #[inline(always)] - pub const fn new(value: T) -> ChannelCell - where - T: ChannelLike, - { - ChannelCell { _value: value } - } -} - -impl ChannelCell { - #[inline(always)] - const fn get(&self) -> *mut T { - // As per UnsafeCell: - // We can just cast the pointer from `ChannelCell` to `T` because of - // #[repr(transparent)]. This exploits libstd's special status, there is - // no guarantee for user code that this will work in future versions of the compiler! - self as *const ChannelCell as *const T as *mut T - } -} - /// Send values to the associated `Receiver`. /// /// Instances are created by the [`split`](split) function. -pub struct Sender<'ch, T> { - channel: &'ch ChannelCell>, +pub struct Sender<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: &'ch Channel, } // Safe to pass the sender around -unsafe impl<'ch, T> Send for Sender<'ch, T> {} -unsafe impl<'ch, T> Sync for Sender<'ch, T> {} +unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} /// Receive values from the associated `Sender`. /// /// Instances are created by the [`split`](split) function. -pub struct Receiver<'ch, T> { - channel: &'ch ChannelCell>, +pub struct Receiver<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: &'ch Channel, } // Safe to pass the receiver around -unsafe impl<'ch, T> Send for Receiver<'ch, T> {} -unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} +unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. /// @@ -125,18 +103,26 @@ unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// mpsc::split(&channel) /// }; /// ``` -pub fn split(channel: &ChannelCell>) -> (Sender, Receiver) { +pub fn split( + channel: &Channel, +) -> (Sender, Receiver) +where + M: Mutex, +{ let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; { - let c = unsafe { &mut *channel.get() }; + let c = channel.get(); c.register_receiver(); c.register_sender(); } (sender, receiver) } -impl<'ch, T> Receiver<'ch, T> { +impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> +where + M: Mutex, +{ /// Receives the next value for this receiver. /// /// This method returns `None` if the channel has been closed and there are @@ -162,7 +148,7 @@ impl<'ch, T> Receiver<'ch, T> { /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - unsafe { &mut *self.channel.get() }.try_recv() + self.channel.get().try_recv() } /// Closes the receiving half of a channel without dropping it. @@ -176,11 +162,14 @@ impl<'ch, T> Receiver<'ch, T> { /// until those are released. /// pub fn close(&mut self) { - unsafe { &mut *self.channel.get() }.close() + self.channel.get().close() } } -impl<'ch, T> Future for Receiver<'ch, T> { +impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> +where + M: Mutex, +{ type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -188,20 +177,26 @@ impl<'ch, T> Future for Receiver<'ch, T> { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => { - unsafe { &mut *self.channel.get() }.set_receiver_waker(cx.waker().clone()); + self.channel.get().set_receiver_waker(cx.waker().clone()); Poll::Pending } } } } -impl<'ch, T> Drop for Receiver<'ch, T> { +impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> +where + M: Mutex, +{ fn drop(&mut self) { - unsafe { &mut *self.channel.get() }.deregister_receiver() + self.channel.get().deregister_receiver() } } -impl<'ch, T> Sender<'ch, T> { +impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> +where + M: Mutex, +{ /// Sends a value, waiting until there is capacity. /// /// A successful send occurs when it is determined that the other end of the @@ -249,7 +244,7 @@ impl<'ch, T> Sender<'ch, T> { /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - unsafe { &mut *self.channel.get() }.try_send(message) + self.channel.get().try_send(message) } /// Completes when the receiver has dropped. @@ -270,16 +265,22 @@ impl<'ch, T> Sender<'ch, T> { /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - unsafe { &mut *self.channel.get() }.is_closed() + self.channel.get().is_closed() } } -struct SendFuture<'ch, T> { - sender: Sender<'ch, T>, +struct SendFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, message: UnsafeCell, } -impl<'ch, T> Future for SendFuture<'ch, T> { +impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> +where + M: Mutex, +{ type Output = Result<(), SendError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -287,7 +288,10 @@ impl<'ch, T> Future for SendFuture<'ch, T> { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); + self.sender + .channel + .get() + .set_senders_waker(cx.waker().clone()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -297,33 +301,48 @@ impl<'ch, T> Future for SendFuture<'ch, T> { } } -struct CloseFuture<'ch, T> { - sender: Sender<'ch, T>, +struct CloseFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, } -impl<'ch, T> Future for CloseFuture<'ch, T> { +impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> +where + M: Mutex, +{ type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.sender.is_closed() { Poll::Ready(()) } else { - unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); + self.sender + .channel + .get() + .set_senders_waker(cx.waker().clone()); Poll::Pending } } } -impl<'ch, T> Drop for Sender<'ch, T> { +impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> +where + M: Mutex, +{ fn drop(&mut self) { - unsafe { &mut *self.channel.get() }.deregister_sender() + self.channel.get().deregister_sender() } } -impl<'ch, T> Clone for Sender<'ch, T> { +impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> +where + M: Mutex, +{ #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - unsafe { &mut *self.channel.get() }.register_sender(); + self.channel.get().register_sender(); Sender { channel: self.channel.clone(), } @@ -378,28 +397,6 @@ impl fmt::Display for TrySendError { } } -pub trait ChannelLike { - fn try_recv(&mut self) -> Result; - - fn try_send(&mut self, message: T) -> Result<(), TrySendError>; - - fn close(&mut self); - - fn is_closed(&mut self) -> bool; - - fn register_receiver(&mut self); - - fn deregister_receiver(&mut self); - - fn register_sender(&mut self); - - fn deregister_sender(&mut self); - - fn set_receiver_waker(&mut self, receiver_waker: Waker); - - fn set_senders_waker(&mut self, senders_waker: Waker); -} - struct ChannelState { buf: [MaybeUninit>; N], read_pos: usize, @@ -505,10 +502,16 @@ impl Channel { } } -impl ChannelLike for Channel +impl Channel where M: Mutex, { + fn get(&self) -> &mut Self { + let const_ptr = self as *const Self; + let mut_ptr = const_ptr as *mut Self; + unsafe { &mut *mut_ptr } + } + fn try_recv(&mut self) -> Result { let state = &mut self.state; self.mutex.lock(|_| { diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index eafa29e6..6a0f8f47 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -11,7 +11,7 @@ mod example_common; use defmt::panic; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; -use embassy::util::mpsc::{ChannelCell, TryRecvError}; +use embassy::util::mpsc::TryRecvError; use embassy::util::{mpsc, Forever}; use embassy_nrf::gpio::{Level, Output, OutputDrive}; use embassy_nrf::Peripherals; @@ -23,10 +23,10 @@ enum LedState { Off, } -static CHANNEL: Forever>> = Forever::new(); +static CHANNEL: Forever> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, LedState>) { +async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await; @@ -39,7 +39,7 @@ async fn my_task(sender: Sender<'static, LedState>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(ChannelCell::new(Channel::with_thread_mode_only())); + let channel = CHANNEL.put(Channel::with_thread_mode_only()); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap();