Avoid a race condition by reducing the locks to one
This commit is contained in:
parent
5a5795ef2b
commit
baab52d40c
@ -145,14 +145,11 @@ where
|
||||
futures::future::poll_fn(|cx| self.recv_poll(cx)).await
|
||||
}
|
||||
|
||||
fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
||||
match self.try_recv() {
|
||||
fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
||||
match self.channel.get().try_recv_with_context(Some(cx)) {
|
||||
Ok(v) => Poll::Ready(Some(v)),
|
||||
Err(TryRecvError::Closed) => Poll::Ready(None),
|
||||
Err(TryRecvError::Empty) => {
|
||||
self.channel.get().set_receiver_waker(&cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
Err(TryRecvError::Empty) => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,11 +276,15 @@ where
|
||||
type Output = Result<(), SendError<T>>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.sender.try_send(unsafe { self.message.get().read() }) {
|
||||
match self
|
||||
.sender
|
||||
.channel
|
||||
.get()
|
||||
.try_send_with_context(unsafe { self.message.get().read() }, Some(cx))
|
||||
{
|
||||
Ok(..) => Poll::Ready(Ok(())),
|
||||
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
|
||||
Err(TrySendError::Full(..)) => {
|
||||
self.sender.channel.get().set_senders_waker(&cx.waker());
|
||||
Poll::Pending
|
||||
// Note we leave the existing UnsafeCell contents - they still
|
||||
// contain the original message. We could create another UnsafeCell
|
||||
@ -307,10 +308,9 @@ where
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.sender.is_closed() {
|
||||
if self.sender.channel.get().is_closed_with_context(Some(cx)) {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.sender.channel.get().set_senders_waker(&cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
@ -513,7 +513,11 @@ where
|
||||
}
|
||||
|
||||
fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||
let state = &mut self.state;
|
||||
self.try_recv_with_context(None)
|
||||
}
|
||||
|
||||
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
|
||||
let mut state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if !state.closed {
|
||||
if state.read_pos != state.write_pos || state.full {
|
||||
@ -526,6 +530,8 @@ where
|
||||
state.read_pos = (state.read_pos + 1) % state.buf.len();
|
||||
Ok(message)
|
||||
} else if !state.closing {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
|
||||
Err(TryRecvError::Empty)
|
||||
} else {
|
||||
state.closed = true;
|
||||
@ -539,7 +545,15 @@ where
|
||||
}
|
||||
|
||||
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
|
||||
let state = &mut self.state;
|
||||
self.try_send_with_context(message, None)
|
||||
}
|
||||
|
||||
fn try_send_with_context(
|
||||
&mut self,
|
||||
message: T,
|
||||
cx: Option<&mut Context<'_>>,
|
||||
) -> Result<(), TrySendError<T>> {
|
||||
let mut state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if !state.closed {
|
||||
if !state.full {
|
||||
@ -551,6 +565,8 @@ where
|
||||
state.receiver_waker.wake();
|
||||
Ok(())
|
||||
} else {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
|
||||
Err(TrySendError::Full(message))
|
||||
}
|
||||
} else {
|
||||
@ -568,8 +584,20 @@ where
|
||||
}
|
||||
|
||||
fn is_closed(&mut self) -> bool {
|
||||
let state = &self.state;
|
||||
self.mutex.lock(|_| state.closing || state.closed)
|
||||
self.is_closed_with_context(None)
|
||||
}
|
||||
|
||||
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
|
||||
let mut state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if state.closing || state.closed {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn register_receiver(&mut self) {
|
||||
@ -610,25 +638,19 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
state.receiver_waker.register(receiver_waker);
|
||||
})
|
||||
fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
|
||||
state.receiver_waker.register(receiver_waker);
|
||||
}
|
||||
|
||||
fn set_senders_waker(&mut self, senders_waker: &Waker) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
// Dispose of any existing sender causing them to be polled again.
|
||||
// This could cause a spin given multiple concurrent senders, however given that
|
||||
// most sends only block waiting for the receiver to become active, this should
|
||||
// be a short-lived activity. The upside is a greatly simplified implementation
|
||||
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
|
||||
// pointers.
|
||||
state.senders_waker.wake();
|
||||
state.senders_waker.register(senders_waker);
|
||||
})
|
||||
fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
|
||||
// Dispose of any existing sender causing them to be polled again.
|
||||
// This could cause a spin given multiple concurrent senders, however given that
|
||||
// most sends only block waiting for the receiver to become active, this should
|
||||
// be a short-lived activity. The upside is a greatly simplified implementation
|
||||
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
|
||||
// pointers.
|
||||
state.senders_waker.wake();
|
||||
state.senders_waker.register(senders_waker);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user