add waker for DNS slots

This commit is contained in:
Ulf Lilleengen 2023-02-10 18:44:51 +01:00
parent 48dff04d64
commit 32c3725631
Failed to extract signature
2 changed files with 32 additions and 11 deletions

View File

@ -10,8 +10,6 @@ use crate::{Driver, Stack};
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))] #[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Error { pub enum Error {
/// No available query slot
NoFreeSlot,
/// Invalid name /// Invalid name
InvalidName, InvalidName,
/// Name too long /// Name too long
@ -29,7 +27,7 @@ impl From<GetQueryResultError> for Error {
impl From<StartQueryError> for Error { impl From<StartQueryError> for Error {
fn from(e: StartQueryError) -> Self { fn from(e: StartQueryError) -> Self {
match e { match e {
StartQueryError::NoFreeSlot => Self::NoFreeSlot, StartQueryError::NoFreeSlot => Self::Failed,
StartQueryError::InvalidName => Self::InvalidName, StartQueryError::InvalidName => Self::InvalidName,
StartQueryError::NameTooLong => Self::NameTooLong, StartQueryError::NameTooLong => Self::NameTooLong,
} }

View File

@ -119,6 +119,8 @@ struct Inner<D: Driver> {
dhcp_socket: Option<SocketHandle>, dhcp_socket: Option<SocketHandle>,
#[cfg(feature = "dns")] #[cfg(feature = "dns")]
dns_socket: SocketHandle, dns_socket: SocketHandle,
#[cfg(feature = "dns")]
dns_waker: WakerRegistration,
} }
pub(crate) struct SocketStack { pub(crate) struct SocketStack {
@ -157,7 +159,6 @@ impl<D: Driver + 'static> Stack<D> {
let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN;
let mut socket = SocketStack { let mut socket = SocketStack {
sockets, sockets,
iface, iface,
@ -172,7 +173,12 @@ impl<D: Driver + 'static> Stack<D> {
#[cfg(feature = "dhcpv4")] #[cfg(feature = "dhcpv4")]
dhcp_socket: None, dhcp_socket: None,
#[cfg(feature = "dns")] #[cfg(feature = "dns")]
dns_socket: socket.sockets.add(dns::Socket::new(&[], managed::ManagedSlice::Borrowed(&mut resources.queries))), dns_socket: socket.sockets.add(dns::Socket::new(
&[],
managed::ManagedSlice::Borrowed(&mut resources.queries),
)),
#[cfg(feature = "dns")]
dns_waker: WakerRegistration::new(),
}; };
match config { match config {
@ -230,10 +236,20 @@ impl<D: Driver + 'static> Stack<D> {
/// Make a query for a given name and return the corresponding IP addresses. /// Make a query for a given name and return the corresponding IP addresses.
#[cfg(feature = "dns")] #[cfg(feature = "dns")]
pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> { pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> {
let query = self.with_mut(|s, i| { let query = poll_fn(|cx| {
self.with_mut(|s, i| {
let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
socket.start_query(s.iface.context(), name, qtype) match socket.start_query(s.iface.context(), name, qtype) {
})?; Ok(handle) => Poll::Ready(Ok(handle)),
Err(dns::StartQueryError::NoFreeSlot) => {
i.dns_waker.register(cx.waker());
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
})
})
.await?;
use embassy_hal_common::drop::OnDrop; use embassy_hal_common::drop::OnDrop;
let drop = OnDrop::new(|| { let drop = OnDrop::new(|| {
@ -241,6 +257,7 @@ impl<D: Driver + 'static> Stack<D> {
let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
socket.cancel_query(query); socket.cancel_query(query);
s.waker.wake(); s.waker.wake();
i.dns_waker.wake();
}) })
}); });
@ -248,12 +265,18 @@ impl<D: Driver + 'static> Stack<D> {
self.with_mut(|s, i| { self.with_mut(|s, i| {
let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
match socket.get_query_result(query) { match socket.get_query_result(query) {
Ok(addrs) => Poll::Ready(Ok(addrs)), Ok(addrs) => {
i.dns_waker.wake();
Poll::Ready(Ok(addrs))
}
Err(dns::GetQueryResultError::Pending) => { Err(dns::GetQueryResultError::Pending) => {
socket.register_query_waker(query, cx.waker()); socket.register_query_waker(query, cx.waker());
Poll::Pending Poll::Pending
} }
Err(e) => Poll::Ready(Err(e.into())), Err(e) => {
i.dns_waker.wake();
Poll::Ready(Err(e.into()))
}
} }
}) })
}) })