diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs index c903fb24..b4ce4094 100644 --- a/embassy-net/src/tcp.rs +++ b/embassy-net/src/tcp.rs @@ -82,6 +82,17 @@ impl<'a> TcpReader<'a> { pub async fn read(&mut self, buf: &mut [u8]) -> Result { self.io.read(buf).await } + + /// Call `f` with the largest contiguous slice of octets in the receive buffer, + /// and dequeue the amount of elements returned by `f`. + /// + /// If no data is available, it waits until there is at least one byte available. + pub async fn read_with(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + self.io.read_with(f).await + } } impl<'a> TcpWriter<'a> { @@ -100,6 +111,17 @@ impl<'a> TcpWriter<'a> { pub async fn flush(&mut self) -> Result<(), Error> { self.io.flush().await } + + /// Call `f` with the largest contiguous slice of octets in the transmit buffer, + /// and enqueue the amount of elements returned by `f`. + /// + /// If the socket is not ready to accept data, it waits until it is. + pub async fn write_with(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + self.io.write_with(f).await + } } impl<'a> TcpSocket<'a> { @@ -121,6 +143,28 @@ impl<'a> TcpSocket<'a> { } } + /// Call `f` with the largest contiguous slice of octets in the transmit buffer, + /// and enqueue the amount of elements returned by `f`. + /// + /// If the socket is not ready to accept data, it waits until it is. + pub async fn write_with(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + self.io.write_with(f).await + } + + /// Call `f` with the largest contiguous slice of octets in the receive buffer, + /// and dequeue the amount of elements returned by `f`. + /// + /// If no data is available, it waits until there is at least one byte available. + pub async fn read_with(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + self.io.read_with(f).await + } + /// Split the socket into reader and a writer halves. pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { (TcpReader { io: self.io }, TcpWriter { io: self.io }) @@ -359,6 +403,64 @@ impl<'d> TcpIo<'d> { .await } + async fn write_with(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + let mut f = Some(f); + poll_fn(move |cx| { + self.with_mut(|s, _| { + if !s.can_send() { + if s.may_send() { + // socket buffer is full wait until it has atleast one byte free + s.register_send_waker(cx.waker()); + Poll::Pending + } else { + // if we can't transmit because the transmit half of the duplex connection is closed then return an error + Poll::Ready(Err(Error::ConnectionReset)) + } + } else { + Poll::Ready(match s.send(f.take().unwrap()) { + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(tcp::SendError::InvalidState) => Err(Error::ConnectionReset), + Ok(r) => Ok(r), + }) + } + }) + }) + .await + } + + async fn read_with(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> (usize, R), + { + let mut f = Some(f); + poll_fn(move |cx| { + self.with_mut(|s, _| { + if !s.can_recv() { + if s.may_recv() { + // socket buffer is empty wait until it has atleast one byte has arrived + s.register_recv_waker(cx.waker()); + Poll::Pending + } else { + // if we can't receive because the recieve half of the duplex connection is closed then return an error + Poll::Ready(Err(Error::ConnectionReset)) + } + } else { + Poll::Ready(match s.recv(f.take().unwrap()) { + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(tcp::RecvError::Finished) | Err(tcp::RecvError::InvalidState) => { + Err(Error::ConnectionReset) + } + Ok(r) => Ok(r), + }) + } + }) + }) + .await + } + async fn flush(&mut self) -> Result<(), Error> { poll_fn(move |cx| { self.with_mut(|s, _| {