diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs index ee90c301..78b09a49 100644 --- a/embassy-net/src/udp.rs +++ b/embassy-net/src/udp.rs @@ -27,7 +27,8 @@ pub enum Error { } pub struct UdpSocket<'a> { - io: UdpIo<'a>, + stack: &'a UnsafeCell, + handle: SocketHandle, } impl<'a> UdpSocket<'a> { @@ -51,10 +52,8 @@ impl<'a> UdpSocket<'a> { )); Self { - io: UdpIo { - stack: &stack.socket, - handle, - }, + stack: &stack.socket, + handle, } } @@ -67,64 +66,17 @@ impl<'a> UdpSocket<'a> { // safety: not accessed reentrantly. if endpoint.port == 0 { // If user didn't specify port allocate a dynamic port. - endpoint.port = unsafe { &mut *self.io.stack.get() }.get_local_port(); + endpoint.port = unsafe { &mut *self.stack.get() }.get_local_port(); } // safety: not accessed reentrantly. - match unsafe { self.io.with_mut(|s, _| s.bind(endpoint)) } { + match unsafe { self.with_mut(|s, _| s.bind(endpoint)) } { Ok(()) => Ok(()), Err(udp::BindError::InvalidState) => Err(BindError::InvalidState), Err(udp::BindError::Unaddressable) => Err(BindError::NoRoute), } } - pub async fn send_to(&self, buf: &[u8], remote_endpoint: T) -> Result<(), Error> - where - T: Into, - { - self.io.write(buf, remote_endpoint.into()).await - } - - pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { - self.io.read(buf).await - } - - pub fn endpoint(&self) -> IpListenEndpoint { - unsafe { self.io.with(|s, _| s.endpoint()) } - } - - pub fn is_open(&self) -> bool { - unsafe { self.io.with(|s, _| s.is_open()) } - } - - pub fn close(&mut self) { - unsafe { self.io.with_mut(|s, _| s.close()) } - } - - pub fn may_send(&self) -> bool { - unsafe { self.io.with(|s, _| s.can_send()) } - } - - pub fn may_recv(&self) -> bool { - unsafe { self.io.with(|s, _| s.can_recv()) } - } -} - -impl Drop for UdpSocket<'_> { - fn drop(&mut self) { - // safety: not accessed reentrantly. - let s = unsafe { &mut *self.io.stack.get() }; - s.sockets.remove(self.io.handle); - } -} - -#[derive(Copy, Clone)] -pub struct UdpIo<'a> { - stack: &'a UnsafeCell, - handle: SocketHandle, -} - -impl UdpIo<'_> { /// SAFETY: must not call reentrantly. unsafe fn with(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { let s = &*self.stack.get(); @@ -141,7 +93,7 @@ impl UdpIo<'_> { res } - async fn read(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { + pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpEndpoint), Error> { poll_fn(move |cx| unsafe { self.with_mut(|s, _| match s.recv_slice(buf) { Ok(x) => Poll::Ready(Ok(x)), @@ -156,9 +108,13 @@ impl UdpIo<'_> { .await } - async fn write(&self, buf: &[u8], ep: IpEndpoint) -> Result<(), Error> { + pub async fn send_to(&self, buf: &[u8], remote_endpoint: T) -> Result<(), Error> + where + T: Into, + { + let remote_endpoint = remote_endpoint.into(); poll_fn(move |cx| unsafe { - self.with_mut(|s, _| match s.send_slice(buf, ep) { + self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { // Entire datagram has been sent Ok(()) => Poll::Ready(Ok(())), Err(udp::SendError::BufferFull) => { @@ -170,4 +126,32 @@ impl UdpIo<'_> { }) .await } + + pub fn endpoint(&self) -> IpListenEndpoint { + unsafe { self.with(|s, _| s.endpoint()) } + } + + pub fn is_open(&self) -> bool { + unsafe { self.with(|s, _| s.is_open()) } + } + + pub fn close(&mut self) { + unsafe { self.with_mut(|s, _| s.close()) } + } + + pub fn may_send(&self) -> bool { + unsafe { self.with(|s, _| s.can_send()) } + } + + pub fn may_recv(&self) -> bool { + unsafe { self.with(|s, _| s.can_recv()) } + } +} + +impl Drop for UdpSocket<'_> { + fn drop(&mut self) { + // safety: not accessed reentrantly. + let s = unsafe { &mut *self.stack.get() }; + s.sockets.remove(self.handle); + } }