diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs index 9b1b936c..2dd44a4e 100644 --- a/embassy-net/src/dns.rs +++ b/embassy-net/src/dns.rs @@ -10,8 +10,6 @@ use crate::{Driver, Stack}; #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Error { - /// No available query slot - NoFreeSlot, /// Invalid name InvalidName, /// Name too long @@ -29,7 +27,7 @@ impl From for Error { impl From for Error { fn from(e: StartQueryError) -> Self { match e { - StartQueryError::NoFreeSlot => Self::NoFreeSlot, + StartQueryError::NoFreeSlot => Self::Failed, StartQueryError::InvalidName => Self::InvalidName, StartQueryError::NameTooLong => Self::NameTooLong, } diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index e5f7479c..9d9de913 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -119,6 +119,8 @@ struct Inner { dhcp_socket: Option, #[cfg(feature = "dns")] dns_socket: SocketHandle, + #[cfg(feature = "dns")] + dns_waker: WakerRegistration, } pub(crate) struct SocketStack { @@ -157,7 +159,6 @@ impl Stack { let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; - let mut socket = SocketStack { sockets, iface, @@ -172,7 +173,12 @@ impl Stack { #[cfg(feature = "dhcpv4")] dhcp_socket: None, #[cfg(feature = "dns")] - dns_socket: socket.sockets.add(dns::Socket::new(&[], managed::ManagedSlice::Borrowed(&mut resources.queries))), + dns_socket: socket.sockets.add(dns::Socket::new( + &[], + managed::ManagedSlice::Borrowed(&mut resources.queries), + )), + #[cfg(feature = "dns")] + dns_waker: WakerRegistration::new(), }; match config { @@ -230,10 +236,20 @@ impl Stack { /// Make a query for a given name and return the corresponding IP addresses. #[cfg(feature = "dns")] pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result, dns::Error> { - let query = self.with_mut(|s, i| { - let socket = s.sockets.get_mut::(i.dns_socket); - socket.start_query(s.iface.context(), name, qtype) - })?; + let query = poll_fn(|cx| { + self.with_mut(|s, i| { + let socket = s.sockets.get_mut::(i.dns_socket); + match socket.start_query(s.iface.context(), name, qtype) { + Ok(handle) => Poll::Ready(Ok(handle)), + Err(dns::StartQueryError::NoFreeSlot) => { + i.dns_waker.register(cx.waker()); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + }) + }) + .await?; use embassy_hal_common::drop::OnDrop; let drop = OnDrop::new(|| { @@ -241,6 +257,7 @@ impl Stack { let socket = s.sockets.get_mut::(i.dns_socket); socket.cancel_query(query); s.waker.wake(); + i.dns_waker.wake(); }) }); @@ -248,12 +265,18 @@ impl Stack { self.with_mut(|s, i| { let socket = s.sockets.get_mut::(i.dns_socket); match socket.get_query_result(query) { - Ok(addrs) => Poll::Ready(Ok(addrs)), + Ok(addrs) => { + i.dns_waker.wake(); + Poll::Ready(Ok(addrs)) + } Err(dns::GetQueryResultError::Pending) => { socket.register_query_waker(query, cx.waker()); Poll::Pending } - Err(e) => Poll::Ready(Err(e.into())), + Err(e) => { + i.dns_waker.wake(); + Poll::Ready(Err(e.into())) + } } }) })