add waker for DNS slots
This commit is contained in:
		| @@ -10,8 +10,6 @@ use crate::{Driver, Stack}; | |||||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy)] | #[derive(Debug, PartialEq, Eq, Clone, Copy)] | ||||||
| #[cfg_attr(feature = "defmt", derive(defmt::Format))] | #[cfg_attr(feature = "defmt", derive(defmt::Format))] | ||||||
| pub enum Error { | pub enum Error { | ||||||
|     /// No available query slot |  | ||||||
|     NoFreeSlot, |  | ||||||
|     /// Invalid name |     /// Invalid name | ||||||
|     InvalidName, |     InvalidName, | ||||||
|     /// Name too long |     /// Name too long | ||||||
| @@ -29,7 +27,7 @@ impl From<GetQueryResultError> for Error { | |||||||
| impl From<StartQueryError> for Error { | impl From<StartQueryError> for Error { | ||||||
|     fn from(e: StartQueryError) -> Self { |     fn from(e: StartQueryError) -> Self { | ||||||
|         match e { |         match e { | ||||||
|             StartQueryError::NoFreeSlot => Self::NoFreeSlot, |             StartQueryError::NoFreeSlot => Self::Failed, | ||||||
|             StartQueryError::InvalidName => Self::InvalidName, |             StartQueryError::InvalidName => Self::InvalidName, | ||||||
|             StartQueryError::NameTooLong => Self::NameTooLong, |             StartQueryError::NameTooLong => Self::NameTooLong, | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -119,6 +119,8 @@ struct Inner<D: Driver> { | |||||||
|     dhcp_socket: Option<SocketHandle>, |     dhcp_socket: Option<SocketHandle>, | ||||||
|     #[cfg(feature = "dns")] |     #[cfg(feature = "dns")] | ||||||
|     dns_socket: SocketHandle, |     dns_socket: SocketHandle, | ||||||
|  |     #[cfg(feature = "dns")] | ||||||
|  |     dns_waker: WakerRegistration, | ||||||
| } | } | ||||||
|  |  | ||||||
| pub(crate) struct SocketStack { | pub(crate) struct SocketStack { | ||||||
| @@ -157,7 +159,6 @@ impl<D: Driver + 'static> Stack<D> { | |||||||
|  |  | ||||||
|         let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; |         let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN; | ||||||
|  |  | ||||||
|  |  | ||||||
|         let mut socket = SocketStack { |         let mut socket = SocketStack { | ||||||
|             sockets, |             sockets, | ||||||
|             iface, |             iface, | ||||||
| @@ -172,7 +173,12 @@ impl<D: Driver + 'static> Stack<D> { | |||||||
|             #[cfg(feature = "dhcpv4")] |             #[cfg(feature = "dhcpv4")] | ||||||
|             dhcp_socket: None, |             dhcp_socket: None, | ||||||
|             #[cfg(feature = "dns")] |             #[cfg(feature = "dns")] | ||||||
|             dns_socket: socket.sockets.add(dns::Socket::new(&[], managed::ManagedSlice::Borrowed(&mut resources.queries))), |             dns_socket: socket.sockets.add(dns::Socket::new( | ||||||
|  |                 &[], | ||||||
|  |                 managed::ManagedSlice::Borrowed(&mut resources.queries), | ||||||
|  |             )), | ||||||
|  |             #[cfg(feature = "dns")] | ||||||
|  |             dns_waker: WakerRegistration::new(), | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         match config { |         match config { | ||||||
| @@ -230,10 +236,20 @@ impl<D: Driver + 'static> Stack<D> { | |||||||
|     /// Make a query for a given name and return the corresponding IP addresses. |     /// Make a query for a given name and return the corresponding IP addresses. | ||||||
|     #[cfg(feature = "dns")] |     #[cfg(feature = "dns")] | ||||||
|     pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> { |     pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> { | ||||||
|         let query = self.with_mut(|s, i| { |         let query = poll_fn(|cx| { | ||||||
|  |             self.with_mut(|s, i| { | ||||||
|                 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); |                 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); | ||||||
|             socket.start_query(s.iface.context(), name, qtype) |                 match socket.start_query(s.iface.context(), name, qtype) { | ||||||
|         })?; |                     Ok(handle) => Poll::Ready(Ok(handle)), | ||||||
|  |                     Err(dns::StartQueryError::NoFreeSlot) => { | ||||||
|  |                         i.dns_waker.register(cx.waker()); | ||||||
|  |                         Poll::Pending | ||||||
|  |                     } | ||||||
|  |                     Err(e) => Poll::Ready(Err(e)), | ||||||
|  |                 } | ||||||
|  |             }) | ||||||
|  |         }) | ||||||
|  |         .await?; | ||||||
|  |  | ||||||
|         use embassy_hal_common::drop::OnDrop; |         use embassy_hal_common::drop::OnDrop; | ||||||
|         let drop = OnDrop::new(|| { |         let drop = OnDrop::new(|| { | ||||||
| @@ -241,6 +257,7 @@ impl<D: Driver + 'static> Stack<D> { | |||||||
|                 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); |                 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); | ||||||
|                 socket.cancel_query(query); |                 socket.cancel_query(query); | ||||||
|                 s.waker.wake(); |                 s.waker.wake(); | ||||||
|  |                 i.dns_waker.wake(); | ||||||
|             }) |             }) | ||||||
|         }); |         }); | ||||||
|  |  | ||||||
| @@ -248,12 +265,18 @@ impl<D: Driver + 'static> Stack<D> { | |||||||
|             self.with_mut(|s, i| { |             self.with_mut(|s, i| { | ||||||
|                 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); |                 let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket); | ||||||
|                 match socket.get_query_result(query) { |                 match socket.get_query_result(query) { | ||||||
|                     Ok(addrs) => Poll::Ready(Ok(addrs)), |                     Ok(addrs) => { | ||||||
|  |                         i.dns_waker.wake(); | ||||||
|  |                         Poll::Ready(Ok(addrs)) | ||||||
|  |                     } | ||||||
|                     Err(dns::GetQueryResultError::Pending) => { |                     Err(dns::GetQueryResultError::Pending) => { | ||||||
|                         socket.register_query_waker(query, cx.waker()); |                         socket.register_query_waker(query, cx.waker()); | ||||||
|                         Poll::Pending |                         Poll::Pending | ||||||
|                     } |                     } | ||||||
|                     Err(e) => Poll::Ready(Err(e.into())), |                     Err(e) => { | ||||||
|  |                         i.dns_waker.wake(); | ||||||
|  |                         Poll::Ready(Err(e.into())) | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|             }) |             }) | ||||||
|         }) |         }) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user