diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index 96a6dfe2..814e7ab6 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs @@ -339,15 +339,16 @@ pub mod client { use super::*; + /// TCP client capable of creating up to N multiple connections with tx and rx buffers according to TX_SZ and RX_SZ. pub struct TcpClient<'d, D: Device, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> { stack: &'d Stack, - tx: &'d BufferPool, - rx: &'d BufferPool, + state: &'d TcpClientState, } impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> { - pub fn new(stack: &'d Stack, tx: &'d BufferPool, rx: &'d BufferPool) -> Self { - Self { stack, tx, rx } + /// Create a new TcpClient + pub fn new(stack: &'d Stack, state: &'d TcpClientState) -> Self { + Self { stack, state } } } @@ -370,7 +371,7 @@ pub mod client { IpAddr::V6(_) => panic!("ipv6 support not enabled"), }; let remote_endpoint = (addr, remote.port()); - let mut socket = TcpConnection::new(&self.stack, self.tx, self.rx)?; + let mut socket = TcpConnection::new(&self.stack, self.state)?; socket .socket .connect(remote_endpoint) @@ -383,26 +384,20 @@ pub mod client { pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> { socket: TcpSocket<'d>, - tx: &'d BufferPool, - rx: &'d BufferPool, - txb: NonNull<[u8; TX_SZ]>, - rxb: NonNull<[u8; RX_SZ]>, + state: &'d TcpClientState, + bufs: NonNull<([u8; TX_SZ], [u8; RX_SZ])>, } impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> { fn new( stack: &'d Stack, - tx: &'d BufferPool, - rx: &'d BufferPool, + state: &'d TcpClientState, ) -> Result { - let mut txb = tx.alloc().ok_or(Error::ConnectionReset)?; - let mut rxb = rx.alloc().ok_or(Error::ConnectionReset)?; + let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?; Ok(Self { - socket: unsafe { TcpSocket::new(stack, rxb.as_mut(), txb.as_mut()) }, - tx, - rx, - txb, - rxb, + socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().0, &mut bufs.as_mut().1) }, + state, + bufs, }) } } @@ -411,8 +406,7 @@ pub mod client { fn drop(&mut self) { unsafe { self.socket.close(); - self.rx.free(self.rxb); - self.tx.free(self.txb); + self.state.pool.free(self.bufs); } } } @@ -455,9 +449,22 @@ pub mod client { } } - pub type BufferPool = Pool<[u8; BUFSZ], N>; + /// State for TcpClient + pub struct TcpClientState { + pool: Pool<([u8; TX_SZ], [u8; RX_SZ]), N>, + } - pub struct Pool { + impl TcpClientState { + pub const fn new() -> Self { + Self { + pool: Pool::new() + } + } + } + + unsafe impl Sync for TcpClientState {} + + struct Pool { used: [AtomicBool; N], data: [UnsafeCell>; N], } @@ -466,7 +473,7 @@ pub mod client { const VALUE: AtomicBool = AtomicBool::new(false); const UNINIT: UnsafeCell> = UnsafeCell::new(MaybeUninit::uninit()); - pub const fn new() -> Self { + const fn new() -> Self { Self { used: [Self::VALUE; N], data: [Self::UNINIT; N], diff --git a/examples/stm32h7/Cargo.toml b/examples/stm32h7/Cargo.toml index 07b7e493..89604675 100644 --- a/examples/stm32h7/Cargo.toml +++ b/examples/stm32h7/Cargo.toml @@ -18,6 +18,7 @@ cortex-m-rt = "0.7.0" embedded-hal = "0.2.6" embedded-hal-1 = { package = "embedded-hal", version = "1.0.0-alpha.8" } embedded-hal-async = { version = "0.1.0-alpha.1" } +embedded-nal-async = "0.2.0" panic-probe = { version = "0.3", features = ["print-defmt"] } futures = { version = "0.3.17", default-features = false, features = ["async-await"] } heapless = { version = "0.7.5", default-features = false }