diff --git a/embassy-net-esp-hosted/src/control.rs b/embassy-net-esp-hosted/src/control.rs index 37f220da..6cd57f68 100644 --- a/embassy-net-esp-hosted/src/control.rs +++ b/embassy-net-esp-hosted/src/control.rs @@ -5,9 +5,12 @@ use heapless::String; use crate::ioctl::Shared; use crate::proto::{self, CtrlMsg}; -#[derive(Debug)] -pub struct Error { - pub status: u32, +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + Failed(u32), + Timeout, + Internal, } pub struct Control<'a> { @@ -23,58 +26,78 @@ enum WifiMode { ApSta = 3, } +macro_rules! ioctl { + ($self:ident, $req_variant:ident, $resp_variant:ident, $req:ident, $resp:ident) => { + let mut msg = proto::CtrlMsg { + msg_id: proto::CtrlMsgId::$req_variant as _, + msg_type: proto::CtrlMsgType::Req as _, + payload: Some(proto::CtrlMsgPayload::$req_variant($req)), + }; + $self.ioctl(&mut msg).await?; + let Some(proto::CtrlMsgPayload::$resp_variant($resp)) = msg.payload else { + warn!("unexpected response variant"); + return Err(Error::Internal); + }; + if $resp.resp != 0 { + return Err(Error::Failed($resp.resp)); + } + }; +} + impl<'a> Control<'a> { pub(crate) fn new(state_ch: ch::StateRunner<'a>, shared: &'a Shared) -> Self { Self { state_ch, shared } } - pub async fn init(&mut self) { + pub async fn init(&mut self) -> Result<(), Error> { debug!("wait for init event..."); self.shared.init_wait().await; - debug!("set wifi mode"); - self.set_wifi_mode(WifiMode::Sta as _).await; + debug!("set heartbeat"); + self.set_heartbeat(10).await?; - let mac_addr = self.get_mac_addr().await; + debug!("set wifi mode"); + self.set_wifi_mode(WifiMode::Sta as _).await?; + + let mac_addr = self.get_mac_addr().await?; debug!("mac addr: {:02x}", mac_addr); self.state_ch.set_ethernet_address(mac_addr); + + Ok(()) } - pub async fn join(&mut self, ssid: &str, password: &str) { - let req = proto::CtrlMsg { - msg_id: proto::CtrlMsgId::ReqConnectAp as _, - msg_type: proto::CtrlMsgType::Req as _, - payload: Some(proto::CtrlMsgPayload::ReqConnectAp(proto::CtrlMsgReqConnectAp { - ssid: String::from(ssid), - pwd: String::from(password), - bssid: String::new(), - listen_interval: 3, - is_wpa3_supported: false, - })), + pub async fn connect(&mut self, ssid: &str, password: &str) -> Result<(), Error> { + let req = proto::CtrlMsgReqConnectAp { + ssid: String::from(ssid), + pwd: String::from(password), + bssid: String::new(), + listen_interval: 3, + is_wpa3_supported: false, }; - let resp = self.ioctl(req).await; - let proto::CtrlMsgPayload::RespConnectAp(resp) = resp.payload.unwrap() else { - panic!("unexpected resp") - }; - assert_eq!(resp.resp, 0); + ioctl!(self, ReqConnectAp, RespConnectAp, req, resp); self.state_ch.set_link_state(LinkState::Up); + Ok(()) } - async fn get_mac_addr(&mut self) -> [u8; 6] { - let req = proto::CtrlMsg { - msg_id: proto::CtrlMsgId::ReqGetMacAddress as _, - msg_type: proto::CtrlMsgType::Req as _, - payload: Some(proto::CtrlMsgPayload::ReqGetMacAddress( - proto::CtrlMsgReqGetMacAddress { - mode: WifiMode::Sta as _, - }, - )), + pub async fn disconnect(&mut self) -> Result<(), Error> { + let req = proto::CtrlMsgReqGetStatus {}; + ioctl!(self, ReqDisconnectAp, RespDisconnectAp, req, resp); + self.state_ch.set_link_state(LinkState::Up); + Ok(()) + } + + /// duration in seconds, clamped to [10, 3600] + async fn set_heartbeat(&mut self, duration: u32) -> Result<(), Error> { + let req = proto::CtrlMsgReqConfigHeartbeat { enable: true, duration }; + ioctl!(self, ReqConfigHeartbeat, RespConfigHeartbeat, req, resp); + Ok(()) + } + + async fn get_mac_addr(&mut self) -> Result<[u8; 6], Error> { + let req = proto::CtrlMsgReqGetMacAddress { + mode: WifiMode::Sta as _, }; - let resp = self.ioctl(req).await; - let proto::CtrlMsgPayload::RespGetMacAddress(resp) = resp.payload.unwrap() else { - panic!("unexpected resp") - }; - assert_eq!(resp.resp, 0); + ioctl!(self, ReqGetMacAddress, RespGetMacAddress, req, resp); // WHY IS THIS A STRING? WHYYYY fn nibble_from_hex(b: u8) -> u8 { @@ -88,32 +111,32 @@ impl<'a> Control<'a> { let mac = resp.mac.as_bytes(); let mut res = [0; 6]; - assert_eq!(mac.len(), 17); + if mac.len() != 17 { + warn!("unexpected MAC respnse length"); + return Err(Error::Internal); + } for (i, b) in res.iter_mut().enumerate() { *b = (nibble_from_hex(mac[i * 3]) << 4) | nibble_from_hex(mac[i * 3 + 1]) } - res + Ok(res) } - async fn set_wifi_mode(&mut self, mode: u32) { - let req = proto::CtrlMsg { - msg_id: proto::CtrlMsgId::ReqSetWifiMode as _, - msg_type: proto::CtrlMsgType::Req as _, - payload: Some(proto::CtrlMsgPayload::ReqSetWifiMode(proto::CtrlMsgReqSetMode { mode })), - }; - let resp = self.ioctl(req).await; - let proto::CtrlMsgPayload::RespSetWifiMode(resp) = resp.payload.unwrap() else { - panic!("unexpected resp") - }; - assert_eq!(resp.resp, 0); + async fn set_wifi_mode(&mut self, mode: u32) -> Result<(), Error> { + let req = proto::CtrlMsgReqSetMode { mode }; + ioctl!(self, ReqSetWifiMode, RespSetWifiMode, req, resp); + + Ok(()) } - async fn ioctl(&mut self, req: CtrlMsg) -> CtrlMsg { - debug!("ioctl req: {:?}", &req); + async fn ioctl(&mut self, msg: &mut CtrlMsg) -> Result<(), Error> { + debug!("ioctl req: {:?}", &msg); let mut buf = [0u8; 128]; - let req_len = noproto::write(&req, &mut buf).unwrap(); + let req_len = noproto::write(msg, &mut buf).map_err(|_| { + warn!("failed to serialize control request"); + Error::Internal + })?; struct CancelOnDrop<'a>(&'a Shared); @@ -135,9 +158,12 @@ impl<'a> Control<'a> { ioctl.defuse(); - let res = noproto::read(&buf[..resp_len]).unwrap(); - debug!("ioctl resp: {:?}", &res); + *msg = noproto::read(&buf[..resp_len]).map_err(|_| { + warn!("failed to serialize control request"); + Error::Internal + })?; + debug!("ioctl resp: {:?}", msg); - res + Ok(()) } } diff --git a/embassy-net-esp-hosted/src/lib.rs b/embassy-net-esp-hosted/src/lib.rs index 96fddce5..4a318b20 100644 --- a/embassy-net-esp-hosted/src/lib.rs +++ b/embassy-net-esp-hosted/src/lib.rs @@ -1,17 +1,15 @@ #![no_std] -use control::Control; -use embassy_futures::select::{select3, Either3}; +use embassy_futures::select::{select4, Either4}; use embassy_net_driver_channel as ch; +use embassy_net_driver_channel::driver::LinkState; use embassy_time::{Duration, Instant, Timer}; use embedded_hal::digital::{InputPin, OutputPin}; use embedded_hal_async::digital::Wait; use embedded_hal_async::spi::SpiDevice; -use ioctl::Shared; -use proto::CtrlMsg; -use crate::ioctl::PendingIoctl; -use crate::proto::CtrlMsgPayload; +use crate::ioctl::{PendingIoctl, Shared}; +use crate::proto::{CtrlMsg, CtrlMsgPayload}; mod proto; @@ -21,6 +19,8 @@ mod fmt; mod control; mod ioctl; +pub use control::*; + const MTU: usize = 1514; macro_rules! impl_bytes { @@ -95,6 +95,7 @@ enum InterfaceType { } const MAX_SPI_BUFFER_SIZE: usize = 1600; +const HEARTBEAT_MAX_GAP: Duration = Duration::from_secs(20); pub struct State { shared: Shared, @@ -129,12 +130,14 @@ where let mut runner = Runner { ch: ch_runner, + state_ch, shared: &state.shared, next_seq: 1, handshake, ready, reset, spi, + heartbeat_deadline: Instant::now() + HEARTBEAT_MAX_GAP, }; runner.init().await; @@ -143,9 +146,11 @@ where pub struct Runner<'a, SPI, IN, OUT> { ch: ch::Runner<'a, MTU>, + state_ch: ch::StateRunner<'a>, shared: &'a Shared, next_seq: u16, + heartbeat_deadline: Instant, spi: SPI, handshake: IN, @@ -177,9 +182,10 @@ where let ioctl = self.shared.ioctl_wait_pending(); let tx = self.ch.tx_buf(); let ev = async { self.ready.wait_for_high().await.unwrap() }; + let hb = Timer::at(self.heartbeat_deadline); - match select3(ioctl, tx, ev).await { - Either3::First(PendingIoctl { buf, req_len }) => { + match select4(ioctl, tx, ev, hb).await { + Either4::First(PendingIoctl { buf, req_len }) => { tx_buf[12..24].copy_from_slice(b"\x01\x08\x00ctrlResp\x02"); tx_buf[24..26].copy_from_slice(&(req_len as u16).to_le_bytes()); tx_buf[26..][..req_len].copy_from_slice(&unsafe { &*buf }[..req_len]); @@ -198,7 +204,7 @@ where header.checksum = checksum(&tx_buf[..26 + req_len]); tx_buf[0..12].copy_from_slice(&header.to_bytes()); } - Either3::Second(packet) => { + Either4::Second(packet) => { tx_buf[12..][..packet.len()].copy_from_slice(packet); let mut header = PayloadHeader { @@ -217,9 +223,12 @@ where self.ch.tx_done(); } - Either3::Third(()) => { + Either4::Third(()) => { tx_buf[..PayloadHeader::SIZE].fill(0); } + Either4::Fourth(()) => { + panic!("heartbeat from esp32 stopped") + } } if tx_buf[0] != 0 { @@ -308,7 +317,7 @@ where } } - fn handle_event(&self, data: &[u8]) { + fn handle_event(&mut self, data: &[u8]) { let Ok(event) = noproto::read::(data) else { warn!("failed to parse event"); return; @@ -323,6 +332,11 @@ where match payload { CtrlMsgPayload::EventEspInit(_) => self.shared.init_done(), + CtrlMsgPayload::EventHeartbeat(_) => self.heartbeat_deadline = Instant::now() + HEARTBEAT_MAX_GAP, + CtrlMsgPayload::EventStationDisconnectFromAp(e) => { + info!("disconnected, code {}", e.resp); + self.state_ch.set_link_state(LinkState::Down); + } _ => {} } } diff --git a/examples/nrf52840/src/bin/wifi_esp_hosted.rs b/examples/nrf52840/src/bin/wifi_esp_hosted.rs index e114e50b..a60822fd 100644 --- a/examples/nrf52840/src/bin/wifi_esp_hosted.rs +++ b/examples/nrf52840/src/bin/wifi_esp_hosted.rs @@ -72,8 +72,8 @@ async fn main(spawner: Spawner) { unwrap!(spawner.spawn(wifi_task(runner))); - control.init().await; - control.join(WIFI_NETWORK, WIFI_PASSWORD).await; + unwrap!(control.init().await); + unwrap!(control.connect(WIFI_NETWORK, WIFI_PASSWORD).await); let config = embassy_net::Config::dhcpv4(Default::default()); // let config = embassy_net::Config::ipv4_static(embassy_net::StaticConfigV4 { diff --git a/tests/nrf/src/bin/wifi_esp_hosted_perf.rs b/tests/nrf/src/bin/wifi_esp_hosted_perf.rs index ee46af2a..97ebafec 100644 --- a/tests/nrf/src/bin/wifi_esp_hosted_perf.rs +++ b/tests/nrf/src/bin/wifi_esp_hosted_perf.rs @@ -73,8 +73,8 @@ async fn main(spawner: Spawner) { unwrap!(spawner.spawn(wifi_task(runner))); - control.init().await; - control.join(WIFI_NETWORK, WIFI_PASSWORD).await; + unwrap!(control.init().await); + unwrap!(control.connect(WIFI_NETWORK, WIFI_PASSWORD).await); // Generate random seed let mut rng = Rng::new(p.RNG, Irqs);