Avoid a race condition by reducing the locks to one
This commit is contained in:
parent
5a5795ef2b
commit
baab52d40c
@ -145,14 +145,11 @@ where
|
|||||||
futures::future::poll_fn(|cx| self.recv_poll(cx)).await
|
futures::future::poll_fn(|cx| self.recv_poll(cx)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
||||||
match self.try_recv() {
|
match self.channel.get().try_recv_with_context(Some(cx)) {
|
||||||
Ok(v) => Poll::Ready(Some(v)),
|
Ok(v) => Poll::Ready(Some(v)),
|
||||||
Err(TryRecvError::Closed) => Poll::Ready(None),
|
Err(TryRecvError::Closed) => Poll::Ready(None),
|
||||||
Err(TryRecvError::Empty) => {
|
Err(TryRecvError::Empty) => Poll::Pending,
|
||||||
self.channel.get().set_receiver_waker(&cx.waker());
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,11 +276,15 @@ where
|
|||||||
type Output = Result<(), SendError<T>>;
|
type Output = Result<(), SendError<T>>;
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
match self.sender.try_send(unsafe { self.message.get().read() }) {
|
match self
|
||||||
|
.sender
|
||||||
|
.channel
|
||||||
|
.get()
|
||||||
|
.try_send_with_context(unsafe { self.message.get().read() }, Some(cx))
|
||||||
|
{
|
||||||
Ok(..) => Poll::Ready(Ok(())),
|
Ok(..) => Poll::Ready(Ok(())),
|
||||||
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
|
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
|
||||||
Err(TrySendError::Full(..)) => {
|
Err(TrySendError::Full(..)) => {
|
||||||
self.sender.channel.get().set_senders_waker(&cx.waker());
|
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
// Note we leave the existing UnsafeCell contents - they still
|
// Note we leave the existing UnsafeCell contents - they still
|
||||||
// contain the original message. We could create another UnsafeCell
|
// contain the original message. We could create another UnsafeCell
|
||||||
@ -307,10 +308,9 @@ where
|
|||||||
type Output = ();
|
type Output = ();
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
if self.sender.is_closed() {
|
if self.sender.channel.get().is_closed_with_context(Some(cx)) {
|
||||||
Poll::Ready(())
|
Poll::Ready(())
|
||||||
} else {
|
} else {
|
||||||
self.sender.channel.get().set_senders_waker(&cx.waker());
|
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -513,7 +513,11 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||||
let state = &mut self.state;
|
self.try_recv_with_context(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
|
||||||
|
let mut state = &mut self.state;
|
||||||
self.mutex.lock(|_| {
|
self.mutex.lock(|_| {
|
||||||
if !state.closed {
|
if !state.closed {
|
||||||
if state.read_pos != state.write_pos || state.full {
|
if state.read_pos != state.write_pos || state.full {
|
||||||
@ -526,6 +530,8 @@ where
|
|||||||
state.read_pos = (state.read_pos + 1) % state.buf.len();
|
state.read_pos = (state.read_pos + 1) % state.buf.len();
|
||||||
Ok(message)
|
Ok(message)
|
||||||
} else if !state.closing {
|
} else if !state.closing {
|
||||||
|
cx.into_iter()
|
||||||
|
.for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
|
||||||
Err(TryRecvError::Empty)
|
Err(TryRecvError::Empty)
|
||||||
} else {
|
} else {
|
||||||
state.closed = true;
|
state.closed = true;
|
||||||
@ -539,7 +545,15 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
|
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
|
||||||
let state = &mut self.state;
|
self.try_send_with_context(message, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_send_with_context(
|
||||||
|
&mut self,
|
||||||
|
message: T,
|
||||||
|
cx: Option<&mut Context<'_>>,
|
||||||
|
) -> Result<(), TrySendError<T>> {
|
||||||
|
let mut state = &mut self.state;
|
||||||
self.mutex.lock(|_| {
|
self.mutex.lock(|_| {
|
||||||
if !state.closed {
|
if !state.closed {
|
||||||
if !state.full {
|
if !state.full {
|
||||||
@ -551,6 +565,8 @@ where
|
|||||||
state.receiver_waker.wake();
|
state.receiver_waker.wake();
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
|
cx.into_iter()
|
||||||
|
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
|
||||||
Err(TrySendError::Full(message))
|
Err(TrySendError::Full(message))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -568,8 +584,20 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_closed(&mut self) -> bool {
|
fn is_closed(&mut self) -> bool {
|
||||||
let state = &self.state;
|
self.is_closed_with_context(None)
|
||||||
self.mutex.lock(|_| state.closing || state.closed)
|
}
|
||||||
|
|
||||||
|
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
|
||||||
|
let mut state = &mut self.state;
|
||||||
|
self.mutex.lock(|_| {
|
||||||
|
if state.closing || state.closed {
|
||||||
|
cx.into_iter()
|
||||||
|
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_receiver(&mut self) {
|
fn register_receiver(&mut self) {
|
||||||
@ -610,16 +638,11 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
|
fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
|
||||||
let state = &mut self.state;
|
|
||||||
self.mutex.lock(|_| {
|
|
||||||
state.receiver_waker.register(receiver_waker);
|
state.receiver_waker.register(receiver_waker);
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_senders_waker(&mut self, senders_waker: &Waker) {
|
fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
|
||||||
let state = &mut self.state;
|
|
||||||
self.mutex.lock(|_| {
|
|
||||||
// Dispose of any existing sender causing them to be polled again.
|
// Dispose of any existing sender causing them to be polled again.
|
||||||
// This could cause a spin given multiple concurrent senders, however given that
|
// This could cause a spin given multiple concurrent senders, however given that
|
||||||
// most sends only block waiting for the receiver to become active, this should
|
// most sends only block waiting for the receiver to become active, this should
|
||||||
@ -628,7 +651,6 @@ where
|
|||||||
// pointers.
|
// pointers.
|
||||||
state.senders_waker.wake();
|
state.senders_waker.wake();
|
||||||
state.senders_waker.register(senders_waker);
|
state.senders_waker.register(senders_waker);
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user