embassy/embassy-net/src/tcp_socket.rs
Dario Nieuwenhuis a2e7c24e00 Clippy fixes
2021-10-18 01:05:29 +02:00

182 lines
5.1 KiB
Rust

use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
use embassy::io;
use embassy::io::{AsyncBufRead, AsyncWrite};
use smoltcp::socket::SocketHandle;
use smoltcp::socket::TcpSocket as SyncTcpSocket;
use smoltcp::socket::{TcpSocketBuffer, TcpState};
use smoltcp::time::Duration;
use smoltcp::wire::IpEndpoint;
use super::stack::Stack;
use crate::{Error, Result};
/// Async handle to a TCP socket that lives in the network stack's socket set.
///
/// The `'a` lifetime is that of the receive/transmit buffers passed to
/// [`TcpSocket::new`].
pub struct TcpSocket<'a> {
// Handle used to look the socket up in `stack.sockets`.
handle: SocketHandle,
// Zero-sized marker that keeps the buffers' `'a` lifetime attached to this handle.
ghost: PhantomData<&'a mut [u8]>,
}
// Explicitly marks the handle as `Unpin`; its fields are a socket handle and a
// zero-sized lifetime marker.
impl<'a> Unpin for TcpSocket<'a> {}
impl<'a> TcpSocket<'a> {
pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
let handle = Stack::with(|stack| {
let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
stack.sockets.add(SyncTcpSocket::new(
TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer),
))
});
Self {
handle,
ghost: PhantomData,
}
}
pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<()>
where
T: Into<IpEndpoint>,
{
let local_port = Stack::with(|stack| stack.get_local_port());
self.with(|s| s.connect(remote_endpoint, local_port))?;
futures::future::poll_fn(|cx| {
self.with(|s| match s.state() {
TcpState::Closed | TcpState::TimeWait => Poll::Ready(Err(Error::Unaddressable)),
TcpState::Listen => Poll::Ready(Err(Error::Illegal)),
TcpState::SynSent | TcpState::SynReceived => {
s.register_send_waker(cx.waker());
Poll::Pending
}
_ => Poll::Ready(Ok(())),
})
})
.await
}
pub fn set_timeout(&mut self, duration: Option<Duration>) {
self.with(|s| s.set_timeout(duration))
}
pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
self.with(|s| s.set_keep_alive(interval))
}
pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
self.with(|s| s.set_hop_limit(hop_limit))
}
pub fn local_endpoint(&self) -> IpEndpoint {
self.with(|s| s.local_endpoint())
}
pub fn remote_endpoint(&self) -> IpEndpoint {
self.with(|s| s.remote_endpoint())
}
pub fn state(&self) -> TcpState {
self.with(|s| s.state())
}
pub fn close(&mut self) {
self.with(|s| s.close())
}
pub fn abort(&mut self) {
self.with(|s| s.abort())
}
pub fn may_send(&self) -> bool {
self.with(|s| s.may_send())
}
pub fn may_recv(&self) -> bool {
self.with(|s| s.may_recv())
}
fn with<R>(&self, f: impl FnOnce(&mut SyncTcpSocket) -> R) -> R {
Stack::with(|stack| {
let res = {
let mut s = stack.sockets.get::<SyncTcpSocket>(self.handle);
f(&mut *s)
};
stack.wake();
res
})
}
}
/// Converts a network-stack `Error` into an `io::Error`.
///
/// The input is currently ignored: every error maps to `io::Error::Other`.
fn to_ioerr(_err: Error) -> io::Error {
// TODO: map individual `Error` variants to more specific `io::Error` values.
io::Error::Other
}
impl<'a> Drop for TcpSocket<'a> {
    /// Removes this socket from the stack's socket set when the handle is dropped.
    fn drop(&mut self) {
        let handle = self.handle;
        Stack::with(|stack| {
            // The removed socket value is discarded.
            stack.sockets.remove(handle);
        })
    }
}
impl<'a> AsyncBufRead for TcpSocket<'a> {
    /// Polls for received data without consuming it.
    ///
    /// Returns `Pending` (after registering the receive waker) while no data is
    /// buffered, the buffered bytes when there are some, an empty slice when the
    /// socket reports `Error::Finished`, and an I/O error otherwise.
    fn poll_fill_buf<'z>(
        self: Pin<&'z mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&'z [u8]>> {
        self.with(|socket| match socket.peek(1 << 30) {
            // End of stream is reported as an empty slice.
            Err(Error::Finished) => Poll::Ready(Ok(&[][..])),
            // Any other failure becomes an I/O error.
            Err(err) => Poll::Ready(Err(to_ioerr(err))),
            // Nothing buffered yet: wait to be woken.
            Ok(data) if data.is_empty() => {
                socket.register_recv_waker(cx.waker());
                Poll::Pending
            }
            // Bytes are available.
            Ok(data) => {
                // Safety:
                // - The inner smoltcp socket is never exposed to the user directly.
                // - The socket leaves these bytes alone until consume() is called,
                //   and calling it requires the user to give up this borrow first.
                let data: &'z [u8] = unsafe { core::mem::transmute(&*data) };
                Poll::Ready(Ok(data))
            }
        })
    }

    /// Marks `amt` bytes of the receive buffer as read.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `recv` call returns an error.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        // smoltcp's recv reports Finished at EOF even for a zero-byte "read",
        // so a zero amount must not reach it.
        if amt != 0 {
            self.with(|socket| socket.recv(|_| (amt, ()))).unwrap();
        }
    }
}
impl<'a> AsyncWrite for TcpSocket<'a> {
    /// Attempts to queue bytes from `buf` into the socket's transmit buffer.
    ///
    /// Returns:
    /// - `Ready(Ok(n))` with `n > 0` once some bytes were queued,
    /// - `Ready(Ok(0))` when `buf` is empty (nothing to queue),
    /// - `Pending` (after registering the send waker) when `buf` is non-empty but
    ///   no byte could be queued,
    /// - `Ready(Err(_))` when the underlying `send_slice` call fails.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.with(|s| match s.send_slice(buf) {
            // A zero-length write is complete as soon as it is accepted. Without
            // this arm it would be mistaken for "no space" and pend forever,
            // because retrying an empty write always queues zero bytes.
            Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
            // Not ready to send (no space in the tx buffer)
            Ok(0) => {
                s.register_send_waker(cx.waker());
                Poll::Pending
            }
            // Some data sent
            Ok(n) => Poll::Ready(Ok(n)),
            // Error
            Err(e) => Poll::Ready(Err(to_ioerr(e))),
        })
    }
}