diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml index 4ec340b7..2d8c5c7b 100644 --- a/embassy-net/Cargo.toml +++ b/embassy-net/Cargo.toml @@ -22,7 +22,7 @@ unstable-traits = [] udp = ["smoltcp/socket-udp"] tcp = ["smoltcp/socket-tcp"] -dns = ["smoltcp/socket-dns"] +dns = ["smoltcp/socket-dns", "smoltcp/proto-dns"] dhcpv4 = ["medium-ethernet", "smoltcp/socket-dhcpv4"] proto-ipv6 = ["smoltcp/proto-ipv6"] medium-ethernet = ["smoltcp/medium-ethernet"] @@ -40,6 +40,7 @@ smoltcp = { version = "0.9.0", default-features = false, features = [ ]} embassy-net-driver = { version = "0.1.0", path = "../embassy-net-driver" } +embassy-hal-common = { version = "0.1.0", path = "../embassy-hal-common" } embassy-time = { version = "0.1.0", path = "../embassy-time" } embassy-sync = { version = "0.1.0", path = "../embassy-sync" } embedded-io = { version = "0.4.0", optional = true } @@ -51,5 +52,5 @@ generic-array = { version = "0.14.4", default-features = false } stable_deref_trait = { version = "1.2.0", default-features = false } futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] } atomic-pool = "1.0" -embedded-nal-async = { version = "0.3.0", optional = true } +embedded-nal-async = { version = "0.4.0", optional = true } atomic-polyfill = { version = "1.0" } diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs new file mode 100644 index 00000000..97320e44 --- /dev/null +++ b/embassy-net/src/dns.rs @@ -0,0 +1,97 @@ +//! DNS socket with async support. +use heapless::Vec; +pub use smoltcp::socket::dns::{DnsQuery, Socket}; +pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError}; +pub use smoltcp::wire::{DnsQueryType, IpAddress}; + +use crate::{Driver, Stack}; + +/// Errors returned by DnsSocket. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// Invalid name + InvalidName, + /// Name too long + NameTooLong, + /// Name lookup failed + Failed, +} + +impl From for Error { + fn from(_: GetQueryResultError) -> Self { + Self::Failed + } +} + +impl From for Error { + fn from(e: StartQueryError) -> Self { + match e { + StartQueryError::NoFreeSlot => Self::Failed, + StartQueryError::InvalidName => Self::InvalidName, + StartQueryError::NameTooLong => Self::NameTooLong, + } + } +} + +/// Async socket for making DNS queries. +pub struct DnsSocket<'a, D> +where + D: Driver + 'static, +{ + stack: &'a Stack, +} + +impl<'a, D> DnsSocket<'a, D> +where + D: Driver + 'static, +{ + /// Create a new DNS socket using the provided stack. + /// + /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated. + pub fn new(stack: &'a Stack) -> Self { + Self { stack } + } + + /// Make a query for a given name and return the corresponding IP addresses. + pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result, Error> { + self.stack.dns_query(name, qtype).await + } +} + +#[cfg(all(feature = "unstable-traits", feature = "nightly"))] +impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D> +where + D: Driver + 'static, +{ + type Error = Error; + + async fn get_host_by_name( + &self, + host: &str, + addr_type: embedded_nal_async::AddrType, + ) -> Result { + use embedded_nal_async::{AddrType, IpAddr}; + let qtype = match addr_type { + AddrType::IPv6 => DnsQueryType::Aaaa, + _ => DnsQueryType::A, + }; + let addrs = self.query(host, qtype).await?; + if let Some(first) = addrs.get(0) { + Ok(match first { + IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()), + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()), + }) + } else { + Err(Error::Failed) + } + } + + async fn get_host_by_address( + &self, + _addr: embedded_nal_async::IpAddr, + ) -> Result, Self::Error> { + todo!() + } +} diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index 0f694ee7..bda5f9e1 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -11,6 +11,8 @@ pub(crate) mod fmt; pub use embassy_net_driver as driver; mod device; +#[cfg(feature = "dns")] +pub mod dns; #[cfg(feature = "tcp")] pub mod tcp; #[cfg(feature = "udp")] @@ -46,15 +48,23 @@ use crate::device::DriverAdapter; const LOCAL_PORT_MIN: u16 = 1025; const LOCAL_PORT_MAX: u16 = 65535; +#[cfg(feature = "dns")] +const MAX_QUERIES: usize = 4; pub struct StackResources { sockets: [SocketStorage<'static>; SOCK], + #[cfg(feature = "dns")] + queries: [Option; MAX_QUERIES], } impl StackResources { pub fn new() -> Self { + #[cfg(feature = "dns")] + const INIT: Option = None; Self { sockets: [SocketStorage::EMPTY; SOCK], + #[cfg(feature = "dns")] + queries: [INIT; MAX_QUERIES], } } } @@ -107,6 +117,10 @@ struct Inner { config: Option, #[cfg(feature = "dhcpv4")] dhcp_socket: Option, + #[cfg(feature = "dns")] + dns_socket: SocketHandle, + #[cfg(feature = "dns")] + dns_waker: WakerRegistration, } pub(crate) struct SocketStack { @@ -145,13 +159,6 @@ impl Stack { let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; - let mut inner = Inner { - device, - link_up: false, - config: None, - #[cfg(feature = "dhcpv4")] - dhcp_socket: None, - }; let mut socket = SocketStack { sockets, iface, @@ -159,8 +166,25 @@ impl Stack { next_local_port, }; + let mut inner = Inner { + device, + link_up: false, + config: None, + #[cfg(feature = "dhcpv4")] + dhcp_socket: None, + #[cfg(feature = "dns")] + dns_socket: socket.sockets.add(dns::Socket::new( + &[], + managed::ManagedSlice::Borrowed(&mut resources.queries), + )), + #[cfg(feature = "dns")] + dns_waker: WakerRegistration::new(), + }; + match config { - Config::Static(config) => inner.apply_config(&mut socket, config), + Config::Static(config) => { + inner.apply_config(&mut socket, config); + } #[cfg(feature = "dhcpv4")] Config::Dhcp(config) => { let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new(); @@ -208,6 +232,60 @@ impl Stack { .await; unreachable!() } + + /// Make a query for a given name and return the corresponding IP addresses. + #[cfg(feature = "dns")] + pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result, dns::Error> { + let query = poll_fn(|cx| { + self.with_mut(|s, i| { + let socket = s.sockets.get_mut::(i.dns_socket); + match socket.start_query(s.iface.context(), name, qtype) { + Ok(handle) => Poll::Ready(Ok(handle)), + Err(dns::StartQueryError::NoFreeSlot) => { + i.dns_waker.register(cx.waker()); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + }) + }) + .await?; + + use embassy_hal_common::drop::OnDrop; + let drop = OnDrop::new(|| { + self.with_mut(|s, i| { + let socket = s.sockets.get_mut::(i.dns_socket); + socket.cancel_query(query); + s.waker.wake(); + i.dns_waker.wake(); + }) + }); + + let res = poll_fn(|cx| { + self.with_mut(|s, i| { + let socket = s.sockets.get_mut::(i.dns_socket); + match socket.get_query_result(query) { + Ok(addrs) => { + i.dns_waker.wake(); + Poll::Ready(Ok(addrs)) + } + Err(dns::GetQueryResultError::Pending) => { + socket.register_query_waker(query, cx.waker()); + Poll::Pending + } + Err(e) => { + i.dns_waker.wake(); + Poll::Ready(Err(e.into())) + } + } + }) + }) + .await; + + drop.defuse(); + + res + } } impl SocketStack { @@ -249,6 +327,13 @@ impl Inner { debug!(" DNS server {}: {}", i, s); } + #[cfg(feature = "dns")] + { + let socket = s.sockets.get_mut::(self.dns_socket); + let servers: Vec = config.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect(); + socket.update_servers(&servers[..]); + } + self.config = Some(config) } @@ -324,6 +409,7 @@ impl Inner { //if old_link_up || self.link_up { // self.poll_configurator(timestamp) //} + // if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) { let t = Timer::at(instant_from_smoltcp(poll_at)); diff --git a/examples/nrf52840/src/bin/usb_ethernet.rs b/examples/nrf52840/src/bin/usb_ethernet.rs index 97978089..430468ad 100644 --- a/examples/nrf52840/src/bin/usb_ethernet.rs +++ b/examples/nrf52840/src/bin/usb_ethernet.rs @@ -46,8 +46,31 @@ async fn net_task(stack: &'static Stack>) -> ! { stack.run().await } +#[inline(never)] +pub fn test_function() -> (usize, u32, [u32; 2]) { + let mut array = [3; 2]; + + let mut index = 0; + let mut result = 0; + + for x in [1, 2] { + if x == 1 { + array[1] = 99; + } else { + index = if x == 2 { 1 } else { 0 }; + + // grabs value from array[0], not array[1] + result = array[index]; + } + } + + (index, result, array) +} + #[embassy_executor::main] async fn main(spawner: Spawner) { + info!("{:?}", test_function()); + let p = embassy_nrf::init(Default::default()); let clock: pac::CLOCK = unsafe { mem::transmute(()) }; diff --git a/examples/std/Cargo.toml b/examples/std/Cargo.toml index af1481e0..8087df09 100644 --- a/examples/std/Cargo.toml +++ b/examples/std/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" embassy-sync = { version = "0.1.0", path = "../../embassy-sync", features = ["log"] } embassy-executor = { version = "0.1.0", path = "../../embassy-executor", features = ["log", "std", "nightly", "integrated-timers"] } embassy-time = { version = "0.1.0", path = "../../embassy-time", features = ["log", "std", "nightly"] } -embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "nightly", "log", "medium-ethernet", "tcp", "udp", "dhcpv4"] } +embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "nightly", "log", "medium-ethernet", "tcp", "udp", "dns", "dhcpv4", "unstable-traits", "proto-ipv6"] } embassy-net-driver = { version = "0.1.0", path = "../../embassy-net-driver" } embedded-io = { version = "0.4.0", features = ["async", "std", "futures"] } critical-section = { version = "1.1", features = ["std"] } diff --git a/examples/std/src/bin/net_dns.rs b/examples/std/src/bin/net_dns.rs new file mode 100644 index 00000000..e1cc45a3 --- /dev/null +++ b/examples/std/src/bin/net_dns.rs @@ -0,0 +1,98 @@ +#![feature(type_alias_impl_trait)] + +use std::default::Default; + +use clap::Parser; +use embassy_executor::{Executor, Spawner}; +use embassy_net::dns::DnsQueryType; +use embassy_net::{Config, Ipv4Address, Ipv4Cidr, Stack, StackResources}; +use heapless::Vec; +use log::*; +use rand_core::{OsRng, RngCore}; +use static_cell::StaticCell; + +#[path = "../tuntap.rs"] +mod tuntap; + +use crate::tuntap::TunTapDevice; + +macro_rules! singleton { + ($val:expr) => {{ + type T = impl Sized; + static STATIC_CELL: StaticCell = StaticCell::new(); + STATIC_CELL.init_with(move || $val) + }}; +} + +#[derive(Parser)] +#[clap(version = "1.0")] +struct Opts { + /// TAP device name + #[clap(long, default_value = "tap0")] + tap: String, + /// use a static IP instead of DHCP + #[clap(long)] + static_ip: bool, +} + +#[embassy_executor::task] +async fn net_task(stack: &'static Stack) -> ! { + stack.run().await +} + +#[embassy_executor::task] +async fn main_task(spawner: Spawner) { + let opts: Opts = Opts::parse(); + + // Init network device + let device = TunTapDevice::new(&opts.tap).unwrap(); + + // Choose between dhcp or static ip + let config = if opts.static_ip { + Config::Static(embassy_net::StaticConfig { + address: Ipv4Cidr::new(Ipv4Address::new(192, 168, 69, 1), 24), + dns_servers: Vec::from_slice(&[Ipv4Address::new(8, 8, 4, 4).into(), Ipv4Address::new(8, 8, 8, 8).into()]) + .unwrap(), + gateway: Some(Ipv4Address::new(192, 168, 69, 100)), + }) + } else { + Config::Dhcp(Default::default()) + }; + + // Generate random seed + let mut seed = [0; 8]; + OsRng.fill_bytes(&mut seed); + let seed = u64::from_le_bytes(seed); + + // Init network stack + let stack: &Stack<_> = &*singleton!(Stack::new(device, config, singleton!(StackResources::<2>::new()), seed)); + + // Launch network task + spawner.spawn(net_task(stack)).unwrap(); + + let host = "example.com"; + info!("querying host {:?}...", host); + match stack.dns_query(host, DnsQueryType::A).await { + Ok(r) => { + info!("query response: {:?}", r); + } + Err(e) => { + warn!("query error: {:?}", e); + } + }; +} + +static EXECUTOR: StaticCell = StaticCell::new(); + +fn main() { + env_logger::builder() + .filter_level(log::LevelFilter::Debug) + .filter_module("async_io", log::LevelFilter::Info) + .format_timestamp_nanos() + .init(); + + let executor = EXECUTOR.init(Executor::new()); + executor.run(|spawner| { + spawner.spawn(main_task(spawner)).unwrap(); + }); +} diff --git a/examples/stm32h7/Cargo.toml b/examples/stm32h7/Cargo.toml index bcf97641..a0413478 100644 --- a/examples/stm32h7/Cargo.toml +++ b/examples/stm32h7/Cargo.toml @@ -21,7 +21,7 @@ cortex-m-rt = "0.7.0" embedded-hal = "0.2.6" embedded-hal-1 = { package = "embedded-hal", version = "=1.0.0-alpha.9" } embedded-hal-async = { version = "=0.2.0-alpha.0" } -embedded-nal-async = "0.3.0" +embedded-nal-async = "0.4.0" panic-probe = { version = "0.3", features = ["print-defmt"] } futures = { version = "0.3.17", default-features = false, features = ["async-await"] } heapless = { version = "0.7.5", default-features = false }