Eliminates unsoundness by using an UnsafeCell for sharing the channel
This commit is contained in:
parent
babee7f32a
commit
d711e8a82c
@ -122,11 +122,10 @@ where
|
||||
{
|
||||
let sender = Sender { channel: &channel };
|
||||
let receiver = Receiver { channel: &channel };
|
||||
{
|
||||
let c = channel.get();
|
||||
channel.lock(|c| {
|
||||
c.register_receiver();
|
||||
c.register_sender();
|
||||
}
|
||||
});
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
@ -155,11 +154,12 @@ where
|
||||
}
|
||||
|
||||
fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
||||
match self.channel.get().try_recv_with_context(Some(cx)) {
|
||||
Ok(v) => Poll::Ready(Some(v)),
|
||||
Err(TryRecvError::Closed) => Poll::Ready(None),
|
||||
Err(TryRecvError::Empty) => Poll::Pending,
|
||||
}
|
||||
self.channel
|
||||
.lock(|c| match c.try_recv_with_context(Some(cx)) {
|
||||
Ok(v) => Poll::Ready(Some(v)),
|
||||
Err(TryRecvError::Closed) => Poll::Ready(None),
|
||||
Err(TryRecvError::Empty) => Poll::Pending,
|
||||
})
|
||||
}
|
||||
|
||||
/// Attempts to immediately receive a message on this `Receiver`
|
||||
@ -167,7 +167,7 @@ where
|
||||
/// This method will either receive a message from the channel immediately or return an error
|
||||
/// if the channel is empty.
|
||||
pub fn try_recv(&self) -> Result<T, TryRecvError> {
|
||||
self.channel.get().try_recv()
|
||||
self.channel.lock(|c| c.try_recv())
|
||||
}
|
||||
|
||||
/// Closes the receiving half of a channel without dropping it.
|
||||
@ -181,7 +181,7 @@ where
|
||||
/// until those are released.
|
||||
///
|
||||
pub fn close(&mut self) {
|
||||
self.channel.get().close()
|
||||
self.channel.lock(|c| c.close())
|
||||
}
|
||||
}
|
||||
|
||||
@ -190,7 +190,7 @@ where
|
||||
M: Mutex<Data = ()>,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
self.channel.get().deregister_receiver()
|
||||
self.channel.lock(|c| c.deregister_receiver())
|
||||
}
|
||||
}
|
||||
|
||||
@ -245,7 +245,7 @@ where
|
||||
/// [`channel`]: channel
|
||||
/// [`close`]: Receiver::close
|
||||
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
|
||||
self.channel.get().try_send(message)
|
||||
self.channel.lock(|c| c.try_send(message))
|
||||
}
|
||||
|
||||
/// Completes when the receiver has dropped.
|
||||
@ -266,7 +266,7 @@ where
|
||||
/// [`Receiver`]: crate::sync::mpsc::Receiver
|
||||
/// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.channel.get().is_closed()
|
||||
self.channel.lock(|c| c.is_closed())
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,7 +286,11 @@ where
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.message.take() {
|
||||
Some(m) => match self.sender.channel.get().try_send_with_context(m, Some(cx)) {
|
||||
Some(m) => match self
|
||||
.sender
|
||||
.channel
|
||||
.lock(|c| c.try_send_with_context(m, Some(cx)))
|
||||
{
|
||||
Ok(..) => Poll::Ready(Ok(())),
|
||||
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
|
||||
Err(TrySendError::Full(m)) => {
|
||||
@ -315,7 +319,11 @@ where
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.sender.channel.get().is_closed_with_context(Some(cx)) {
|
||||
if self
|
||||
.sender
|
||||
.channel
|
||||
.lock(|c| c.is_closed_with_context(Some(cx)))
|
||||
{
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
Poll::Pending
|
||||
@ -328,7 +336,7 @@ where
|
||||
M: Mutex<Data = ()>,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
self.channel.get().deregister_sender()
|
||||
self.channel.lock(|c| c.deregister_sender())
|
||||
}
|
||||
}
|
||||
|
||||
@ -338,7 +346,7 @@ where
|
||||
{
|
||||
#[allow(clippy::clone_double_ref)]
|
||||
fn clone(&self) -> Self {
|
||||
self.channel.get().register_sender();
|
||||
self.channel.lock(|c| c.register_sender());
|
||||
Sender {
|
||||
channel: self.channel.clone(),
|
||||
}
|
||||
@ -421,6 +429,116 @@ impl<T, const N: usize> ChannelState<T, N> {
|
||||
senders_waker: WakerRegistration::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||
self.try_recv_with_context(None)
|
||||
}
|
||||
|
||||
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
|
||||
if self.read_pos != self.write_pos || self.full {
|
||||
if self.full {
|
||||
self.full = false;
|
||||
self.senders_waker.wake();
|
||||
}
|
||||
let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() };
|
||||
self.read_pos = (self.read_pos + 1) % self.buf.len();
|
||||
Ok(message)
|
||||
} else if !self.closed {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| self.set_receiver_waker(&cx.waker()));
|
||||
Err(TryRecvError::Empty)
|
||||
} else {
|
||||
Err(TryRecvError::Closed)
|
||||
}
|
||||
}
|
||||
|
||||
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
|
||||
self.try_send_with_context(message, None)
|
||||
}
|
||||
|
||||
fn try_send_with_context(
|
||||
&mut self,
|
||||
message: T,
|
||||
cx: Option<&mut Context<'_>>,
|
||||
) -> Result<(), TrySendError<T>> {
|
||||
if !self.closed {
|
||||
if !self.full {
|
||||
self.buf[self.write_pos] = MaybeUninit::new(message.into());
|
||||
self.write_pos = (self.write_pos + 1) % self.buf.len();
|
||||
if self.write_pos == self.read_pos {
|
||||
self.full = true;
|
||||
}
|
||||
self.receiver_waker.wake();
|
||||
Ok(())
|
||||
} else {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| self.set_senders_waker(&cx.waker()));
|
||||
Err(TrySendError::Full(message))
|
||||
}
|
||||
} else {
|
||||
Err(TrySendError::Closed(message))
|
||||
}
|
||||
}
|
||||
|
||||
fn close(&mut self) {
|
||||
self.receiver_waker.wake();
|
||||
self.closed = true;
|
||||
}
|
||||
|
||||
fn is_closed(&mut self) -> bool {
|
||||
self.is_closed_with_context(None)
|
||||
}
|
||||
|
||||
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
|
||||
if self.closed {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| self.set_senders_waker(&cx.waker()));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn register_receiver(&mut self) {
|
||||
assert!(!self.receiver_registered);
|
||||
self.receiver_registered = true;
|
||||
}
|
||||
|
||||
fn deregister_receiver(&mut self) {
|
||||
if self.receiver_registered {
|
||||
self.closed = true;
|
||||
self.senders_waker.wake();
|
||||
}
|
||||
self.receiver_registered = false;
|
||||
}
|
||||
|
||||
fn register_sender(&mut self) {
|
||||
self.senders_registered += 1;
|
||||
}
|
||||
|
||||
fn deregister_sender(&mut self) {
|
||||
assert!(self.senders_registered > 0);
|
||||
self.senders_registered -= 1;
|
||||
if self.senders_registered == 0 {
|
||||
self.receiver_waker.wake();
|
||||
self.closed = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
|
||||
self.receiver_waker.register(receiver_waker);
|
||||
}
|
||||
|
||||
fn set_senders_waker(&mut self, senders_waker: &Waker) {
|
||||
// Dispose of any existing sender causing them to be polled again.
|
||||
// This could cause a spin given multiple concurrent senders, however given that
|
||||
// most sends only block waiting for the receiver to become active, this should
|
||||
// be a short-lived activity. The upside is a greatly simplified implementation
|
||||
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
|
||||
// pointers.
|
||||
self.senders_waker.wake();
|
||||
self.senders_waker.register(senders_waker);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const N: usize> Drop for ChannelState<T, N> {
|
||||
@ -442,6 +560,13 @@ impl<T, const N: usize> Drop for ChannelState<T, N> {
|
||||
///
|
||||
/// All data sent will become available in the same order as it was sent.
|
||||
pub struct Channel<M, T, const N: usize>
|
||||
where
|
||||
M: Mutex<Data = ()>,
|
||||
{
|
||||
sync_channel: UnsafeCell<ChannelCell<M, T, N>>,
|
||||
}
|
||||
|
||||
struct ChannelCell<M, T, const N: usize>
|
||||
where
|
||||
M: Mutex<Data = ()>,
|
||||
{
|
||||
@ -468,7 +593,10 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> {
|
||||
pub const fn with_critical_sections() -> Self {
|
||||
let mutex = CriticalSectionMutex::new(());
|
||||
let state = ChannelState::new();
|
||||
Channel { mutex, state }
|
||||
let sync_channel = ChannelCell { mutex, state };
|
||||
Channel {
|
||||
sync_channel: UnsafeCell::new(sync_channel),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -492,7 +620,10 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> {
|
||||
pub const fn with_thread_mode_only() -> Self {
|
||||
let mutex = ThreadModeMutex::new(());
|
||||
let state = ChannelState::new();
|
||||
Channel { mutex, state }
|
||||
let sync_channel = ChannelCell { mutex, state };
|
||||
Channel {
|
||||
sync_channel: UnsafeCell::new(sync_channel),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -513,7 +644,10 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> {
|
||||
pub const fn with_no_threads() -> Self {
|
||||
let mutex = NoopMutex::new(());
|
||||
let state = ChannelState::new();
|
||||
Channel { mutex, state }
|
||||
let sync_channel = ChannelCell { mutex, state };
|
||||
Channel {
|
||||
sync_channel: UnsafeCell::new(sync_channel),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -521,144 +655,13 @@ impl<M, T, const N: usize> Channel<M, T, N>
|
||||
where
|
||||
M: Mutex<Data = ()>,
|
||||
{
|
||||
fn get(&self) -> &mut Self {
|
||||
let const_ptr = self as *const Self;
|
||||
let mut_ptr = const_ptr as *mut Self;
|
||||
unsafe { &mut *mut_ptr }
|
||||
}
|
||||
|
||||
fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||
self.try_recv_with_context(None)
|
||||
}
|
||||
|
||||
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
|
||||
let mut state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if state.read_pos != state.write_pos || state.full {
|
||||
if state.full {
|
||||
state.full = false;
|
||||
state.senders_waker.wake();
|
||||
}
|
||||
let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() };
|
||||
state.read_pos = (state.read_pos + 1) % state.buf.len();
|
||||
Ok(message)
|
||||
} else if !state.closed {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
|
||||
Err(TryRecvError::Empty)
|
||||
} else {
|
||||
Err(TryRecvError::Closed)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
|
||||
self.try_send_with_context(message, None)
|
||||
}
|
||||
|
||||
fn try_send_with_context(
|
||||
&mut self,
|
||||
message: T,
|
||||
cx: Option<&mut Context<'_>>,
|
||||
) -> Result<(), TrySendError<T>> {
|
||||
let mut state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if !state.closed {
|
||||
if !state.full {
|
||||
state.buf[state.write_pos] = MaybeUninit::new(message.into());
|
||||
state.write_pos = (state.write_pos + 1) % state.buf.len();
|
||||
if state.write_pos == state.read_pos {
|
||||
state.full = true;
|
||||
}
|
||||
state.receiver_waker.wake();
|
||||
Ok(())
|
||||
} else {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
|
||||
Err(TrySendError::Full(message))
|
||||
}
|
||||
} else {
|
||||
Err(TrySendError::Closed(message))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn close(&mut self) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
state.receiver_waker.wake();
|
||||
state.closed = true;
|
||||
});
|
||||
}
|
||||
|
||||
fn is_closed(&mut self) -> bool {
|
||||
self.is_closed_with_context(None)
|
||||
}
|
||||
|
||||
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
|
||||
let mut state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if state.closed {
|
||||
cx.into_iter()
|
||||
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn register_receiver(&mut self) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
assert!(!state.receiver_registered);
|
||||
state.receiver_registered = true;
|
||||
});
|
||||
}
|
||||
|
||||
fn deregister_receiver(&mut self) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
if state.receiver_registered {
|
||||
state.closed = true;
|
||||
state.senders_waker.wake();
|
||||
}
|
||||
state.receiver_registered = false;
|
||||
})
|
||||
}
|
||||
|
||||
fn register_sender(&mut self) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
state.senders_registered += 1;
|
||||
})
|
||||
}
|
||||
|
||||
fn deregister_sender(&mut self) {
|
||||
let state = &mut self.state;
|
||||
self.mutex.lock(|_| {
|
||||
assert!(state.senders_registered > 0);
|
||||
state.senders_registered -= 1;
|
||||
if state.senders_registered == 0 {
|
||||
state.receiver_waker.wake();
|
||||
state.closed = true;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
|
||||
state.receiver_waker.register(receiver_waker);
|
||||
}
|
||||
|
||||
fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
|
||||
// Dispose of any existing sender causing them to be polled again.
|
||||
// This could cause a spin given multiple concurrent senders, however given that
|
||||
// most sends only block waiting for the receiver to become active, this should
|
||||
// be a short-lived activity. The upside is a greatly simplified implementation
|
||||
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
|
||||
// pointers.
|
||||
state.senders_waker.wake();
|
||||
state.senders_waker.register(senders_waker);
|
||||
fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R {
|
||||
unsafe {
|
||||
let sync_channel = &mut *(self.sync_channel.get());
|
||||
let mutex = &mut sync_channel.mutex;
|
||||
let mut state = &mut sync_channel.state;
|
||||
mutex.lock(|_| f(&mut state))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -672,15 +675,12 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn capacity<M, T, const N: usize>(c: &Channel<M, T, N>) -> usize
|
||||
where
|
||||
M: Mutex<Data = ()>,
|
||||
{
|
||||
if !c.state.full {
|
||||
if c.state.write_pos > c.state.read_pos {
|
||||
(c.state.buf.len() - c.state.write_pos) + c.state.read_pos
|
||||
fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize {
|
||||
if !c.full {
|
||||
if c.write_pos > c.read_pos {
|
||||
(c.buf.len() - c.write_pos) + c.read_pos
|
||||
} else {
|
||||
(c.state.buf.len() - c.state.read_pos) + c.state.write_pos
|
||||
(c.buf.len() - c.read_pos) + c.write_pos
|
||||
}
|
||||
} else {
|
||||
0
|
||||
@ -689,14 +689,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn sending_once() {
|
||||
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
|
||||
let mut c = ChannelState::<u32, 3>::new();
|
||||
assert!(c.try_send(1).is_ok());
|
||||
assert_eq!(capacity(&c), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sending_when_full() {
|
||||
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
|
||||
let mut c = ChannelState::<u32, 3>::new();
|
||||
let _ = c.try_send(1);
|
||||
let _ = c.try_send(1);
|
||||
let _ = c.try_send(1);
|
||||
@ -709,8 +709,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn sending_when_closed() {
|
||||
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
|
||||
c.state.closed = true;
|
||||
let mut c = ChannelState::<u32, 3>::new();
|
||||
c.closed = true;
|
||||
match c.try_send(2) {
|
||||
Err(TrySendError::Closed(2)) => assert!(true),
|
||||
_ => assert!(false),
|
||||
@ -719,7 +719,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn receiving_once_with_one_send() {
|
||||
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
|
||||
let mut c = ChannelState::<u32, 3>::new();
|
||||
assert!(c.try_send(1).is_ok());
|
||||
assert_eq!(c.try_recv().unwrap(), 1);
|
||||
assert_eq!(capacity(&c), 3);
|
||||
@ -727,7 +727,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn receiving_when_empty() {
|
||||
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
|
||||
let mut c = ChannelState::<u32, 3>::new();
|
||||
match c.try_recv() {
|
||||
Err(TryRecvError::Empty) => assert!(true),
|
||||
_ => assert!(false),
|
||||
@ -737,8 +737,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn receiving_when_closed() {
|
||||
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
|
||||
c.state.closed = true;
|
||||
let mut c = ChannelState::<u32, 3>::new();
|
||||
c.closed = true;
|
||||
match c.try_recv() {
|
||||
Err(TryRecvError::Closed) => assert!(true),
|
||||
_ => assert!(false),
|
||||
|
Loading…
Reference in New Issue
Block a user