From a614a55c7ddecc171d48c61bf9fa8c6c11ed16f4 Mon Sep 17 00:00:00 2001 From: Dion Dokter Date: Thu, 16 Jun 2022 22:11:29 +0200 Subject: [PATCH] Put most behaviour one level lower (under the mutex instead of above). Changed the PubSubBehavior to only have high level functions. --- embassy/src/channel/pubsub.rs | 347 ++++++++++++++++++---------------- 1 file changed, 181 insertions(+), 166 deletions(-) diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs index 20878187..c5a8c01f 100644 --- a/embassy/src/channel/pubsub.rs +++ b/embassy/src/channel/pubsub.rs @@ -94,122 +94,74 @@ impl PubSubBehavior for PubSubChannel { - fn try_publish(&self, message: T) -> Result<(), T> { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); + fn get_message_with_context( + &self, + next_message_id: &mut u64, + subscriber_index: usize, + cx: Option<&mut Context<'_>>, + ) -> Poll> { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); - let active_subscriber_count = s.subscriber_wakers.iter().flatten().count(); - - if active_subscriber_count == 0 { - // We don't need to publish anything because there is no one to receive it - return Ok(()); + // Check if we can read a message + match s.get_message(*next_message_id) { + // Yes, so we are done polling + Some(WaitResult::Message(message)) => { + *next_message_id += 1; + Poll::Ready(WaitResult::Message(message)) + } + // No, so we need to reregister our waker and sleep again + None => { + if let Some(cx) = cx { + s.register_subscriber_waker(subscriber_index, cx.waker()); + } + Poll::Pending + } + // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged + Some(WaitResult::Lagged(amount)) => { + *next_message_id += amount; + Poll::Ready(WaitResult::Lagged(amount)) + } } + }) + } - if s.queue.is_full() { - return Err(message); + fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T> { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + // Try to publish the message + match s.try_publish(message) { + // We did it, we are ready + Ok(()) => Ok(()), + // The queue is full, so we need to reregister our waker and go to sleep + Err(message) => { + if let Some(cx) = cx { + s.register_publisher_waker(publisher_index, cx.waker()); + } + Err(message) + } } - // We just did a check for this - s.queue.push_back((message, active_subscriber_count)).ok().unwrap(); - - s.next_message_id += 1; - - // Wake all of the subscribers - for active_subscriber in s.subscriber_wakers.iter_mut().flatten() { - active_subscriber.wake() - } - - Ok(()) }) } fn publish_immediate(&self, message: T) { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); - - // Make space in the queue if required - if s.queue.is_full() { - s.queue.pop_front(); - } - - // We are going to call something is Self again. - // The lock is fine, but we need to get rid of the refcell borrow - drop(s); - - // This will succeed because we made sure there is space - self.try_publish(message).ok().unwrap(); - }); - } - - fn get_message(&self, message_id: u64) -> Option> { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); - - let start_id = s.next_message_id - s.queue.len() as u64; - - if message_id < start_id { - return Some(WaitResult::Lagged(start_id - message_id)); - } - - let current_message_index = (message_id - start_id) as usize; - - if current_message_index >= s.queue.len() { - return None; - } - - // We've checked that the index is valid - let queue_item = s.queue.iter_mut().nth(current_message_index).unwrap(); - - // We're reading this item, so decrement the counter - queue_item.1 -= 1; - let message = queue_item.0.clone(); - - if current_message_index == 0 && queue_item.1 == 0 { - s.queue.pop_front(); - s.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake()); - } - - Some(WaitResult::Message(message)) - }) - } - - fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker) { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); - s.subscriber_wakers[subscriber_index].as_mut().unwrap().register(waker); - }) - } - - fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker) { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); - s.publisher_wakers[publisher_index].as_mut().unwrap().register(waker); + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.publish_immediate(message) }) } fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64) { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); - - // Remove the subscriber from the wakers - s.subscriber_wakers[subscriber_index] = None; - - // All messages that haven't been read yet by this subscriber must have their counter decremented - let start_id = s.next_message_id - s.queue.len() as u64; - if subscriber_next_message_id >= start_id { - let current_message_index = (subscriber_next_message_id - start_id) as usize; - s.queue - .iter_mut() - .skip(current_message_index) - .for_each(|(_, counter)| *counter -= 1); - } + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.unregister_subscriber(subscriber_index, subscriber_next_message_id) }) } fn unregister_publisher(&self, publisher_index: usize) { - self.inner.lock(|inner| { - let mut s = inner.borrow_mut(); - // Remove the publisher from the wakers - s.publisher_wakers[publisher_index] = None; + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.unregister_publisher(publisher_index) }) } } @@ -241,6 +193,99 @@ impl PubSubSta publisher_wakers: [WAKER_INIT; PUBS], } } + + fn try_publish(&mut self, message: T) -> Result<(), T> { + let active_subscriber_count = self.subscriber_wakers.iter().flatten().count(); + + if active_subscriber_count == 0 { + // We don't need to publish anything because there is no one to receive it + return Ok(()); + } + + if self.queue.is_full() { + return Err(message); + } + // We just did a check for this + self.queue.push_back((message, active_subscriber_count)).ok().unwrap(); + + self.next_message_id += 1; + + // Wake all of the subscribers + for active_subscriber in self.subscriber_wakers.iter_mut().flatten() { + active_subscriber.wake() + } + + Ok(()) + } + + fn publish_immediate(&mut self, message: T) { + // Make space in the queue if required + if self.queue.is_full() { + self.queue.pop_front(); + } + + // This will succeed because we made sure there is space + self.try_publish(message).ok().unwrap(); + } + + fn get_message(&mut self, message_id: u64) -> Option> { + let start_id = self.next_message_id - self.queue.len() as u64; + + if message_id < start_id { + return Some(WaitResult::Lagged(start_id - message_id)); + } + + let current_message_index = (message_id - start_id) as usize; + + if current_message_index >= self.queue.len() { + return None; + } + + // We've checked that the index is valid + let queue_item = self.queue.iter_mut().nth(current_message_index).unwrap(); + + // We're reading this item, so decrement the counter + queue_item.1 -= 1; + let message = queue_item.0.clone(); + + if current_message_index == 0 && queue_item.1 == 0 { + self.queue.pop_front(); + self.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake()); + } + + Some(WaitResult::Message(message)) + } + + fn register_subscriber_waker(&mut self, subscriber_index: usize, waker: &Waker) { + self.subscriber_wakers[subscriber_index] + .as_mut() + .unwrap() + .register(waker); + } + + fn register_publisher_waker(&mut self, publisher_index: usize, waker: &Waker) { + self.publisher_wakers[publisher_index].as_mut().unwrap().register(waker); + } + + fn unregister_subscriber(&mut self, subscriber_index: usize, subscriber_next_message_id: u64) { + // Remove the subscriber from the wakers + self.subscriber_wakers[subscriber_index] = None; + + // All messages that haven't been read yet by this subscriber must have their counter decremented + let start_id = self.next_message_id - self.queue.len() as u64; + if subscriber_next_message_id >= start_id { + let current_message_index = (subscriber_next_message_id - start_id) as usize; + self.queue + .iter_mut() + .skip(current_message_index) + .for_each(|(_, counter)| *counter -= 1); + } + } + + fn unregister_publisher(&mut self, publisher_index: usize) { + // Remove the publisher from the wakers + self.publisher_wakers[publisher_index] = None; + } } /// A subscriber to a channel @@ -276,15 +321,12 @@ impl<'a, T: Clone> Subscriber<'a, T> { /// /// This function does not peek. The message is received if there is one. pub fn try_next_message(&mut self) -> Option> { - match self.channel.get_message(self.next_message_id) { - Some(WaitResult::Lagged(amount)) => { - self.next_message_id += amount; - Some(WaitResult::Lagged(amount)) - } - result => { - self.next_message_id += 1; - result - } + match self + .channel + .get_message_with_context(&mut self.next_message_id, self.subscriber_index, None) + { + Poll::Ready(result) => Some(result), + Poll::Pending => None, } } @@ -317,26 +359,16 @@ impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = unsafe { self.get_unchecked_mut() }; - // Check if we can read a message - match this.channel.get_message(this.next_message_id) { - // Yes, so we are done polling - Some(WaitResult::Message(message)) => { - this.next_message_id += 1; - Poll::Ready(Some(message)) - } - // No, so we need to reregister our waker and sleep again - None => { - this.channel - .register_subscriber_waker(this.subscriber_index, cx.waker()); - Poll::Pending - } - // We missed a couple of messages. We must do our internal bookkeeping. - // This stream impl doesn't return lag results, so we just ignore and start over - Some(WaitResult::Lagged(amount)) => { - this.next_message_id += amount; + match this + .channel + .get_message_with_context(&mut this.next_message_id, this.subscriber_index, Some(cx)) + { + Poll::Ready(WaitResult::Message(message)) => Poll::Ready(Some(message)), + Poll::Ready(WaitResult::Lagged(_)) => { cx.waker().wake_by_ref(); Poll::Pending } + Poll::Pending => Poll::Pending, } } } @@ -369,7 +401,7 @@ impl<'a, T: Clone> Publisher<'a, T> { /// Publish a message if there is space in the message queue pub fn try_publish(&self, message: T) -> Result<(), T> { - self.channel.try_publish(message) + self.channel.publish_with_context(message, self.publisher_index, None) } } @@ -395,7 +427,7 @@ impl<'a, T: Clone> ImmediatePublisher<'a, T> { /// Publish a message if there is space in the message queue pub fn try_publish(&self, message: T) -> Result<(), T> { - self.channel.try_publish(message) + self.channel.publish_with_context(message, usize::MAX, None) } } @@ -411,19 +443,19 @@ pub enum Error { } trait PubSubBehavior { - /// Try to publish a message. If the queue is full it won't succeed - fn try_publish(&self, message: T) -> Result<(), T>; - /// Publish a message immediately. If the queue is full, just throw out the oldest one. + fn get_message_with_context( + &self, + next_message_id: &mut u64, + subscriber_index: usize, + cx: Option<&mut Context<'_>>, + ) -> Poll>; + + fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T>; + fn publish_immediate(&self, message: T); - /// Tries to read the message if available - fn get_message(&self, message_id: u64) -> Option>; - /// Register the given waker for the given subscriber. - fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker); - /// Register the given waker for the given publisher. - fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker); - /// Make the channel forget the subscriber. + fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64); - /// Make the channel forget the publisher. + fn unregister_publisher(&self, publisher_index: usize); } @@ -436,26 +468,10 @@ impl<'s, 'a, T: Clone> Future for SubscriberWaitFuture<'s, 'a, T> { type Output = WaitResult; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // Check if we can read a message - match self.subscriber.channel.get_message(self.subscriber.next_message_id) { - // Yes, so we are done polling - Some(WaitResult::Message(message)) => { - self.subscriber.next_message_id += 1; - Poll::Ready(WaitResult::Message(message)) - } - // No, so we need to reregister our waker and sleep again - None => { - self.subscriber - .channel - .register_subscriber_waker(self.subscriber.subscriber_index, cx.waker()); - Poll::Pending - } - // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged - Some(WaitResult::Lagged(amount)) => { - self.subscriber.next_message_id += amount; - Poll::Ready(WaitResult::Lagged(amount)) - } - } + let sub_index = self.subscriber.subscriber_index; + self.subscriber + .channel + .get_message_with_context(&mut self.subscriber.next_message_id, sub_index, Some(cx)) } } @@ -474,16 +490,15 @@ impl<'s, 'a, T: Clone> Future for PublisherWaitFuture<'s, 'a, T> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = unsafe { self.get_unchecked_mut() }; - // Try to publish the message - match this.publisher.channel.try_publish(this.message.take().unwrap()) { - // We did it, we are ready + let message = this.message.take().unwrap(); + match this + .publisher + .channel + .publish_with_context(message, this.publisher.publisher_index, Some(cx)) + { Ok(()) => Poll::Ready(()), - // The queue is full, so we need to reregister our waker and go to sleep Err(message) => { this.message = Some(message); - this.publisher - .channel - .register_publisher_waker(this.publisher.publisher_index, cx.waker()); Poll::Pending } }