From 9cfea693edec5af17ba698f64b3f0a168ad92944 Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Tue, 31 Jan 2023 22:06:41 +0100 Subject: [PATCH] Add DNS socket to embassy-net --- embassy-net/Cargo.toml | 4 +- embassy-net/src/dns.rs | 114 ++++++++++++++++++++++++++++++++ embassy-net/src/lib.rs | 2 + examples/std/Cargo.toml | 2 +- examples/std/src/bin/net_dns.rs | 102 ++++++++++++++++++++++++++++ 5 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 embassy-net/src/dns.rs create mode 100644 examples/std/src/bin/net_dns.rs diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml index 4ec340b7..6b346828 100644 --- a/embassy-net/Cargo.toml +++ b/embassy-net/Cargo.toml @@ -13,7 +13,7 @@ target = "thumbv7em-none-eabi" [features] default = [] -std = [] +std = ["smoltcp/alloc", "managed/std"] defmt = ["dep:defmt", "smoltcp/defmt", "embassy-net-driver/defmt"] @@ -22,7 +22,7 @@ unstable-traits = [] udp = ["smoltcp/socket-udp"] tcp = ["smoltcp/socket-tcp"] -dns = ["smoltcp/socket-dns"] +dns = ["smoltcp/socket-dns", "smoltcp/proto-dns"] dhcpv4 = ["medium-ethernet", "smoltcp/socket-dhcpv4"] proto-ipv6 = ["smoltcp/proto-ipv6"] medium-ethernet = ["smoltcp/medium-ethernet"] diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs new file mode 100644 index 00000000..f18750cc --- /dev/null +++ b/embassy-net/src/dns.rs @@ -0,0 +1,114 @@ +//! DNS socket with async support. +use core::cell::RefCell; +use core::future::poll_fn; +use core::mem; +use core::task::Poll; + +use embassy_net_driver::Driver; +use heapless::Vec; +use managed::ManagedSlice; +use smoltcp::iface::{Interface, SocketHandle}; +pub use smoltcp::socket::dns::DnsQuery; +use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError, MAX_ADDRESS_COUNT}; +pub use smoltcp::wire::{DnsQueryType, IpAddress}; + +use crate::{SocketStack, Stack}; + +/// Errors returned by DnsSocket. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// No available query slot + NoFreeSlot, + /// Invalid name + InvalidName, + /// Name too long + NameTooLong, + /// Name lookup failed + Failed, +} + +impl From for Error { + fn from(_: GetQueryResultError) -> Self { + Self::Failed + } +} + +impl From for Error { + fn from(e: StartQueryError) -> Self { + match e { + StartQueryError::NoFreeSlot => Self::NoFreeSlot, + StartQueryError::InvalidName => Self::InvalidName, + StartQueryError::NameTooLong => Self::NameTooLong, + } + } +} + +/// Async socket for making DNS queries. +pub struct DnsSocket<'a> { + stack: &'a RefCell, + handle: SocketHandle, +} + +impl<'a> DnsSocket<'a> { + /// Create a new DNS socket using the provided stack and query storage. + /// + /// DNS servers are derived from the stack configuration. + /// + /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated. + pub fn new(stack: &'a Stack, queries: Q) -> Self + where + D: Driver + 'static, + Q: Into>>, + { + let servers = stack + .config() + .map(|c| { + let v: Vec = c.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect(); + v + }) + .unwrap_or(Vec::new()); + let s = &mut *stack.socket.borrow_mut(); + let queries: ManagedSlice<'static, Option> = unsafe { mem::transmute(queries.into()) }; + + let handle = s.sockets.add(dns::Socket::new(&servers[..], queries)); + Self { + stack: &stack.socket, + handle, + } + } + + fn with_mut(&mut self, f: impl FnOnce(&mut dns::Socket, &mut Interface) -> R) -> R { + let s = &mut *self.stack.borrow_mut(); + let socket = s.sockets.get_mut::(self.handle); + let res = f(socket, &mut s.iface); + s.waker.wake(); + res + } + + /// Make a query for a given name and return the corresponding IP addresses. + pub async fn query(&mut self, name: &str, qtype: DnsQueryType) -> Result, Error> { + let query = match { self.with_mut(|s, i| s.start_query(i.context(), name, qtype)) } { + Ok(handle) => handle, + Err(e) => return Err(e.into()), + }; + + poll_fn(|cx| { + self.with_mut(|s, _| match s.get_query_result(query) { + Ok(addrs) => Poll::Ready(Ok(addrs)), + Err(GetQueryResultError::Pending) => { + s.register_query_waker(query, cx.waker()); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + }) + }) + .await + } +} + +impl<'a> Drop for DnsSocket<'a> { + fn drop(&mut self) { + self.stack.borrow_mut().sockets.remove(self.handle); + } +} diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index 0f694ee7..ae447d06 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -11,6 +11,8 @@ pub(crate) mod fmt; pub use embassy_net_driver as driver; mod device; +#[cfg(feature = "dns")] +pub mod dns; #[cfg(feature = "tcp")] pub mod tcp; #[cfg(feature = "udp")] diff --git a/examples/std/Cargo.toml b/examples/std/Cargo.toml index af1481e0..8087df09 100644 --- a/examples/std/Cargo.toml +++ b/examples/std/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" embassy-sync = { version = "0.1.0", path = "../../embassy-sync", features = ["log"] } embassy-executor = { version = "0.1.0", path = "../../embassy-executor", features = ["log", "std", "nightly", "integrated-timers"] } embassy-time = { version = "0.1.0", path = "../../embassy-time", features = ["log", "std", "nightly"] } -embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "nightly", "log", "medium-ethernet", "tcp", "udp", "dhcpv4"] } +embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "nightly", "log", "medium-ethernet", "tcp", "udp", "dns", "dhcpv4", "unstable-traits", "proto-ipv6"] } embassy-net-driver = { version = "0.1.0", path = "../../embassy-net-driver" } embedded-io = { version = "0.4.0", features = ["async", "std", "futures"] } critical-section = { version = "1.1", features = ["std"] } diff --git a/examples/std/src/bin/net_dns.rs b/examples/std/src/bin/net_dns.rs new file mode 100644 index 00000000..6203f837 --- /dev/null +++ b/examples/std/src/bin/net_dns.rs @@ -0,0 +1,102 @@ +#![feature(type_alias_impl_trait)] + +use std::default::Default; + +use clap::Parser; +use embassy_executor::{Executor, Spawner}; +use embassy_net::dns::{DnsQueryType, DnsSocket}; +use embassy_net::{Config, Ipv4Address, Ipv4Cidr, Stack, StackResources}; +use heapless::Vec; +use log::*; +use rand_core::{OsRng, RngCore}; +use static_cell::StaticCell; + +#[path = "../tuntap.rs"] +mod tuntap; + +use crate::tuntap::TunTapDevice; + +macro_rules! singleton { + ($val:expr) => {{ + type T = impl Sized; + static STATIC_CELL: StaticCell = StaticCell::new(); + STATIC_CELL.init_with(move || $val) + }}; +} + +#[derive(Parser)] +#[clap(version = "1.0")] +struct Opts { + /// TAP device name + #[clap(long, default_value = "tap0")] + tap: String, + /// use a static IP instead of DHCP + #[clap(long)] + static_ip: bool, +} + +#[embassy_executor::task] +async fn net_task(stack: &'static Stack) -> ! { + stack.run().await +} + +#[embassy_executor::task] +async fn main_task(spawner: Spawner) { + let opts: Opts = Opts::parse(); + + // Init network device + let device = TunTapDevice::new(&opts.tap).unwrap(); + + // Choose between dhcp or static ip + let config = if opts.static_ip { + Config::Static(embassy_net::StaticConfig { + address: Ipv4Cidr::new(Ipv4Address::new(192, 168, 69, 1), 24), + dns_servers: Vec::from_slice(&[Ipv4Address::new(8, 8, 4, 4).into(), Ipv4Address::new(8, 8, 8, 8).into()]) + .unwrap(), + gateway: Some(Ipv4Address::new(192, 168, 69, 100)), + }) + } else { + Config::Dhcp(Default::default()) + }; + + // Generate random seed + let mut seed = [0; 8]; + OsRng.fill_bytes(&mut seed); + let seed = u64::from_le_bytes(seed); + + // Init network stack + let stack = &*singleton!(Stack::new(device, config, singleton!(StackResources::<2>::new()), seed)); + + // Launch network task + spawner.spawn(net_task(stack)).unwrap(); + + // Then we can use it! + + let mut socket = DnsSocket::new(stack, vec![]); + + let host = "example.com"; + info!("querying host {:?}...", host); + match socket.query(host, DnsQueryType::A).await { + Ok(r) => { + info!("query response: {:?}", r); + } + Err(e) => { + warn!("query error: {:?}", e); + } + }; +} + +static EXECUTOR: StaticCell = StaticCell::new(); + +fn main() { + env_logger::builder() + .filter_level(log::LevelFilter::Debug) + .filter_module("async_io", log::LevelFilter::Info) + .format_timestamp_nanos() + .init(); + + let executor = EXECUTOR.init(Executor::new()); + executor.run(|spawner| { + spawner.spawn(main_task(spawner)).unwrap(); + }); +}