Reduce duplication, fix tests

This commit is contained in:
Scott Mabin 2023-11-18 14:31:09 +00:00
parent ca0d02933b
commit 270ec324b0
2 changed files with 41 additions and 61 deletions

View File

@ -321,7 +321,7 @@ impl<'ch, T> Future for DynamicSendFuture<'ch, T> {
impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {} impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {}
trait DynamicChannel<T> { pub(crate) trait DynamicChannel<T> {
fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>>; fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>>;
fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError>; fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError>;

View File

@ -7,17 +7,19 @@ use core::future::Future;
use core::pin::Pin; use core::pin::Pin;
use core::task::{Context, Poll}; use core::task::{Context, Poll};
use heapless::binary_heap::Kind;
use heapless::BinaryHeap; use heapless::BinaryHeap;
use crate::blocking_mutex::raw::RawMutex; use crate::blocking_mutex::raw::RawMutex;
use crate::blocking_mutex::Mutex; use crate::blocking_mutex::Mutex;
use crate::channel::{DynamicChannel, TryReceiveError, TrySendError};
use crate::waitqueue::WakerRegistration; use crate::waitqueue::WakerRegistration;
/// Send-only access to a [`PriorityChannel`]. /// Send-only access to a [`PriorityChannel`].
pub struct Sender<'ch, M, T, K, const N: usize> pub struct Sender<'ch, M, T, K, const N: usize>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
channel: &'ch PriorityChannel<M, T, K, N>, channel: &'ch PriorityChannel<M, T, K, N>,
@ -26,7 +28,7 @@ where
impl<'ch, M, T, K, const N: usize> Clone for Sender<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Clone for Sender<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@ -37,7 +39,7 @@ where
impl<'ch, M, T, K, const N: usize> Copy for Sender<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Copy for Sender<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
} }
@ -45,7 +47,7 @@ where
impl<'ch, M, T, K, const N: usize> Sender<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Sender<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
/// Sends a value. /// Sends a value.
@ -86,7 +88,7 @@ impl<'ch, T> Copy for DynamicSender<'ch, T> {}
impl<'ch, M, T, K, const N: usize> From<Sender<'ch, M, T, K, N>> for DynamicSender<'ch, T> impl<'ch, M, T, K, const N: usize> From<Sender<'ch, M, T, K, N>> for DynamicSender<'ch, T>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
fn from(s: Sender<'ch, M, T, K, N>) -> Self { fn from(s: Sender<'ch, M, T, K, N>) -> Self {
@ -124,7 +126,7 @@ impl<'ch, T> DynamicSender<'ch, T> {
pub struct Receiver<'ch, M, T, K, const N: usize> pub struct Receiver<'ch, M, T, K, const N: usize>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
channel: &'ch PriorityChannel<M, T, K, N>, channel: &'ch PriorityChannel<M, T, K, N>,
@ -133,7 +135,7 @@ where
impl<'ch, M, T, K, const N: usize> Clone for Receiver<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Clone for Receiver<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
@ -144,7 +146,7 @@ where
impl<'ch, M, T, K, const N: usize> Copy for Receiver<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Copy for Receiver<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
} }
@ -152,7 +154,7 @@ where
impl<'ch, M, T, K, const N: usize> Receiver<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Receiver<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
/// Receive the next value. /// Receive the next value.
@ -230,7 +232,7 @@ impl<'ch, T> DynamicReceiver<'ch, T> {
impl<'ch, M, T, K, const N: usize> From<Receiver<'ch, M, T, K, N>> for DynamicReceiver<'ch, T> impl<'ch, M, T, K, const N: usize> From<Receiver<'ch, M, T, K, N>> for DynamicReceiver<'ch, T>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
fn from(s: Receiver<'ch, M, T, K, N>) -> Self { fn from(s: Receiver<'ch, M, T, K, N>) -> Self {
@ -243,7 +245,7 @@ where
pub struct ReceiveFuture<'ch, M, T, K, const N: usize> pub struct ReceiveFuture<'ch, M, T, K, const N: usize>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
channel: &'ch PriorityChannel<M, T, K, N>, channel: &'ch PriorityChannel<M, T, K, N>,
@ -252,7 +254,7 @@ where
impl<'ch, M, T, K, const N: usize> Future for ReceiveFuture<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Future for ReceiveFuture<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
type Output = T; type Output = T;
@ -284,7 +286,7 @@ impl<'ch, T> Future for DynamicReceiveFuture<'ch, T> {
pub struct SendFuture<'ch, M, T, K, const N: usize> pub struct SendFuture<'ch, M, T, K, const N: usize>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
channel: &'ch PriorityChannel<M, T, K, N>, channel: &'ch PriorityChannel<M, T, K, N>,
@ -294,7 +296,7 @@ where
impl<'ch, M, T, K, const N: usize> Future for SendFuture<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Future for SendFuture<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
type Output = (); type Output = ();
@ -316,7 +318,7 @@ where
impl<'ch, M, T, K, const N: usize> Unpin for SendFuture<'ch, M, T, K, N> impl<'ch, M, T, K, const N: usize> Unpin for SendFuture<'ch, M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
} }
@ -347,34 +349,6 @@ impl<'ch, T> Future for DynamicSendFuture<'ch, T> {
impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {} impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {}
trait DynamicChannel<T> {
fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>>;
fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError>;
fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()>;
fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()>;
fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<T>;
}
/// Error returned by [`try_receive`](PriorityChannel::try_receive).
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TryReceiveError {
/// A message could not be received because the channel is empty.
Empty,
}
/// Error returned by [`try_send`](PriorityChannel::try_send).
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TrySendError<T> {
/// The data could not be sent on the channel because the channel is
/// currently full and sending would require blocking.
Full(T),
}
struct ChannelState<T, K, const N: usize> { struct ChannelState<T, K, const N: usize> {
queue: BinaryHeap<T, K, N>, queue: BinaryHeap<T, K, N>,
receiver_waker: WakerRegistration, receiver_waker: WakerRegistration,
@ -384,7 +358,7 @@ struct ChannelState<T, K, const N: usize> {
impl<T, K, const N: usize> ChannelState<T, K, N> impl<T, K, const N: usize> ChannelState<T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
{ {
const fn new() -> Self { const fn new() -> Self {
ChannelState { ChannelState {
@ -477,7 +451,7 @@ where
pub struct PriorityChannel<M, T, K, const N: usize> pub struct PriorityChannel<M, T, K, const N: usize>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
inner: Mutex<M, RefCell<ChannelState<T, K, N>>>, inner: Mutex<M, RefCell<ChannelState<T, K, N>>>,
@ -486,17 +460,18 @@ where
impl<M, T, K, const N: usize> PriorityChannel<M, T, K, N> impl<M, T, K, const N: usize> PriorityChannel<M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
/// Establish a new bounded channel. For example, to create one with a NoopMutex: /// Establish a new bounded channel. For example, to create one with a NoopMutex:
/// ///
/// ``` /// ```
/// use embassy_sync::channel::PriorityChannel; /// # use heapless::binary_heap::Max;
/// use embassy_sync::priority_channel::PriorityChannel;
/// use embassy_sync::blocking_mutex::raw::NoopRawMutex; /// use embassy_sync::blocking_mutex::raw::NoopRawMutex;
/// ///
/// // Declare a bounded channel of 3 u32s. /// // Declare a bounded channel of 3 u32s.
/// let mut channel = PriorityChannel::<NoopRawMutex, u32, 3>::new(); /// let mut channel = PriorityChannel::<NoopRawMutex, u32, Max, 3>::new();
/// ``` /// ```
pub const fn new() -> Self { pub const fn new() -> Self {
Self { Self {
@ -588,7 +563,7 @@ where
impl<M, T, K, const N: usize> DynamicChannel<T> for PriorityChannel<M, T, K, N> impl<M, T, K, const N: usize> DynamicChannel<T> for PriorityChannel<M, T, K, N>
where where
T: Ord, T: Ord,
K: heapless::binary_heap::Kind, K: Kind,
M: RawMutex, M: RawMutex,
{ {
fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> { fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> {
@ -619,25 +594,30 @@ mod tests {
use futures_executor::ThreadPool; use futures_executor::ThreadPool;
use futures_timer::Delay; use futures_timer::Delay;
use futures_util::task::SpawnExt; use futures_util::task::SpawnExt;
use heapless::binary_heap::{Kind, Max};
use static_cell::StaticCell; use static_cell::StaticCell;
use super::*; use super::*;
use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex}; use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize { fn capacity<T, K, const N: usize>(c: &ChannelState<T, K, N>) -> usize
where
T: Ord,
K: Kind,
{
c.queue.capacity() - c.queue.len() c.queue.capacity() - c.queue.len()
} }
#[test] #[test]
fn sending_once() { fn sending_once() {
let mut c = ChannelState::<u32, 3>::new(); let mut c = ChannelState::<u32, Max, 3>::new();
assert!(c.try_send(1).is_ok()); assert!(c.try_send(1).is_ok());
assert_eq!(capacity(&c), 2); assert_eq!(capacity(&c), 2);
} }
#[test] #[test]
fn sending_when_full() { fn sending_when_full() {
let mut c = ChannelState::<u32, 3>::new(); let mut c = ChannelState::<u32, Max, 3>::new();
let _ = c.try_send(1); let _ = c.try_send(1);
let _ = c.try_send(1); let _ = c.try_send(1);
let _ = c.try_send(1); let _ = c.try_send(1);
@ -650,7 +630,7 @@ mod tests {
#[test] #[test]
fn receiving_once_with_one_send() { fn receiving_once_with_one_send() {
let mut c = ChannelState::<u32, 3>::new(); let mut c = ChannelState::<u32, Max, 3>::new();
assert!(c.try_send(1).is_ok()); assert!(c.try_send(1).is_ok());
assert_eq!(c.try_receive().unwrap(), 1); assert_eq!(c.try_receive().unwrap(), 1);
assert_eq!(capacity(&c), 3); assert_eq!(capacity(&c), 3);
@ -658,7 +638,7 @@ mod tests {
#[test] #[test]
fn receiving_when_empty() { fn receiving_when_empty() {
let mut c = ChannelState::<u32, 3>::new(); let mut c = ChannelState::<u32, Max, 3>::new();
match c.try_receive() { match c.try_receive() {
Err(TryReceiveError::Empty) => assert!(true), Err(TryReceiveError::Empty) => assert!(true),
_ => assert!(false), _ => assert!(false),
@ -668,14 +648,14 @@ mod tests {
#[test] #[test]
fn simple_send_and_receive() { fn simple_send_and_receive() {
let c = PriorityChannel::<NoopRawMutex, u32, 3>::new(); let c = PriorityChannel::<NoopRawMutex, u32, Max, 3>::new();
assert!(c.try_send(1).is_ok()); assert!(c.try_send(1).is_ok());
assert_eq!(c.try_receive().unwrap(), 1); assert_eq!(c.try_receive().unwrap(), 1);
} }
#[test] #[test]
fn cloning() { fn cloning() {
let c = PriorityChannel::<NoopRawMutex, u32, 3>::new(); let c = PriorityChannel::<NoopRawMutex, u32, Max, 3>::new();
let r1 = c.receiver(); let r1 = c.receiver();
let s1 = c.sender(); let s1 = c.sender();
@ -685,7 +665,7 @@ mod tests {
#[test] #[test]
fn dynamic_dispatch() { fn dynamic_dispatch() {
let c = PriorityChannel::<NoopRawMutex, u32, 3>::new(); let c = PriorityChannel::<NoopRawMutex, u32, Max, 3>::new();
let s: DynamicSender<'_, u32> = c.sender().into(); let s: DynamicSender<'_, u32> = c.sender().into();
let r: DynamicReceiver<'_, u32> = c.receiver().into(); let r: DynamicReceiver<'_, u32> = c.receiver().into();
@ -697,7 +677,7 @@ mod tests {
async fn receiver_receives_given_try_send_async() { async fn receiver_receives_given_try_send_async() {
let executor = ThreadPool::new().unwrap(); let executor = ThreadPool::new().unwrap();
static CHANNEL: StaticCell<PriorityChannel<CriticalSectionRawMutex, u32, 3>> = StaticCell::new(); static CHANNEL: StaticCell<PriorityChannel<CriticalSectionRawMutex, u32, Max, 3>> = StaticCell::new();
let c = &*CHANNEL.init(PriorityChannel::new()); let c = &*CHANNEL.init(PriorityChannel::new());
let c2 = c; let c2 = c;
assert!(executor assert!(executor
@ -710,7 +690,7 @@ mod tests {
#[futures_test::test] #[futures_test::test]
async fn sender_send_completes_if_capacity() { async fn sender_send_completes_if_capacity() {
let c = PriorityChannel::<CriticalSectionRawMutex, u32, 1>::new(); let c = PriorityChannel::<CriticalSectionRawMutex, u32, Max, 1>::new();
c.send(1).await; c.send(1).await;
assert_eq!(c.receive().await, 1); assert_eq!(c.receive().await, 1);
} }
@ -719,7 +699,7 @@ mod tests {
async fn senders_sends_wait_until_capacity() { async fn senders_sends_wait_until_capacity() {
let executor = ThreadPool::new().unwrap(); let executor = ThreadPool::new().unwrap();
static CHANNEL: StaticCell<PriorityChannel<CriticalSectionRawMutex, u32, 1>> = StaticCell::new(); static CHANNEL: StaticCell<PriorityChannel<CriticalSectionRawMutex, u32, Max, 1>> = StaticCell::new();
let c = &*CHANNEL.init(PriorityChannel::new()); let c = &*CHANNEL.init(PriorityChannel::new());
assert!(c.try_send(1).is_ok()); assert!(c.try_send(1).is_ok());