Avoid a race condition by reducing the locks to one

This commit is contained in:
huntc 2021-07-11 13:01:36 +10:00
parent 5a5795ef2b
commit baab52d40c

View File

@ -145,14 +145,11 @@ where
futures::future::poll_fn(|cx| self.recv_poll(cx)).await
}
fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> {
match self.try_recv() {
fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
match self.channel.get().try_recv_with_context(Some(cx)) {
Ok(v) => Poll::Ready(Some(v)),
Err(TryRecvError::Closed) => Poll::Ready(None),
Err(TryRecvError::Empty) => {
self.channel.get().set_receiver_waker(&cx.waker());
Poll::Pending
}
Err(TryRecvError::Empty) => Poll::Pending,
}
}
@ -279,11 +276,15 @@ where
type Output = Result<(), SendError<T>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.sender.try_send(unsafe { self.message.get().read() }) {
match self
.sender
.channel
.get()
.try_send_with_context(unsafe { self.message.get().read() }, Some(cx))
{
Ok(..) => Poll::Ready(Ok(())),
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
Err(TrySendError::Full(..)) => {
self.sender.channel.get().set_senders_waker(&cx.waker());
Poll::Pending
// Note we leave the existing UnsafeCell contents - they still
// contain the original message. We could create another UnsafeCell
@ -307,10 +308,9 @@ where
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.sender.is_closed() {
if self.sender.channel.get().is_closed_with_context(Some(cx)) {
Poll::Ready(())
} else {
self.sender.channel.get().set_senders_waker(&cx.waker());
Poll::Pending
}
}
@ -513,7 +513,11 @@ where
}
fn try_recv(&mut self) -> Result<T, TryRecvError> {
let state = &mut self.state;
self.try_recv_with_context(None)
}
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if !state.closed {
if state.read_pos != state.write_pos || state.full {
@ -526,6 +530,8 @@ where
state.read_pos = (state.read_pos + 1) % state.buf.len();
Ok(message)
} else if !state.closing {
cx.into_iter()
.for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
Err(TryRecvError::Empty)
} else {
state.closed = true;
@ -539,7 +545,15 @@ where
}
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
let state = &mut self.state;
self.try_send_with_context(message, None)
}
fn try_send_with_context(
&mut self,
message: T,
cx: Option<&mut Context<'_>>,
) -> Result<(), TrySendError<T>> {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if !state.closed {
if !state.full {
@ -551,6 +565,8 @@ where
state.receiver_waker.wake();
Ok(())
} else {
cx.into_iter()
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
Err(TrySendError::Full(message))
}
} else {
@ -568,8 +584,20 @@ where
}
fn is_closed(&mut self) -> bool {
let state = &self.state;
self.mutex.lock(|_| state.closing || state.closed)
self.is_closed_with_context(None)
}
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if state.closing || state.closed {
cx.into_iter()
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
true
} else {
false
}
})
}
fn register_receiver(&mut self) {
@ -610,25 +638,19 @@ where
})
}
fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
let state = &mut self.state;
self.mutex.lock(|_| {
state.receiver_waker.register(receiver_waker);
})
fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
state.receiver_waker.register(receiver_waker);
}
fn set_senders_waker(&mut self, senders_waker: &Waker) {
let state = &mut self.state;
self.mutex.lock(|_| {
// Dispose of any existing sender causing them to be polled again.
// This could cause a spin given multiple concurrent senders, however given that
// most sends only block waiting for the receiver to become active, this should
// be a short-lived activity. The upside is a greatly simplified implementation
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
// pointers.
state.senders_waker.wake();
state.senders_waker.register(senders_waker);
})
fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
// Dispose of any existing sender causing them to be polled again.
// This could cause a spin given multiple concurrent senders, however given that
// most sends only block waiting for the receiver to become active, this should
// be a short-lived activity. The upside is a greatly simplified implementation
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
// pointers.
state.senders_waker.wake();
state.senders_waker.register(senders_waker);
}
}