From baab52d40cb8d1ede339a3a422006108a86d8efb Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 11 Jul 2021 13:01:36 +1000 Subject: [PATCH] Avoid a race condition by reducing the locks to one --- embassy/src/util/mpsc.rs | 84 +++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 68fcdf7f..8d534dc4 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -145,14 +145,11 @@ where futures::future::poll_fn(|cx| self.recv_poll(cx)).await } - fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll> { - match self.try_recv() { + fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.channel.get().try_recv_with_context(Some(cx)) { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), - Err(TryRecvError::Empty) => { - self.channel.get().set_receiver_waker(&cx.waker()); - Poll::Pending - } + Err(TryRecvError::Empty) => Poll::Pending, } } @@ -279,11 +276,15 @@ where type Output = Result<(), SendError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.sender.try_send(unsafe { self.message.get().read() }) { + match self + .sender + .channel + .get() + .try_send_with_context(unsafe { self.message.get().read() }, Some(cx)) + { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -307,10 +308,9 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.sender.is_closed() { + if self.sender.channel.get().is_closed_with_context(Some(cx)) { Poll::Ready(()) } else { - self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending } } @@ -513,7 +513,11 @@ where } fn try_recv(&mut self) -> Result { - let state = &mut self.state; + self.try_recv_with_context(None) + } + + fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { + let mut state = &mut self.state; self.mutex.lock(|_| { if !state.closed { if state.read_pos != state.write_pos || state.full { @@ -526,6 +530,8 @@ where state.read_pos = (state.read_pos + 1) % state.buf.len(); Ok(message) } else if !state.closing { + cx.into_iter() + .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); Err(TryRecvError::Empty) } else { state.closed = true; @@ -539,7 +545,15 @@ where } fn try_send(&mut self, message: T) -> Result<(), TrySendError> { - let state = &mut self.state; + self.try_send_with_context(message, None) + } + + fn try_send_with_context( + &mut self, + message: T, + cx: Option<&mut Context<'_>>, + ) -> Result<(), TrySendError> { + let mut state = &mut self.state; self.mutex.lock(|_| { if !state.closed { if !state.full { @@ -551,6 +565,8 @@ where state.receiver_waker.wake(); Ok(()) } else { + cx.into_iter() + .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); Err(TrySendError::Full(message)) } } else { @@ -568,8 +584,20 @@ where } fn is_closed(&mut self) -> bool { - let state = &self.state; - self.mutex.lock(|_| state.closing || state.closed) + self.is_closed_with_context(None) + } + + fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { + let mut state = &mut self.state; + self.mutex.lock(|_| { + if state.closing || state.closed { + cx.into_iter() + .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); + true + } else { + false + } + }) } fn register_receiver(&mut self) { @@ -610,25 +638,19 @@ where }) } - fn set_receiver_waker(&mut self, receiver_waker: &Waker) { - let state = &mut self.state; - self.mutex.lock(|_| { - state.receiver_waker.register(receiver_waker); - }) + fn set_receiver_waker(state: &mut ChannelState, receiver_waker: &Waker) { + state.receiver_waker.register(receiver_waker); } - fn set_senders_waker(&mut self, senders_waker: &Waker) { - let state = &mut self.state; - self.mutex.lock(|_| { - // Dispose of any existing sender causing them to be polled again. - // This could cause a spin given multiple concurrent senders, however given that - // most sends only block waiting for the receiver to become active, this should - // be a short-lived activity. The upside is a greatly simplified implementation - // that avoids the need for intrusive linked-lists and unsafe operations on pinned - // pointers. - state.senders_waker.wake(); - state.senders_waker.register(senders_waker); - }) + fn set_senders_waker(state: &mut ChannelState, senders_waker: &Waker) { + // Dispose of any existing sender causing them to be polled again. + // This could cause a spin given multiple concurrent senders, however given that + // most sends only block waiting for the receiver to become active, this should + // be a short-lived activity. The upside is a greatly simplified implementation + // that avoids the need for intrusive linked-lists and unsafe operations on pinned + // pointers. + state.senders_waker.wake(); + state.senders_waker.register(senders_waker); } }