diff --git a/src/control.rs b/src/control.rs index 79677b55..0dbf6d44 100644 --- a/src/control.rs +++ b/src/control.rs @@ -1,8 +1,6 @@ -use core::cell::Cell; use core::cmp::{max, min}; use ch::driver::LinkState; -use embassy_futures::yield_now; use embassy_net_driver_channel as ch; use embassy_time::{Duration, Timer}; @@ -10,21 +8,18 @@ pub use crate::bus::SpiBusCyw43; use crate::consts::*; use crate::events::{Event, EventQueue}; use crate::fmt::Bytes; +use crate::ioctl::{IoctlState, IoctlType}; use crate::structs::*; -use crate::{countries, IoctlState, IoctlType, PowerManagementMode}; +use crate::{countries, PowerManagementMode}; pub struct Control<'a> { state_ch: ch::StateRunner<'a>, event_sub: &'a EventQueue, - ioctl_state: &'a Cell, + ioctl_state: &'a IoctlState, } impl<'a> Control<'a> { - pub(crate) fn new( - state_ch: ch::StateRunner<'a>, - event_sub: &'a EventQueue, - ioctl_state: &'a Cell, - ) -> Self { + pub(crate) fn new(state_ch: ch::StateRunner<'a>, event_sub: &'a EventQueue, ioctl_state: &'a IoctlState) -> Self { Self { state_ch, event_sub, @@ -285,21 +280,8 @@ impl<'a> Control<'a> { async fn ioctl(&mut self, kind: IoctlType, cmd: u32, iface: u32, buf: &mut [u8]) -> usize { // TODO cancel ioctl on future drop. - while !matches!(self.ioctl_state.get(), IoctlState::Idle) { - yield_now().await; - } - - self.ioctl_state.set(IoctlState::Pending { kind, cmd, iface, buf }); - - let resp_len = loop { - if let IoctlState::Done { resp_len } = self.ioctl_state.get() { - break resp_len; - } - yield_now().await; - }; - - self.ioctl_state.set(IoctlState::Idle); - + self.ioctl_state.do_ioctl(kind, cmd, iface, buf).await; + let resp_len = self.ioctl_state.wait_complete().await; resp_len } } diff --git a/src/ioctl.rs b/src/ioctl.rs new file mode 100644 index 00000000..6a746559 --- /dev/null +++ b/src/ioctl.rs @@ -0,0 +1,111 @@ +use core::cell::{Cell, RefCell}; +use core::future::poll_fn; +use core::task::{Poll, Waker}; + +use embassy_sync::waitqueue::WakerRegistration; + +#[derive(Clone, Copy)] +pub enum IoctlType { + Get = 0, + Set = 2, +} + +#[derive(Clone, Copy)] +pub struct PendingIoctl { + pub buf: *mut [u8], + pub kind: IoctlType, + pub cmd: u32, + pub iface: u32, +} + +#[derive(Clone, Copy)] +enum IoctlStateInner { + Pending(PendingIoctl), + Sent { buf: *mut [u8] }, + Done { resp_len: usize }, +} + +#[derive(Default)] +struct Wakers { + control: WakerRegistration, + runner: WakerRegistration, +} + +pub struct IoctlState { + state: Cell, + wakers: RefCell, +} + +impl IoctlState { + pub fn new() -> Self { + Self { + state: Cell::new(IoctlStateInner::Done { resp_len: 0 }), + wakers: Default::default(), + } + } + + fn wake_control(&self) { + self.wakers.borrow_mut().control.wake(); + } + + fn register_control(&self, waker: &Waker) { + self.wakers.borrow_mut().control.register(waker); + } + + fn wake_runner(&self) { + self.wakers.borrow_mut().runner.wake(); + } + + fn register_runner(&self, waker: &Waker) { + self.wakers.borrow_mut().runner.register(waker); + } + + pub async fn wait_complete(&self) -> usize { + poll_fn(|cx| { + if let IoctlStateInner::Done { resp_len } = self.state.get() { + Poll::Ready(resp_len) + } else { + self.register_control(cx.waker()); + Poll::Pending + } + }) + .await + } + + pub async fn wait_pending(&self) -> PendingIoctl { + let pending = poll_fn(|cx| { + if let IoctlStateInner::Pending(pending) = self.state.get() { + warn!("found pending ioctl"); + Poll::Ready(pending) + } else { + self.register_runner(cx.waker()); + Poll::Pending + } + }) + .await; + + self.state.set(IoctlStateInner::Sent { buf: pending.buf }); + pending + } + + pub async fn do_ioctl(&self, kind: IoctlType, cmd: u32, iface: u32, buf: &mut [u8]) -> usize { + warn!("doing ioctl"); + self.state + .set(IoctlStateInner::Pending(PendingIoctl { buf, kind, cmd, iface })); + self.wake_runner(); + self.wait_complete().await + } + + pub fn ioctl_done(&self, response: &[u8]) { + if let IoctlStateInner::Sent { buf } = self.state.get() { + warn!("ioctl complete"); + // TODO fix this + (unsafe { &mut *buf }[..response.len()]).copy_from_slice(response); + + self.state.set(IoctlStateInner::Done { + resp_len: response.len(), + }); + self.wake_control(); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index af8f74a6..069ca40f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,17 +11,17 @@ mod bus; mod consts; mod countries; mod events; +mod ioctl; mod structs; mod control; mod nvram; mod runner; -use core::cell::Cell; - use embassy_net_driver_channel as ch; use embedded_hal_1::digital::OutputPin; use events::EventQueue; +use ioctl::IoctlState; use crate::bus::Bus; pub use crate::bus::SpiBusCyw43; @@ -30,12 +30,6 @@ pub use crate::runner::Runner; const MTU: usize = 1514; -#[derive(Clone, Copy)] -pub enum IoctlType { - Get = 0, - Set = 2, -} - #[allow(unused)] #[derive(Clone, Copy, PartialEq, Eq)] enum Core { @@ -106,26 +100,8 @@ const CHIP: Chip = Chip { chanspec_ctl_sb_mask: 0x0700, }; -#[derive(Clone, Copy)] -enum IoctlState { - Idle, - - Pending { - kind: IoctlType, - cmd: u32, - iface: u32, - buf: *mut [u8], - }, - Sent { - buf: *mut [u8], - }, - Done { - resp_len: usize, - }, -} - pub struct State { - ioctl_state: Cell, + ioctl_state: IoctlState, ch: ch::State, events: EventQueue, } @@ -133,7 +109,7 @@ pub struct State { impl State { pub fn new() -> Self { Self { - ioctl_state: Cell::new(IoctlState::Idle), + ioctl_state: IoctlState::new(), ch: ch::State::new(), events: EventQueue::new(), } diff --git a/src/runner.rs b/src/runner.rs index 9945af3f..4abccf48 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -1,6 +1,6 @@ -use core::cell::Cell; use core::slice; +use embassy_futures::select::{select3, Either3}; use embassy_futures::yield_now; use embassy_net_driver_channel as ch; use embassy_sync::pubsub::PubSubBehavior; @@ -12,9 +12,10 @@ pub use crate::bus::SpiBusCyw43; use crate::consts::*; use crate::events::{EventQueue, EventStatus}; use crate::fmt::Bytes; +use crate::ioctl::{IoctlState, IoctlType, PendingIoctl}; use crate::nvram::NVRAM; use crate::structs::*; -use crate::{events, Core, IoctlState, IoctlType, CHIP, MTU}; +use crate::{events, Core, CHIP, MTU}; #[cfg(feature = "firmware-logs")] struct LogState { @@ -40,7 +41,7 @@ pub struct Runner<'a, PWR, SPI> { ch: ch::Runner<'a, MTU>, bus: Bus, - ioctl_state: &'a Cell, + ioctl_state: &'a IoctlState, ioctl_id: u16, sdpcm_seq: u8, sdpcm_seq_max: u8, @@ -59,7 +60,7 @@ where pub(crate) fn new( ch: ch::Runner<'a, MTU>, bus: Bus, - ioctl_state: &'a Cell, + ioctl_state: &'a IoctlState, events: &'a EventQueue, ) -> Self { Self { @@ -226,19 +227,22 @@ where #[cfg(feature = "firmware-logs")] self.log_read().await; - // Send stuff - // TODO flow control not yet complete - if !self.has_credit() { - warn!("TX stalled"); - } else { - if let IoctlState::Pending { kind, cmd, iface, buf } = self.ioctl_state.get() { - self.send_ioctl(kind, cmd, iface, unsafe { &*buf }).await; - self.ioctl_state.set(IoctlState::Sent { buf }); - } - if !self.has_credit() { - warn!("TX stalled"); - } else { - if let Some(packet) = self.ch.try_tx_buf() { + let ev = || async { + // TODO use IRQs + yield_now().await; + }; + + if self.has_credit() { + let ioctl = self.ioctl_state.wait_pending(); + let tx = self.ch.tx_buf(); + + match select3(ioctl, tx, ev()).await { + Either3::First(PendingIoctl { buf, kind, cmd, iface }) => { + warn!("ioctl"); + self.send_ioctl(kind, cmd, iface, unsafe { &*buf }).await; + } + Either3::Second(packet) => { + warn!("packet"); trace!("tx pkt {:02x}", Bytes(&packet[..packet.len().min(48)])); let mut buf = [0; 512]; @@ -281,28 +285,46 @@ where self.bus.wlan_write(&buf[..(total_len / 4)]).await; self.ch.tx_done(); } + Either3::Third(()) => { + // Receive stuff + let irq = self.bus.read16(FUNC_BUS, REG_BUS_INTERRUPT).await; + + if irq & IRQ_F2_PACKET_AVAILABLE != 0 { + let mut status = 0xFFFF_FFFF; + while status == 0xFFFF_FFFF { + status = self.bus.read32(FUNC_BUS, REG_BUS_STATUS).await; + } + + if status & STATUS_F2_PKT_AVAILABLE != 0 { + let len = (status & STATUS_F2_PKT_LEN_MASK) >> STATUS_F2_PKT_LEN_SHIFT; + self.bus.wlan_read(&mut buf, len).await; + trace!("rx {:02x}", Bytes(&slice8_mut(&mut buf)[..(len as usize).min(48)])); + self.rx(&slice8_mut(&mut buf)[..len as usize]); + } + } + } + } + } else { + warn!("TX stalled"); + ev().await; + + // Receive stuff + let irq = self.bus.read16(FUNC_BUS, REG_BUS_INTERRUPT).await; + + if irq & IRQ_F2_PACKET_AVAILABLE != 0 { + let mut status = 0xFFFF_FFFF; + while status == 0xFFFF_FFFF { + status = self.bus.read32(FUNC_BUS, REG_BUS_STATUS).await; + } + + if status & STATUS_F2_PKT_AVAILABLE != 0 { + let len = (status & STATUS_F2_PKT_LEN_MASK) >> STATUS_F2_PKT_LEN_SHIFT; + self.bus.wlan_read(&mut buf, len).await; + trace!("rx {:02x}", Bytes(&slice8_mut(&mut buf)[..(len as usize).min(48)])); + self.rx(&slice8_mut(&mut buf)[..len as usize]); + } } } - - // Receive stuff - let irq = self.bus.read16(FUNC_BUS, REG_BUS_INTERRUPT).await; - - if irq & IRQ_F2_PACKET_AVAILABLE != 0 { - let mut status = 0xFFFF_FFFF; - while status == 0xFFFF_FFFF { - status = self.bus.read32(FUNC_BUS, REG_BUS_STATUS).await; - } - - if status & STATUS_F2_PKT_AVAILABLE != 0 { - let len = (status & STATUS_F2_PKT_LEN_MASK) >> STATUS_F2_PKT_LEN_SHIFT; - self.bus.wlan_read(&mut buf, len).await; - trace!("rx {:02x}", Bytes(&slice8_mut(&mut buf)[..(len as usize).min(48)])); - self.rx(&slice8_mut(&mut buf)[..len as usize]); - } - } - - // TODO use IRQs - yield_now().await; } } @@ -340,19 +362,17 @@ where let cdc_header = CdcHeader::from_bytes(payload[..CdcHeader::SIZE].try_into().unwrap()); trace!(" {:?}", cdc_header); - if let IoctlState::Sent { buf } = self.ioctl_state.get() { - if cdc_header.id == self.ioctl_id { - if cdc_header.status != 0 { - // TODO: propagate error instead - panic!("IOCTL error {}", cdc_header.status as i32); - } - - let resp_len = cdc_header.len as usize; - info!("IOCTL Response: {:02x}", Bytes(&payload[CdcHeader::SIZE..][..resp_len])); - - (unsafe { &mut *buf }[..resp_len]).copy_from_slice(&payload[CdcHeader::SIZE..][..resp_len]); - self.ioctl_state.set(IoctlState::Done { resp_len }); + if cdc_header.id == self.ioctl_id { + if cdc_header.status != 0 { + // TODO: propagate error instead + panic!("IOCTL error {}", cdc_header.status as i32); } + + let resp_len = cdc_header.len as usize; + let response = &payload[CdcHeader::SIZE..][..resp_len]; + info!("IOCTL Response: {:02x}", Bytes(response)); + + self.ioctl_state.ioctl_done(response); } } CHANNEL_TYPE_EVENT => {