embassy/embassy-net/src/tcp_socket.rs

204 lines
5.9 KiB
Rust
Raw Normal View History

2021-02-03 05:09:37 +01:00
use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
use embassy::io;
use embassy::io::{AsyncBufRead, AsyncWrite};
2021-11-26 04:12:14 +01:00
use smoltcp::iface::{Context as SmolContext, SocketHandle};
2021-02-03 05:09:37 +01:00
use smoltcp::socket::TcpSocket as SyncTcpSocket;
use smoltcp::socket::{TcpSocketBuffer, TcpState};
use smoltcp::time::Duration;
use smoltcp::wire::IpEndpoint;
use super::stack::Stack;
2021-02-12 01:48:21 +01:00
use crate::{Error, Result};
2021-02-03 05:09:37 +01:00
/// An async TCP socket registered with the global embassy-net [`Stack`].
///
/// The smoltcp socket itself lives inside the stack's interface and is addressed through
/// `handle`. `ghost` records the `'a` mutable borrow of the receive/transmit buffers passed to
/// [`TcpSocket::new`], so the caller cannot touch those buffers while this value exists.
pub struct TcpSocket<'a> {
    handle: SocketHandle,
    ghost: PhantomData<&'a mut [u8]>,
}

// Explicitly `Unpin`: the value is only a handle plus a marker, so pinning it is a no-op.
// This lets the I/O trait impls treat `Pin<&mut Self>` like `&mut Self`.
impl Unpin for TcpSocket<'_> {}
impl<'a> TcpSocket<'a> {
pub fn new(rx_buffer: &'a mut [u8], tx_buffer: &'a mut [u8]) -> Self {
let handle = Stack::with(|stack| {
let rx_buffer: &'static mut [u8] = unsafe { mem::transmute(rx_buffer) };
let tx_buffer: &'static mut [u8] = unsafe { mem::transmute(tx_buffer) };
2021-11-26 04:12:14 +01:00
stack.iface.add_socket(SyncTcpSocket::new(
2021-02-03 05:09:37 +01:00
TcpSocketBuffer::new(rx_buffer),
TcpSocketBuffer::new(tx_buffer),
))
});
Self {
handle,
ghost: PhantomData,
}
}
pub async fn connect<T>(&mut self, remote_endpoint: T) -> Result<()>
where
T: Into<IpEndpoint>,
{
let local_port = Stack::with(|stack| stack.get_local_port());
2021-11-26 04:12:14 +01:00
self.with(|s, cx| s.connect(cx, remote_endpoint, local_port))?;
2021-02-03 05:09:37 +01:00
futures::future::poll_fn(|cx| {
2021-11-26 04:12:14 +01:00
self.with(|s, _| match s.state() {
2021-02-03 05:09:37 +01:00
TcpState::Closed | TcpState::TimeWait => Poll::Ready(Err(Error::Unaddressable)),
TcpState::Listen => Poll::Ready(Err(Error::Illegal)),
TcpState::SynSent | TcpState::SynReceived => {
s.register_send_waker(cx.waker());
Poll::Pending
}
_ => Poll::Ready(Ok(())),
})
})
.await
}
2021-11-04 13:34:13 +01:00
pub async fn listen<T>(&mut self, local_endpoint: T) -> Result<()>
where
T: Into<IpEndpoint>,
{
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.listen(local_endpoint))?;
2021-11-04 13:34:13 +01:00
futures::future::poll_fn(|cx| {
2021-11-26 04:12:14 +01:00
self.with(|s, _| match s.state() {
2021-11-04 13:34:13 +01:00
TcpState::Closed | TcpState::TimeWait => Poll::Ready(Err(Error::Unaddressable)),
TcpState::Listen => Poll::Ready(Ok(())),
TcpState::SynSent | TcpState::SynReceived => {
s.register_send_waker(cx.waker());
Poll::Pending
}
_ => Poll::Ready(Ok(())),
})
})
.await
}
2021-02-03 05:09:37 +01:00
pub fn set_timeout(&mut self, duration: Option<Duration>) {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.set_timeout(duration))
2021-02-03 05:09:37 +01:00
}
pub fn set_keep_alive(&mut self, interval: Option<Duration>) {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.set_keep_alive(interval))
2021-02-03 05:09:37 +01:00
}
pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.set_hop_limit(hop_limit))
2021-02-03 05:09:37 +01:00
}
pub fn local_endpoint(&self) -> IpEndpoint {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.local_endpoint())
2021-02-03 05:09:37 +01:00
}
pub fn remote_endpoint(&self) -> IpEndpoint {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.remote_endpoint())
2021-02-03 05:09:37 +01:00
}
pub fn state(&self) -> TcpState {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.state())
2021-02-03 05:09:37 +01:00
}
pub fn close(&mut self) {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.close())
2021-02-03 05:09:37 +01:00
}
pub fn abort(&mut self) {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.abort())
2021-02-03 05:09:37 +01:00
}
pub fn may_send(&self) -> bool {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.may_send())
2021-02-03 05:09:37 +01:00
}
pub fn may_recv(&self) -> bool {
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.may_recv())
2021-02-03 05:09:37 +01:00
}
2021-11-26 04:12:14 +01:00
fn with<R>(&self, f: impl FnOnce(&mut SyncTcpSocket, &mut SmolContext) -> R) -> R {
2021-02-03 05:09:37 +01:00
Stack::with(|stack| {
let res = {
2021-11-26 04:12:14 +01:00
let (s, cx) = stack
.iface
.get_socket_and_context::<SyncTcpSocket>(self.handle);
f(s, cx)
2021-02-03 05:09:37 +01:00
};
stack.wake();
res
})
}
}
2021-03-02 21:20:00 +01:00
fn to_ioerr(_err: Error) -> io::Error {
2021-02-03 05:09:37 +01:00
// todo
io::Error::Other
}
impl<'a> Drop for TcpSocket<'a> {
    /// Unregisters the socket from the global stack, so the stack no longer references the
    /// buffers that were handed to `TcpSocket::new`.
    fn drop(&mut self) {
        let handle = self.handle;
        Stack::with(|stack| {
            // The removed socket is not needed; it is dropped right here.
            let _ = stack.iface.remove_socket(handle);
        })
    }
}
impl<'a> AsyncBufRead for TcpSocket<'a> {
fn poll_fill_buf<'z>(
self: Pin<&'z mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<&'z [u8]>> {
2021-11-26 04:12:14 +01:00
self.with(|s, _| match s.peek(1 << 30) {
2021-02-03 05:09:37 +01:00
// No data ready
2021-10-18 00:55:43 +02:00
Ok(buf) if buf.is_empty() => {
2021-11-26 04:12:14 +01:00
s.register_recv_waker(cx.waker());
2021-02-03 05:09:37 +01:00
Poll::Pending
}
// Data ready!
Ok(buf) => {
// Safety:
// - User can't touch the inner TcpSocket directly at all.
// - The socket itself won't touch these bytes until consume() is called, which
// requires the user to release this borrow.
let buf: &'z [u8] = unsafe { core::mem::transmute(&*buf) };
Poll::Ready(Ok(buf))
}
// EOF
Err(Error::Finished) => Poll::Ready(Ok(&[][..])),
// Error
Err(e) => Poll::Ready(Err(to_ioerr(e))),
})
}
fn consume(self: Pin<&mut Self>, amt: usize) {
if amt == 0 {
// smoltcp's recv returns Finished if we're at EOF,
// even if we're "reading" 0 bytes.
return;
}
2021-11-26 04:12:14 +01:00
self.with(|s, _| s.recv(|_| (amt, ()))).unwrap()
2021-02-03 05:09:37 +01:00
}
}
impl<'a> AsyncWrite for TcpSocket<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
2021-11-26 04:12:14 +01:00
self.with(|s, _| match s.send_slice(buf) {
2021-02-03 05:09:37 +01:00
// Not ready to send (no space in the tx buffer)
Ok(0) => {
s.register_send_waker(cx.waker());
Poll::Pending
}
// Some data sent
Ok(n) => Poll::Ready(Ok(n)),
// Error
Err(e) => Poll::Ready(Err(to_ioerr(e))),
})
}
}