diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml index 8484aebc..b58b52f1 100644 --- a/embassy-net/Cargo.toml +++ b/embassy-net/Cargo.toml @@ -31,15 +31,13 @@ pool-32 = [] pool-64 = [] pool-128 = [] -nightly = ["embedded-io/async"] - [dependencies] defmt = { version = "0.3", optional = true } log = { version = "0.4.14", optional = true } embassy = { version = "0.1.0", path = "../embassy" } -embedded-io = "0.3.0" +embedded-io = { version = "0.3.0", features = [ "async" ] } managed = { version = "0.8.0", default-features = false, features = [ "map" ] } heapless = { version = "0.7.5", default-features = false } diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index ded84190..18dc1ef6 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -1,9 +1,6 @@ #![cfg_attr(not(feature = "std"), no_std)] #![allow(clippy::new_without_default)] -#![cfg_attr( - feature = "nightly", - feature(generic_associated_types, type_alias_impl_trait) -)] +#![feature(generic_associated_types, type_alias_impl_trait)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs new file mode 100644 index 00000000..c18651b9 --- /dev/null +++ b/embassy-net/src/tcp.rs @@ -0,0 +1,353 @@ +use core::future::Future; +use core::marker::PhantomData; +use core::mem; +use core::task::Poll; +use futures::future::poll_fn; +use smoltcp::iface::{Context as SmolContext, SocketHandle}; +use smoltcp::socket::TcpSocket as SyncTcpSocket; +use smoltcp::socket::{TcpSocketBuffer, TcpState}; +use smoltcp::time::Duration; +use smoltcp::wire::IpEndpoint; + +use super::stack::Stack; + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + ConnectionReset, +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ConnectError { + /// The socket is already connected or listening. + InvalidState, + /// The remote host rejected the connection with a RST packet. + ConnectionReset, + /// Connect timed out. + TimedOut, + /// No route to host. + NoRoute, +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum AcceptError { + /// The socket is already connected or listening. + InvalidState, + /// Invalid listen port + InvalidPort, + /// The remote host rejected the connection with a RST packet. + ConnectionReset, +} + +pub struct TcpSocket<'a> { + handle: SocketHandle, + ghost: PhantomData<&'a mut [u8]>, +} + +impl<'a> Unpin for TcpSocket<'a> {} + +pub struct TcpReader<'a> { + handle: SocketHandle, + ghost: PhantomData<&'a mut [u8]>, +} + +impl<'a> Unpin for TcpReader<'a> {} + +pub struct TcpWriter<'a> { + handle: SocketHandle, + ghost: PhantomData<&'a mut [u8]>, +} + +impl<'a> Unpin for TcpWriter<'a> {} + +impl<'a> TcpSocket<'a> { + pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { + let handle = Stack::with(|stack| { + let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; + let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; + stack.iface.add_socket(SyncTcpSocket::new( + TcpSocketBuffer::new(rx_buffer), + TcpSocketBuffer::new(tx_buffer), + )) + }); + + Self { + handle, + ghost: PhantomData, + } + } + + pub fn split(&mut self) -> (TcpReader<'_>, TcpWriter<'_>) { + ( + TcpReader { + handle: self.handle, + ghost: PhantomData, + }, + TcpWriter { + handle: self.handle, + ghost: PhantomData, + }, + ) + } + + pub async fn connect(&mut self, remote_endpoint: T) -> Result<(), ConnectError> + where + T: Into, + { + let local_port = Stack::with(|stack| stack.get_local_port()); + match with_socket(self.handle, |s, cx| { + s.connect(cx, remote_endpoint, local_port) + }) { + Ok(()) => {} + Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), + Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + } + + futures::future::poll_fn(|cx| { + with_socket(self.handle, |s, _| match s.state() { + TcpState::Closed | TcpState::TimeWait => { + Poll::Ready(Err(ConnectError::ConnectionReset)) + } + TcpState::Listen => unreachable!(), + TcpState::SynSent | TcpState::SynReceived => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + _ => Poll::Ready(Ok(())), + }) + }) + .await + } + + pub async fn accept(&mut self, local_endpoint: T) -> Result<(), AcceptError> + where + T: Into, + { + match with_socket(self.handle, |s, _| s.listen(local_endpoint)) { + Ok(()) => {} + Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), + Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + } + + futures::future::poll_fn(|cx| { + with_socket(self.handle, |s, _| match s.state() { + TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + _ => Poll::Ready(Ok(())), + }) + }) + .await + } + + pub fn set_timeout(&mut self, duration: Option) { + with_socket(self.handle, |s, _| s.set_timeout(duration)) + } + + pub fn set_keep_alive(&mut self, interval: Option) { + with_socket(self.handle, |s, _| s.set_keep_alive(interval)) + } + + pub fn set_hop_limit(&mut self, hop_limit: Option) { + with_socket(self.handle, |s, _| s.set_hop_limit(hop_limit)) + } + + pub fn local_endpoint(&self) -> IpEndpoint { + with_socket(self.handle, |s, _| s.local_endpoint()) + } + + pub fn remote_endpoint(&self) -> IpEndpoint { + with_socket(self.handle, |s, _| s.remote_endpoint()) + } + + pub fn state(&self) -> TcpState { + with_socket(self.handle, |s, _| s.state()) + } + + pub fn close(&mut self) { + with_socket(self.handle, |s, _| s.close()) + } + + pub fn abort(&mut self) { + with_socket(self.handle, |s, _| s.abort()) + } + + pub fn may_send(&self) -> bool { + with_socket(self.handle, |s, _| s.may_send()) + } + + pub fn may_recv(&self) -> bool { + with_socket(self.handle, |s, _| s.may_recv()) + } +} + +fn with_socket( + handle: SocketHandle, + f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R, +) -> R { + Stack::with(|stack| { + let res = { + let (s, cx) = stack.iface.get_socket_and_context::(handle); + f(s, cx) + }; + stack.wake(); + res + }) +} + +impl<'a> Drop for TcpSocket<'a> { + fn drop(&mut self) { + Stack::with(|stack| { + stack.iface.remove_socket(self.handle); + }) + } +} + +impl embedded_io::Error for Error { + fn kind(&self) -> embedded_io::ErrorKind { + embedded_io::ErrorKind::Other + } +} + +impl<'d> embedded_io::Io for TcpSocket<'d> { + type Error = Error; +} + +impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { + type ReadFuture<'a> = impl Future> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + poll_fn(move |cx| { + // CAUTION: smoltcp semantics around EOF are different to what you'd expect + // from posix-like IO, so we have to tweak things here. + with_socket(self.handle, |s, _| match s.recv_slice(buf) { + // No data ready + Ok(0) => { + s.register_recv_waker(cx.waker()); + Poll::Pending + } + // Data ready! + Ok(n) => Poll::Ready(Ok(n)), + // EOF + Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + }) + }) + } +} + +impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { + type WriteFuture<'a> = impl Future> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + poll_fn(move |cx| { + with_socket(self.handle, |s, _| match s.send_slice(buf) { + // Not ready to send (no space in the tx buffer) + Ok(0) => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + // Some data sent + Ok(n) => Poll::Ready(Ok(n)), + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + }) + }) + } + + type FlushFuture<'a> = impl Future> + where + Self: 'a; + + fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { + poll_fn(move |_| { + Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? + }) + } +} + +impl<'d> embedded_io::Io for TcpReader<'d> { + type Error = Error; +} + +impl<'d> embedded_io::asynch::Read for TcpReader<'d> { + type ReadFuture<'a> = impl Future> + where + Self: 'a; + + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + poll_fn(move |cx| { + // CAUTION: smoltcp semantics around EOF are different to what you'd expect + // from posix-like IO, so we have to tweak things here. + with_socket(self.handle, |s, _| match s.recv_slice(buf) { + // No data ready + Ok(0) => { + s.register_recv_waker(cx.waker()); + Poll::Pending + } + // Data ready! + Ok(n) => Poll::Ready(Ok(n)), + // EOF + Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + }) + }) + } +} + +impl<'d> embedded_io::Io for TcpWriter<'d> { + type Error = Error; +} + +impl<'d> embedded_io::asynch::Write for TcpWriter<'d> { + type WriteFuture<'a> = impl Future> + where + Self: 'a; + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { + poll_fn(move |cx| { + with_socket(self.handle, |s, _| match s.send_slice(buf) { + // Not ready to send (no space in the tx buffer) + Ok(0) => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + // Some data sent + Ok(n) => Poll::Ready(Ok(n)), + // Connection reset. TODO: this can also be timeouts etc, investigate. + Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), + // smoltcp returns no errors other than the above. + Err(_) => unreachable!(), + }) + }) + } + + type FlushFuture<'a> = impl Future> + where + Self: 'a; + + fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { + poll_fn(move |_| { + Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? + }) + } +} diff --git a/embassy-net/src/tcp/io_impl.rs b/embassy-net/src/tcp/io_impl.rs deleted file mode 100644 index 15573349..00000000 --- a/embassy-net/src/tcp/io_impl.rs +++ /dev/null @@ -1,67 +0,0 @@ -use core::future::Future; -use core::task::Poll; -use futures::future::poll_fn; - -use super::{Error, TcpSocket}; - -impl<'d> embedded_io::asynch::Read for TcpSocket<'d> { - type ReadFuture<'a> = impl Future> - where - Self: 'a; - - fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { - poll_fn(move |cx| { - // CAUTION: smoltcp semantics around EOF are different to what you'd expect - // from posix-like IO, so we have to tweak things here. - self.with(|s, _| match s.recv_slice(buf) { - // No data ready - Ok(0) => { - s.register_recv_waker(cx.waker()); - Poll::Pending - } - // Data ready! - Ok(n) => Poll::Ready(Ok(n)), - // EOF - Err(smoltcp::Error::Finished) => Poll::Ready(Ok(0)), - // Connection reset. TODO: this can also be timeouts etc, investigate. - Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), - }) - }) - } -} - -impl<'d> embedded_io::asynch::Write for TcpSocket<'d> { - type WriteFuture<'a> = impl Future> - where - Self: 'a; - - fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { - poll_fn(move |cx| { - self.with(|s, _| match s.send_slice(buf) { - // Not ready to send (no space in the tx buffer) - Ok(0) => { - s.register_send_waker(cx.waker()); - Poll::Pending - } - // Some data sent - Ok(n) => Poll::Ready(Ok(n)), - // Connection reset. TODO: this can also be timeouts etc, investigate. - Err(smoltcp::Error::Illegal) => Poll::Ready(Err(Error::ConnectionReset)), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), - }) - }) - } - - type FlushFuture<'a> = impl Future> - where - Self: 'a; - - fn flush<'a>(&'a mut self) -> Self::FlushFuture<'a> { - poll_fn(move |_| { - Poll::Ready(Ok(())) // TODO: Is there a better implementation for this? - }) - } -} diff --git a/embassy-net/src/tcp/mod.rs b/embassy-net/src/tcp/mod.rs deleted file mode 100644 index 3bfd4c7b..00000000 --- a/embassy-net/src/tcp/mod.rs +++ /dev/null @@ -1,192 +0,0 @@ -use core::marker::PhantomData; -use core::mem; -use core::task::Poll; -use smoltcp::iface::{Context as SmolContext, SocketHandle}; -use smoltcp::socket::TcpSocket as SyncTcpSocket; -use smoltcp::socket::{TcpSocketBuffer, TcpState}; -use smoltcp::time::Duration; -use smoltcp::wire::IpEndpoint; - -#[cfg(feature = "nightly")] -mod io_impl; - -use super::stack::Stack; - -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum Error { - ConnectionReset, -} - -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum ConnectError { - /// The socket is already connected or listening. - InvalidState, - /// The remote host rejected the connection with a RST packet. - ConnectionReset, - /// Connect timed out. - TimedOut, - /// No route to host. - NoRoute, -} - -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum AcceptError { - /// The socket is already connected or listening. - InvalidState, - /// Invalid listen port - InvalidPort, - /// The remote host rejected the connection with a RST packet. - ConnectionReset, -} - -pub struct TcpSocket<'a> { - handle: SocketHandle, - ghost: PhantomData<&'a mut [u8]>, -} - -impl<'a> Unpin for TcpSocket<'a> {} - -impl<'a> TcpSocket<'a> { - pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self { - let handle = Stack::with(|stack| { - let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) }; - let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) }; - stack.iface.add_socket(SyncTcpSocket::new( - TcpSocketBuffer::new(rx_buffer), - TcpSocketBuffer::new(tx_buffer), - )) - }); - - Self { - handle, - ghost: PhantomData, - } - } - - pub async fn connect(&mut self, remote_endpoint: T) -> Result<(), ConnectError> - where - T: Into, - { - let local_port = Stack::with(|stack| stack.get_local_port()); - match self.with(|s, cx| s.connect(cx, remote_endpoint, local_port)) { - Ok(()) => {} - Err(smoltcp::Error::Illegal) => return Err(ConnectError::InvalidState), - Err(smoltcp::Error::Unaddressable) => return Err(ConnectError::NoRoute), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), - } - - futures::future::poll_fn(|cx| { - self.with(|s, _| match s.state() { - TcpState::Closed | TcpState::TimeWait => { - Poll::Ready(Err(ConnectError::ConnectionReset)) - } - TcpState::Listen => unreachable!(), - TcpState::SynSent | TcpState::SynReceived => { - s.register_send_waker(cx.waker()); - Poll::Pending - } - _ => Poll::Ready(Ok(())), - }) - }) - .await - } - - pub async fn accept(&mut self, local_endpoint: T) -> Result<(), AcceptError> - where - T: Into, - { - match self.with(|s, _| s.listen(local_endpoint)) { - Ok(()) => {} - Err(smoltcp::Error::Illegal) => return Err(AcceptError::InvalidState), - Err(smoltcp::Error::Unaddressable) => return Err(AcceptError::InvalidPort), - // smoltcp returns no errors other than the above. - Err(_) => unreachable!(), - } - - futures::future::poll_fn(|cx| { - self.with(|s, _| match s.state() { - TcpState::Listen | TcpState::SynSent | TcpState::SynReceived => { - s.register_send_waker(cx.waker()); - Poll::Pending - } - _ => Poll::Ready(Ok(())), - }) - }) - .await - } - - pub fn set_timeout(&mut self, duration: Option) { - self.with(|s, _| s.set_timeout(duration)) - } - - pub fn set_keep_alive(&mut self, interval: Option) { - self.with(|s, _| s.set_keep_alive(interval)) - } - - pub fn set_hop_limit(&mut self, hop_limit: Option) { - self.with(|s, _| s.set_hop_limit(hop_limit)) - } - - pub fn local_endpoint(&self) -> IpEndpoint { - self.with(|s, _| s.local_endpoint()) - } - - pub fn remote_endpoint(&self) -> IpEndpoint { - self.with(|s, _| s.remote_endpoint()) - } - - pub fn state(&self) -> TcpState { - self.with(|s, _| s.state()) - } - - pub fn close(&mut self) { - self.with(|s, _| s.close()) - } - - pub fn abort(&mut self) { - self.with(|s, _| s.abort()) - } - - pub fn may_send(&self) -> bool { - self.with(|s, _| s.may_send()) - } - - pub fn may_recv(&self) -> bool { - self.with(|s, _| s.may_recv()) - } - - fn with(&self, f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R) -> R { - Stack::with(|stack| { - let res = { - let (s, cx) = stack - .iface - .get_socket_and_context::(self.handle); - f(s, cx) - }; - stack.wake(); - res - }) - } -} - -impl<'a> Drop for TcpSocket<'a> { - fn drop(&mut self) { - Stack::with(|stack| { - stack.iface.remove_socket(self.handle); - }) - } -} - -impl embedded_io::Error for Error { - fn kind(&self) -> embedded_io::ErrorKind { - embedded_io::ErrorKind::Other - } -} - -impl<'d> embedded_io::Io for TcpSocket<'d> { - type Error = Error; -} diff --git a/embassy-stm32/Cargo.toml b/embassy-stm32/Cargo.toml index ce36c7da..e310d25f 100644 --- a/embassy-stm32/Cargo.toml +++ b/embassy-stm32/Cargo.toml @@ -90,7 +90,7 @@ time-driver-tim12 = ["_time-driver"] time-driver-tim15 = ["_time-driver"] # Enable nightly-only features -nightly = ["embassy/nightly", "embassy-net?/nightly", "embedded-hal-1", "embedded-hal-async", "embedded-storage-async", "dep:embedded-io"] +nightly = ["embassy/nightly", "embedded-hal-1", "embedded-hal-async", "embedded-storage-async", "dep:embedded-io"] # Reexport stm32-metapac at `embassy_stm32::pac`. # This is unstable because semver-minor (non-breaking) releases of embassy-stm32 may major-bump (breaking) the stm32-metapac version. diff --git a/examples/nrf/Cargo.toml b/examples/nrf/Cargo.toml index d96eedf9..124725f9 100644 --- a/examples/nrf/Cargo.toml +++ b/examples/nrf/Cargo.toml @@ -6,12 +6,12 @@ version = "0.1.0" [features] default = ["nightly"] -nightly = ["embassy-nrf/nightly", "embassy-nrf/unstable-traits", "embassy-usb", "embassy-usb-serial", "embassy-usb-hid", "embassy-usb-ncm", "embedded-io/async", "embassy-net/nightly"] +nightly = ["embassy-nrf/nightly", "embassy-nrf/unstable-traits", "embassy-usb", "embassy-usb-serial", "embassy-usb-hid", "embassy-usb-ncm", "embedded-io/async", "embassy-net"] [dependencies] embassy = { version = "0.1.0", path = "../../embassy", features = ["defmt", "defmt-timestamp-uptime"] } embassy-nrf = { version = "0.1.0", path = "../../embassy-nrf", features = ["defmt", "nrf52840", "time-driver-rtc1", "gpiote", "unstable-pac"] } -embassy-net = { version = "0.1.0", path = "../../embassy-net", features = ["defmt", "tcp", "dhcpv4", "medium-ethernet", "pool-16"] } +embassy-net = { version = "0.1.0", path = "../../embassy-net", features = ["defmt", "tcp", "dhcpv4", "medium-ethernet", "pool-16"], optional = true } embassy-usb = { version = "0.1.0", path = "../../embassy-usb", features = ["defmt"], optional = true } embassy-usb-serial = { version = "0.1.0", path = "../../embassy-usb-serial", features = ["defmt"], optional = true } embassy-usb-hid = { version = "0.1.0", path = "../../embassy-usb-hid", features = ["defmt"], optional = true } diff --git a/examples/std/Cargo.toml b/examples/std/Cargo.toml index 863760a4..7e1c2e4b 100644 --- a/examples/std/Cargo.toml +++ b/examples/std/Cargo.toml @@ -6,7 +6,7 @@ version = "0.1.0" [dependencies] embassy = { version = "0.1.0", path = "../../embassy", features = ["log", "std", "time", "nightly"] } -embassy-net = { version = "0.1.0", path = "../../embassy-net", features=["nightly", "std", "log", "medium-ethernet", "tcp", "dhcpv4", "pool-16"] } +embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "log", "medium-ethernet", "tcp", "dhcpv4", "pool-16"] } embedded-io = { version = "0.3.0", features = ["async", "std", "futures"] } async-io = "1.6.0"