Avoid a race condition by reducing the locks to one

This commit is contained in:
huntc 2021-07-11 13:01:36 +10:00
parent 5a5795ef2b
commit baab52d40c

View File

@ -145,14 +145,11 @@ where
futures::future::poll_fn(|cx| self.recv_poll(cx)).await futures::future::poll_fn(|cx| self.recv_poll(cx)).await
} }
fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> { fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
match self.try_recv() { match self.channel.get().try_recv_with_context(Some(cx)) {
Ok(v) => Poll::Ready(Some(v)), Ok(v) => Poll::Ready(Some(v)),
Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Closed) => Poll::Ready(None),
Err(TryRecvError::Empty) => { Err(TryRecvError::Empty) => Poll::Pending,
self.channel.get().set_receiver_waker(&cx.waker());
Poll::Pending
}
} }
} }
@ -279,11 +276,15 @@ where
type Output = Result<(), SendError<T>>; type Output = Result<(), SendError<T>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.sender.try_send(unsafe { self.message.get().read() }) { match self
.sender
.channel
.get()
.try_send_with_context(unsafe { self.message.get().read() }, Some(cx))
{
Ok(..) => Poll::Ready(Ok(())), Ok(..) => Poll::Ready(Ok(())),
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
Err(TrySendError::Full(..)) => { Err(TrySendError::Full(..)) => {
self.sender.channel.get().set_senders_waker(&cx.waker());
Poll::Pending Poll::Pending
// Note we leave the existing UnsafeCell contents - they still // Note we leave the existing UnsafeCell contents - they still
// contain the original message. We could create another UnsafeCell // contain the original message. We could create another UnsafeCell
@ -307,10 +308,9 @@ where
type Output = (); type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.sender.is_closed() { if self.sender.channel.get().is_closed_with_context(Some(cx)) {
Poll::Ready(()) Poll::Ready(())
} else { } else {
self.sender.channel.get().set_senders_waker(&cx.waker());
Poll::Pending Poll::Pending
} }
} }
@ -513,7 +513,11 @@ where
} }
fn try_recv(&mut self) -> Result<T, TryRecvError> { fn try_recv(&mut self) -> Result<T, TryRecvError> {
let state = &mut self.state; self.try_recv_with_context(None)
}
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
let mut state = &mut self.state;
self.mutex.lock(|_| { self.mutex.lock(|_| {
if !state.closed { if !state.closed {
if state.read_pos != state.write_pos || state.full { if state.read_pos != state.write_pos || state.full {
@ -526,6 +530,8 @@ where
state.read_pos = (state.read_pos + 1) % state.buf.len(); state.read_pos = (state.read_pos + 1) % state.buf.len();
Ok(message) Ok(message)
} else if !state.closing { } else if !state.closing {
cx.into_iter()
.for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
Err(TryRecvError::Empty) Err(TryRecvError::Empty)
} else { } else {
state.closed = true; state.closed = true;
@ -539,7 +545,15 @@ where
} }
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
let state = &mut self.state; self.try_send_with_context(message, None)
}
fn try_send_with_context(
&mut self,
message: T,
cx: Option<&mut Context<'_>>,
) -> Result<(), TrySendError<T>> {
let mut state = &mut self.state;
self.mutex.lock(|_| { self.mutex.lock(|_| {
if !state.closed { if !state.closed {
if !state.full { if !state.full {
@ -551,6 +565,8 @@ where
state.receiver_waker.wake(); state.receiver_waker.wake();
Ok(()) Ok(())
} else { } else {
cx.into_iter()
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
Err(TrySendError::Full(message)) Err(TrySendError::Full(message))
} }
} else { } else {
@ -568,8 +584,20 @@ where
} }
fn is_closed(&mut self) -> bool { fn is_closed(&mut self) -> bool {
let state = &self.state; self.is_closed_with_context(None)
self.mutex.lock(|_| state.closing || state.closed) }
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if state.closing || state.closed {
cx.into_iter()
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
true
} else {
false
}
})
} }
fn register_receiver(&mut self) { fn register_receiver(&mut self) {
@ -610,16 +638,11 @@ where
}) })
} }
fn set_receiver_waker(&mut self, receiver_waker: &Waker) { fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
let state = &mut self.state;
self.mutex.lock(|_| {
state.receiver_waker.register(receiver_waker); state.receiver_waker.register(receiver_waker);
})
} }
fn set_senders_waker(&mut self, senders_waker: &Waker) { fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
let state = &mut self.state;
self.mutex.lock(|_| {
// Dispose of any existing sender causing them to be polled again. // Dispose of any existing sender causing them to be polled again.
// This could cause a spin given multiple concurrent senders, however given that // This could cause a spin given multiple concurrent senders, however given that
// most sends only block waiting for the receiver to become active, this should // most sends only block waiting for the receiver to become active, this should
@ -628,7 +651,6 @@ where
// pointers. // pointers.
state.senders_waker.wake(); state.senders_waker.wake();
state.senders_waker.register(senders_waker); state.senders_waker.register(senders_waker);
})
} }
} }