diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs index cc00f47a..41f275c4 100644 --- a/embassy/src/channel/pubsub.rs +++ b/embassy/src/channel/pubsub.rs @@ -10,7 +10,7 @@ use heapless::Deque; use crate::blocking_mutex::raw::RawMutex; use crate::blocking_mutex::Mutex; -use crate::waitqueue::WakerRegistration; +use crate::waitqueue::MultiWakerRegistration; /// A broadcast channel implementation where multiple publishers can send messages to multiple subscribers /// @@ -42,21 +42,15 @@ impl= SUBS { + Err(Error::MaximumSubscribersReached) + } else { + s.subscriber_count += 1; + Ok(Subscriber { + next_message_id: s.next_message_id, + channel: self, + }) } - - // No spot was found, we're full - Err(Error::MaximumSubscribersReached) }) } @@ -67,20 +61,12 @@ impl= PUBS { + Err(Error::MaximumPublishersReached) + } else { + s.publisher_count += 1; + Ok(Publisher { channel: self }) } - - // No spot was found, we're full - Err(Error::MaximumPublishersReached) }) } @@ -94,12 +80,7 @@ impl PubSubBehavior for PubSubChannel { - fn get_message_with_context( - &self, - next_message_id: &mut u64, - subscriber_index: usize, - cx: Option<&mut Context<'_>>, - ) -> Poll> { + fn get_message_with_context(&self, next_message_id: &mut u64, cx: Option<&mut Context<'_>>) -> Poll> { self.inner.lock(|s| { let mut s = s.borrow_mut(); @@ -113,7 +94,7 @@ impl { if let Some(cx) = cx { - s.register_subscriber_waker(subscriber_index, cx.waker()); + s.register_subscriber_waker(cx.waker()); } Poll::Pending } @@ -126,7 +107,7 @@ impl>) -> Result<(), T> { + fn publish_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), T> { self.inner.lock(|s| { let mut s = s.borrow_mut(); // Try to publish the message @@ -136,7 +117,7 @@ impl { if let Some(cx) = cx { - s.register_publisher_waker(publisher_index, cx.waker()); + s.register_publisher_waker(cx.waker()); } Err(message) } @@ -151,17 +132,17 @@ impl; SUBS], + subscriber_wakers: MultiWakerRegistration, /// Collection of wakers for Publishers that are waiting. - /// The [Publisher::publisher_index] field indexes into this array. - publisher_wakers: [Option; PUBS], + publisher_wakers: MultiWakerRegistration, + /// The amount of subscribers that are active + subscriber_count: usize, + /// The amount of publishers that are active + publisher_count: usize, } impl PubSubState { /// Create a new internal channel state const fn new() -> Self { - const WAKER_INIT: Option = None; Self { queue: Deque::new(), next_message_id: 0, - subscriber_wakers: [WAKER_INIT; SUBS], - publisher_wakers: [WAKER_INIT; PUBS], + subscriber_wakers: MultiWakerRegistration::new(), + publisher_wakers: MultiWakerRegistration::new(), + subscriber_count: 0, + publisher_count: 0, } } fn try_publish(&mut self, message: T) -> Result<(), T> { - let active_subscriber_count = self.subscriber_wakers.iter().flatten().count(); - - if active_subscriber_count == 0 { + if self.subscriber_count == 0 { // We don't need to publish anything because there is no one to receive it return Ok(()); } @@ -206,14 +188,12 @@ impl PubSubSta return Err(message); } // We just did a check for this - self.queue.push_back((message, active_subscriber_count)).ok().unwrap(); + self.queue.push_back((message, self.subscriber_count)).ok().unwrap(); self.next_message_id += 1; // Wake all of the subscribers - for active_subscriber in self.subscriber_wakers.iter_mut().flatten() { - active_subscriber.wake() - } + self.subscriber_wakers.wake(); Ok(()) } @@ -250,26 +230,42 @@ impl PubSubSta if current_message_index == 0 && queue_item.1 == 0 { self.queue.pop_front(); - self.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake()); + self.publisher_wakers.wake(); } Some(WaitResult::Message(message)) } - fn register_subscriber_waker(&mut self, subscriber_index: usize, waker: &Waker) { - self.subscriber_wakers[subscriber_index] - .as_mut() - .unwrap() - .register(waker); + fn register_subscriber_waker(&mut self, waker: &Waker) { + match self.subscriber_wakers.register(waker) { + Ok(()) => {} + Err(_) => { + // All waker slots were full. This can only happen when there was a subscriber that now has dropped. + // We need to throw it away. It's a bit inefficient, but we can wake everything. + // Any future that is still active will simply reregister. + // This won't happen a lot, so it's ok. + self.subscriber_wakers.wake(); + self.subscriber_wakers.register(waker).unwrap(); + } + } } - fn register_publisher_waker(&mut self, publisher_index: usize, waker: &Waker) { - self.publisher_wakers[publisher_index].as_mut().unwrap().register(waker); + fn register_publisher_waker(&mut self, waker: &Waker) { + match self.publisher_wakers.register(waker) { + Ok(()) => {} + Err(_) => { + // All waker slots were full. This can only happen when there was a publisher that now has dropped. + // We need to throw it away. It's a bit inefficient, but we can wake everything. + // Any future that is still active will simply reregister. + // This won't happen a lot, so it's ok. + self.publisher_wakers.wake(); + self.publisher_wakers.register(waker).unwrap(); + } + } } - fn unregister_subscriber(&mut self, subscriber_index: usize, subscriber_next_message_id: u64) { - // Remove the subscriber from the wakers - self.subscriber_wakers[subscriber_index] = None; + fn unregister_subscriber(&mut self, subscriber_next_message_id: u64) { + self.subscriber_count -= 1; // All messages that haven't been read yet by this subscriber must have their counter decremented let start_id = self.next_message_id - self.queue.len() as u64; @@ -282,9 +278,8 @@ impl PubSubSta } } - fn unregister_publisher(&mut self, publisher_index: usize) { - // Remove the publisher from the wakers - self.publisher_wakers[publisher_index] = None; + fn unregister_publisher(&mut self) { + self.publisher_count -= 1; } } @@ -293,8 +288,6 @@ impl PubSubSta /// This instance carries a reference to the channel, but uses a trait object for it so that the channel's /// generics are erased on this subscriber pub struct Subscriber<'a, T: Clone> { - /// Our index into the channel - subscriber_index: usize, /// The message id of the next message we are yet to receive next_message_id: u64, /// The channel we are a subscriber to @@ -321,10 +314,7 @@ impl<'a, T: Clone> Subscriber<'a, T> { /// /// This function does not peek. The message is received if there is one. pub fn try_next_message(&mut self) -> Option> { - match self - .channel - .get_message_with_context(&mut self.next_message_id, self.subscriber_index, None) - { + match self.channel.get_message_with_context(&mut self.next_message_id, None) { Poll::Ready(result) => Some(result), Poll::Pending => None, } @@ -346,8 +336,7 @@ impl<'a, T: Clone> Subscriber<'a, T> { impl<'a, T: Clone> Drop for Subscriber<'a, T> { fn drop(&mut self) { - self.channel - .unregister_subscriber(self.subscriber_index, self.next_message_id) + self.channel.unregister_subscriber(self.next_message_id) } } @@ -357,10 +346,9 @@ impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> { type Item = T; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let sub_index = self.subscriber_index; match self .channel - .get_message_with_context(&mut self.next_message_id, sub_index, Some(cx)) + .get_message_with_context(&mut self.next_message_id, Some(cx)) { Poll::Ready(WaitResult::Message(message)) => Poll::Ready(Some(message)), Poll::Ready(WaitResult::Lagged(_)) => { @@ -377,8 +365,6 @@ impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> { /// This instance carries a reference to the channel, but uses a trait object for it so that the channel's /// generics are erased on this subscriber pub struct Publisher<'a, T: Clone> { - /// Our index into the channel - publisher_index: usize, /// The channel we are a publisher for channel: &'a dyn PubSubBehavior, } @@ -400,13 +386,13 @@ impl<'a, T: Clone> Publisher<'a, T> { /// Publish a message if there is space in the message queue pub fn try_publish(&self, message: T) -> Result<(), T> { - self.channel.publish_with_context(message, self.publisher_index, None) + self.channel.publish_with_context(message, None) } } impl<'a, T: Clone> Drop for Publisher<'a, T> { fn drop(&mut self) { - self.channel.unregister_publisher(self.publisher_index) + self.channel.unregister_publisher() } } @@ -426,7 +412,7 @@ impl<'a, T: Clone> ImmediatePublisher<'a, T> { /// Publish a message if there is space in the message queue pub fn try_publish(&self, message: T) -> Result<(), T> { - self.channel.publish_with_context(message, usize::MAX, None) + self.channel.publish_with_context(message, None) } } @@ -442,20 +428,15 @@ pub enum Error { } trait PubSubBehavior { - fn get_message_with_context( - &self, - next_message_id: &mut u64, - subscriber_index: usize, - cx: Option<&mut Context<'_>>, - ) -> Poll>; + fn get_message_with_context(&self, next_message_id: &mut u64, cx: Option<&mut Context<'_>>) -> Poll>; - fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T>; + fn publish_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), T>; fn publish_immediate(&self, message: T); - fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64); + fn unregister_subscriber(&self, subscriber_next_message_id: u64); - fn unregister_publisher(&self, publisher_index: usize); + fn unregister_publisher(&self); } /// Future for the subscriber wait action @@ -467,10 +448,9 @@ impl<'s, 'a, T: Clone> Future for SubscriberWaitFuture<'s, 'a, T> { type Output = WaitResult; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let sub_index = self.subscriber.subscriber_index; self.subscriber .channel - .get_message_with_context(&mut self.subscriber.next_message_id, sub_index, Some(cx)) + .get_message_with_context(&mut self.subscriber.next_message_id, Some(cx)) } } @@ -488,11 +468,7 @@ impl<'s, 'a, T: Clone> Future for PublisherWaitFuture<'s, 'a, T> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let message = self.message.take().unwrap(); - match self - .publisher - .channel - .publish_with_context(message, self.publisher.publisher_index, Some(cx)) - { + match self.publisher.channel.publish_with_context(message, Some(cx)) { Ok(()) => Poll::Ready(()), Err(message) => { self.message = Some(message); diff --git a/embassy/src/waitqueue/mod.rs b/embassy/src/waitqueue/mod.rs index a2bafad9..5c4e1bc3 100644 --- a/embassy/src/waitqueue/mod.rs +++ b/embassy/src/waitqueue/mod.rs @@ -3,3 +3,6 @@ #[cfg_attr(feature = "executor-agnostic", path = "waker_agnostic.rs")] mod waker; pub use waker::*; + +mod multi_waker; +pub use multi_waker::*; diff --git a/embassy/src/waitqueue/multi_waker.rs b/embassy/src/waitqueue/multi_waker.rs new file mode 100644 index 00000000..6e8710cb --- /dev/null +++ b/embassy/src/waitqueue/multi_waker.rs @@ -0,0 +1,31 @@ +use core::task::Waker; + +use super::WakerRegistration; + +pub struct MultiWakerRegistration { + wakers: [WakerRegistration; N], +} + +impl MultiWakerRegistration { + pub const fn new() -> Self { + const WAKER: WakerRegistration = WakerRegistration::new(); + Self { wakers: [WAKER; N] } + } + + /// Register a waker. If the buffer is full the function returns it in the error + pub fn register<'a>(&mut self, w: &'a Waker) -> Result<(), &'a Waker> { + if let Some(waker_slot) = self.wakers.iter_mut().find(|waker_slot| !waker_slot.occupied()) { + waker_slot.register(w); + Ok(()) + } else { + Err(w) + } + } + + /// Wake all registered wakers. This clears the buffer + pub fn wake(&mut self) { + for waker_slot in self.wakers.iter_mut() { + waker_slot.wake() + } + } +} diff --git a/embassy/src/waitqueue/waker.rs b/embassy/src/waitqueue/waker.rs index da907300..a90154cc 100644 --- a/embassy/src/waitqueue/waker.rs +++ b/embassy/src/waitqueue/waker.rs @@ -50,6 +50,11 @@ impl WakerRegistration { unsafe { wake_task(w) } } } + + /// Returns true if a waker is currently registered + pub fn occupied(&self) -> bool { + self.waker.is_some() + } } // SAFETY: `WakerRegistration` effectively contains an `Option`, diff --git a/embassy/src/waitqueue/waker_agnostic.rs b/embassy/src/waitqueue/waker_agnostic.rs index 89430aa4..62e3adb7 100644 --- a/embassy/src/waitqueue/waker_agnostic.rs +++ b/embassy/src/waitqueue/waker_agnostic.rs @@ -47,6 +47,11 @@ impl WakerRegistration { w.wake() } } + + /// Returns true if a waker is currently registered + pub fn occupied(&self) -> bool { + self.waker.is_some() + } } /// Utility struct to register and wake a waker.