Avoid a race condition by reducing the locks to one
This commit is contained in:
		| @@ -145,14 +145,11 @@ where | ||||
|         futures::future::poll_fn(|cx| self.recv_poll(cx)).await | ||||
|     } | ||||
|  | ||||
|     fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> { | ||||
|         match self.try_recv() { | ||||
|     fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { | ||||
|         match self.channel.get().try_recv_with_context(Some(cx)) { | ||||
|             Ok(v) => Poll::Ready(Some(v)), | ||||
|             Err(TryRecvError::Closed) => Poll::Ready(None), | ||||
|             Err(TryRecvError::Empty) => { | ||||
|                 self.channel.get().set_receiver_waker(&cx.waker()); | ||||
|                 Poll::Pending | ||||
|             } | ||||
|             Err(TryRecvError::Empty) => Poll::Pending, | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -279,11 +276,15 @@ where | ||||
|     type Output = Result<(), SendError<T>>; | ||||
|  | ||||
|     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||
|         match self.sender.try_send(unsafe { self.message.get().read() }) { | ||||
|         match self | ||||
|             .sender | ||||
|             .channel | ||||
|             .get() | ||||
|             .try_send_with_context(unsafe { self.message.get().read() }, Some(cx)) | ||||
|         { | ||||
|             Ok(..) => Poll::Ready(Ok(())), | ||||
|             Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), | ||||
|             Err(TrySendError::Full(..)) => { | ||||
|                 self.sender.channel.get().set_senders_waker(&cx.waker()); | ||||
|                 Poll::Pending | ||||
|                 // Note we leave the existing UnsafeCell contents - they still | ||||
|                 // contain the original message. We could create another UnsafeCell | ||||
| @@ -307,10 +308,9 @@ where | ||||
|     type Output = (); | ||||
|  | ||||
|     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||
|         if self.sender.is_closed() { | ||||
|         if self.sender.channel.get().is_closed_with_context(Some(cx)) { | ||||
|             Poll::Ready(()) | ||||
|         } else { | ||||
|             self.sender.channel.get().set_senders_waker(&cx.waker()); | ||||
|             Poll::Pending | ||||
|         } | ||||
|     } | ||||
| @@ -513,7 +513,11 @@ where | ||||
|     } | ||||
|  | ||||
|     fn try_recv(&mut self) -> Result<T, TryRecvError> { | ||||
|         let state = &mut self.state; | ||||
|         self.try_recv_with_context(None) | ||||
|     } | ||||
|  | ||||
|     fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> { | ||||
|         let mut state = &mut self.state; | ||||
|         self.mutex.lock(|_| { | ||||
|             if !state.closed { | ||||
|                 if state.read_pos != state.write_pos || state.full { | ||||
| @@ -526,6 +530,8 @@ where | ||||
|                     state.read_pos = (state.read_pos + 1) % state.buf.len(); | ||||
|                     Ok(message) | ||||
|                 } else if !state.closing { | ||||
|                     cx.into_iter() | ||||
|                         .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); | ||||
|                     Err(TryRecvError::Empty) | ||||
|                 } else { | ||||
|                     state.closed = true; | ||||
| @@ -539,7 +545,15 @@ where | ||||
|     } | ||||
|  | ||||
|     fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { | ||||
|         let state = &mut self.state; | ||||
|         self.try_send_with_context(message, None) | ||||
|     } | ||||
|  | ||||
|     fn try_send_with_context( | ||||
|         &mut self, | ||||
|         message: T, | ||||
|         cx: Option<&mut Context<'_>>, | ||||
|     ) -> Result<(), TrySendError<T>> { | ||||
|         let mut state = &mut self.state; | ||||
|         self.mutex.lock(|_| { | ||||
|             if !state.closed { | ||||
|                 if !state.full { | ||||
| @@ -551,6 +565,8 @@ where | ||||
|                     state.receiver_waker.wake(); | ||||
|                     Ok(()) | ||||
|                 } else { | ||||
|                     cx.into_iter() | ||||
|                         .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); | ||||
|                     Err(TrySendError::Full(message)) | ||||
|                 } | ||||
|             } else { | ||||
| @@ -568,8 +584,20 @@ where | ||||
|     } | ||||
|  | ||||
|     fn is_closed(&mut self) -> bool { | ||||
|         let state = &self.state; | ||||
|         self.mutex.lock(|_| state.closing || state.closed) | ||||
|         self.is_closed_with_context(None) | ||||
|     } | ||||
|  | ||||
|     fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { | ||||
|         let mut state = &mut self.state; | ||||
|         self.mutex.lock(|_| { | ||||
|             if state.closing || state.closed { | ||||
|                 cx.into_iter() | ||||
|                     .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); | ||||
|                 true | ||||
|             } else { | ||||
|                 false | ||||
|             } | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     fn register_receiver(&mut self) { | ||||
| @@ -610,25 +638,19 @@ where | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     fn set_receiver_waker(&mut self, receiver_waker: &Waker) { | ||||
|         let state = &mut self.state; | ||||
|         self.mutex.lock(|_| { | ||||
|             state.receiver_waker.register(receiver_waker); | ||||
|         }) | ||||
|     fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) { | ||||
|         state.receiver_waker.register(receiver_waker); | ||||
|     } | ||||
|  | ||||
|     fn set_senders_waker(&mut self, senders_waker: &Waker) { | ||||
|         let state = &mut self.state; | ||||
|         self.mutex.lock(|_| { | ||||
|             // Dispose of any existing sender causing them to be polled again. | ||||
|             // This could cause a spin given multiple concurrent senders, however given that | ||||
|             // most sends only block waiting for the receiver to become active, this should | ||||
|             // be a short-lived activity. The upside is a greatly simplified implementation | ||||
|             // that avoids the need for intrusive linked-lists and unsafe operations on pinned | ||||
|             // pointers. | ||||
|             state.senders_waker.wake(); | ||||
|             state.senders_waker.register(senders_waker); | ||||
|         }) | ||||
|     fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) { | ||||
|         // Dispose of any existing sender causing them to be polled again. | ||||
|         // This could cause a spin given multiple concurrent senders, however given that | ||||
|         // most sends only block waiting for the receiver to become active, this should | ||||
|         // be a short-lived activity. The upside is a greatly simplified implementation | ||||
|         // that avoids the need for intrusive linked-lists and unsafe operations on pinned | ||||
|         // pointers. | ||||
|         state.senders_waker.wake(); | ||||
|         state.senders_waker.register(senders_waker); | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user