Reduce memory overhead and simplify logic for merging endpoint and control request output reports.

This commit is contained in:
alexmoon 2022-04-01 15:48:37 -04:00 committed by Dario Nieuwenhuis
parent c309531874
commit c8ad82057d
4 changed files with 163 additions and 61 deletions

View File

@ -443,8 +443,23 @@ unsafe fn write_dma<T: Instance>(i: usize, buf: &[u8]) {
impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a; type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a;
type DataReadyFuture<'a> = impl Future<Output = ()> + 'a where Self: 'a;
fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
async move {
let i = self.info.addr.index();
assert!(i != 0);
self.wait_data_ready().await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
unsafe { read_dma::<T>(i, buf) }
}
}
fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a> {
async move { async move {
let i = self.info.addr.index(); let i = self.info.addr.index();
assert!(i != 0); assert!(i != 0);
@ -460,11 +475,6 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
} }
}) })
.await; .await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
unsafe { read_dma::<T>(i, buf) }
} }
} }
} }

View File

@ -0,0 +1,90 @@
use core::cell::Cell;
use core::future::Future;
use core::task::{Poll, Waker};
enum AsyncLeaseState {
Empty,
Waiting(*mut u8, usize, Waker),
Done(usize),
}
impl Default for AsyncLeaseState {
fn default() -> Self {
AsyncLeaseState::Empty
}
}
#[derive(Default)]
pub struct AsyncLease {
state: Cell<AsyncLeaseState>,
}
pub struct AsyncLeaseFuture<'a> {
buf: &'a mut [u8],
state: &'a Cell<AsyncLeaseState>,
}
impl<'a> Drop for AsyncLeaseFuture<'a> {
fn drop(&mut self) {
self.state.set(AsyncLeaseState::Empty);
}
}
impl<'a> Future for AsyncLeaseFuture<'a> {
type Output = usize;
fn poll(
mut self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> Poll<Self::Output> {
match self.state.take() {
AsyncLeaseState::Done(len) => Poll::Ready(len),
state => {
if let AsyncLeaseState::Waiting(ptr, _, _) = state {
assert_eq!(
ptr,
self.buf.as_mut_ptr(),
"lend() called on a busy AsyncLease."
);
}
self.state.set(AsyncLeaseState::Waiting(
self.buf.as_mut_ptr(),
self.buf.len(),
cx.waker().clone(),
));
Poll::Pending
}
}
}
}
pub struct AsyncLeaseNotReady {}
impl AsyncLease {
pub fn new() -> Self {
Default::default()
}
pub fn try_borrow_mut<F: FnOnce(&mut [u8]) -> usize>(
&self,
f: F,
) -> Result<(), AsyncLeaseNotReady> {
if let AsyncLeaseState::Waiting(data, len, waker) = self.state.take() {
let buf = unsafe { core::slice::from_raw_parts_mut(data, len) };
let len = f(buf);
self.state.set(AsyncLeaseState::Done(len));
waker.wake();
Ok(())
} else {
Err(AsyncLeaseNotReady {})
}
}
pub fn lend<'a>(&'a self, buf: &'a mut [u8]) -> AsyncLeaseFuture<'a> {
AsyncLeaseFuture {
buf,
state: &self.state,
}
}
}

View File

@ -9,7 +9,7 @@ pub(crate) mod fmt;
use core::mem::MaybeUninit; use core::mem::MaybeUninit;
use embassy::channel::signal::Signal; use async_lease::AsyncLease;
use embassy::time::Duration; use embassy::time::Duration;
use embassy_usb::driver::{EndpointOut, ReadError}; use embassy_usb::driver::{EndpointOut, ReadError};
use embassy_usb::{ use embassy_usb::{
@ -24,6 +24,8 @@ use ssmarshal::serialize;
#[cfg(feature = "usbd-hid")] #[cfg(feature = "usbd-hid")]
use usbd_hid::descriptor::AsInputReport; use usbd_hid::descriptor::AsInputReport;
mod async_lease;
const USB_CLASS_HID: u8 = 0x03; const USB_CLASS_HID: u8 = 0x03;
const USB_SUBCLASS_NONE: u8 = 0x00; const USB_SUBCLASS_NONE: u8 = 0x00;
const USB_PROTOCOL_NONE: u8 = 0x00; const USB_PROTOCOL_NONE: u8 = 0x00;
@ -61,15 +63,15 @@ impl ReportId {
} }
pub struct State<'a, const IN_N: usize, const OUT_N: usize> { pub struct State<'a, const IN_N: usize, const OUT_N: usize> {
control: MaybeUninit<Control<'a, OUT_N>>, control: MaybeUninit<Control<'a>>,
out_signal: Signal<(usize, [u8; OUT_N])>, lease: AsyncLease,
} }
impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> {
pub fn new() -> Self { pub fn new() -> Self {
State { State {
control: MaybeUninit::uninit(), control: MaybeUninit::uninit(),
out_signal: Signal::new(), lease: AsyncLease::new(),
} }
} }
} }
@ -139,22 +141,22 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize>
poll_ms: u8, poll_ms: u8,
max_packet_size: u16, max_packet_size: u16,
) -> Self { ) -> Self {
let ep_out = Some(builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms)); let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms);
let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
let control = state.control.write(Control::new( let control = state.control.write(Control::new(
report_descriptor, report_descriptor,
Some(&state.out_signal), Some(&state.lease),
request_handler, request_handler,
)); ));
control.build(builder, ep_out.as_ref(), &ep_in); control.build(builder, Some(&ep_out), &ep_in);
Self { Self {
input: ReportWriter { ep_in }, input: ReportWriter { ep_in },
output: ReportReader { output: ReportReader {
ep_out, ep_out,
receiver: &state.out_signal, lease: &state.lease,
}, },
} }
} }
@ -175,8 +177,8 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> {
} }
pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { pub struct ReportReader<'d, D: Driver<'d>, const N: usize> {
ep_out: Option<D::EndpointOut>, ep_out: D::EndpointOut,
receiver: &'d Signal<(usize, [u8; N])>, lease: &'d AsyncLease,
} }
impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
@ -216,41 +218,29 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> { pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
assert!(buf.len() >= N); assert!(buf.len() >= N);
if let Some(ep) = &mut self.ep_out {
let max_packet_size = usize::from(ep.info().max_packet_size);
let mut chunks = buf.chunks_mut(max_packet_size); // Wait until a packet is ready to read from the endpoint or a SET_REPORT control request is received
{
// Wait until we've received a chunk from the endpoint or a report from a SET_REPORT control request let data_ready = self.ep_out.wait_data_ready();
let (mut total, data) = { pin_mut!(data_ready);
let chunk = unwrap!(chunks.next()); match select(data_ready, self.lease.lend(buf)).await {
let fut1 = ep.read(chunk); Either::Left(_) => (),
pin_mut!(fut1); Either::Right((len, _)) => return Ok(len),
match select(fut1, self.receiver.wait()).await {
Either::Left((Ok(size), _)) => (size, None),
Either::Left((Err(err), _)) => return Err(err),
Either::Right(((size, data), _)) => (size, Some(data)),
}
};
if let Some(data) = data {
buf[0..total].copy_from_slice(&data[0..total]);
Ok(total)
} else {
for chunk in chunks {
let size = ep.read(chunk).await?;
total += size;
if size < max_packet_size || total == N {
break;
}
}
Ok(total)
} }
} else {
let (total, data) = self.receiver.wait().await;
buf[0..total].copy_from_slice(&data[0..total]);
Ok(total)
} }
// Read packets from the endpoint
let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
let mut total = 0;
for chunk in buf.chunks_mut(max_packet_size) {
let size = self.ep_out.read(chunk).await?;
total += size;
if size < max_packet_size || total == N {
break;
}
}
Ok(total)
} }
} }
@ -292,22 +282,22 @@ pub trait RequestHandler {
} }
} }
struct Control<'d, const OUT_N: usize> { struct Control<'d> {
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
out_signal: Option<&'d Signal<(usize, [u8; OUT_N])>>, out_lease: Option<&'d AsyncLease>,
request_handler: Option<&'d dyn RequestHandler>, request_handler: Option<&'d dyn RequestHandler>,
hid_descriptor: [u8; 9], hid_descriptor: [u8; 9],
} }
impl<'a, const OUT_N: usize> Control<'a, OUT_N> { impl<'a> Control<'a> {
fn new( fn new(
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
out_signal: Option<&'a Signal<(usize, [u8; OUT_N])>>, out_lease: Option<&'a AsyncLease>,
request_handler: Option<&'a dyn RequestHandler>, request_handler: Option<&'a dyn RequestHandler>,
) -> Self { ) -> Self {
Control { Control {
report_descriptor, report_descriptor,
out_signal, out_lease,
request_handler, request_handler,
hid_descriptor: [ hid_descriptor: [
// Length of buf inclusive of size prefix // Length of buf inclusive of size prefix
@ -372,7 +362,7 @@ impl<'a, const OUT_N: usize> Control<'a, OUT_N> {
} }
} }
impl<'d, const OUT_N: usize> ControlHandler for Control<'d, OUT_N> { impl<'d> ControlHandler for Control<'d> {
fn reset(&mut self) {} fn reset(&mut self) {}
fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse { fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse {
@ -395,17 +385,21 @@ impl<'d, const OUT_N: usize> ControlHandler for Control<'d, OUT_N> {
} }
HID_REQ_SET_REPORT => match ( HID_REQ_SET_REPORT => match (
ReportId::try_from(req.value), ReportId::try_from(req.value),
self.out_signal, self.out_lease,
self.request_handler.as_ref(), self.request_handler.as_ref(),
) { ) {
(Ok(ReportId::Out(_)), Some(signal), _) => { (Ok(ReportId::Out(_)), Some(lease), _) => {
let mut buf = [0; OUT_N]; match lease.try_borrow_mut(|buf| {
buf[0..data.len()].copy_from_slice(data); let len = buf.len().min(data.len());
if signal.signaled() { buf[0..len].copy_from_slice(&data[0..len]);
warn!("Output report dropped before being read!"); len
}) {
Ok(()) => OutResponse::Accepted,
Err(_) => {
warn!("SET_REPORT received for output report with no reader listening.");
OutResponse::Rejected
}
} }
signal.signal((data.len(), buf));
OutResponse::Accepted
} }
(Ok(id), _, Some(handler)) => handler.set_report(id, data), (Ok(id), _, Some(handler)) => handler.set_report(id, data),
_ => OutResponse::Rejected, _ => OutResponse::Rejected,

View File

@ -120,6 +120,9 @@ pub trait Endpoint {
pub trait EndpointOut: Endpoint { pub trait EndpointOut: Endpoint {
type ReadFuture<'a>: Future<Output = Result<usize, ReadError>> + 'a type ReadFuture<'a>: Future<Output = Result<usize, ReadError>> + 'a
where
Self: 'a;
type DataReadyFuture<'a>: Future<Output = ()> + 'a
where where
Self: 'a; Self: 'a;
@ -128,6 +131,11 @@ pub trait EndpointOut: Endpoint {
/// ///
/// This should also clear any NAK flags and prepare the endpoint to receive the next packet. /// This should also clear any NAK flags and prepare the endpoint to receive the next packet.
fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>; fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>;
/// Waits until a packet of data is ready to be read from the endpoint.
///
/// A call to[`read()`](Self::read()) after this future completes should not block.
fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a>;
} }
pub trait ControlPipe { pub trait ControlPipe {