diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml index b2ad8049..d2649024 100644 --- a/embassy/Cargo.toml +++ b/embassy/Cargo.toml @@ -3,6 +3,7 @@ name = "embassy" version = "0.1.0" authors = ["Dario Nieuwenhuis "] edition = "2018" +resolver = "2" [features] default = [] @@ -36,3 +37,9 @@ embedded-hal = "0.2.5" # Workaround https://github.com/japaric/cast.rs/pull/27 cast = { version = "=0.2.3", default-features = false } + +[dev-dependencies] +futures-executor = { version = "0.3", features = [ "thread-pool" ] } +futures-test = "0.3" +futures-timer = "0.3" +futures-util = { version = "0.3", features = [ "channel" ] } diff --git a/embassy/src/lib.rs b/embassy/src/lib.rs index 41102a18..845f82a3 100644 --- a/embassy/src/lib.rs +++ b/embassy/src/lib.rs @@ -7,6 +7,7 @@ #![feature(min_type_alias_impl_trait)] #![feature(impl_trait_in_bindings)] #![feature(type_alias_impl_trait)] +#![feature(maybe_uninit_ref)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; diff --git a/embassy/src/util/mod.rs b/embassy/src/util/mod.rs index 88ae5c28..87d313e2 100644 --- a/embassy/src/util/mod.rs +++ b/embassy/src/util/mod.rs @@ -11,6 +11,7 @@ mod waker; pub use drop_bomb::*; pub use forever::*; +pub mod mpsc; pub use mutex::*; pub use on_drop::*; pub use portal::*; diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs new file mode 100644 index 00000000..d24eb00b --- /dev/null +++ b/embassy/src/util/mpsc.rs @@ -0,0 +1,919 @@ +//! A multi-producer, single-consumer queue for sending values between +//! asynchronous tasks. This queue takes a Mutex type so that various +//! targets can be attained. For example, a ThreadModeMutex can be used +//! for single-core Cortex-M targets where messages are only passed +//! between tasks running in thread mode. Similarly, a CriticalSectionMutex +//! can also be used for single-core targets where messages are to be +//! passed from exception mode e.g. out of an interrupt handler. +//! +//! This module provides a bounded channel that has a limit on the number of +//! messages that it can store, and if this limit is reached, trying to send +//! another message will result in an error being returned. +//! +//! Similar to the `mpsc` channels provided by `std`, the channel constructor +//! functions provide separate send and receive handles, [`Sender`] and +//! [`Receiver`]. If there is no message to read, the current task will be +//! notified when a new value is sent. [`Sender`] allows sending values into +//! the channel. If the bounded channel is at capacity, the send is rejected. +//! +//! # Disconnection +//! +//! When all [`Sender`] handles have been dropped, it is no longer +//! possible to send values into the channel. This is considered the termination +//! event of the stream. +//! +//! If the [`Receiver`] handle is dropped, then messages can no longer +//! be read out of the channel. In this case, all further attempts to send will +//! result in an error. +//! +//! # Clean Shutdown +//! +//! When the [`Receiver`] is dropped, it is possible for unprocessed messages to +//! remain in the channel. Instead, it is usually desirable to perform a "clean" +//! shutdown. To do this, the receiver first calls `close`, which will prevent +//! any further messages to be sent into the channel. Then, the receiver +//! consumes the channel to completion, at which point the receiver can be +//! dropped. +//! +//! This channel and its associated types were derived from https://docs.rs/tokio/0.1.22/tokio/sync/mpsc/fn.channel.html + +use core::cell::UnsafeCell; +use core::fmt; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::Context; +use core::task::Poll; +use core::task::Waker; + +use futures::Future; + +use super::CriticalSectionMutex; +use super::Mutex; +use super::ThreadModeMutex; + +/// Send values to the associated `Receiver`. +/// +/// Instances are created by the [`split`](split) function. +pub struct Sender<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: *mut Channel, + phantom_data: &'ch PhantomData, +} + +// Safe to pass the sender around +unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} + +/// Receive values from the associated `Sender`. +/// +/// Instances are created by the [`split`](split) function. +pub struct Receiver<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: *mut Channel, + _phantom_data: &'ch PhantomData, +} + +// Safe to pass the receiver around +unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} + +/// Splits a bounded mpsc channel into a `Sender` and `Receiver`. +/// +/// All data sent on `Sender` will become available on `Receiver` in the same +/// order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple code +/// locations. Only one `Receiver` is valid. +/// +/// If the `Receiver` is disconnected while trying to `send`, the `send` method +/// will return a `SendError`. Similarly, if `Sender` is disconnected while +/// trying to `recv`, the `recv` method will return a `RecvError`. +/// +/// Note that when splitting the channel, the sender and receiver cannot outlive +/// their channel. The following will therefore fail compilation: +//// +/// ```compile_fail +/// use embassy::util::mpsc; +/// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; +/// +/// let (sender, receiver) = { +/// let mut channel = Channel::::with_thread_mode_only(); +/// mpsc::split(&mut channel) +/// }; +/// ``` +pub fn split<'ch, M, T, const N: usize>( + channel: &'ch mut Channel, +) -> (Sender<'ch, M, T, N>, Receiver<'ch, M, T, N>) +where + M: Mutex, +{ + let sender = Sender { + channel, + phantom_data: &PhantomData, + }; + let receiver = Receiver { + channel, + _phantom_data: &PhantomData, + }; + channel.register_receiver(); + channel.register_sender(); + (sender, receiver) +} + +impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> +where + M: Mutex, +{ + /// Receives the next value for this receiver. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. + /// + /// Note that if [`close`] is called, but there are still outstanding + /// messages from before it was closed, the channel is not considered + /// closed by `recv` until they are all consumed. + /// + /// [`close`]: Self::close + pub async fn recv(&mut self) -> Option { + self.await + } + + /// Attempts to immediately receive a message on this `Receiver` + /// + /// This method will either receive a message from the channel immediately or return an error + /// if the channel is empty. + pub fn try_recv(&self) -> Result { + unsafe { self.channel.as_mut().unwrap().try_recv() } + } + + /// Closes the receiving half of a channel without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. + /// + /// To guarantee that no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. If there are + /// outstanding messages, the `recv` method will not return `None` + /// until those are released. + /// + pub fn close(&mut self) { + unsafe { self.channel.as_mut().unwrap().close() } + } +} + +impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> +where + M: Mutex, +{ + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.try_recv() { + Ok(v) => Poll::Ready(Some(v)), + Err(TryRecvError::Closed) => Poll::Ready(None), + Err(TryRecvError::Empty) => { + unsafe { + self.channel + .as_mut() + .unwrap() + .set_receiver_waker(cx.waker().clone()); + }; + Poll::Pending + } + } + } +} + +impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> +where + M: Mutex, +{ + fn drop(&mut self) { + unsafe { self.channel.as_mut().unwrap().deregister_receiver() } + } +} + +impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> +where + M: Mutex, +{ + /// Sends a value, waiting until there is capacity. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been closed. Note that a return + /// value of `Err` means that the data will never be received, but a return + /// value of `Ok` does not mean that the data will be received. It is + /// possible for the corresponding receiver to hang up immediately after + /// this function returns `Ok`. + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + pub async fn send(&self, message: T) -> Result<(), SendError> { + SendFuture { + sender: self.clone(), + message: UnsafeCell::new(message), + } + .await + } + + /// Attempts to immediately send a message on this `Sender` + /// + /// This method differs from [`send`] by returning immediately if the channel's + /// buffer is full or no receiver is waiting to acquire some data. Compared + /// with [`send`], this function has two failure cases instead of one (one for + /// disconnection, one for a full buffer). + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`channel`], then an + /// error is returned. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`send`]: Sender::send + /// [`channel`]: channel + /// [`close`]: Receiver::close + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + unsafe { self.channel.as_mut().unwrap().try_send(message) } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + pub async fn closed(&self) { + CloseFuture { + sender: self.clone(), + } + .await + } + + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. + /// + /// [`Receiver`]: crate::sync::mpsc::Receiver + /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close + pub fn is_closed(&self) -> bool { + unsafe { self.channel.as_mut().unwrap().is_closed() } + } +} + +struct SendFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, + message: UnsafeCell, +} + +impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> +where + M: Mutex, +{ + type Output = Result<(), SendError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.sender.try_send(unsafe { self.message.get().read() }) { + Ok(..) => Poll::Ready(Ok(())), + Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), + Err(TrySendError::Full(..)) => { + unsafe { + self.sender + .channel + .as_mut() + .unwrap() + .set_senders_waker(cx.waker().clone()); + }; + Poll::Pending + // Note we leave the existing UnsafeCell contents - they still + // contain the original message. We could create another UnsafeCell + // with the message of Full, but there's no real need. + } + } + } +} + +struct CloseFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, +} + +impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> +where + M: Mutex, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.sender.is_closed() { + Poll::Ready(()) + } else { + unsafe { + self.sender + .channel + .as_mut() + .unwrap() + .set_senders_waker(cx.waker().clone()); + }; + Poll::Pending + } + } +} + +impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> +where + M: Mutex, +{ + fn drop(&mut self) { + unsafe { self.channel.as_mut().unwrap().deregister_sender() } + } +} + +impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> +where + M: Mutex, +{ + fn clone(&self) -> Self { + unsafe { self.channel.as_mut().unwrap().register_sender() }; + Sender { + channel: self.channel, + phantom_data: self.phantom_data, + } + } +} + +/// An error returned from the [`try_recv`] method. +/// +/// [`try_recv`]: super::Receiver::try_recv +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// A message could not be received because the channel is empty. + Empty, + + /// The message could not be received because the channel is empty and closed. + Closed, +} + +/// Error returned by the `Sender`. +#[derive(Debug)] +pub struct SendError(pub T); + +impl fmt::Display for SendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +/// This enumeration is the list of the possible error outcomes for the +/// [try_send](super::Sender::try_send) method. +#[derive(Debug)] +pub enum TrySendError { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), +} + +impl fmt::Display for TrySendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TrySendError::Full(..) => "no available capacity", + TrySendError::Closed(..) => "channel closed", + } + ) + } +} + +pub struct ChannelState { + buf: [MaybeUninit>; N], + read_pos: usize, + write_pos: usize, + full: bool, + closing: bool, + closed: bool, + receiver_registered: bool, + senders_registered: u32, + receiver_waker: Option, + senders_waker: Option, +} + +impl ChannelState { + const INIT: MaybeUninit> = MaybeUninit::uninit(); + + const fn new() -> Self { + let buf = [Self::INIT; N]; + let read_pos = 0; + let write_pos = 0; + let full = false; + let closing = false; + let closed = false; + let receiver_registered = false; + let senders_registered = 0; + let receiver_waker = None; + let senders_waker = None; + ChannelState { + buf, + read_pos, + write_pos, + full, + closing, + closed, + receiver_registered, + senders_registered, + receiver_waker, + senders_waker, + } + } +} + +/// A a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `send` new messages will wait until a message is +/// received from the channel. +/// +/// All data sent will become available in the same order as it was sent. +pub struct Channel +where + M: Mutex, +{ + mutex: M, + state: ChannelState, +} + +pub type WithCriticalSections = CriticalSectionMutex<()>; + +impl Channel { + /// Establish a new bounded channel using critical sections. Critical sections + /// should be used only single core targets where communication is required + /// from exception mode e.g. interrupt handlers. To create one: + /// + /// ``` + /// use embassy::util::mpsc; + /// use embassy::util::mpsc::{Channel, WithCriticalSections}; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = mpsc::Channel::::with_critical_sections(); + /// // once we have a channel, obtain its sender and receiver + /// let (sender, receiver) = mpsc::split(&mut channel); + /// ``` + pub const fn with_critical_sections() -> Self { + let mutex = CriticalSectionMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } +} + +pub type WithThreadModeOnly = ThreadModeMutex<()>; + +impl Channel { + /// Establish a new bounded channel for use in Cortex-M thread mode. Thread + /// mode is intended for application threads on a single core, not interrupts. + /// As such, only one task at a time can acquire a resource and so this + /// channel avoids all locks. To create one: + /// + /// ``` no_run + /// use embassy::util::mpsc; + /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = Channel::::with_thread_mode_only(); + /// // once we have a channel, obtain its sender and receiver + /// let (sender, receiver) = mpsc::split(&mut channel); + /// ``` + pub const fn with_thread_mode_only() -> Self { + let mutex = ThreadModeMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } +} + +impl Channel +where + M: Mutex, +{ + fn try_recv(&mut self) -> Result { + let state = &mut self.state; + self.mutex.lock(|_| { + if !state.closed { + if state.read_pos != state.write_pos || state.full { + if state.full { + state.full = false; + state.senders_waker.take().map(|w| w.wake()); + } + let message = + unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; + state.read_pos = (state.read_pos + 1) % state.buf.len(); + Ok(message) + } else if !state.closing { + Err(TryRecvError::Empty) + } else { + state.closed = true; + state.senders_waker.take().map(|w| w.wake()); + Err(TryRecvError::Closed) + } + } else { + Err(TryRecvError::Closed) + } + }) + } + + fn try_send(&mut self, message: T) -> Result<(), TrySendError> { + let state = &mut self.state; + self.mutex.lock(|_| { + if !state.closed { + if !state.full { + state.buf[state.write_pos] = MaybeUninit::new(message.into()); + state.write_pos = (state.write_pos + 1) % state.buf.len(); + if state.write_pos == state.read_pos { + state.full = true; + } + state.receiver_waker.take().map(|w| w.wake()); + Ok(()) + } else { + Err(TrySendError::Full(message)) + } + } else { + Err(TrySendError::Closed(message)) + } + }) + } + + fn close(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + state.receiver_waker.take().map(|w| w.wake()); + state.closing = true; + }); + } + + fn is_closed(&mut self) -> bool { + let state = &self.state; + self.mutex.lock(|_| state.closing || state.closed) + } + + fn register_receiver(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + assert!(!state.receiver_registered); + state.receiver_registered = true; + }); + } + + fn deregister_receiver(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + if state.receiver_registered { + state.closed = true; + state.senders_waker.take().map(|w| w.wake()); + } + state.receiver_registered = false; + }) + } + + fn register_sender(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + state.senders_registered = state.senders_registered + 1; + }) + } + + fn deregister_sender(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + assert!(state.senders_registered > 0); + state.senders_registered = state.senders_registered - 1; + if state.senders_registered == 0 { + state.receiver_waker.take().map(|w| w.wake()); + state.closing = true; + } + }) + } + + fn set_receiver_waker(&mut self, receiver_waker: Waker) { + let state = &mut self.state; + self.mutex.lock(|_| { + state.receiver_waker = Some(receiver_waker); + }) + } + + fn set_senders_waker(&mut self, senders_waker: Waker) { + let state = &mut self.state; + self.mutex.lock(|_| { + + // Dispose of any existing sender causing them to be polled again. + // This could cause a spin given multiple concurrent senders, however given that + // most sends only block waiting for the receiver to become active, this should + // be a short-lived activity. The upside is a greatly simplified implementation + // that avoids the need for intrusive linked-lists and unsafe operations on pinned + // pointers. + if let Some(waker) = state.senders_waker.clone() { + if !senders_waker.will_wake(&waker) { + trace!("Waking an an active send waker due to being superseded with a new one. While benign, please report this."); + waker.wake(); + } + } + state.senders_waker = Some(senders_waker); + }) + } +} + +#[cfg(test)] +mod tests { + use core::time::Duration; + + use futures::task::SpawnExt; + use futures_executor::ThreadPool; + use futures_timer::Delay; + + use super::*; + + fn capacity(c: &Channel) -> usize + where + M: Mutex, + { + if !c.state.full { + if c.state.write_pos > c.state.read_pos { + (c.state.buf.len() - c.state.write_pos) + c.state.read_pos + } else { + (c.state.buf.len() - c.state.read_pos) + c.state.write_pos + } + } else { + 0 + } + } + + /// A mutex that does nothing - useful for our testing purposes + pub struct NoopMutex { + inner: UnsafeCell, + } + + impl NoopMutex { + pub const fn new(value: T) -> Self { + NoopMutex { + inner: UnsafeCell::new(value), + } + } + } + + impl NoopMutex { + pub fn borrow(&self) -> &T { + unsafe { &*self.inner.get() } + } + } + + impl Mutex for NoopMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + f(self.borrow()) + } + } + + pub type WithNoThreads = NoopMutex<()>; + + impl Channel { + pub const fn with_no_threads() -> Self { + let mutex = NoopMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } + } + + #[test] + fn sending_once() { + let mut c = Channel::::with_no_threads(); + assert!(c.try_send(1).is_ok()); + assert_eq!(capacity(&c), 2); + } + + #[test] + fn sending_when_full() { + let mut c = Channel::::with_no_threads(); + let _ = c.try_send(1); + let _ = c.try_send(1); + let _ = c.try_send(1); + match c.try_send(2) { + Err(TrySendError::Full(2)) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 0); + } + + #[test] + fn sending_when_closed() { + let mut c = Channel::::with_no_threads(); + c.state.closed = true; + match c.try_send(2) { + Err(TrySendError::Closed(2)) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn receiving_once_with_one_send() { + let mut c = Channel::::with_no_threads(); + assert!(c.try_send(1).is_ok()); + assert_eq!(c.try_recv().unwrap(), 1); + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_empty() { + let mut c = Channel::::with_no_threads(); + match c.try_recv() { + Err(TryRecvError::Empty) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_closed() { + let mut c = Channel::::with_no_threads(); + c.state.closed = true; + match c.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn simple_send_and_receive() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + assert!(s.clone().try_send(1).is_ok()); + assert_eq!(r.try_recv().unwrap(), 1); + } + + #[test] + fn should_close_without_sender() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + drop(s); + match r.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn should_close_once_drained() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + assert!(s.try_send(1).is_ok()); + drop(s); + assert_eq!(r.try_recv().unwrap(), 1); + match r.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn should_reject_send_when_receiver_dropped() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + drop(r); + match s.try_send(1) { + Err(TrySendError::Closed(1)) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn should_reject_send_when_channel_closed() { + let mut c = Channel::::with_no_threads(); + let (s, mut r) = split(&mut c); + assert!(s.try_send(1).is_ok()); + r.close(); + assert_eq!(r.try_recv().unwrap(), 1); + match r.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + assert!(s.is_closed()); + } + + #[futures_test::test] + async fn receiver_closes_when_sender_dropped_async() { + let executor = ThreadPool::new().unwrap(); + + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + assert!(executor + .spawn(async move { + drop(s); + }) + .is_ok()); + assert_eq!(r.recv().await, None); + } + + #[futures_test::test] + async fn receiver_receives_given_try_send_async() { + let executor = ThreadPool::new().unwrap(); + + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + assert!(executor + .spawn(async move { + let _ = s.try_send(1); + }) + .is_ok()); + assert_eq!(r.recv().await, Some(1)); + } + + #[futures_test::test] + async fn sender_send_completes_if_capacity() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + assert!(s.send(1).await.is_ok()); + assert_eq!(r.recv().await, Some(1)); + } + + #[futures_test::test] + async fn sender_send_completes_if_closed() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, r) = split(unsafe { &mut CHANNEL }); + drop(r); + match s.send(1).await { + Err(SendError(1)) => assert!(true), + _ => assert!(false), + } + } + + #[futures_test::test] + async fn senders_sends_wait_until_capacity() { + let executor = ThreadPool::new().unwrap(); + + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s0, mut r) = split(unsafe { &mut CHANNEL }); + assert!(s0.try_send(1).is_ok()); + let s1 = s0.clone(); + let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); + let send_task_2 = executor.spawn_with_handle(async move { s1.send(3).await }); + // Wish I could think of a means of determining that the async send is waiting instead. + // However, I've used the debugger to observe that the send does indeed wait. + assert!(Delay::new(Duration::from_millis(500)).await.is_ok()); + assert_eq!(r.recv().await, Some(1)); + assert!(executor + .spawn(async move { while let Some(_) = r.recv().await {} }) + .is_ok()); + assert!(send_task_1.unwrap().await.is_ok()); + assert!(send_task_2.unwrap().await.is_ok()); + } + + #[futures_test::test] + async fn sender_close_completes_if_closing() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + r.close(); + s.closed().await; + } + + #[futures_test::test] + async fn sender_close_completes_if_closed() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, r) = split(unsafe { &mut CHANNEL }); + drop(r); + s.closed().await; + } +} diff --git a/embassy/src/util/mutex.rs b/embassy/src/util/mutex.rs index e4b7764c..682fcb39 100644 --- a/embassy/src/util/mutex.rs +++ b/embassy/src/util/mutex.rs @@ -1,6 +1,17 @@ use core::cell::UnsafeCell; use critical_section::CriticalSection; +/// Any object implementing this trait guarantees exclusive access to the data contained +/// within the mutex for the duration of the lock. +/// Adapted from https://github.com/rust-embedded/mutex-trait. +pub trait Mutex { + /// Data protected by the mutex. + type Data; + + /// Creates a critical section and grants temporary access to the protected data. + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R; +} + /// A "mutex" based on critical sections /// /// # Safety @@ -33,6 +44,14 @@ impl CriticalSectionMutex { } } +impl Mutex for CriticalSectionMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + critical_section::with(|cs| f(self.borrow(cs))) + } +} + /// A "mutex" that only allows borrowing from thread mode. /// /// # Safety @@ -70,6 +89,14 @@ impl ThreadModeMutex { } } +impl Mutex for ThreadModeMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + f(self.borrow()) + } +} + pub fn in_thread_mode() -> bool { #[cfg(feature = "std")] return Some("main") == std::thread::current().name(); diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs new file mode 100644 index 00000000..6a0f8f47 --- /dev/null +++ b/examples/nrf/src/bin/mpsc.rs @@ -0,0 +1,64 @@ +#![no_std] +#![no_main] +#![feature(min_type_alias_impl_trait)] +#![feature(impl_trait_in_bindings)] +#![feature(type_alias_impl_trait)] +#![allow(incomplete_features)] + +#[path = "../example_common.rs"] +mod example_common; + +use defmt::panic; +use embassy::executor::Spawner; +use embassy::time::{Duration, Timer}; +use embassy::util::mpsc::TryRecvError; +use embassy::util::{mpsc, Forever}; +use embassy_nrf::gpio::{Level, Output, OutputDrive}; +use embassy_nrf::Peripherals; +use embedded_hal::digital::v2::OutputPin; +use mpsc::{Channel, Sender, WithThreadModeOnly}; + +enum LedState { + On, + Off, +} + +static CHANNEL: Forever> = Forever::new(); + +#[embassy::task(pool_size = 1)] +async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { + loop { + let _ = sender.send(LedState::On).await; + Timer::after(Duration::from_secs(1)).await; + let _ = sender.send(LedState::Off).await; + Timer::after(Duration::from_secs(1)).await; + } +} + +#[embassy::main] +async fn main(spawner: Spawner, p: Peripherals) { + let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); + + let channel = CHANNEL.put(Channel::with_thread_mode_only()); + let (sender, mut receiver) = mpsc::split(channel); + + spawner.spawn(my_task(sender)).unwrap(); + + // We could just loop on `receiver.recv()` for simplicity. The code below + // is optimized to drain the queue as fast as possible in the spirit of + // handling events as fast as possible. This optimization is benign when in + // thread mode, but can be useful when interrupts are sending messages + // with the channel having been created via with_critical_sections. + loop { + let maybe_message = match receiver.try_recv() { + m @ Ok(..) => m.ok(), + Err(TryRecvError::Empty) => receiver.recv().await, + Err(TryRecvError::Closed) => break, + }; + match maybe_message { + Some(LedState::On) => led.set_high().unwrap(), + Some(LedState::Off) => led.set_low().unwrap(), + _ => (), + } + } +}