From a2656f402b1c59461cec5f5dc685b2692119b996 Mon Sep 17 00:00:00 2001 From: Ruben De Smet Date: Fri, 11 Aug 2023 12:04:30 +0200 Subject: [PATCH] Move embassy-net-driver-channel::zerocopy_channel to embassy_sync::zero_copy_channel --- embassy-net-driver-channel/src/lib.rs | 237 ++------------------------ embassy-sync/src/lib.rs | 1 + embassy-sync/src/zero_copy_channel.rs | 209 +++++++++++++++++++++++ 3 files changed, 223 insertions(+), 224 deletions(-) create mode 100644 embassy-sync/src/zero_copy_channel.rs diff --git a/embassy-net-driver-channel/src/lib.rs b/embassy-net-driver-channel/src/lib.rs index f2aa6b25..e8cd66f8 100644 --- a/embassy-net-driver-channel/src/lib.rs +++ b/embassy-net-driver-channel/src/lib.rs @@ -14,6 +14,7 @@ use embassy_net_driver::{Capabilities, LinkState, Medium}; use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_sync::blocking_mutex::Mutex; use embassy_sync::waitqueue::WakerRegistration; +use embassy_sync::zero_copy_channel; pub struct State { rx: [PacketBuf; N_RX], @@ -34,8 +35,8 @@ impl State { - rx: zerocopy_channel::Channel<'d, NoopRawMutex, PacketBuf>, - tx: zerocopy_channel::Channel<'d, NoopRawMutex, PacketBuf>, + rx: zero_copy_channel::Channel<'d, NoopRawMutex, PacketBuf>, + tx: zero_copy_channel::Channel<'d, NoopRawMutex, PacketBuf>, shared: Mutex>, } @@ -47,8 +48,8 @@ struct Shared { } pub struct Runner<'d, const MTU: usize> { - tx_chan: zerocopy_channel::Receiver<'d, NoopRawMutex, PacketBuf>, - rx_chan: zerocopy_channel::Sender<'d, NoopRawMutex, PacketBuf>, + tx_chan: zero_copy_channel::Receiver<'d, NoopRawMutex, PacketBuf>, + rx_chan: zero_copy_channel::Sender<'d, NoopRawMutex, PacketBuf>, shared: &'d Mutex>, } @@ -58,11 +59,11 @@ pub struct StateRunner<'d> { } pub struct RxRunner<'d, const MTU: usize> { - rx_chan: zerocopy_channel::Sender<'d, NoopRawMutex, PacketBuf>, + rx_chan: zero_copy_channel::Sender<'d, NoopRawMutex, PacketBuf>, } pub struct TxRunner<'d, const MTU: usize> { - tx_chan: zerocopy_channel::Receiver<'d, NoopRawMutex, PacketBuf>, + tx_chan: zero_copy_channel::Receiver<'d, NoopRawMutex, PacketBuf>, } impl<'d, const MTU: usize> Runner<'d, MTU> { @@ -243,8 +244,8 @@ pub fn new<'d, const MTU: usize, const N_RX: usize, const N_TX: usize>( let state_uninit: *mut MaybeUninit> = (&mut state.inner as *mut MaybeUninit>).cast(); let state = unsafe { &mut *state_uninit }.write(StateInner { - rx: zerocopy_channel::Channel::new(&mut state.rx[..]), - tx: zerocopy_channel::Channel::new(&mut state.tx[..]), + rx: zero_copy_channel::Channel::new(&mut state.rx[..]), + tx: zero_copy_channel::Channel::new(&mut state.tx[..]), shared: Mutex::new(RefCell::new(Shared { link_state: LinkState::Down, hardware_address, @@ -282,8 +283,8 @@ impl PacketBuf { } pub struct Device<'d, const MTU: usize> { - rx: zerocopy_channel::Receiver<'d, NoopRawMutex, PacketBuf>, - tx: zerocopy_channel::Sender<'d, NoopRawMutex, PacketBuf>, + rx: zero_copy_channel::Receiver<'d, NoopRawMutex, PacketBuf>, + tx: zero_copy_channel::Sender<'d, NoopRawMutex, PacketBuf>, shared: &'d Mutex>, caps: Capabilities, } @@ -328,7 +329,7 @@ impl<'d, const MTU: usize> embassy_net_driver::Driver for Device<'d, MTU> { } pub struct RxToken<'a, const MTU: usize> { - rx: zerocopy_channel::Receiver<'a, NoopRawMutex, PacketBuf>, + rx: zero_copy_channel::Receiver<'a, NoopRawMutex, PacketBuf>, } impl<'a, const MTU: usize> embassy_net_driver::RxToken for RxToken<'a, MTU> { @@ -345,7 +346,7 @@ impl<'a, const MTU: usize> embassy_net_driver::RxToken for RxToken<'a, MTU> { } pub struct TxToken<'a, const MTU: usize> { - tx: zerocopy_channel::Sender<'a, NoopRawMutex, PacketBuf>, + tx: zero_copy_channel::Sender<'a, NoopRawMutex, PacketBuf>, } impl<'a, const MTU: usize> embassy_net_driver::TxToken for TxToken<'a, MTU> { @@ -361,215 +362,3 @@ impl<'a, const MTU: usize> embassy_net_driver::TxToken for TxToken<'a, MTU> { r } } - -mod zerocopy_channel { - use core::cell::RefCell; - use core::future::poll_fn; - use core::marker::PhantomData; - use core::task::{Context, Poll}; - - use embassy_sync::blocking_mutex::raw::RawMutex; - use embassy_sync::blocking_mutex::Mutex; - use embassy_sync::waitqueue::WakerRegistration; - - pub struct Channel<'a, M: RawMutex, T> { - buf: *mut T, - phantom: PhantomData<&'a mut T>, - state: Mutex>, - } - - impl<'a, M: RawMutex, T> Channel<'a, M, T> { - pub fn new(buf: &'a mut [T]) -> Self { - let len = buf.len(); - assert!(len != 0); - - Self { - buf: buf.as_mut_ptr(), - phantom: PhantomData, - state: Mutex::new(RefCell::new(State { - len, - front: 0, - back: 0, - full: false, - send_waker: WakerRegistration::new(), - recv_waker: WakerRegistration::new(), - })), - } - } - - pub fn split(&mut self) -> (Sender<'_, M, T>, Receiver<'_, M, T>) { - (Sender { channel: self }, Receiver { channel: self }) - } - } - - pub struct Sender<'a, M: RawMutex, T> { - channel: &'a Channel<'a, M, T>, - } - - impl<'a, M: RawMutex, T> Sender<'a, M, T> { - pub fn borrow(&mut self) -> Sender<'_, M, T> { - Sender { channel: self.channel } - } - - pub fn try_send(&mut self) -> Option<&mut T> { - self.channel.state.lock(|s| { - let s = &mut *s.borrow_mut(); - match s.push_index() { - Some(i) => Some(unsafe { &mut *self.channel.buf.add(i) }), - None => None, - } - }) - } - - pub fn poll_send(&mut self, cx: &mut Context) -> Poll<&mut T> { - self.channel.state.lock(|s| { - let s = &mut *s.borrow_mut(); - match s.push_index() { - Some(i) => Poll::Ready(unsafe { &mut *self.channel.buf.add(i) }), - None => { - s.recv_waker.register(cx.waker()); - Poll::Pending - } - } - }) - } - - pub async fn send(&mut self) -> &mut T { - let i = poll_fn(|cx| { - self.channel.state.lock(|s| { - let s = &mut *s.borrow_mut(); - match s.push_index() { - Some(i) => Poll::Ready(i), - None => { - s.recv_waker.register(cx.waker()); - Poll::Pending - } - } - }) - }) - .await; - unsafe { &mut *self.channel.buf.add(i) } - } - - pub fn send_done(&mut self) { - self.channel.state.lock(|s| s.borrow_mut().push_done()) - } - } - pub struct Receiver<'a, M: RawMutex, T> { - channel: &'a Channel<'a, M, T>, - } - - impl<'a, M: RawMutex, T> Receiver<'a, M, T> { - pub fn borrow(&mut self) -> Receiver<'_, M, T> { - Receiver { channel: self.channel } - } - - pub fn try_recv(&mut self) -> Option<&mut T> { - self.channel.state.lock(|s| { - let s = &mut *s.borrow_mut(); - match s.pop_index() { - Some(i) => Some(unsafe { &mut *self.channel.buf.add(i) }), - None => None, - } - }) - } - - pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<&mut T> { - self.channel.state.lock(|s| { - let s = &mut *s.borrow_mut(); - match s.pop_index() { - Some(i) => Poll::Ready(unsafe { &mut *self.channel.buf.add(i) }), - None => { - s.send_waker.register(cx.waker()); - Poll::Pending - } - } - }) - } - - pub async fn recv(&mut self) -> &mut T { - let i = poll_fn(|cx| { - self.channel.state.lock(|s| { - let s = &mut *s.borrow_mut(); - match s.pop_index() { - Some(i) => Poll::Ready(i), - None => { - s.send_waker.register(cx.waker()); - Poll::Pending - } - } - }) - }) - .await; - unsafe { &mut *self.channel.buf.add(i) } - } - - pub fn recv_done(&mut self) { - self.channel.state.lock(|s| s.borrow_mut().pop_done()) - } - } - - struct State { - len: usize, - - /// Front index. Always 0..=(N-1) - front: usize, - /// Back index. Always 0..=(N-1). - back: usize, - - /// Used to distinguish "empty" and "full" cases when `front == back`. - /// May only be `true` if `front == back`, always `false` otherwise. - full: bool, - - send_waker: WakerRegistration, - recv_waker: WakerRegistration, - } - - impl State { - fn increment(&self, i: usize) -> usize { - if i + 1 == self.len { - 0 - } else { - i + 1 - } - } - - fn is_full(&self) -> bool { - self.full - } - - fn is_empty(&self) -> bool { - self.front == self.back && !self.full - } - - fn push_index(&mut self) -> Option { - match self.is_full() { - true => None, - false => Some(self.back), - } - } - - fn push_done(&mut self) { - assert!(!self.is_full()); - self.back = self.increment(self.back); - if self.back == self.front { - self.full = true; - } - self.send_waker.wake(); - } - - fn pop_index(&mut self) -> Option { - match self.is_empty() { - true => None, - false => Some(self.front), - } - } - - fn pop_done(&mut self) { - assert!(!self.is_empty()); - self.front = self.increment(self.front); - self.full = false; - self.recv_waker.wake(); - } - } -} diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs index 53d95d08..48a7b13f 100644 --- a/embassy-sync/src/lib.rs +++ b/embassy-sync/src/lib.rs @@ -17,3 +17,4 @@ pub mod pipe; pub mod pubsub; pub mod signal; pub mod waitqueue; +pub mod zero_copy_channel; diff --git a/embassy-sync/src/zero_copy_channel.rs b/embassy-sync/src/zero_copy_channel.rs new file mode 100644 index 00000000..3701ccf1 --- /dev/null +++ b/embassy-sync/src/zero_copy_channel.rs @@ -0,0 +1,209 @@ +use core::cell::RefCell; +use core::future::poll_fn; +use core::marker::PhantomData; +use core::task::{Context, Poll}; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::waitqueue::WakerRegistration; + +pub struct Channel<'a, M: RawMutex, T> { + buf: *mut T, + phantom: PhantomData<&'a mut T>, + state: Mutex>, +} + +impl<'a, M: RawMutex, T> Channel<'a, M, T> { + pub fn new(buf: &'a mut [T]) -> Self { + let len = buf.len(); + assert!(len != 0); + + Self { + buf: buf.as_mut_ptr(), + phantom: PhantomData, + state: Mutex::new(RefCell::new(State { + len, + front: 0, + back: 0, + full: false, + send_waker: WakerRegistration::new(), + recv_waker: WakerRegistration::new(), + })), + } + } + + pub fn split(&mut self) -> (Sender<'_, M, T>, Receiver<'_, M, T>) { + (Sender { channel: self }, Receiver { channel: self }) + } +} + +pub struct Sender<'a, M: RawMutex, T> { + channel: &'a Channel<'a, M, T>, +} + +impl<'a, M: RawMutex, T> Sender<'a, M, T> { + pub fn borrow(&mut self) -> Sender<'_, M, T> { + Sender { channel: self.channel } + } + + pub fn try_send(&mut self) -> Option<&mut T> { + self.channel.state.lock(|s| { + let s = &mut *s.borrow_mut(); + match s.push_index() { + Some(i) => Some(unsafe { &mut *self.channel.buf.add(i) }), + None => None, + } + }) + } + + pub fn poll_send(&mut self, cx: &mut Context) -> Poll<&mut T> { + self.channel.state.lock(|s| { + let s = &mut *s.borrow_mut(); + match s.push_index() { + Some(i) => Poll::Ready(unsafe { &mut *self.channel.buf.add(i) }), + None => { + s.recv_waker.register(cx.waker()); + Poll::Pending + } + } + }) + } + + pub async fn send(&mut self) -> &mut T { + let i = poll_fn(|cx| { + self.channel.state.lock(|s| { + let s = &mut *s.borrow_mut(); + match s.push_index() { + Some(i) => Poll::Ready(i), + None => { + s.recv_waker.register(cx.waker()); + Poll::Pending + } + } + }) + }) + .await; + unsafe { &mut *self.channel.buf.add(i) } + } + + pub fn send_done(&mut self) { + self.channel.state.lock(|s| s.borrow_mut().push_done()) + } +} +pub struct Receiver<'a, M: RawMutex, T> { + channel: &'a Channel<'a, M, T>, +} + +impl<'a, M: RawMutex, T> Receiver<'a, M, T> { + pub fn borrow(&mut self) -> Receiver<'_, M, T> { + Receiver { channel: self.channel } + } + + pub fn try_recv(&mut self) -> Option<&mut T> { + self.channel.state.lock(|s| { + let s = &mut *s.borrow_mut(); + match s.pop_index() { + Some(i) => Some(unsafe { &mut *self.channel.buf.add(i) }), + None => None, + } + }) + } + + pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<&mut T> { + self.channel.state.lock(|s| { + let s = &mut *s.borrow_mut(); + match s.pop_index() { + Some(i) => Poll::Ready(unsafe { &mut *self.channel.buf.add(i) }), + None => { + s.send_waker.register(cx.waker()); + Poll::Pending + } + } + }) + } + + pub async fn recv(&mut self) -> &mut T { + let i = poll_fn(|cx| { + self.channel.state.lock(|s| { + let s = &mut *s.borrow_mut(); + match s.pop_index() { + Some(i) => Poll::Ready(i), + None => { + s.send_waker.register(cx.waker()); + Poll::Pending + } + } + }) + }) + .await; + unsafe { &mut *self.channel.buf.add(i) } + } + + pub fn recv_done(&mut self) { + self.channel.state.lock(|s| s.borrow_mut().pop_done()) + } +} + +struct State { + len: usize, + + /// Front index. Always 0..=(N-1) + front: usize, + /// Back index. Always 0..=(N-1). + back: usize, + + /// Used to distinguish "empty" and "full" cases when `front == back`. + /// May only be `true` if `front == back`, always `false` otherwise. + full: bool, + + send_waker: WakerRegistration, + recv_waker: WakerRegistration, +} + +impl State { + fn increment(&self, i: usize) -> usize { + if i + 1 == self.len { + 0 + } else { + i + 1 + } + } + + fn is_full(&self) -> bool { + self.full + } + + fn is_empty(&self) -> bool { + self.front == self.back && !self.full + } + + fn push_index(&mut self) -> Option { + match self.is_full() { + true => None, + false => Some(self.back), + } + } + + fn push_done(&mut self) { + assert!(!self.is_full()); + self.back = self.increment(self.back); + if self.back == self.front { + self.full = true; + } + self.send_waker.wake(); + } + + fn pop_index(&mut self) -> Option { + match self.is_empty() { + true => None, + false => Some(self.front), + } + } + + fn pop_done(&mut self) { + assert!(!self.is_empty()); + self.front = self.increment(self.front); + self.full = false; + self.recv_waker.wake(); + } +}