diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml index 0a8ab443..ae06bc19 100644 --- a/embassy/Cargo.toml +++ b/embassy/Cargo.toml @@ -42,6 +42,7 @@ embassy-traits = { version = "0.1.0", path = "../embassy-traits"} atomic-polyfill = "0.1.3" critical-section = "0.2.1" embedded-hal = "0.2.6" +heapless = "0.7.5" [dev-dependencies] embassy = { path = ".", features = ["executor-agnostic"] } diff --git a/embassy/src/blocking_mutex/kind.rs b/embassy/src/blocking_mutex/kind.rs new file mode 100644 index 00000000..30fc9049 --- /dev/null +++ b/embassy/src/blocking_mutex/kind.rs @@ -0,0 +1,19 @@ +use super::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex}; + +pub trait MutexKind { + type Mutex: Mutex; +} + +pub enum CriticalSection {} +impl MutexKind for CriticalSection { + type Mutex = CriticalSectionMutex; +} + +pub enum ThreadMode {} +impl MutexKind for ThreadMode { + type Mutex = ThreadModeMutex; +} +pub enum Noop {} +impl MutexKind for Noop { + type Mutex = NoopMutex; +} diff --git a/embassy/src/blocking_mutex/mod.rs b/embassy/src/blocking_mutex/mod.rs index d112d2ed..641a1ed9 100644 --- a/embassy/src/blocking_mutex/mod.rs +++ b/embassy/src/blocking_mutex/mod.rs @@ -1,5 +1,7 @@ //! Blocking mutex (not async) +pub mod kind; + use core::cell::UnsafeCell; use critical_section::CriticalSection; @@ -13,7 +15,7 @@ pub trait Mutex { fn new(data: Self::Data) -> Self; /// Creates a critical section and grants temporary access to the protected data. - fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R; + fn lock(&self, f: impl FnOnce(&Self::Data) -> R) -> R; } /// A "mutex" based on critical sections @@ -55,7 +57,7 @@ impl Mutex for CriticalSectionMutex { Self::new(data) } - fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + fn lock(&self, f: impl FnOnce(&Self::Data) -> R) -> R { critical_section::with(|cs| f(self.borrow(cs))) } } @@ -102,7 +104,7 @@ impl Mutex for ThreadModeMutex { Self::new(data) } - fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + fn lock(&self, f: impl FnOnce(&Self::Data) -> R) -> R { f(self.borrow()) } } @@ -155,7 +157,7 @@ impl Mutex for NoopMutex { Self::new(data) } - fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + fn lock(&self, f: impl FnOnce(&Self::Data) -> R) -> R { f(self.borrow()) } } diff --git a/embassy/src/channel/mpsc.rs b/embassy/src/channel/mpsc.rs index b20d48a9..9a57c0b1 100644 --- a/embassy/src/channel/mpsc.rs +++ b/embassy/src/channel/mpsc.rs @@ -37,19 +37,18 @@ //! //! This channel and its associated types were derived from https://docs.rs/tokio/0.1.22/tokio/sync/mpsc/fn.channel.html -use core::cell::UnsafeCell; +use core::cell::RefCell; use core::fmt; -use core::marker::PhantomData; -use core::mem::MaybeUninit; use core::pin::Pin; -use core::ptr; use core::task::Context; use core::task::Poll; use core::task::Waker; use futures::Future; +use heapless::Deque; -use crate::blocking_mutex::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex}; +use crate::blocking_mutex::kind::MutexKind; +use crate::blocking_mutex::Mutex; use crate::waitqueue::WakerRegistration; /// Send values to the associated `Receiver`. @@ -57,36 +56,19 @@ use crate::waitqueue::WakerRegistration; /// Instances are created by the [`split`](split) function. pub struct Sender<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - channel_cell: &'ch UnsafeCell>, + channel: &'ch Channel, } -// Safe to pass the sender around -unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex + Sync -{} -unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex + Sync -{} - /// Receive values from the associated `Sender`. /// /// Instances are created by the [`split`](split) function. pub struct Receiver<'ch, M, T, const N: usize> where - M: Mutex, -{ - channel_cell: &'ch UnsafeCell>, - _receiver_consumed: &'ch mut PhantomData<()>, -} - -// Safe to pass the receiver around -unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where - M: Mutex + Sync -{ -} -unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where - M: Mutex + Sync + M: MutexKind, { + channel: &'ch Channel, } /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. @@ -117,16 +99,11 @@ pub fn split( channel: &mut Channel, ) -> (Sender, Receiver) where - M: Mutex, + M: MutexKind, { - let sender = Sender { - channel_cell: &channel.channel_cell, - }; - let receiver = Receiver { - channel_cell: &channel.channel_cell, - _receiver_consumed: &mut channel.receiver_consumed, - }; - Channel::lock(&channel.channel_cell, |c| { + let sender = Sender { channel }; + let receiver = Receiver { channel }; + channel.lock(|c| { c.register_receiver(); c.register_sender(); }); @@ -135,7 +112,7 @@ where impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { /// Receives the next value for this receiver. /// @@ -155,7 +132,7 @@ where /// [`close`]: Self::close pub fn recv<'m>(&'m mut self) -> RecvFuture<'m, M, T, N> { RecvFuture { - channel_cell: self.channel_cell, + channel: self.channel, } } @@ -164,7 +141,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - Channel::lock(self.channel_cell, |c| c.try_recv()) + self.channel.lock(|c| c.try_recv()) } /// Closes the receiving half of a channel without dropping it. @@ -178,56 +155,45 @@ where /// until those are released. /// pub fn close(&mut self) { - Channel::lock(self.channel_cell, |c| c.close()) + self.channel.lock(|c| c.close()) } } impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { fn drop(&mut self) { - Channel::lock(self.channel_cell, |c| c.deregister_receiver()) + self.channel.lock(|c| c.deregister_receiver()) } } pub struct RecvFuture<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - channel_cell: &'ch UnsafeCell>, + channel: &'ch Channel, } impl<'ch, M, T, const N: usize> Future for RecvFuture<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Channel::lock(self.channel_cell, |c| { - match c.try_recv_with_context(Some(cx)) { + self.channel + .lock(|c| match c.try_recv_with_context(Some(cx)) { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => Poll::Pending, - } - }) + }) } } -// Safe to pass the receive future around since it locks channel whenever polled -unsafe impl<'ch, M, T, const N: usize> Send for RecvFuture<'ch, M, T, N> where - M: Mutex + Sync -{ -} -unsafe impl<'ch, M, T, const N: usize> Sync for RecvFuture<'ch, M, T, N> where - M: Mutex + Sync -{ -} - impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { /// Sends a value, waiting until there is capacity. /// @@ -249,7 +215,7 @@ where /// [`Receiver`]: Receiver pub fn send(&self, message: T) -> SendFuture<'ch, M, T, N> { SendFuture { - sender: self.clone(), + channel: self.channel, message: Some(message), } } @@ -275,7 +241,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - Channel::lock(self.channel_cell, |c| c.try_send(message)) + self.channel.lock(|c| c.try_send(message)) } /// Completes when the receiver has dropped. @@ -284,7 +250,7 @@ where /// values is canceled and immediately stop doing work. pub async fn closed(&self) { CloseFuture { - sender: self.clone(), + channel: self.channel, } .await } @@ -296,29 +262,27 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - Channel::lock(self.channel_cell, |c| c.is_closed()) + self.channel.lock(|c| c.is_closed()) } } pub struct SendFuture<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - sender: Sender<'ch, M, T, N>, + channel: &'ch Channel, message: Option, } impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { type Output = Result<(), SendError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.message.take() { - Some(m) => match Channel::lock(self.sender.channel_cell, |c| { - c.try_send_with_context(m, Some(cx)) - }) { + Some(m) => match self.channel.lock(|c| c.try_send_with_context(m, Some(cx))) { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(m)) => { @@ -331,25 +295,23 @@ where } } -impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: Mutex {} +impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: MutexKind {} struct CloseFuture<'ch, M, T, const N: usize> where - M: Mutex, + M: MutexKind, { - sender: Sender<'ch, M, T, N>, + channel: &'ch Channel, } impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if Channel::lock(self.sender.channel_cell, |c| { - c.is_closed_with_context(Some(cx)) - }) { + if self.channel.lock(|c| c.is_closed_with_context(Some(cx))) { Poll::Ready(()) } else { Poll::Pending @@ -359,22 +321,21 @@ where impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { fn drop(&mut self) { - Channel::lock(self.channel_cell, |c| c.deregister_sender()) + self.channel.lock(|c| c.deregister_sender()) } } impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> where - M: Mutex, + M: MutexKind, { - #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - Channel::lock(self.channel_cell, |c| c.register_sender()); + self.channel.lock(|c| c.register_sender()); Sender { - channel_cell: self.channel_cell.clone(), + channel: self.channel, } } } @@ -446,10 +407,7 @@ impl defmt::Format for TrySendError { } struct ChannelState { - buf: [MaybeUninit>; N], - read_pos: usize, - write_pos: usize, - full: bool, + queue: Deque, closed: bool, receiver_registered: bool, senders_registered: u32, @@ -458,14 +416,9 @@ struct ChannelState { } impl ChannelState { - const INIT: MaybeUninit> = MaybeUninit::uninit(); - const fn new() -> Self { ChannelState { - buf: [Self::INIT; N], - read_pos: 0, - write_pos: 0, - full: false, + queue: Deque::new(), closed: false, receiver_registered: false, senders_registered: 0, @@ -479,17 +432,16 @@ impl ChannelState { } fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { - if self.read_pos != self.write_pos || self.full { - if self.full { - self.full = false; - self.senders_waker.wake(); - } - let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() }; - self.read_pos = (self.read_pos + 1) % self.buf.len(); + if self.queue.is_full() { + self.senders_waker.wake(); + } + + if let Some(message) = self.queue.pop_front() { Ok(message) } else if !self.closed { - cx.into_iter() - .for_each(|cx| self.set_receiver_waker(&cx.waker())); + if let Some(cx) = cx { + self.set_receiver_waker(cx.waker()); + } Err(TryRecvError::Empty) } else { Err(TryRecvError::Closed) @@ -505,22 +457,21 @@ impl ChannelState { message: T, cx: Option<&mut Context<'_>>, ) -> Result<(), TrySendError> { - if !self.closed { - if !self.full { - self.buf[self.write_pos] = MaybeUninit::new(message.into()); - self.write_pos = (self.write_pos + 1) % self.buf.len(); - if self.write_pos == self.read_pos { - self.full = true; - } + if self.closed { + return Err(TrySendError::Closed(message)); + } + + match self.queue.push_back(message) { + Ok(()) => { self.receiver_waker.wake(); + Ok(()) - } else { + } + Err(message) => { cx.into_iter() .for_each(|cx| self.set_senders_waker(&cx.waker())); Err(TrySendError::Full(message)) } - } else { - Err(TrySendError::Closed(message)) } } @@ -585,16 +536,6 @@ impl ChannelState { } } -impl Drop for ChannelState { - fn drop(&mut self) { - while self.read_pos != self.write_pos || self.full { - self.full = false; - unsafe { ptr::drop_in_place(self.buf[self.read_pos].as_mut_ptr()) }; - self.read_pos = (self.read_pos + 1) % N; - } - } -} - /// A a bounded mpsc channel for communicating between asynchronous tasks /// with backpressure. /// @@ -605,61 +546,35 @@ impl Drop for ChannelState { /// All data sent will become available in the same order as it was sent. pub struct Channel where - M: Mutex, + M: MutexKind, { - channel_cell: UnsafeCell>, - receiver_consumed: PhantomData<()>, + inner: M::Mutex>>, } -struct ChannelCell -where - M: Mutex, -{ - mutex: M, - state: ChannelState, -} - -pub type WithCriticalSections = CriticalSectionMutex<()>; - -pub type WithThreadModeOnly = ThreadModeMutex<()>; - -pub type WithNoThreads = NoopMutex<()>; - impl Channel where - M: Mutex, + M: MutexKind, { /// Establish a new bounded channel. For example, to create one with a NoopMutex: /// /// ``` /// use embassy::channel::mpsc; - /// use embassy::channel::mpsc::{Channel, WithNoThreads}; + /// use embassy::blocking_mutex::kind::Noop; + /// use embassy::channel::mpsc::Channel; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::new(); + /// let mut channel = Channel::::new(); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub fn new() -> Self { - let mutex = M::new(()); - let state = ChannelState::new(); - let channel_cell = ChannelCell { mutex, state }; - Channel { - channel_cell: UnsafeCell::new(channel_cell), - receiver_consumed: PhantomData, + Self { + inner: M::Mutex::new(RefCell::new(ChannelState::new())), } } - fn lock( - channel_cell: &UnsafeCell>, - f: impl FnOnce(&mut ChannelState) -> R, - ) -> R { - unsafe { - let channel_cell = &mut *(channel_cell.get()); - let mutex = &mut channel_cell.mutex; - let mut state = &mut channel_cell.state; - mutex.lock(|_| f(&mut state)) - } + fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + self.inner.lock(|rc| f(&mut *rc.borrow_mut())) } } @@ -671,20 +586,13 @@ mod tests { use futures_executor::ThreadPool; use futures_timer::Delay; + use crate::blocking_mutex::kind::{CriticalSection, Noop}; use crate::util::Forever; use super::*; fn capacity(c: &ChannelState) -> usize { - if !c.full { - if c.write_pos > c.read_pos { - (c.buf.len() - c.write_pos) + c.read_pos - } else { - (c.buf.len() - c.read_pos) + c.write_pos - } - } else { - 0 - } + c.queue.capacity() - c.queue.len() } #[test] @@ -747,7 +655,7 @@ mod tests { #[test] fn simple_send_and_receive() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); @@ -755,7 +663,7 @@ mod tests { #[test] fn should_close_without_sender() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); drop(s); match r.try_recv() { @@ -766,7 +674,7 @@ mod tests { #[test] fn should_close_once_drained() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); assert!(s.try_send(1).is_ok()); drop(s); @@ -779,7 +687,7 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); drop(r); match s.try_send(1) { @@ -790,7 +698,7 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, mut r) = split(&mut c); assert!(s.try_send(1).is_ok()); r.close(); @@ -806,7 +714,7 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, mut r) = split(c); assert!(executor @@ -821,7 +729,7 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, mut r) = split(c); assert!(executor @@ -834,7 +742,7 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - let mut c = Channel::::new(); + let mut c = Channel::::new(); let (s, mut r) = split(&mut c); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -842,7 +750,7 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_closed() { - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, r) = split(c); drop(r); @@ -856,7 +764,7 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s0, mut r) = split(c); assert!(s0.try_send(1).is_ok()); @@ -876,7 +784,7 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, mut r) = split(c); r.close(); @@ -885,7 +793,7 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closed() { - static CHANNEL: Forever> = Forever::new(); + static CHANNEL: Forever> = Forever::new(); let c = CHANNEL.put(Channel::new()); let (s, r) = split(c); drop(r); diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index 79fa3dfb..c85b7c28 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -6,7 +6,8 @@ mod example_common; use defmt::unwrap; -use embassy::channel::mpsc::{self, Channel, Sender, TryRecvError, WithNoThreads}; +use embassy::blocking_mutex::kind::Noop; +use embassy::channel::mpsc::{self, Channel, Sender, TryRecvError}; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; use embassy::util::Forever; @@ -19,10 +20,10 @@ enum LedState { Off, } -static CHANNEL: Forever> = Forever::new(); +static CHANNEL: Forever> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, WithNoThreads, LedState, 1>) { +async fn my_task(sender: Sender<'static, Noop, LedState, 1>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await;