From 36b363a5b71497d32d2968fcecf046773822c2e9 Mon Sep 17 00:00:00 2001 From: Dion Dokter Date: Thu, 16 Jun 2022 13:48:26 +0200 Subject: [PATCH] Changed names of subscriber methods and implemented the Stream trait for it --- embassy/src/channel/pubsub.rs | 55 +++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs index fb8d0ef5..5d81431e 100644 --- a/embassy/src/channel/pubsub.rs +++ b/embassy/src/channel/pubsub.rs @@ -270,14 +270,14 @@ pub struct Subscriber<'a, T: Clone> { impl<'a, T: Clone> Subscriber<'a, T> { /// Wait for a published message - pub fn wait<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, T> { + pub fn next<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, T> { SubscriberWaitFuture { subscriber: self } } /// Try to see if there's a published message we haven't received yet. /// /// This function does not peek. The message is received if there is one. - pub fn check(&mut self) -> Option> { + pub fn try_next(&mut self) -> Option> { match self.channel.get_message(self.next_message_id) { Some(WaitResult::Lagged(amount)) => { self.next_message_id += amount; @@ -300,6 +300,37 @@ impl<'a, T: Clone> Drop for Subscriber<'a, T> { } } +impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> { + type Item = WaitResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = unsafe { self.get_unchecked_mut() }; + + // Check if we can read a message + match this.channel.get_message(this.next_message_id) { + // Yes, so we are done polling + Some(WaitResult::Message(message)) => { + this.next_message_id += 1; + Poll::Ready(Some(WaitResult::Message(message))) + } + // No, so we need to reregister our waker and sleep again + None => { + unsafe { + this + .channel + .register_subscriber_waker(this.subscriber_index, cx.waker()); + } + Poll::Pending + } + // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged + Some(WaitResult::Lagged(amount)) => { + this.next_message_id += amount; + Poll::Ready(Some(WaitResult::Lagged(amount))) + } + } + } +} + /// A publisher to a channel /// /// This instance carries a reference to the channel, but uses a trait object for it so that the channel's @@ -494,11 +525,11 @@ mod tests { pub0.publish(42).await; - assert_eq!(sub0.wait().await, WaitResult::Message(42)); - assert_eq!(sub1.wait().await, WaitResult::Message(42)); + assert_eq!(sub0.next().await, WaitResult::Message(42)); + assert_eq!(sub1.next().await, WaitResult::Message(42)); - assert_eq!(sub0.check(), None); - assert_eq!(sub1.check(), None); + assert_eq!(sub0.try_next(), None); + assert_eq!(sub1.try_next(), None); } #[futures_test::test] @@ -515,12 +546,12 @@ mod tests { pub0.publish_immediate(46); pub0.publish_immediate(47); - assert_eq!(sub0.check(), Some(WaitResult::Lagged(2))); - assert_eq!(sub0.wait().await, WaitResult::Message(44)); - assert_eq!(sub0.wait().await, WaitResult::Message(45)); - assert_eq!(sub0.wait().await, WaitResult::Message(46)); - assert_eq!(sub0.wait().await, WaitResult::Message(47)); - assert_eq!(sub0.check(), None); + assert_eq!(sub0.try_next(), Some(WaitResult::Lagged(2))); + assert_eq!(sub0.next().await, WaitResult::Message(44)); + assert_eq!(sub0.next().await, WaitResult::Message(45)); + assert_eq!(sub0.next().await, WaitResult::Message(46)); + assert_eq!(sub0.next().await, WaitResult::Message(47)); + assert_eq!(sub0.try_next(), None); } #[test]