add waker for DNS slots
This commit is contained in:
		@@ -10,8 +10,6 @@ use crate::{Driver, Stack};
 | 
				
			|||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 | 
					#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 | 
				
			||||||
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
 | 
					#[cfg_attr(feature = "defmt", derive(defmt::Format))]
 | 
				
			||||||
pub enum Error {
 | 
					pub enum Error {
 | 
				
			||||||
    /// No available query slot
 | 
					 | 
				
			||||||
    NoFreeSlot,
 | 
					 | 
				
			||||||
    /// Invalid name
 | 
					    /// Invalid name
 | 
				
			||||||
    InvalidName,
 | 
					    InvalidName,
 | 
				
			||||||
    /// Name too long
 | 
					    /// Name too long
 | 
				
			||||||
@@ -29,7 +27,7 @@ impl From<GetQueryResultError> for Error {
 | 
				
			|||||||
impl From<StartQueryError> for Error {
 | 
					impl From<StartQueryError> for Error {
 | 
				
			||||||
    fn from(e: StartQueryError) -> Self {
 | 
					    fn from(e: StartQueryError) -> Self {
 | 
				
			||||||
        match e {
 | 
					        match e {
 | 
				
			||||||
            StartQueryError::NoFreeSlot => Self::NoFreeSlot,
 | 
					            StartQueryError::NoFreeSlot => Self::Failed,
 | 
				
			||||||
            StartQueryError::InvalidName => Self::InvalidName,
 | 
					            StartQueryError::InvalidName => Self::InvalidName,
 | 
				
			||||||
            StartQueryError::NameTooLong => Self::NameTooLong,
 | 
					            StartQueryError::NameTooLong => Self::NameTooLong,
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -119,6 +119,8 @@ struct Inner<D: Driver> {
 | 
				
			|||||||
    dhcp_socket: Option<SocketHandle>,
 | 
					    dhcp_socket: Option<SocketHandle>,
 | 
				
			||||||
    #[cfg(feature = "dns")]
 | 
					    #[cfg(feature = "dns")]
 | 
				
			||||||
    dns_socket: SocketHandle,
 | 
					    dns_socket: SocketHandle,
 | 
				
			||||||
 | 
					    #[cfg(feature = "dns")]
 | 
				
			||||||
 | 
					    dns_waker: WakerRegistration,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub(crate) struct SocketStack {
 | 
					pub(crate) struct SocketStack {
 | 
				
			||||||
@@ -157,7 +159,6 @@ impl<D: Driver + 'static> Stack<D> {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN;
 | 
					        let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        let mut socket = SocketStack {
 | 
					        let mut socket = SocketStack {
 | 
				
			||||||
            sockets,
 | 
					            sockets,
 | 
				
			||||||
            iface,
 | 
					            iface,
 | 
				
			||||||
@@ -172,7 +173,12 @@ impl<D: Driver + 'static> Stack<D> {
 | 
				
			|||||||
            #[cfg(feature = "dhcpv4")]
 | 
					            #[cfg(feature = "dhcpv4")]
 | 
				
			||||||
            dhcp_socket: None,
 | 
					            dhcp_socket: None,
 | 
				
			||||||
            #[cfg(feature = "dns")]
 | 
					            #[cfg(feature = "dns")]
 | 
				
			||||||
            dns_socket: socket.sockets.add(dns::Socket::new(&[], managed::ManagedSlice::Borrowed(&mut resources.queries))),
 | 
					            dns_socket: socket.sockets.add(dns::Socket::new(
 | 
				
			||||||
 | 
					                &[],
 | 
				
			||||||
 | 
					                managed::ManagedSlice::Borrowed(&mut resources.queries),
 | 
				
			||||||
 | 
					            )),
 | 
				
			||||||
 | 
					            #[cfg(feature = "dns")]
 | 
				
			||||||
 | 
					            dns_waker: WakerRegistration::new(),
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        match config {
 | 
					        match config {
 | 
				
			||||||
@@ -230,10 +236,20 @@ impl<D: Driver + 'static> Stack<D> {
 | 
				
			|||||||
    /// Make a query for a given name and return the corresponding IP addresses.
 | 
					    /// Make a query for a given name and return the corresponding IP addresses.
 | 
				
			||||||
    #[cfg(feature = "dns")]
 | 
					    #[cfg(feature = "dns")]
 | 
				
			||||||
    pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> {
 | 
					    pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> {
 | 
				
			||||||
        let query = self.with_mut(|s, i| {
 | 
					        let query = poll_fn(|cx| {
 | 
				
			||||||
            let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
 | 
					            self.with_mut(|s, i| {
 | 
				
			||||||
            socket.start_query(s.iface.context(), name, qtype)
 | 
					                let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
 | 
				
			||||||
        })?;
 | 
					                match socket.start_query(s.iface.context(), name, qtype) {
 | 
				
			||||||
 | 
					                    Ok(handle) => Poll::Ready(Ok(handle)),
 | 
				
			||||||
 | 
					                    Err(dns::StartQueryError::NoFreeSlot) => {
 | 
				
			||||||
 | 
					                        i.dns_waker.register(cx.waker());
 | 
				
			||||||
 | 
					                        Poll::Pending
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    Err(e) => Poll::Ready(Err(e)),
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            })
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					        .await?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        use embassy_hal_common::drop::OnDrop;
 | 
					        use embassy_hal_common::drop::OnDrop;
 | 
				
			||||||
        let drop = OnDrop::new(|| {
 | 
					        let drop = OnDrop::new(|| {
 | 
				
			||||||
@@ -241,6 +257,7 @@ impl<D: Driver + 'static> Stack<D> {
 | 
				
			|||||||
                let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
 | 
					                let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
 | 
				
			||||||
                socket.cancel_query(query);
 | 
					                socket.cancel_query(query);
 | 
				
			||||||
                s.waker.wake();
 | 
					                s.waker.wake();
 | 
				
			||||||
 | 
					                i.dns_waker.wake();
 | 
				
			||||||
            })
 | 
					            })
 | 
				
			||||||
        });
 | 
					        });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -248,12 +265,18 @@ impl<D: Driver + 'static> Stack<D> {
 | 
				
			|||||||
            self.with_mut(|s, i| {
 | 
					            self.with_mut(|s, i| {
 | 
				
			||||||
                let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
 | 
					                let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
 | 
				
			||||||
                match socket.get_query_result(query) {
 | 
					                match socket.get_query_result(query) {
 | 
				
			||||||
                    Ok(addrs) => Poll::Ready(Ok(addrs)),
 | 
					                    Ok(addrs) => {
 | 
				
			||||||
 | 
					                        i.dns_waker.wake();
 | 
				
			||||||
 | 
					                        Poll::Ready(Ok(addrs))
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
                    Err(dns::GetQueryResultError::Pending) => {
 | 
					                    Err(dns::GetQueryResultError::Pending) => {
 | 
				
			||||||
                        socket.register_query_waker(query, cx.waker());
 | 
					                        socket.register_query_waker(query, cx.waker());
 | 
				
			||||||
                        Poll::Pending
 | 
					                        Poll::Pending
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    Err(e) => Poll::Ready(Err(e.into())),
 | 
					                    Err(e) => {
 | 
				
			||||||
 | 
					                        i.dns_waker.wake();
 | 
				
			||||||
 | 
					                        Poll::Ready(Err(e.into()))
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            })
 | 
					            })
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user