Eliminates unsoundness by using an UnsafeCell for sharing the channel

This commit is contained in:
huntc 2021-07-14 16:34:32 +10:00
parent babee7f32a
commit d711e8a82c

View File

@ -122,11 +122,10 @@ where
{ {
let sender = Sender { channel: &channel }; let sender = Sender { channel: &channel };
let receiver = Receiver { channel: &channel }; let receiver = Receiver { channel: &channel };
{ channel.lock(|c| {
let c = channel.get();
c.register_receiver(); c.register_receiver();
c.register_sender(); c.register_sender();
} });
(sender, receiver) (sender, receiver)
} }
@ -155,11 +154,12 @@ where
} }
fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
match self.channel.get().try_recv_with_context(Some(cx)) { self.channel
.lock(|c| match c.try_recv_with_context(Some(cx)) {
Ok(v) => Poll::Ready(Some(v)), Ok(v) => Poll::Ready(Some(v)),
Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Closed) => Poll::Ready(None),
Err(TryRecvError::Empty) => Poll::Pending, Err(TryRecvError::Empty) => Poll::Pending,
} })
} }
/// Attempts to immediately receive a message on this `Receiver` /// Attempts to immediately receive a message on this `Receiver`
@ -167,7 +167,7 @@ where
/// This method will either receive a message from the channel immediately or return an error /// This method will either receive a message from the channel immediately or return an error
/// if the channel is empty. /// if the channel is empty.
pub fn try_recv(&self) -> Result<T, TryRecvError> { pub fn try_recv(&self) -> Result<T, TryRecvError> {
self.channel.get().try_recv() self.channel.lock(|c| c.try_recv())
} }
/// Closes the receiving half of a channel without dropping it. /// Closes the receiving half of a channel without dropping it.
@ -181,7 +181,7 @@ where
/// until those are released. /// until those are released.
/// ///
pub fn close(&mut self) { pub fn close(&mut self) {
self.channel.get().close() self.channel.lock(|c| c.close())
} }
} }
@ -190,7 +190,7 @@ where
M: Mutex<Data = ()>, M: Mutex<Data = ()>,
{ {
fn drop(&mut self) { fn drop(&mut self) {
self.channel.get().deregister_receiver() self.channel.lock(|c| c.deregister_receiver())
} }
} }
@ -245,7 +245,7 @@ where
/// [`channel`]: channel /// [`channel`]: channel
/// [`close`]: Receiver::close /// [`close`]: Receiver::close
pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
self.channel.get().try_send(message) self.channel.lock(|c| c.try_send(message))
} }
/// Completes when the receiver has dropped. /// Completes when the receiver has dropped.
@ -266,7 +266,7 @@ where
/// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver`]: crate::sync::mpsc::Receiver
/// [`Receiver::close`]: crate::sync::mpsc::Receiver::close /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
pub fn is_closed(&self) -> bool { pub fn is_closed(&self) -> bool {
self.channel.get().is_closed() self.channel.lock(|c| c.is_closed())
} }
} }
@ -286,7 +286,11 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.message.take() { match self.message.take() {
Some(m) => match self.sender.channel.get().try_send_with_context(m, Some(cx)) { Some(m) => match self
.sender
.channel
.lock(|c| c.try_send_with_context(m, Some(cx)))
{
Ok(..) => Poll::Ready(Ok(())), Ok(..) => Poll::Ready(Ok(())),
Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
Err(TrySendError::Full(m)) => { Err(TrySendError::Full(m)) => {
@ -315,7 +319,11 @@ where
type Output = (); type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.sender.channel.get().is_closed_with_context(Some(cx)) { if self
.sender
.channel
.lock(|c| c.is_closed_with_context(Some(cx)))
{
Poll::Ready(()) Poll::Ready(())
} else { } else {
Poll::Pending Poll::Pending
@ -328,7 +336,7 @@ where
M: Mutex<Data = ()>, M: Mutex<Data = ()>,
{ {
fn drop(&mut self) { fn drop(&mut self) {
self.channel.get().deregister_sender() self.channel.lock(|c| c.deregister_sender())
} }
} }
@ -338,7 +346,7 @@ where
{ {
#[allow(clippy::clone_double_ref)] #[allow(clippy::clone_double_ref)]
fn clone(&self) -> Self { fn clone(&self) -> Self {
self.channel.get().register_sender(); self.channel.lock(|c| c.register_sender());
Sender { Sender {
channel: self.channel.clone(), channel: self.channel.clone(),
} }
@ -421,6 +429,116 @@ impl<T, const N: usize> ChannelState<T, N> {
senders_waker: WakerRegistration::new(), senders_waker: WakerRegistration::new(),
} }
} }
fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.try_recv_with_context(None)
}
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
if self.read_pos != self.write_pos || self.full {
if self.full {
self.full = false;
self.senders_waker.wake();
}
let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() };
self.read_pos = (self.read_pos + 1) % self.buf.len();
Ok(message)
} else if !self.closed {
cx.into_iter()
.for_each(|cx| self.set_receiver_waker(&cx.waker()));
Err(TryRecvError::Empty)
} else {
Err(TryRecvError::Closed)
}
}
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
self.try_send_with_context(message, None)
}
fn try_send_with_context(
&mut self,
message: T,
cx: Option<&mut Context<'_>>,
) -> Result<(), TrySendError<T>> {
if !self.closed {
if !self.full {
self.buf[self.write_pos] = MaybeUninit::new(message.into());
self.write_pos = (self.write_pos + 1) % self.buf.len();
if self.write_pos == self.read_pos {
self.full = true;
}
self.receiver_waker.wake();
Ok(())
} else {
cx.into_iter()
.for_each(|cx| self.set_senders_waker(&cx.waker()));
Err(TrySendError::Full(message))
}
} else {
Err(TrySendError::Closed(message))
}
}
fn close(&mut self) {
self.receiver_waker.wake();
self.closed = true;
}
fn is_closed(&mut self) -> bool {
self.is_closed_with_context(None)
}
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
if self.closed {
cx.into_iter()
.for_each(|cx| self.set_senders_waker(&cx.waker()));
true
} else {
false
}
}
fn register_receiver(&mut self) {
assert!(!self.receiver_registered);
self.receiver_registered = true;
}
fn deregister_receiver(&mut self) {
if self.receiver_registered {
self.closed = true;
self.senders_waker.wake();
}
self.receiver_registered = false;
}
fn register_sender(&mut self) {
self.senders_registered += 1;
}
fn deregister_sender(&mut self) {
assert!(self.senders_registered > 0);
self.senders_registered -= 1;
if self.senders_registered == 0 {
self.receiver_waker.wake();
self.closed = true;
}
}
fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
self.receiver_waker.register(receiver_waker);
}
fn set_senders_waker(&mut self, senders_waker: &Waker) {
// Dispose of any existing sender causing them to be polled again.
// This could cause a spin given multiple concurrent senders, however given that
// most sends only block waiting for the receiver to become active, this should
// be a short-lived activity. The upside is a greatly simplified implementation
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
// pointers.
self.senders_waker.wake();
self.senders_waker.register(senders_waker);
}
} }
impl<T, const N: usize> Drop for ChannelState<T, N> { impl<T, const N: usize> Drop for ChannelState<T, N> {
@ -442,6 +560,13 @@ impl<T, const N: usize> Drop for ChannelState<T, N> {
/// ///
/// All data sent will become available in the same order as it was sent. /// All data sent will become available in the same order as it was sent.
pub struct Channel<M, T, const N: usize> pub struct Channel<M, T, const N: usize>
where
M: Mutex<Data = ()>,
{
sync_channel: UnsafeCell<ChannelCell<M, T, N>>,
}
struct ChannelCell<M, T, const N: usize>
where where
M: Mutex<Data = ()>, M: Mutex<Data = ()>,
{ {
@ -468,7 +593,10 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> {
pub const fn with_critical_sections() -> Self { pub const fn with_critical_sections() -> Self {
let mutex = CriticalSectionMutex::new(()); let mutex = CriticalSectionMutex::new(());
let state = ChannelState::new(); let state = ChannelState::new();
Channel { mutex, state } let sync_channel = ChannelCell { mutex, state };
Channel {
sync_channel: UnsafeCell::new(sync_channel),
}
} }
} }
@ -492,7 +620,10 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> {
pub const fn with_thread_mode_only() -> Self { pub const fn with_thread_mode_only() -> Self {
let mutex = ThreadModeMutex::new(()); let mutex = ThreadModeMutex::new(());
let state = ChannelState::new(); let state = ChannelState::new();
Channel { mutex, state } let sync_channel = ChannelCell { mutex, state };
Channel {
sync_channel: UnsafeCell::new(sync_channel),
}
} }
} }
@ -513,7 +644,10 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> {
pub const fn with_no_threads() -> Self { pub const fn with_no_threads() -> Self {
let mutex = NoopMutex::new(()); let mutex = NoopMutex::new(());
let state = ChannelState::new(); let state = ChannelState::new();
Channel { mutex, state } let sync_channel = ChannelCell { mutex, state };
Channel {
sync_channel: UnsafeCell::new(sync_channel),
}
} }
} }
@ -521,144 +655,13 @@ impl<M, T, const N: usize> Channel<M, T, N>
where where
M: Mutex<Data = ()>, M: Mutex<Data = ()>,
{ {
fn get(&self) -> &mut Self { fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R {
let const_ptr = self as *const Self; unsafe {
let mut_ptr = const_ptr as *mut Self; let sync_channel = &mut *(self.sync_channel.get());
unsafe { &mut *mut_ptr } let mutex = &mut sync_channel.mutex;
let mut state = &mut sync_channel.state;
mutex.lock(|_| f(&mut state))
} }
fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.try_recv_with_context(None)
}
fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if state.read_pos != state.write_pos || state.full {
if state.full {
state.full = false;
state.senders_waker.wake();
}
let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() };
state.read_pos = (state.read_pos + 1) % state.buf.len();
Ok(message)
} else if !state.closed {
cx.into_iter()
.for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
Err(TryRecvError::Empty)
} else {
Err(TryRecvError::Closed)
}
})
}
fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
self.try_send_with_context(message, None)
}
fn try_send_with_context(
&mut self,
message: T,
cx: Option<&mut Context<'_>>,
) -> Result<(), TrySendError<T>> {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if !state.closed {
if !state.full {
state.buf[state.write_pos] = MaybeUninit::new(message.into());
state.write_pos = (state.write_pos + 1) % state.buf.len();
if state.write_pos == state.read_pos {
state.full = true;
}
state.receiver_waker.wake();
Ok(())
} else {
cx.into_iter()
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
Err(TrySendError::Full(message))
}
} else {
Err(TrySendError::Closed(message))
}
})
}
fn close(&mut self) {
let state = &mut self.state;
self.mutex.lock(|_| {
state.receiver_waker.wake();
state.closed = true;
});
}
fn is_closed(&mut self) -> bool {
self.is_closed_with_context(None)
}
fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
let mut state = &mut self.state;
self.mutex.lock(|_| {
if state.closed {
cx.into_iter()
.for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
true
} else {
false
}
})
}
fn register_receiver(&mut self) {
let state = &mut self.state;
self.mutex.lock(|_| {
assert!(!state.receiver_registered);
state.receiver_registered = true;
});
}
fn deregister_receiver(&mut self) {
let state = &mut self.state;
self.mutex.lock(|_| {
if state.receiver_registered {
state.closed = true;
state.senders_waker.wake();
}
state.receiver_registered = false;
})
}
fn register_sender(&mut self) {
let state = &mut self.state;
self.mutex.lock(|_| {
state.senders_registered += 1;
})
}
fn deregister_sender(&mut self) {
let state = &mut self.state;
self.mutex.lock(|_| {
assert!(state.senders_registered > 0);
state.senders_registered -= 1;
if state.senders_registered == 0 {
state.receiver_waker.wake();
state.closed = true;
}
})
}
fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
state.receiver_waker.register(receiver_waker);
}
fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
// Dispose of any existing sender causing them to be polled again.
// This could cause a spin given multiple concurrent senders, however given that
// most sends only block waiting for the receiver to become active, this should
// be a short-lived activity. The upside is a greatly simplified implementation
// that avoids the need for intrusive linked-lists and unsafe operations on pinned
// pointers.
state.senders_waker.wake();
state.senders_waker.register(senders_waker);
} }
} }
@ -672,15 +675,12 @@ mod tests {
use super::*; use super::*;
fn capacity<M, T, const N: usize>(c: &Channel<M, T, N>) -> usize fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize {
where if !c.full {
M: Mutex<Data = ()>, if c.write_pos > c.read_pos {
{ (c.buf.len() - c.write_pos) + c.read_pos
if !c.state.full {
if c.state.write_pos > c.state.read_pos {
(c.state.buf.len() - c.state.write_pos) + c.state.read_pos
} else { } else {
(c.state.buf.len() - c.state.read_pos) + c.state.write_pos (c.buf.len() - c.read_pos) + c.write_pos
} }
} else { } else {
0 0
@ -689,14 +689,14 @@ mod tests {
#[test] #[test]
fn sending_once() { fn sending_once() {
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); let mut c = ChannelState::<u32, 3>::new();
assert!(c.try_send(1).is_ok()); assert!(c.try_send(1).is_ok());
assert_eq!(capacity(&c), 2); assert_eq!(capacity(&c), 2);
} }
#[test] #[test]
fn sending_when_full() { fn sending_when_full() {
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); let mut c = ChannelState::<u32, 3>::new();
let _ = c.try_send(1); let _ = c.try_send(1);
let _ = c.try_send(1); let _ = c.try_send(1);
let _ = c.try_send(1); let _ = c.try_send(1);
@ -709,8 +709,8 @@ mod tests {
#[test] #[test]
fn sending_when_closed() { fn sending_when_closed() {
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); let mut c = ChannelState::<u32, 3>::new();
c.state.closed = true; c.closed = true;
match c.try_send(2) { match c.try_send(2) {
Err(TrySendError::Closed(2)) => assert!(true), Err(TrySendError::Closed(2)) => assert!(true),
_ => assert!(false), _ => assert!(false),
@ -719,7 +719,7 @@ mod tests {
#[test] #[test]
fn receiving_once_with_one_send() { fn receiving_once_with_one_send() {
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); let mut c = ChannelState::<u32, 3>::new();
assert!(c.try_send(1).is_ok()); assert!(c.try_send(1).is_ok());
assert_eq!(c.try_recv().unwrap(), 1); assert_eq!(c.try_recv().unwrap(), 1);
assert_eq!(capacity(&c), 3); assert_eq!(capacity(&c), 3);
@ -727,7 +727,7 @@ mod tests {
#[test] #[test]
fn receiving_when_empty() { fn receiving_when_empty() {
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); let mut c = ChannelState::<u32, 3>::new();
match c.try_recv() { match c.try_recv() {
Err(TryRecvError::Empty) => assert!(true), Err(TryRecvError::Empty) => assert!(true),
_ => assert!(false), _ => assert!(false),
@ -737,8 +737,8 @@ mod tests {
#[test] #[test]
fn receiving_when_closed() { fn receiving_when_closed() {
let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); let mut c = ChannelState::<u32, 3>::new();
c.state.closed = true; c.closed = true;
match c.try_recv() { match c.try_recv() {
Err(TryRecvError::Closed) => assert!(true), Err(TryRecvError::Closed) => assert!(true),
_ => assert!(false), _ => assert!(false),