embassy/channel: switch to use MutexKind

This commit is contained in:
Dario Nieuwenhuis 2021-09-12 23:36:52 +02:00
parent 5be5bdfd20
commit 70e5877d68
2 changed files with 69 additions and 126 deletions

View File

@ -37,7 +37,7 @@
//! //!
//! This channel and its associated types were derived from https://docs.rs/tokio/0.1.22/tokio/sync/mpsc/fn.channel.html //! This channel and its associated types were derived from https://docs.rs/tokio/0.1.22/tokio/sync/mpsc/fn.channel.html
use core::cell::UnsafeCell; use core::cell::RefCell;
use core::fmt; use core::fmt;
use core::pin::Pin; use core::pin::Pin;
use core::task::Context; use core::task::Context;
@ -47,7 +47,8 @@ use core::task::Waker;
use futures::Future; use futures::Future;
use heapless::Deque; use heapless::Deque;
use crate::blocking_mutex::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex}; use crate::blocking_mutex::kind::MutexKind;
use crate::blocking_mutex::Mutex;
use crate::waitqueue::WakerRegistration; use crate::waitqueue::WakerRegistration;
/// Send values to the associated `Receiver`. /// Send values to the associated `Receiver`.
@ -55,35 +56,19 @@ use crate::waitqueue::WakerRegistration;
/// Instances are created by the [`split`](split) function. /// Instances are created by the [`split`](split) function.
pub struct Sender<'ch, M, T, const N: usize> pub struct Sender<'ch, M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>, channel: &'ch Channel<M, T, N>,
} }
// Safe to pass the sender around
unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex<Data = ()> + Sync
{}
unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex<Data = ()> + Sync
{}
/// Receive values from the associated `Sender`. /// Receive values from the associated `Sender`.
/// ///
/// Instances are created by the [`split`](split) function. /// Instances are created by the [`split`](split) function.
pub struct Receiver<'ch, M, T, const N: usize> pub struct Receiver<'ch, M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: MutexKind,
{
channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>,
}
// Safe to pass the receiver around
unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where
M: Mutex<Data = ()> + Sync
{
}
unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where
M: Mutex<Data = ()> + Sync
{ {
channel: &'ch Channel<M, T, N>,
} }
/// Splits a bounded mpsc channel into a `Sender` and `Receiver`. /// Splits a bounded mpsc channel into a `Sender` and `Receiver`.
@ -114,15 +99,11 @@ pub fn split<M, T, const N: usize>(
channel: &mut Channel<M, T, N>, channel: &mut Channel<M, T, N>,
) -> (Sender<M, T, N>, Receiver<M, T, N>) ) -> (Sender<M, T, N>, Receiver<M, T, N>)
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
let sender = Sender { let sender = Sender { channel };
channel_cell: &channel.channel_cell, let receiver = Receiver { channel };
}; channel.lock(|c| {
let receiver = Receiver {
channel_cell: &channel.channel_cell,
};
Channel::lock(&channel.channel_cell, |c| {
c.register_receiver(); c.register_receiver();
c.register_sender(); c.register_sender();
}); });
@ -131,7 +112,7 @@ where
impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
/// Receives the next value for this receiver. /// Receives the next value for this receiver.
/// ///
@ -151,7 +132,7 @@ where
/// [`close`]: Self::close /// [`close`]: Self::close
pub fn recv<'m>(&'m mut self) -> RecvFuture<'m, M, T, N> { pub fn recv<'m>(&'m mut self) -> RecvFuture<'m, M, T, N> {
RecvFuture { RecvFuture {
channel_cell: self.channel_cell, channel: self.channel,
} }
} }
@ -160,7 +141,7 @@ where
/// This method will either receive a message from the channel immediately or return an error /// This method will either receive a message from the channel immediately or return an error
/// if the channel is empty. /// if the channel is empty.
pub fn try_recv(&self) -> Result<T, TryRecvError> { pub fn try_recv(&self) -> Result<T, TryRecvError> {
Channel::lock(self.channel_cell, |c| c.try_recv()) self.channel.lock(|c| c.try_recv())
} }
/// Closes the receiving half of a channel without dropping it. /// Closes the receiving half of a channel without dropping it.
@ -174,56 +155,45 @@ where
/// until those are released. /// until those are released.
/// ///
pub fn close(&mut self) { pub fn close(&mut self) {
Channel::lock(self.channel_cell, |c| c.close()) self.channel.lock(|c| c.close())
} }
} }
impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
fn drop(&mut self) { fn drop(&mut self) {
Channel::lock(self.channel_cell, |c| c.deregister_receiver()) self.channel.lock(|c| c.deregister_receiver())
} }
} }
pub struct RecvFuture<'ch, M, T, const N: usize> pub struct RecvFuture<'ch, M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>, channel: &'ch Channel<M, T, N>,
} }
impl<'ch, M, T, const N: usize> Future for RecvFuture<'ch, M, T, N> impl<'ch, M, T, const N: usize> Future for RecvFuture<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
type Output = Option<T>; type Output = Option<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
Channel::lock(self.channel_cell, |c| { self.channel
match c.try_recv_with_context(Some(cx)) { .lock(|c| match c.try_recv_with_context(Some(cx)) {
Ok(v) => Poll::Ready(Some(v)), Ok(v) => Poll::Ready(Some(v)),
Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Closed) => Poll::Ready(None),
Err(TryRecvError::Empty) => Poll::Pending, Err(TryRecvError::Empty) => Poll::Pending,
} })
})
} }
} }
// Safe to pass the receive future around since it locks channel whenever polled
unsafe impl<'ch, M, T, const N: usize> Send for RecvFuture<'ch, M, T, N> where
M: Mutex<Data = ()> + Sync
{
}
unsafe impl<'ch, M, T, const N: usize> Sync for RecvFuture<'ch, M, T, N> where
M: Mutex<Data = ()> + Sync
{
}
impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
/// Sends a value, waiting until there is capacity. /// Sends a value, waiting until there is capacity.
/// ///
@ -245,7 +215,7 @@ where
/// [`Receiver`]: Receiver /// [`Receiver`]: Receiver
pub fn send(&self, message: T) -> SendFuture<'ch, M, T, N> { pub fn send(&self, message: T) -> SendFuture<'ch, M, T, N> {
SendFuture { SendFuture {
sender: self.clone(), channel: self.channel,
message: Some(message), message: Some(message),
} }
} }
@ -271,7 +241,7 @@ where
/// [`channel`]: channel /// [`channel`]: channel
/// [`close`]: Receiver::close /// [`close`]: Receiver::close
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
Channel::lock(self.channel_cell, |c| c.try_send(message)) self.channel.lock(|c| c.try_send(message))
} }
/// Completes when the receiver has dropped. /// Completes when the receiver has dropped.
@ -280,7 +250,7 @@ where
/// values is canceled and immediately stop doing work. /// values is canceled and immediately stop doing work.
pub async fn closed(&self) { pub async fn closed(&self) {
CloseFuture { CloseFuture {
sender: self.clone(), channel: self.channel,
} }
.await .await
} }
@ -292,29 +262,27 @@ where
/// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver`]: crate::sync::mpsc::Receiver
/// [`Receiver::close`]: crate::sync::mpsc::Receiver::close /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
pub fn is_closed(&self) -> bool { pub fn is_closed(&self) -> bool {
Channel::lock(self.channel_cell, |c| c.is_closed()) self.channel.lock(|c| c.is_closed())
} }
} }
pub struct SendFuture<'ch, M, T, const N: usize> pub struct SendFuture<'ch, M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
sender: Sender<'ch, M, T, N>, channel: &'ch Channel<M, T, N>,
message: Option<T>, message: Option<T>,
} }
impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
type Output = Result<(), SendError<T>>; type Output = Result<(), SendError<T>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.message.take() { match self.message.take() {
Some(m) => match Channel::lock(self.sender.channel_cell, |c| { Some(m) => match self.channel.lock(|c| c.try_send_with_context(m, Some(cx))) {
c.try_send_with_context(m, Some(cx))
}) {
Ok(..) => Poll::Ready(Ok(())), Ok(..) => Poll::Ready(Ok(())),
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
Err(TrySendError::Full(m)) => { Err(TrySendError::Full(m)) => {
@ -327,25 +295,23 @@ where
} }
} }
impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: Mutex<Data = ()> {} impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: MutexKind {}
struct CloseFuture<'ch, M, T, const N: usize> struct CloseFuture<'ch, M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
sender: Sender<'ch, M, T, N>, channel: &'ch Channel<M, T, N>,
} }
impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
type Output = (); type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if Channel::lock(self.sender.channel_cell, |c| { if self.channel.lock(|c| c.is_closed_with_context(Some(cx))) {
c.is_closed_with_context(Some(cx))
}) {
Poll::Ready(()) Poll::Ready(())
} else { } else {
Poll::Pending Poll::Pending
@ -355,22 +321,21 @@ where
impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
fn drop(&mut self) { fn drop(&mut self) {
Channel::lock(self.channel_cell, |c| c.deregister_sender()) self.channel.lock(|c| c.deregister_sender())
} }
} }
impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
#[allow(clippy::clone_double_ref)]
fn clone(&self) -> Self { fn clone(&self) -> Self {
Channel::lock(self.channel_cell, |c| c.register_sender()); self.channel.lock(|c| c.register_sender());
Sender { Sender {
channel_cell: self.channel_cell.clone(), channel: self.channel,
} }
} }
} }
@ -581,59 +546,35 @@ impl<T, const N: usize> ChannelState<T, N> {
/// All data sent will become available in the same order as it was sent. /// All data sent will become available in the same order as it was sent.
pub struct Channel<M, T, const N: usize> pub struct Channel<M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
channel_cell: UnsafeCell<ChannelCell<M, T, N>>, inner: M::Mutex<RefCell<ChannelState<T, N>>>,
} }
struct ChannelCell<M, T, const N: usize>
where
M: Mutex<Data = ()>,
{
mutex: M,
state: ChannelState<T, N>,
}
pub type WithCriticalSections = CriticalSectionMutex<()>;
pub type WithThreadModeOnly = ThreadModeMutex<()>;
pub type WithNoThreads = NoopMutex<()>;
impl<M, T, const N: usize> Channel<M, T, N> impl<M, T, const N: usize> Channel<M, T, N>
where where
M: Mutex<Data = ()>, M: MutexKind,
{ {
/// Establish a new bounded channel. For example, to create one with a NoopMutex: /// Establish a new bounded channel. For example, to create one with a NoopMutex:
/// ///
/// ``` /// ```
/// use embassy::channel::mpsc; /// use embassy::channel::mpsc;
/// use embassy::channel::mpsc::{Channel, WithNoThreads}; /// use embassy::blocking_mutex::kind::Noop;
/// use embassy::channel::mpsc::Channel;
/// ///
/// // Declare a bounded channel of 3 u32s. /// // Declare a bounded channel of 3 u32s.
/// let mut channel = Channel::<WithNoThreads, u32, 3>::new(); /// let mut channel = Channel::<Noop, u32, 3>::new();
/// // once we have a channel, obtain its sender and receiver /// // once we have a channel, obtain its sender and receiver
/// let (sender, receiver) = mpsc::split(&mut channel); /// let (sender, receiver) = mpsc::split(&mut channel);
/// ``` /// ```
pub fn new() -> Self { pub fn new() -> Self {
let mutex = M::new(()); Self {
let state = ChannelState::new(); inner: M::Mutex::new(RefCell::new(ChannelState::new())),
let channel_cell = ChannelCell { mutex, state };
Channel {
channel_cell: UnsafeCell::new(channel_cell),
} }
} }
fn lock<R>( fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R {
channel_cell: &UnsafeCell<ChannelCell<M, T, N>>, self.inner.lock(|rc| f(&mut *rc.borrow_mut()))
f: impl FnOnce(&mut ChannelState<T, N>) -> R,
) -> R {
unsafe {
let channel_cell = &mut *(channel_cell.get());
let mutex = &mut channel_cell.mutex;
let mut state = &mut channel_cell.state;
mutex.lock(|_| f(&mut state))
}
} }
} }
@ -645,6 +586,7 @@ mod tests {
use futures_executor::ThreadPool; use futures_executor::ThreadPool;
use futures_timer::Delay; use futures_timer::Delay;
use crate::blocking_mutex::kind::{CriticalSection, Noop};
use crate::util::Forever; use crate::util::Forever;
use super::*; use super::*;
@ -713,7 +655,7 @@ mod tests {
#[test] #[test]
fn simple_send_and_receive() { fn simple_send_and_receive() {
let mut c = Channel::<WithNoThreads, u32, 3>::new(); let mut c = Channel::<Noop, u32, 3>::new();
let (s, r) = split(&mut c); let (s, r) = split(&mut c);
assert!(s.clone().try_send(1).is_ok()); assert!(s.clone().try_send(1).is_ok());
assert_eq!(r.try_recv().unwrap(), 1); assert_eq!(r.try_recv().unwrap(), 1);
@ -721,7 +663,7 @@ mod tests {
#[test] #[test]
fn should_close_without_sender() { fn should_close_without_sender() {
let mut c = Channel::<WithNoThreads, u32, 3>::new(); let mut c = Channel::<Noop, u32, 3>::new();
let (s, r) = split(&mut c); let (s, r) = split(&mut c);
drop(s); drop(s);
match r.try_recv() { match r.try_recv() {
@ -732,7 +674,7 @@ mod tests {
#[test] #[test]
fn should_close_once_drained() { fn should_close_once_drained() {
let mut c = Channel::<WithNoThreads, u32, 3>::new(); let mut c = Channel::<Noop, u32, 3>::new();
let (s, r) = split(&mut c); let (s, r) = split(&mut c);
assert!(s.try_send(1).is_ok()); assert!(s.try_send(1).is_ok());
drop(s); drop(s);
@ -745,7 +687,7 @@ mod tests {
#[test] #[test]
fn should_reject_send_when_receiver_dropped() { fn should_reject_send_when_receiver_dropped() {
let mut c = Channel::<WithNoThreads, u32, 3>::new(); let mut c = Channel::<Noop, u32, 3>::new();
let (s, r) = split(&mut c); let (s, r) = split(&mut c);
drop(r); drop(r);
match s.try_send(1) { match s.try_send(1) {
@ -756,7 +698,7 @@ mod tests {
#[test] #[test]
fn should_reject_send_when_channel_closed() { fn should_reject_send_when_channel_closed() {
let mut c = Channel::<WithNoThreads, u32, 3>::new(); let mut c = Channel::<Noop, u32, 3>::new();
let (s, mut r) = split(&mut c); let (s, mut r) = split(&mut c);
assert!(s.try_send(1).is_ok()); assert!(s.try_send(1).is_ok());
r.close(); r.close();
@ -772,7 +714,7 @@ mod tests {
async fn receiver_closes_when_sender_dropped_async() { async fn receiver_closes_when_sender_dropped_async() {
let executor = ThreadPool::new().unwrap(); let executor = ThreadPool::new().unwrap();
static CHANNEL: Forever<Channel<WithCriticalSections, u32, 3>> = Forever::new(); static CHANNEL: Forever<Channel<CriticalSection, u32, 3>> = Forever::new();
let c = CHANNEL.put(Channel::new()); let c = CHANNEL.put(Channel::new());
let (s, mut r) = split(c); let (s, mut r) = split(c);
assert!(executor assert!(executor
@ -787,7 +729,7 @@ mod tests {
async fn receiver_receives_given_try_send_async() { async fn receiver_receives_given_try_send_async() {
let executor = ThreadPool::new().unwrap(); let executor = ThreadPool::new().unwrap();
static CHANNEL: Forever<Channel<WithCriticalSections, u32, 3>> = Forever::new(); static CHANNEL: Forever<Channel<CriticalSection, u32, 3>> = Forever::new();
let c = CHANNEL.put(Channel::new()); let c = CHANNEL.put(Channel::new());
let (s, mut r) = split(c); let (s, mut r) = split(c);
assert!(executor assert!(executor
@ -800,7 +742,7 @@ mod tests {
#[futures_test::test] #[futures_test::test]
async fn sender_send_completes_if_capacity() { async fn sender_send_completes_if_capacity() {
let mut c = Channel::<WithCriticalSections, u32, 1>::new(); let mut c = Channel::<CriticalSection, u32, 1>::new();
let (s, mut r) = split(&mut c); let (s, mut r) = split(&mut c);
assert!(s.send(1).await.is_ok()); assert!(s.send(1).await.is_ok());
assert_eq!(r.recv().await, Some(1)); assert_eq!(r.recv().await, Some(1));
@ -808,7 +750,7 @@ mod tests {
#[futures_test::test] #[futures_test::test]
async fn sender_send_completes_if_closed() { async fn sender_send_completes_if_closed() {
static CHANNEL: Forever<Channel<WithCriticalSections, u32, 1>> = Forever::new(); static CHANNEL: Forever<Channel<CriticalSection, u32, 1>> = Forever::new();
let c = CHANNEL.put(Channel::new()); let c = CHANNEL.put(Channel::new());
let (s, r) = split(c); let (s, r) = split(c);
drop(r); drop(r);
@ -822,7 +764,7 @@ mod tests {
async fn senders_sends_wait_until_capacity() { async fn senders_sends_wait_until_capacity() {
let executor = ThreadPool::new().unwrap(); let executor = ThreadPool::new().unwrap();
static CHANNEL: Forever<Channel<WithCriticalSections, u32, 1>> = Forever::new(); static CHANNEL: Forever<Channel<CriticalSection, u32, 1>> = Forever::new();
let c = CHANNEL.put(Channel::new()); let c = CHANNEL.put(Channel::new());
let (s0, mut r) = split(c); let (s0, mut r) = split(c);
assert!(s0.try_send(1).is_ok()); assert!(s0.try_send(1).is_ok());
@ -842,7 +784,7 @@ mod tests {
#[futures_test::test] #[futures_test::test]
async fn sender_close_completes_if_closing() { async fn sender_close_completes_if_closing() {
static CHANNEL: Forever<Channel<WithCriticalSections, u32, 1>> = Forever::new(); static CHANNEL: Forever<Channel<CriticalSection, u32, 1>> = Forever::new();
let c = CHANNEL.put(Channel::new()); let c = CHANNEL.put(Channel::new());
let (s, mut r) = split(c); let (s, mut r) = split(c);
r.close(); r.close();
@ -851,7 +793,7 @@ mod tests {
#[futures_test::test] #[futures_test::test]
async fn sender_close_completes_if_closed() { async fn sender_close_completes_if_closed() {
static CHANNEL: Forever<Channel<WithCriticalSections, u32, 1>> = Forever::new(); static CHANNEL: Forever<Channel<CriticalSection, u32, 1>> = Forever::new();
let c = CHANNEL.put(Channel::new()); let c = CHANNEL.put(Channel::new());
let (s, r) = split(c); let (s, r) = split(c);
drop(r); drop(r);

View File

@ -6,7 +6,8 @@
mod example_common; mod example_common;
use defmt::unwrap; use defmt::unwrap;
use embassy::channel::mpsc::{self, Channel, Sender, TryRecvError, WithNoThreads}; use embassy::blocking_mutex::kind::Noop;
use embassy::channel::mpsc::{self, Channel, Sender, TryRecvError};
use embassy::executor::Spawner; use embassy::executor::Spawner;
use embassy::time::{Duration, Timer}; use embassy::time::{Duration, Timer};
use embassy::util::Forever; use embassy::util::Forever;
@ -19,10 +20,10 @@ enum LedState {
Off, Off,
} }
static CHANNEL: Forever<Channel<WithNoThreads, LedState, 1>> = Forever::new(); static CHANNEL: Forever<Channel<Noop, LedState, 1>> = Forever::new();
#[embassy::task(pool_size = 1)] #[embassy::task(pool_size = 1)]
async fn my_task(sender: Sender<'static, WithNoThreads, LedState, 1>) { async fn my_task(sender: Sender<'static, Noop, LedState, 1>) {
loop { loop {
let _ = sender.send(LedState::On).await; let _ = sender.send(LedState::On).await;
Timer::after(Duration::from_secs(1)).await; Timer::after(Duration::from_secs(1)).await;