diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index 124316a2..df0efa51 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -443,8 +443,23 @@ unsafe fn write_dma(i: usize, buf: &[u8]) { impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { type ReadFuture<'a> = impl Future> + 'a where Self: 'a; + type DataReadyFuture<'a> = impl Future + 'a where Self: 'a; fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { + async move { + let i = self.info.addr.index(); + assert!(i != 0); + + self.wait_data_ready().await; + + // Mark as not ready + READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); + + unsafe { read_dma::(i, buf) } + } + } + + fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a> { async move { let i = self.info.addr.index(); assert!(i != 0); @@ -460,11 +475,6 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { } }) .await; - - // Mark as not ready - READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); - - unsafe { read_dma::(i, buf) } } } } diff --git a/embassy-usb-hid/src/async_lease.rs b/embassy-usb-hid/src/async_lease.rs new file mode 100644 index 00000000..0971daa2 --- /dev/null +++ b/embassy-usb-hid/src/async_lease.rs @@ -0,0 +1,90 @@ +use core::cell::Cell; +use core::future::Future; +use core::task::{Poll, Waker}; + +enum AsyncLeaseState { + Empty, + Waiting(*mut u8, usize, Waker), + Done(usize), +} + +impl Default for AsyncLeaseState { + fn default() -> Self { + AsyncLeaseState::Empty + } +} + +#[derive(Default)] +pub struct AsyncLease { + state: Cell, +} + +pub struct AsyncLeaseFuture<'a> { + buf: &'a mut [u8], + state: &'a Cell, +} + +impl<'a> Drop for AsyncLeaseFuture<'a> { + fn drop(&mut self) { + self.state.set(AsyncLeaseState::Empty); + } +} + +impl<'a> Future for AsyncLeaseFuture<'a> { + type Output = usize; + + fn poll( + mut self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> Poll { + match self.state.take() { + AsyncLeaseState::Done(len) => Poll::Ready(len), + state => { + if let AsyncLeaseState::Waiting(ptr, _, _) = state { + assert_eq!( + ptr, + self.buf.as_mut_ptr(), + "lend() called on a busy AsyncLease." + ); + } + + self.state.set(AsyncLeaseState::Waiting( + self.buf.as_mut_ptr(), + self.buf.len(), + cx.waker().clone(), + )); + Poll::Pending + } + } + } +} + +pub struct AsyncLeaseNotReady {} + +impl AsyncLease { + pub fn new() -> Self { + Default::default() + } + + pub fn try_borrow_mut usize>( + &self, + f: F, + ) -> Result<(), AsyncLeaseNotReady> { + if let AsyncLeaseState::Waiting(data, len, waker) = self.state.take() { + let buf = unsafe { core::slice::from_raw_parts_mut(data, len) }; + let len = f(buf); + self.state.set(AsyncLeaseState::Done(len)); + waker.wake(); + Ok(()) + } else { + Err(AsyncLeaseNotReady {}) + } + } + + pub fn lend<'a>(&'a self, buf: &'a mut [u8]) -> AsyncLeaseFuture<'a> { + AsyncLeaseFuture { + buf, + state: &self.state, + } + } +} diff --git a/embassy-usb-hid/src/lib.rs b/embassy-usb-hid/src/lib.rs index 8bc9efdb..43e67880 100644 --- a/embassy-usb-hid/src/lib.rs +++ b/embassy-usb-hid/src/lib.rs @@ -9,7 +9,7 @@ pub(crate) mod fmt; use core::mem::MaybeUninit; -use embassy::channel::signal::Signal; +use async_lease::AsyncLease; use embassy::time::Duration; use embassy_usb::driver::{EndpointOut, ReadError}; use embassy_usb::{ @@ -24,6 +24,8 @@ use ssmarshal::serialize; #[cfg(feature = "usbd-hid")] use usbd_hid::descriptor::AsInputReport; +mod async_lease; + const USB_CLASS_HID: u8 = 0x03; const USB_SUBCLASS_NONE: u8 = 0x00; const USB_PROTOCOL_NONE: u8 = 0x00; @@ -61,15 +63,15 @@ impl ReportId { } pub struct State<'a, const IN_N: usize, const OUT_N: usize> { - control: MaybeUninit>, - out_signal: Signal<(usize, [u8; OUT_N])>, + control: MaybeUninit>, + lease: AsyncLease, } impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { pub fn new() -> Self { State { control: MaybeUninit::uninit(), - out_signal: Signal::new(), + lease: AsyncLease::new(), } } } @@ -139,22 +141,22 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize> poll_ms: u8, max_packet_size: u16, ) -> Self { - let ep_out = Some(builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms)); + let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let control = state.control.write(Control::new( report_descriptor, - Some(&state.out_signal), + Some(&state.lease), request_handler, )); - control.build(builder, ep_out.as_ref(), &ep_in); + control.build(builder, Some(&ep_out), &ep_in); Self { input: ReportWriter { ep_in }, output: ReportReader { ep_out, - receiver: &state.out_signal, + lease: &state.lease, }, } } @@ -175,8 +177,8 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> { } pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { - ep_out: Option, - receiver: &'d Signal<(usize, [u8; N])>, + ep_out: D::EndpointOut, + lease: &'d AsyncLease, } impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { @@ -216,41 +218,29 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { pub async fn read(&mut self, buf: &mut [u8]) -> Result { assert!(buf.len() >= N); - if let Some(ep) = &mut self.ep_out { - let max_packet_size = usize::from(ep.info().max_packet_size); - let mut chunks = buf.chunks_mut(max_packet_size); - - // Wait until we've received a chunk from the endpoint or a report from a SET_REPORT control request - let (mut total, data) = { - let chunk = unwrap!(chunks.next()); - let fut1 = ep.read(chunk); - pin_mut!(fut1); - match select(fut1, self.receiver.wait()).await { - Either::Left((Ok(size), _)) => (size, None), - Either::Left((Err(err), _)) => return Err(err), - Either::Right(((size, data), _)) => (size, Some(data)), - } - }; - - if let Some(data) = data { - buf[0..total].copy_from_slice(&data[0..total]); - Ok(total) - } else { - for chunk in chunks { - let size = ep.read(chunk).await?; - total += size; - if size < max_packet_size || total == N { - break; - } - } - Ok(total) + // Wait until a packet is ready to read from the endpoint or a SET_REPORT control request is received + { + let data_ready = self.ep_out.wait_data_ready(); + pin_mut!(data_ready); + match select(data_ready, self.lease.lend(buf)).await { + Either::Left(_) => (), + Either::Right((len, _)) => return Ok(len), } - } else { - let (total, data) = self.receiver.wait().await; - buf[0..total].copy_from_slice(&data[0..total]); - Ok(total) } + + // Read packets from the endpoint + let max_packet_size = usize::from(self.ep_out.info().max_packet_size); + let mut total = 0; + for chunk in buf.chunks_mut(max_packet_size) { + let size = self.ep_out.read(chunk).await?; + total += size; + if size < max_packet_size || total == N { + break; + } + } + + Ok(total) } } @@ -292,22 +282,22 @@ pub trait RequestHandler { } } -struct Control<'d, const OUT_N: usize> { +struct Control<'d> { report_descriptor: &'static [u8], - out_signal: Option<&'d Signal<(usize, [u8; OUT_N])>>, + out_lease: Option<&'d AsyncLease>, request_handler: Option<&'d dyn RequestHandler>, hid_descriptor: [u8; 9], } -impl<'a, const OUT_N: usize> Control<'a, OUT_N> { +impl<'a> Control<'a> { fn new( report_descriptor: &'static [u8], - out_signal: Option<&'a Signal<(usize, [u8; OUT_N])>>, + out_lease: Option<&'a AsyncLease>, request_handler: Option<&'a dyn RequestHandler>, ) -> Self { Control { report_descriptor, - out_signal, + out_lease, request_handler, hid_descriptor: [ // Length of buf inclusive of size prefix @@ -372,7 +362,7 @@ impl<'a, const OUT_N: usize> Control<'a, OUT_N> { } } -impl<'d, const OUT_N: usize> ControlHandler for Control<'d, OUT_N> { +impl<'d> ControlHandler for Control<'d> { fn reset(&mut self) {} fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse { @@ -395,17 +385,21 @@ impl<'d, const OUT_N: usize> ControlHandler for Control<'d, OUT_N> { } HID_REQ_SET_REPORT => match ( ReportId::try_from(req.value), - self.out_signal, + self.out_lease, self.request_handler.as_ref(), ) { - (Ok(ReportId::Out(_)), Some(signal), _) => { - let mut buf = [0; OUT_N]; - buf[0..data.len()].copy_from_slice(data); - if signal.signaled() { - warn!("Output report dropped before being read!"); + (Ok(ReportId::Out(_)), Some(lease), _) => { + match lease.try_borrow_mut(|buf| { + let len = buf.len().min(data.len()); + buf[0..len].copy_from_slice(&data[0..len]); + len + }) { + Ok(()) => OutResponse::Accepted, + Err(_) => { + warn!("SET_REPORT received for output report with no reader listening."); + OutResponse::Rejected + } } - signal.signal((data.len(), buf)); - OutResponse::Accepted } (Ok(id), _, Some(handler)) => handler.set_report(id, data), _ => OutResponse::Rejected, diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs index 82b59bd1..03e39b8c 100644 --- a/embassy-usb/src/driver.rs +++ b/embassy-usb/src/driver.rs @@ -120,6 +120,9 @@ pub trait Endpoint { pub trait EndpointOut: Endpoint { type ReadFuture<'a>: Future> + 'a + where + Self: 'a; + type DataReadyFuture<'a>: Future + 'a where Self: 'a; @@ -128,6 +131,11 @@ pub trait EndpointOut: Endpoint { /// /// This should also clear any NAK flags and prepare the endpoint to receive the next packet. fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>; + + /// Waits until a packet of data is ready to be read from the endpoint. + /// + /// A call to[`read()`](Self::read()) after this future completes should not block. + fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a>; } pub trait ControlPipe {