Priority channel using binary heap

This commit is contained in:
Scott Mabin 2023-11-18 14:21:43 +00:00
parent 5bc7557826
commit ca0d02933b
2 changed files with 745 additions and 0 deletions

View File

@ -15,6 +15,7 @@ pub mod blocking_mutex;
pub mod channel; pub mod channel;
pub mod mutex; pub mod mutex;
pub mod pipe; pub mod pipe;
pub mod priority_channel;
pub mod pubsub; pub mod pubsub;
pub mod signal; pub mod signal;
pub mod waitqueue; pub mod waitqueue;

View File

@ -0,0 +1,744 @@
//! A queue for sending values between asynchronous tasks.
//!
//! Similar to a [`Channel`](crate::channel::Channel), however [`PriorityChannel`] sifts higher priority items to the front of the queue.
use core::cell::RefCell;
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use heapless::BinaryHeap;
use crate::blocking_mutex::raw::RawMutex;
use crate::blocking_mutex::Mutex;
use crate::waitqueue::WakerRegistration;
/// Send-only access to a [`PriorityChannel`].
pub struct Sender<'ch, M, T, K, const N: usize>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
channel: &'ch PriorityChannel<M, T, K, N>,
}
impl<'ch, M, T, K, const N: usize> Clone for Sender<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
fn clone(&self) -> Self {
Sender { channel: self.channel }
}
}
impl<'ch, M, T, K, const N: usize> Copy for Sender<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
}
impl<'ch, M, T, K, const N: usize> Sender<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
/// Sends a value.
///
/// See [`PriorityChannel::send()`]
pub fn send(&self, message: T) -> SendFuture<'ch, M, T, K, N> {
self.channel.send(message)
}
/// Attempt to immediately send a message.
///
/// See [`PriorityChannel::send()`]
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
self.channel.try_send(message)
}
/// Allows a poll_fn to poll until the channel is ready to send
///
/// See [`PriorityChannel::poll_ready_to_send()`]
pub fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> {
self.channel.poll_ready_to_send(cx)
}
}
/// Send-only access to a [`PriorityChannel`] without knowing channel size.
pub struct DynamicSender<'ch, T> {
channel: &'ch dyn DynamicChannel<T>,
}
impl<'ch, T> Clone for DynamicSender<'ch, T> {
fn clone(&self) -> Self {
DynamicSender { channel: self.channel }
}
}
impl<'ch, T> Copy for DynamicSender<'ch, T> {}
impl<'ch, M, T, K, const N: usize> From<Sender<'ch, M, T, K, N>> for DynamicSender<'ch, T>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
fn from(s: Sender<'ch, M, T, K, N>) -> Self {
Self { channel: s.channel }
}
}
impl<'ch, T> DynamicSender<'ch, T> {
/// Sends a value.
///
/// See [`PriorityChannel::send()`]
pub fn send(&self, message: T) -> DynamicSendFuture<'ch, T> {
DynamicSendFuture {
channel: self.channel,
message: Some(message),
}
}
/// Attempt to immediately send a message.
///
/// See [`PriorityChannel::send()`]
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
self.channel.try_send_with_context(message, None)
}
/// Allows a poll_fn to poll until the channel is ready to send
///
/// See [`PriorityChannel::poll_ready_to_send()`]
pub fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> {
self.channel.poll_ready_to_send(cx)
}
}
/// Receive-only access to a [`PriorityChannel`].
pub struct Receiver<'ch, M, T, K, const N: usize>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
channel: &'ch PriorityChannel<M, T, K, N>,
}
impl<'ch, M, T, K, const N: usize> Clone for Receiver<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
fn clone(&self) -> Self {
Receiver { channel: self.channel }
}
}
impl<'ch, M, T, K, const N: usize> Copy for Receiver<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
}
impl<'ch, M, T, K, const N: usize> Receiver<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
/// Receive the next value.
///
/// See [`PriorityChannel::receive()`].
pub fn receive(&self) -> ReceiveFuture<'_, M, T, K, N> {
self.channel.receive()
}
/// Attempt to immediately receive the next value.
///
/// See [`PriorityChannel::try_receive()`]
pub fn try_receive(&self) -> Result<T, TryReceiveError> {
self.channel.try_receive()
}
/// Allows a poll_fn to poll until the channel is ready to receive
///
/// See [`PriorityChannel::poll_ready_to_receive()`]
pub fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> {
self.channel.poll_ready_to_receive(cx)
}
/// Poll the channel for the next item
///
/// See [`PriorityChannel::poll_receive()`]
pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<T> {
self.channel.poll_receive(cx)
}
}
/// Receive-only access to a [`PriorityChannel`] without knowing channel size.
pub struct DynamicReceiver<'ch, T> {
channel: &'ch dyn DynamicChannel<T>,
}
impl<'ch, T> Clone for DynamicReceiver<'ch, T> {
fn clone(&self) -> Self {
DynamicReceiver { channel: self.channel }
}
}
impl<'ch, T> Copy for DynamicReceiver<'ch, T> {}
impl<'ch, T> DynamicReceiver<'ch, T> {
/// Receive the next value.
///
/// See [`PriorityChannel::receive()`].
pub fn receive(&self) -> DynamicReceiveFuture<'_, T> {
DynamicReceiveFuture { channel: self.channel }
}
/// Attempt to immediately receive the next value.
///
/// See [`PriorityChannel::try_receive()`]
pub fn try_receive(&self) -> Result<T, TryReceiveError> {
self.channel.try_receive_with_context(None)
}
/// Allows a poll_fn to poll until the channel is ready to receive
///
/// See [`PriorityChannel::poll_ready_to_receive()`]
pub fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> {
self.channel.poll_ready_to_receive(cx)
}
/// Poll the channel for the next item
///
/// See [`PriorityChannel::poll_receive()`]
pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<T> {
self.channel.poll_receive(cx)
}
}
impl<'ch, M, T, K, const N: usize> From<Receiver<'ch, M, T, K, N>> for DynamicReceiver<'ch, T>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
fn from(s: Receiver<'ch, M, T, K, N>) -> Self {
Self { channel: s.channel }
}
}
/// Future returned by [`PriorityChannel::receive`] and [`Receiver::receive`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReceiveFuture<'ch, M, T, K, const N: usize>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
channel: &'ch PriorityChannel<M, T, K, N>,
}
impl<'ch, M, T, K, const N: usize> Future for ReceiveFuture<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
self.channel.poll_receive(cx)
}
}
/// Future returned by [`DynamicReceiver::receive`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct DynamicReceiveFuture<'ch, T> {
channel: &'ch dyn DynamicChannel<T>,
}
impl<'ch, T> Future for DynamicReceiveFuture<'ch, T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
match self.channel.try_receive_with_context(Some(cx)) {
Ok(v) => Poll::Ready(v),
Err(TryReceiveError::Empty) => Poll::Pending,
}
}
}
/// Future returned by [`PriorityChannel::send`] and [`Sender::send`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendFuture<'ch, M, T, K, const N: usize>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
channel: &'ch PriorityChannel<M, T, K, N>,
message: Option<T>,
}
impl<'ch, M, T, K, const N: usize> Future for SendFuture<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.message.take() {
Some(m) => match self.channel.try_send_with_context(m, Some(cx)) {
Ok(..) => Poll::Ready(()),
Err(TrySendError::Full(m)) => {
self.message = Some(m);
Poll::Pending
}
},
None => panic!("Message cannot be None"),
}
}
}
impl<'ch, M, T, K, const N: usize> Unpin for SendFuture<'ch, M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
}
/// Future returned by [`DynamicSender::send`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct DynamicSendFuture<'ch, T> {
channel: &'ch dyn DynamicChannel<T>,
message: Option<T>,
}
impl<'ch, T> Future for DynamicSendFuture<'ch, T> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.message.take() {
Some(m) => match self.channel.try_send_with_context(m, Some(cx)) {
Ok(..) => Poll::Ready(()),
Err(TrySendError::Full(m)) => {
self.message = Some(m);
Poll::Pending
}
},
None => panic!("Message cannot be None"),
}
}
}
impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {}
trait DynamicChannel<T> {
fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>>;
fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError>;
fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()>;
fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()>;
fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<T>;
}
/// Error returned by [`try_receive`](PriorityChannel::try_receive).
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TryReceiveError {
/// A message could not be received because the channel is empty.
Empty,
}
/// Error returned by [`try_send`](PriorityChannel::try_send).
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TrySendError<T> {
/// The data could not be sent on the channel because the channel is
/// currently full and sending would require blocking.
Full(T),
}
struct ChannelState<T, K, const N: usize> {
queue: BinaryHeap<T, K, N>,
receiver_waker: WakerRegistration,
senders_waker: WakerRegistration,
}
impl<T, K, const N: usize> ChannelState<T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
{
const fn new() -> Self {
ChannelState {
queue: BinaryHeap::new(),
receiver_waker: WakerRegistration::new(),
senders_waker: WakerRegistration::new(),
}
}
fn try_receive(&mut self) -> Result<T, TryReceiveError> {
self.try_receive_with_context(None)
}
fn try_receive_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError> {
if self.queue.len() == self.queue.capacity() {
self.senders_waker.wake();
}
if let Some(message) = self.queue.pop() {
Ok(message)
} else {
if let Some(cx) = cx {
self.receiver_waker.register(cx.waker());
}
Err(TryReceiveError::Empty)
}
}
fn poll_receive(&mut self, cx: &mut Context<'_>) -> Poll<T> {
if self.queue.len() == self.queue.capacity() {
self.senders_waker.wake();
}
if let Some(message) = self.queue.pop() {
Poll::Ready(message)
} else {
self.receiver_waker.register(cx.waker());
Poll::Pending
}
}
fn poll_ready_to_receive(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.receiver_waker.register(cx.waker());
if !self.queue.is_empty() {
Poll::Ready(())
} else {
Poll::Pending
}
}
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
self.try_send_with_context(message, None)
}
fn try_send_with_context(&mut self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> {
match self.queue.push(message) {
Ok(()) => {
self.receiver_waker.wake();
Ok(())
}
Err(message) => {
if let Some(cx) = cx {
self.senders_waker.register(cx.waker());
}
Err(TrySendError::Full(message))
}
}
}
fn poll_ready_to_send(&mut self, cx: &mut Context<'_>) -> Poll<()> {
self.senders_waker.register(cx.waker());
if !self.queue.len() == self.queue.capacity() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
/// A bounded channel for communicating between asynchronous tasks
/// with backpressure.
///
/// The channel will buffer up to the provided number of messages. Once the
/// buffer is full, attempts to `send` new messages will wait until a message is
/// received from the channel.
///
/// All data sent will become available in the same order as it was sent.
pub struct PriorityChannel<M, T, K, const N: usize>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
inner: Mutex<M, RefCell<ChannelState<T, K, N>>>,
}
impl<M, T, K, const N: usize> PriorityChannel<M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
/// Establish a new bounded channel. For example, to create one with a NoopMutex:
///
/// ```
/// use embassy_sync::channel::PriorityChannel;
/// use embassy_sync::blocking_mutex::raw::NoopRawMutex;
///
/// // Declare a bounded channel of 3 u32s.
/// let mut channel = PriorityChannel::<NoopRawMutex, u32, 3>::new();
/// ```
pub const fn new() -> Self {
Self {
inner: Mutex::new(RefCell::new(ChannelState::new())),
}
}
fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, K, N>) -> R) -> R {
self.inner.lock(|rc| f(&mut *unwrap!(rc.try_borrow_mut())))
}
fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError> {
self.lock(|c| c.try_receive_with_context(cx))
}
/// Poll the channel for the next message
pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<T> {
self.lock(|c| c.poll_receive(cx))
}
fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> {
self.lock(|c| c.try_send_with_context(m, cx))
}
/// Allows a poll_fn to poll until the channel is ready to receive
pub fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> {
self.lock(|c| c.poll_ready_to_receive(cx))
}
/// Allows a poll_fn to poll until the channel is ready to send
pub fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> {
self.lock(|c| c.poll_ready_to_send(cx))
}
/// Get a sender for this channel.
pub fn sender(&self) -> Sender<'_, M, T, K, N> {
Sender { channel: self }
}
/// Get a receiver for this channel.
pub fn receiver(&self) -> Receiver<'_, M, T, K, N> {
Receiver { channel: self }
}
/// Send a value, waiting until there is capacity.
///
/// Sending completes when the value has been pushed to the channel's queue.
/// This doesn't mean the value has been received yet.
pub fn send(&self, message: T) -> SendFuture<'_, M, T, K, N> {
SendFuture {
channel: self,
message: Some(message),
}
}
/// Attempt to immediately send a message.
///
/// This method differs from [`send`](PriorityChannel::send) by returning immediately if the channel's
/// buffer is full, instead of waiting.
///
/// # Errors
///
/// If the channel capacity has been reached, i.e., the channel has `n`
/// buffered values where `n` is the argument passed to [`PriorityChannel`], then an
/// error is returned.
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
self.lock(|c| c.try_send(message))
}
/// Receive the next value.
///
/// If there are no messages in the channel's buffer, this method will
/// wait until a message is sent.
pub fn receive(&self) -> ReceiveFuture<'_, M, T, K, N> {
ReceiveFuture { channel: self }
}
/// Attempt to immediately receive a message.
///
/// This method will either receive a message from the channel immediately or return an error
/// if the channel is empty.
pub fn try_receive(&self) -> Result<T, TryReceiveError> {
self.lock(|c| c.try_receive())
}
}
/// Implements the DynamicChannel to allow creating types that are unaware of the queue size with the
/// tradeoff cost of dynamic dispatch.
impl<M, T, K, const N: usize> DynamicChannel<T> for PriorityChannel<M, T, K, N>
where
T: Ord,
K: heapless::binary_heap::Kind,
M: RawMutex,
{
fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError<T>> {
PriorityChannel::try_send_with_context(self, m, cx)
}
fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result<T, TryReceiveError> {
PriorityChannel::try_receive_with_context(self, cx)
}
fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> {
PriorityChannel::poll_ready_to_send(self, cx)
}
fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> {
PriorityChannel::poll_ready_to_receive(self, cx)
}
fn poll_receive(&self, cx: &mut Context<'_>) -> Poll<T> {
PriorityChannel::poll_receive(self, cx)
}
}
#[cfg(test)]
mod tests {
use core::time::Duration;
use futures_executor::ThreadPool;
use futures_timer::Delay;
use futures_util::task::SpawnExt;
use static_cell::StaticCell;
use super::*;
use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize {
c.queue.capacity() - c.queue.len()
}
#[test]
fn sending_once() {
let mut c = ChannelState::<u32, 3>::new();
assert!(c.try_send(1).is_ok());
assert_eq!(capacity(&c), 2);
}
#[test]
fn sending_when_full() {
let mut c = ChannelState::<u32, 3>::new();
let _ = c.try_send(1);
let _ = c.try_send(1);
let _ = c.try_send(1);
match c.try_send(2) {
Err(TrySendError::Full(2)) => assert!(true),
_ => assert!(false),
}
assert_eq!(capacity(&c), 0);
}
#[test]
fn receiving_once_with_one_send() {
let mut c = ChannelState::<u32, 3>::new();
assert!(c.try_send(1).is_ok());
assert_eq!(c.try_receive().unwrap(), 1);
assert_eq!(capacity(&c), 3);
}
#[test]
fn receiving_when_empty() {
let mut c = ChannelState::<u32, 3>::new();
match c.try_receive() {
Err(TryReceiveError::Empty) => assert!(true),
_ => assert!(false),
}
assert_eq!(capacity(&c), 3);
}
#[test]
fn simple_send_and_receive() {
let c = PriorityChannel::<NoopRawMutex, u32, 3>::new();
assert!(c.try_send(1).is_ok());
assert_eq!(c.try_receive().unwrap(), 1);
}
#[test]
fn cloning() {
let c = PriorityChannel::<NoopRawMutex, u32, 3>::new();
let r1 = c.receiver();
let s1 = c.sender();
let _ = r1.clone();
let _ = s1.clone();
}
#[test]
fn dynamic_dispatch() {
let c = PriorityChannel::<NoopRawMutex, u32, 3>::new();
let s: DynamicSender<'_, u32> = c.sender().into();
let r: DynamicReceiver<'_, u32> = c.receiver().into();
assert!(s.try_send(1).is_ok());
assert_eq!(r.try_receive().unwrap(), 1);
}
#[futures_test::test]
async fn receiver_receives_given_try_send_async() {
let executor = ThreadPool::new().unwrap();
static CHANNEL: StaticCell<PriorityChannel<CriticalSectionRawMutex, u32, 3>> = StaticCell::new();
let c = &*CHANNEL.init(PriorityChannel::new());
let c2 = c;
assert!(executor
.spawn(async move {
assert!(c2.try_send(1).is_ok());
})
.is_ok());
assert_eq!(c.receive().await, 1);
}
#[futures_test::test]
async fn sender_send_completes_if_capacity() {
let c = PriorityChannel::<CriticalSectionRawMutex, u32, 1>::new();
c.send(1).await;
assert_eq!(c.receive().await, 1);
}
#[futures_test::test]
async fn senders_sends_wait_until_capacity() {
let executor = ThreadPool::new().unwrap();
static CHANNEL: StaticCell<PriorityChannel<CriticalSectionRawMutex, u32, 1>> = StaticCell::new();
let c = &*CHANNEL.init(PriorityChannel::new());
assert!(c.try_send(1).is_ok());
let c2 = c;
let send_task_1 = executor.spawn_with_handle(async move { c2.send(2).await });
let c2 = c;
let send_task_2 = executor.spawn_with_handle(async move { c2.send(3).await });
// Wish I could think of a means of determining that the async send is waiting instead.
// However, I've used the debugger to observe that the send does indeed wait.
Delay::new(Duration::from_millis(500)).await;
assert_eq!(c.receive().await, 1);
assert!(executor
.spawn(async move {
loop {
c.receive().await;
}
})
.is_ok());
send_task_1.unwrap().await;
send_task_2.unwrap().await;
}
}