diff --git a/embassy-net/src/tcp/io_impl.rs b/embassy-net/src/tcp/io_impl.rs index 15573349..b30c920b 100644 --- a/embassy-net/src/tcp/io_impl.rs +++ b/embassy-net/src/tcp/io_impl.rs @@ -2,7 +2,7 @@ use core::future::Future; use core::task::Poll; use futures::future::poll_fn; -use super::{Error, TcpSocket}; +use super::{with_socket, Error, TcpReader, TcpSocket, TcpWriter}; impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { type ReadFuture<'a> = impl Future> @@ -13,7 +13,7 @@ impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { poll_fn(move |cx| { // CAUTION: smoltcp semantics around EOF are different to what you'd expect // from posix-like IO, so we have to tweak things here. - self.with(|s, _| match s.recv_slice(buf) { + with_socket(self.handle, |s, _| match s.recv_slice(buf) { // No data ready Ok(0) => { s.register_recv_waker(cx.waker()); @@ -39,7 +39,69 @@ impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { poll_fn(move |cx| { - self.with(|s, _| match s.send_slice(buf) { + with_socket(self.handle, |s, _| match s.send_slice(buf) { + // Not ready to send (no space in the tx buffer) + Ok(0) => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + // Some data sent + Ok(n) => Poll::Ready(Ok(n)), + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + }) + }) + } + + type FlushFuture<'a> = impl Future> + where + Self: 'a; + + fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { + poll_fn(move |_| { + Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? + }) + } +} + +impl<'d> embedded_io::asynch::Read for TcpReader<'d> { + type ReadFuture<'a> = impl Future> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + poll_fn(move |cx| { + // CAUTION: smoltcp semantics around EOF are different to what you'd expect + // from posix-like IO, so we have to tweak things here. + with_socket(self.handle, |s, _| match s.recv_slice(buf) { + // No data ready + Ok(0) => { + s.register_recv_waker(cx.waker()); + Poll::Pending + } + // Data ready! + Ok(n) => Poll::Ready(Ok(n)), + // EOF + Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + }) + }) + } +} + +impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { + type WriteFuture<'a> = impl Future> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + poll_fn(move |cx| { + with_socket(self.handle, |s, _| match s.send_slice(buf) { // Not ready to send (no space in the tx buffer) Ok(0) => { s.register_send_waker(cx.waker()); diff --git a/embassy-net/src/tcp/mod.rs b/embassy-net/src/tcp/mod.rs index 3bfd4c7b..425e6acb 100644 --- a/embassy-net/src/tcp/mod.rs +++ b/embassy-net/src/tcp/mod.rs @@ -49,6 +49,20 @@ pub struct TcpSocket<'a> { impl<'a> Unpin for TcpSocket<'a> {} +pub struct TcpReader<'a> { + handle: SocketHandle, + ghost: PhantomData<&'a mut [u8]>, +} + +impl<'a> Unpin for TcpReader<'a> {} + +pub struct TcpWriter<'a> { + handle: SocketHandle, + ghost: PhantomData<&'a mut [u8]>, +} + +impl<'a> Unpin for TcpWriter<'a> {} + impl<'a> TcpSocket<'a> { pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { let handle = Stack::with(|stack| { @@ -66,12 +80,27 @@ impl<'a> TcpSocket<'a> { } } + pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { + ( + TcpReader { + handle: self.handle, + ghost: PhantomData, + }, + TcpWriter { + handle: self.handle, + ghost: PhantomData, + }, + ) + } + pub async fn connect(&mut self, remote_endpoint: T) -> Result<(), ConnectError> where T: Into, { let local_port = Stack::with(|stack| stack.get_local_port()); - match self.with(|s, cx| s.connect(cx, remote_endpoint, local_port)) { + match with_socket(self.handle, |s, cx| { + s.connect(cx, remote_endpoint, local_port) + }) { Ok(()) => {} Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), @@ -80,7 +109,7 @@ impl<'a> TcpSocket<'a> { } futures::future::poll_fn(|cx| { - self.with(|s, _| match s.state() { + with_socket(self.handle, |s, _| match s.state() { TcpState::Closed | TcpState::TimeWait => { Poll::Ready(Err(ConnectError::ConnectionReset)) } @@ -99,7 +128,7 @@ impl<'a> TcpSocket<'a> { where T: Into, { - match self.with(|s, _| s.listen(local_endpoint)) { + match with_socket(self.handle, |s, _| s.listen(local_endpoint)) { Ok(()) => {} Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), @@ -108,7 +137,7 @@ impl<'a> TcpSocket<'a> { } futures::future::poll_fn(|cx| { - self.with(|s, _| match s.state() { + with_socket(self.handle, |s, _| match s.state() { TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { s.register_send_waker(cx.waker()); Poll::Pending @@ -120,57 +149,58 @@ impl<'a> TcpSocket<'a> { } pub fn set_timeout(&mut self, duration: Option) { - self.with(|s, _| s.set_timeout(duration)) + with_socket(self.handle, |s, _| s.set_timeout(duration)) } pub fn set_keep_alive(&mut self, interval: Option) { - self.with(|s, _| s.set_keep_alive(interval)) + with_socket(self.handle, |s, _| s.set_keep_alive(interval)) } pub fn set_hop_limit(&mut self, hop_limit: Option) { - self.with(|s, _| s.set_hop_limit(hop_limit)) + with_socket(self.handle, |s, _| s.set_hop_limit(hop_limit)) } pub fn local_endpoint(&self) -> IpEndpoint { - self.with(|s, _| s.local_endpoint()) + with_socket(self.handle, |s, _| s.local_endpoint()) } pub fn remote_endpoint(&self) -> IpEndpoint { - self.with(|s, _| s.remote_endpoint()) + with_socket(self.handle, |s, _| s.remote_endpoint()) } pub fn state(&self) -> TcpState { - self.with(|s, _| s.state()) + with_socket(self.handle, |s, _| s.state()) } pub fn close(&mut self) { - self.with(|s, _| s.close()) + with_socket(self.handle, |s, _| s.close()) } pub fn abort(&mut self) { - self.with(|s, _| s.abort()) + with_socket(self.handle, |s, _| s.abort()) } pub fn may_send(&self) -> bool { - self.with(|s, _| s.may_send()) + with_socket(self.handle, |s, _| s.may_send()) } pub fn may_recv(&self) -> bool { - self.with(|s, _| s.may_recv()) + with_socket(self.handle, |s, _| s.may_recv()) } +} - fn with(&self, f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R) -> R { - Stack::with(|stack| { - let res = { - let (s, cx) = stack - .iface - .get_socket_and_context::(self.handle); - f(s, cx) - }; - stack.wake(); - res - }) - } +fn with_socket( + handle: SocketHandle, + f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R, +) -> R { + Stack::with(|stack| { + let res = { + let (s, cx) = stack.iface.get_socket_and_context::(handle); + f(s, cx) + }; + stack.wake(); + res + }) } impl<'a> Drop for TcpSocket<'a> { @@ -190,3 +220,11 @@ impl embedded_io::Error for Error { impl<'d> embedded_io::Io for TcpSocket<'d> { type Error = Error; } + +impl<'d> embedded_io::Io for TcpReader<'d> { + type Error = Error; +} + +impl<'d> embedded_io::Io for TcpWriter<'d> { + type Error = Error; +}