diff --git a/embassy/src/channel/mpsc.rs b/embassy/src/channel/mpsc.rs index 3585b803..9a57c0b1 100644 --- a/embassy/src/channel/mpsc.rs +++ b/embassy/src/channel/mpsc.rs @@ -37,7 +37,7 @@ //! //! This channel and its associated types were derived from https://docs.rs/tokio/0.1.22/tokio/sync/mpsc/fn.channel.html -use core::cell::UnsafeCell; +use core::cell::RefCell; use core::fmt; use core::pin::Pin; use core::task::Context; @@ -47,7 +47,8 @@ use core::task::Waker; use futures::Future; use heapless::Deque; -use crate::blocking_mutex::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex}; +use crate::blocking_mutex::kind::MutexKind; +use crate::blocking_mutex::Mutex; use crate::waitqueue::WakerRegistration; /// Send values to the associated `Receiver`. @@ -55,35 +56,19 @@ use crate::waitqueue::WakerRegistration; /// Instances are created by the [`split`](split) function. pub struct Sender<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - channel_cell: &'ch UnsafeCell>, + channel: &'ch Channel, } -// Safe to pass the sender around -unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex + Sync -{} -unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex + Sync -{} - /// Receive values from the associated `Sender`. /// /// Instances are created by the [`split`](split) function. pub struct Receiver<'ch, M, T, const N: usize> where - M: Mutex, -{ - channel_cell: &'ch UnsafeCell>, -} - -// Safe to pass the receiver around -unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where - M: Mutex + Sync -{ -} -unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where - M: Mutex + Sync + M: MutexKind, { + channel: &'ch Channel, } /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. @@ -114,15 +99,11 @@ pub fn split( channel: &mut Channel, ) -> (Sender, Receiver) where - M: Mutex, + M: MutexKind, { - let sender = Sender { - channel_cell: &channel.channel_cell, - }; - let receiver = Receiver { - channel_cell: &channel.channel_cell, - }; - Channel::lock(&channel.channel_cell, |c| { + let sender = Sender { channel }; + let receiver = Receiver { channel }; + channel.lock(|c| { c.register_receiver(); c.register_sender(); }); @@ -131,7 +112,7 @@ where impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { /// Receives the next value for this receiver. /// @@ -151,7 +132,7 @@ where /// [`close`]: Self::close pub fn recv<'m>(&'m mut self) -> RecvFuture<'m, M, T, N> { RecvFuture { - channel_cell: self.channel_cell, + channel: self.channel, } } @@ -160,7 +141,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - Channel::lock(self.channel_cell, |c| c.try_recv()) + self.channel.lock(|c| c.try_recv()) } /// Closes the receiving half of a channel without dropping it. @@ -174,56 +155,45 @@ where /// until those are released. /// pub fn close(&mut self) { - Channel::lock(self.channel_cell, |c| c.close()) + self.channel.lock(|c| c.close()) } } impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { fn drop(&mut self) { - Channel::lock(self.channel_cell, |c| c.deregister_receiver()) + self.channel.lock(|c| c.deregister_receiver()) } } pub struct RecvFuture<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - channel_cell: &'ch UnsafeCell>, + channel: &'ch Channel, } impl<'ch, M, T, const N: usize> Future for RecvFuture<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Channel::lock(self.channel_cell, |c| { - match c.try_recv_with_context(Some(cx)) { + self.channel + .lock(|c| match c.try_recv_with_context(Some(cx)) { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => Poll::Pending, - } - }) + }) } } -// Safe to pass the receive future around since it locks channel whenever polled -unsafe impl<'ch, M, T, const N: usize> Send for RecvFuture<'ch, M, T, N> where - M: Mutex + Sync -{ -} -unsafe impl<'ch, M, T, const N: usize> Sync for RecvFuture<'ch, M, T, N> where - M: Mutex + Sync -{ -} - impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { /// Sends a value, waiting until there is capacity. /// @@ -245,7 +215,7 @@ where /// [`Receiver`]: Receiver pub fn send(&self, message: T) -> SendFuture<'ch, M, T, N> { SendFuture { - sender: self.clone(), + channel: self.channel, message: Some(message), } } @@ -271,7 +241,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - Channel::lock(self.channel_cell, |c| c.try_send(message)) + self.channel.lock(|c| c.try_send(message)) } /// Completes when the receiver has dropped. @@ -280,7 +250,7 @@ where /// values is canceled and immediately stop doing work. pub async fn closed(&self) { CloseFuture { - sender: self.clone(), + channel: self.channel, } .await } @@ -292,29 +262,27 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - Channel::lock(self.channel_cell, |c| c.is_closed()) + self.channel.lock(|c| c.is_closed()) } } pub struct SendFuture<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - sender: Sender<'ch, M, T, N>, + channel: &'ch Channel, message: Option, } impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { type Output = Result<(), SendError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.message.take() { - Some(m) => match Channel::lock(self.sender.channel_cell, |c| { - c.try_send_with_context(m, Some(cx)) - }) { + Some(m) => match self.channel.lock(|c| c.try_send_with_context(m, Some(cx))) { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(m)) => { @@ -327,25 +295,23 @@ where } } -impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: Mutex {} +impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: MutexKind {} struct CloseFuture<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - sender: Sender<'ch, M, T, N>, + channel: &'ch Channel, } impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if Channel::lock(self.sender.channel_cell, |c| { - c.is_closed_with_context(Some(cx)) - }) { + if self.channel.lock(|c| c.is_closed_with_context(Some(cx))) { Poll::Ready(()) } else { Poll::Pending @@ -355,22 +321,21 @@ where impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { fn drop(&mut self) { - Channel::lock(self.channel_cell, |c| c.deregister_sender()) + self.channel.lock(|c| c.deregister_sender()) } } impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { - #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - Channel::lock(self.channel_cell, |c| c.register_sender()); + self.channel.lock(|c| c.register_sender()); Sender { - channel_cell: self.channel_cell.clone(), + channel: self.channel, } } } @@ -581,59 +546,35 @@ impl ChannelState { /// All data sent will become available in the same order as it was sent. pub struct Channel where - M: Mutex, + M: MutexKind, { - channel_cell: UnsafeCell>, + inner: M::Mutex>>, } -struct ChannelCell -where - M: Mutex, -{ - mutex: M, - state: ChannelState, -} - -pub type WithCriticalSections = CriticalSectionMutex<()>; - -pub type WithThreadModeOnly = ThreadModeMutex<()>; - -pub type WithNoThreads = NoopMutex<()>; - impl Channel where - M: Mutex, + M: MutexKind, { /// Establish a new bounded channel. For example, to create one with a NoopMutex: /// /// ``` /// use embassy::channel::mpsc; - /// use embassy::channel::mpsc::{Channel, WithNoThreads}; + /// use embassy::blocking_mutex::kind::Noop; + /// use embassy::channel::mpsc::Channel; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::new(); + /// let mut channel = Channel::::new(); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub fn new() -> Self { - let mutex = M::new(()); - let state = ChannelState::new(); - let channel_cell = ChannelCell { mutex, state }; - Channel { - channel_cell: UnsafeCell::new(channel_cell), + Self { + inner: M::Mutex::new(RefCell::new(ChannelState::new())), } } - fn lock( - channel_cell: &UnsafeCell>, - f: impl FnOnce(&mut ChannelState) -> R, - ) -> R { - unsafe { - let channel_cell = &mut *(channel_cell.get()); - let mutex = &mut channel_cell.mutex; - let mut state = &mut channel_cell.state; - mutex.lock(|_| f(&mut state)) - } + fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + self.inner.lock(|rc| f(&mut *rc.borrow_mut())) } } @@ -645,6 +586,7 @@ mod tests { use futures_executor::ThreadPool; use futures_timer::Delay; + use crate::blocking_mutex::kind::{CriticalSection, Noop}; use crate::util::Forever; use super::*; @@ -713,7 +655,7 @@ mod tests { #[test] fn simple_send_and_receive() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); @@ -721,7 +663,7 @@ mod tests { #[test] fn should_close_without_sender() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); drop(s); match r.try_recv() { @@ -732,7 +674,7 @@ mod tests { #[test] fn should_close_once_drained() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); assert!(s.try_send(1).is_ok()); drop(s); @@ -745,7 +687,7 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); drop(r); match s.try_send(1) { @@ -756,7 +698,7 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, mut r) = split(&mut c); assert!(s.try_send(1).is_ok()); r.close(); @@ -772,7 +714,7 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, mut r) = split(c); assert!(executor @@ -787,7 +729,7 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, mut r) = split(c); assert!(executor @@ -800,7 +742,7 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, mut r) = split(&mut c); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -808,7 +750,7 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_closed() { - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, r) = split(c); drop(r); @@ -822,7 +764,7 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s0, mut r) = split(c); assert!(s0.try_send(1).is_ok()); @@ -842,7 +784,7 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, mut r) = split(c); r.close(); @@ -851,7 +793,7 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closed() { - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, r) = split(c); drop(r); diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index 79fa3dfb..c85b7c28 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -6,7 +6,8 @@ mod example_common; use defmt::unwrap; -use embassy::channel::mpsc::{self, Channel, Sender, TryRecvError, WithNoThreads}; +use embassy::blocking_mutex::kind::Noop; +use embassy::channel::mpsc::{self, Channel, Sender, TryRecvError}; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; use embassy::util::Forever; @@ -19,10 +20,10 @@ enum LedState { Off, } -static CHANNEL: Forever> = Forever::new(); +static CHANNEL: Forever> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, WithNoThreads, LedState, 1>) { +async fn my_task(sender: Sender<'static, Noop, LedState, 1>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await;