refactor: Multicast method modifiers on stack to public

revert: udp.rs
This commit is contained in:
Leon Camus 2023-03-08 12:37:00 +01:00
parent 993875e11f
commit e484cb1b87
2 changed files with 18 additions and 35 deletions

View File

@ -306,7 +306,7 @@ impl<D: Driver + 'static> Stack<D> {
#[cfg(feature = "igmp")] #[cfg(feature = "igmp")]
impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> { impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
pub(crate) fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
where where
T: Into<IpAddress>, T: Into<IpAddress>,
{ {
@ -318,7 +318,7 @@ impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
}) })
} }
pub(crate) fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
where where
T: Into<IpAddress>, T: Into<IpAddress>,
{ {
@ -330,7 +330,7 @@ impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
}) })
} }
pub(crate) fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool { pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
self.socket.borrow().iface.has_multicast_group(addr) self.socket.borrow().iface.has_multicast_group(addr)
} }
} }

View File

@ -1,3 +1,4 @@
use core::cell::RefCell;
use core::future::poll_fn; use core::future::poll_fn;
use core::mem; use core::mem;
use core::task::Poll; use core::task::Poll;
@ -7,7 +8,7 @@ use smoltcp::iface::{Interface, SocketHandle};
use smoltcp::socket::udp::{self, PacketMetadata}; use smoltcp::socket::udp::{self, PacketMetadata};
use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
use crate::Stack; use crate::{SocketStack, Stack};
#[derive(PartialEq, Eq, Clone, Copy, Debug)] #[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))] #[cfg_attr(feature = "defmt", derive(defmt::Format))]
@ -25,13 +26,13 @@ pub enum Error {
NoRoute, NoRoute,
} }
pub struct UdpSocket<'a, D: Driver> { pub struct UdpSocket<'a> {
stack: &'a Stack<D>, stack: &'a RefCell<SocketStack>,
handle: SocketHandle, handle: SocketHandle,
} }
impl<'a, D: Driver> UdpSocket<'a, D> { impl<'a> UdpSocket<'a> {
pub fn new( pub fn new<D: Driver>(
stack: &'a Stack<D>, stack: &'a Stack<D>,
rx_meta: &'a mut [PacketMetadata], rx_meta: &'a mut [PacketMetadata],
rx_buffer: &'a mut [u8], rx_buffer: &'a mut [u8],
@ -49,7 +50,10 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
udp::PacketBuffer::new(tx_meta, tx_buffer), udp::PacketBuffer::new(tx_meta, tx_buffer),
)); ));
Self { stack, handle } Self {
stack: &stack.socket,
handle,
}
} }
pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError> pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError>
@ -60,7 +64,7 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
if endpoint.port == 0 { if endpoint.port == 0 {
// If user didn't specify port allocate a dynamic port. // If user didn't specify port allocate a dynamic port.
endpoint.port = self.stack.socket.borrow_mut().get_local_port(); endpoint.port = self.stack.borrow_mut().get_local_port();
} }
match self.with_mut(|s, _| s.bind(endpoint)) { match self.with_mut(|s, _| s.bind(endpoint)) {
@ -71,13 +75,13 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
} }
fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
let s = &*self.stack.socket.borrow(); let s = &*self.stack.borrow();
let socket = s.sockets.get::<udp::Socket>(self.handle); let socket = s.sockets.get::<udp::Socket>(self.handle);
f(socket, &s.iface) f(socket, &s.iface)
} }
fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
let s = &mut *self.stack.socket.borrow_mut(); let s = &mut *self.stack.borrow_mut();
let socket = s.sockets.get_mut::<udp::Socket>(self.handle); let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
let res = f(socket, &mut s.iface); let res = f(socket, &mut s.iface);
s.waker.wake(); s.waker.wake();
@ -139,29 +143,8 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
} }
} }
#[cfg(feature = "igmp")] impl Drop for UdpSocket<'_> {
impl<'a, D: Driver + smoltcp::phy::Device + 'static> UdpSocket<'a, D> {
pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
where
T: Into<smoltcp::wire::IpAddress>,
{
self.stack.join_multicast_group(addr)
}
pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
where
T: Into<smoltcp::wire::IpAddress>,
{
self.stack.leave_multicast_group(addr)
}
pub fn has_multicast_group<T: Into<smoltcp::wire::IpAddress>>(&self, addr: T) -> bool {
self.stack.has_multicast_group(addr)
}
}
impl<D: Driver> Drop for UdpSocket<'_, D> {
fn drop(&mut self) { fn drop(&mut self) {
self.stack.socket.borrow_mut().sockets.remove(self.handle); self.stack.borrow_mut().sockets.remove(self.handle);
} }
} }