Simplify hid output report handling

This commit is contained in:
alexmoon 2022-04-02 11:58:01 -04:00 committed by Dario Nieuwenhuis
parent c8ad82057d
commit 99f95a33c3
5 changed files with 78 additions and 187 deletions

View File

@ -443,23 +443,8 @@ unsafe fn write_dma<T: Instance>(i: usize, buf: &[u8]) {
impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a; type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a;
type DataReadyFuture<'a> = impl Future<Output = ()> + 'a where Self: 'a;
fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
async move {
let i = self.info.addr.index();
assert!(i != 0);
self.wait_data_ready().await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
unsafe { read_dma::<T>(i, buf) }
}
}
fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a> {
async move { async move {
let i = self.info.addr.index(); let i = self.info.addr.index();
assert!(i != 0); assert!(i != 0);
@ -475,6 +460,11 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
} }
}) })
.await; .await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
unsafe { read_dma::<T>(i, buf) }
} }
} }
} }

View File

@ -1,90 +0,0 @@
use core::cell::Cell;
use core::future::Future;
use core::task::{Poll, Waker};
enum AsyncLeaseState {
Empty,
Waiting(*mut u8, usize, Waker),
Done(usize),
}
impl Default for AsyncLeaseState {
fn default() -> Self {
AsyncLeaseState::Empty
}
}
#[derive(Default)]
pub struct AsyncLease {
state: Cell<AsyncLeaseState>,
}
pub struct AsyncLeaseFuture<'a> {
buf: &'a mut [u8],
state: &'a Cell<AsyncLeaseState>,
}
impl<'a> Drop for AsyncLeaseFuture<'a> {
fn drop(&mut self) {
self.state.set(AsyncLeaseState::Empty);
}
}
impl<'a> Future for AsyncLeaseFuture<'a> {
type Output = usize;
fn poll(
mut self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> Poll<Self::Output> {
match self.state.take() {
AsyncLeaseState::Done(len) => Poll::Ready(len),
state => {
if let AsyncLeaseState::Waiting(ptr, _, _) = state {
assert_eq!(
ptr,
self.buf.as_mut_ptr(),
"lend() called on a busy AsyncLease."
);
}
self.state.set(AsyncLeaseState::Waiting(
self.buf.as_mut_ptr(),
self.buf.len(),
cx.waker().clone(),
));
Poll::Pending
}
}
}
}
pub struct AsyncLeaseNotReady {}
impl AsyncLease {
pub fn new() -> Self {
Default::default()
}
pub fn try_borrow_mut<F: FnOnce(&mut [u8]) -> usize>(
&self,
f: F,
) -> Result<(), AsyncLeaseNotReady> {
if let AsyncLeaseState::Waiting(data, len, waker) = self.state.take() {
let buf = unsafe { core::slice::from_raw_parts_mut(data, len) };
let len = f(buf);
self.state.set(AsyncLeaseState::Done(len));
waker.wake();
Ok(())
} else {
Err(AsyncLeaseNotReady {})
}
}
pub fn lend<'a>(&'a self, buf: &'a mut [u8]) -> AsyncLeaseFuture<'a> {
AsyncLeaseFuture {
buf,
state: &self.state,
}
}
}

View File

@ -8,24 +8,21 @@
pub(crate) mod fmt; pub(crate) mod fmt;
use core::mem::MaybeUninit; use core::mem::MaybeUninit;
use core::ops::Range;
use async_lease::AsyncLease;
use embassy::time::Duration; use embassy::time::Duration;
use embassy_usb::driver::{EndpointOut, ReadError}; use embassy_usb::driver::EndpointOut;
use embassy_usb::{ use embassy_usb::{
control::{ControlHandler, InResponse, OutResponse, Request, RequestType}, control::{ControlHandler, InResponse, OutResponse, Request, RequestType},
driver::{Driver, Endpoint, EndpointIn, WriteError}, driver::{Driver, Endpoint, EndpointIn, WriteError},
UsbDeviceBuilder, UsbDeviceBuilder,
}; };
use futures_util::future::{select, Either};
use futures_util::pin_mut;
#[cfg(feature = "usbd-hid")] #[cfg(feature = "usbd-hid")]
use ssmarshal::serialize; use ssmarshal::serialize;
#[cfg(feature = "usbd-hid")] #[cfg(feature = "usbd-hid")]
use usbd_hid::descriptor::AsInputReport; use usbd_hid::descriptor::AsInputReport;
mod async_lease;
const USB_CLASS_HID: u8 = 0x03; const USB_CLASS_HID: u8 = 0x03;
const USB_SUBCLASS_NONE: u8 = 0x00; const USB_SUBCLASS_NONE: u8 = 0x00;
const USB_PROTOCOL_NONE: u8 = 0x00; const USB_PROTOCOL_NONE: u8 = 0x00;
@ -64,14 +61,12 @@ impl ReportId {
pub struct State<'a, const IN_N: usize, const OUT_N: usize> { pub struct State<'a, const IN_N: usize, const OUT_N: usize> {
control: MaybeUninit<Control<'a>>, control: MaybeUninit<Control<'a>>,
lease: AsyncLease,
} }
impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> {
pub fn new() -> Self { pub fn new() -> Self {
State { State {
control: MaybeUninit::uninit(), control: MaybeUninit::uninit(),
lease: AsyncLease::new(),
} }
} }
} }
@ -90,9 +85,9 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> {
/// high performance uses, and a value of 255 is good for best-effort usecases. /// high performance uses, and a value of 255 is good for best-effort usecases.
/// ///
/// This allocates an IN endpoint only. /// This allocates an IN endpoint only.
pub fn new( pub fn new<const OUT_N: usize>(
builder: &mut UsbDeviceBuilder<'d, D>, builder: &mut UsbDeviceBuilder<'d, D>,
state: &'d mut State<'d, IN_N, 0>, state: &'d mut State<'d, IN_N, OUT_N>,
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
request_handler: Option<&'d dyn RequestHandler>, request_handler: Option<&'d dyn RequestHandler>,
poll_ms: u8, poll_ms: u8,
@ -101,8 +96,7 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> {
let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
let control = state let control = state
.control .control
.write(Control::new(report_descriptor, None, request_handler)); .write(Control::new(report_descriptor, request_handler));
control.build(builder, None, &ep_in); control.build(builder, None, &ep_in);
Self { Self {
@ -144,20 +138,14 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize>
let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms);
let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
let control = state.control.write(Control::new( let control = state
report_descriptor, .control
Some(&state.lease), .write(Control::new(report_descriptor, request_handler));
request_handler,
));
control.build(builder, Some(&ep_out), &ep_in); control.build(builder, Some(&ep_out), &ep_in);
Self { Self {
input: ReportWriter { ep_in }, input: ReportWriter { ep_in },
output: ReportReader { output: ReportReader { ep_out, offset: 0 },
ep_out,
lease: &state.lease,
},
} }
} }
@ -178,7 +166,21 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> {
pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { pub struct ReportReader<'d, D: Driver<'d>, const N: usize> {
ep_out: D::EndpointOut, ep_out: D::EndpointOut,
lease: &'d AsyncLease, offset: usize,
}
pub enum ReadError {
BufferOverflow,
Sync(Range<usize>),
}
impl From<embassy_usb::driver::ReadError> for ReadError {
fn from(val: embassy_usb::driver::ReadError) -> Self {
use embassy_usb::driver::ReadError::*;
match val {
BufferOverflow => ReadError::BufferOverflow,
}
}
} }
impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
@ -216,32 +218,56 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
} }
impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
/// Starts a task to deliver output reports from the Interrupt Out pipe to
/// `handler`.
pub async fn run<T: RequestHandler>(mut self, handler: &T) -> ! {
assert!(self.offset == 0);
let mut buf = [0; N];
loop {
match self.read(&mut buf).await {
Ok(len) => { handler.set_report(ReportId::Out(0), &buf[0..len]); }
Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N),
Err(ReadError::Sync(_)) => unreachable!(),
}
}
}
/// Reads an output report from the Interrupt Out pipe.
///
/// **Note:** Any reports sent from the host over the control pipe will be
/// passed to [`RequestHandler::set_report()`] for handling. The application
/// is responsible for ensuring output reports from both pipes are handled
/// correctly.
///
/// **Note:** If `N` > the maximum packet size of the endpoint (i.e. output
/// reports may be split across multiple packets) and this method's future
/// is dropped after some packets have been read, the next call to `read()`
/// will return a [`ReadError::SyncError()`]. The range in the sync error
/// indicates the portion `buf` that was filled by the current call to
/// `read()`. If the dropped future used the same `buf`, then `buf` will
/// contain the full report.
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> { pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
assert!(buf.len() >= N); assert!(buf.len() >= N);
// Wait until a packet is ready to read from the endpoint or a SET_REPORT control request is received
{
let data_ready = self.ep_out.wait_data_ready();
pin_mut!(data_ready);
match select(data_ready, self.lease.lend(buf)).await {
Either::Left(_) => (),
Either::Right((len, _)) => return Ok(len),
}
}
// Read packets from the endpoint // Read packets from the endpoint
let max_packet_size = usize::from(self.ep_out.info().max_packet_size); let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
let mut total = 0; let starting_offset = self.offset;
for chunk in buf.chunks_mut(max_packet_size) { for chunk in buf[starting_offset..].chunks_mut(max_packet_size) {
let size = self.ep_out.read(chunk).await?; let size = self.ep_out.read(chunk).await?;
total += size; self.offset += size;
if size < max_packet_size || total == N { if size < max_packet_size || self.offset == N {
break; break;
} }
} }
let total = self.offset;
self.offset = 0;
if starting_offset > 0 {
Err(ReadError::Sync(starting_offset..total))
} else {
Ok(total) Ok(total)
} }
}
} }
pub trait RequestHandler { pub trait RequestHandler {
@ -254,10 +280,6 @@ pub trait RequestHandler {
} }
/// Sets the value of report `id` to `data`. /// Sets the value of report `id` to `data`.
///
/// If an output endpoint has been allocated, output reports
/// are routed through [`HidClass::output()`]. Otherwise they
/// are sent here, along with input and feature reports.
fn set_report(&self, id: ReportId, data: &[u8]) -> OutResponse { fn set_report(&self, id: ReportId, data: &[u8]) -> OutResponse {
let _ = (id, data); let _ = (id, data);
OutResponse::Rejected OutResponse::Rejected
@ -266,8 +288,8 @@ pub trait RequestHandler {
/// Get the idle rate for `id`. /// Get the idle rate for `id`.
/// ///
/// If `id` is `None`, get the idle rate for all reports. Returning `None` /// If `id` is `None`, get the idle rate for all reports. Returning `None`
/// will reject the control request. Any duration above 1.020 seconds or 0 /// will reject the control request. Any duration at or above 1.024 seconds
/// will be returned as an indefinite idle rate. /// or below 4ms will be returned as an indefinite idle rate.
fn get_idle(&self, id: Option<ReportId>) -> Option<Duration> { fn get_idle(&self, id: Option<ReportId>) -> Option<Duration> {
let _ = id; let _ = id;
None None
@ -284,7 +306,6 @@ pub trait RequestHandler {
struct Control<'d> { struct Control<'d> {
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
out_lease: Option<&'d AsyncLease>,
request_handler: Option<&'d dyn RequestHandler>, request_handler: Option<&'d dyn RequestHandler>,
hid_descriptor: [u8; 9], hid_descriptor: [u8; 9],
} }
@ -292,12 +313,10 @@ struct Control<'d> {
impl<'a> Control<'a> { impl<'a> Control<'a> {
fn new( fn new(
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
out_lease: Option<&'a AsyncLease>,
request_handler: Option<&'a dyn RequestHandler>, request_handler: Option<&'a dyn RequestHandler>,
) -> Self { ) -> Self {
Control { Control {
report_descriptor, report_descriptor,
out_lease,
request_handler, request_handler,
hid_descriptor: [ hid_descriptor: [
// Length of buf inclusive of size prefix // Length of buf inclusive of size prefix
@ -370,7 +389,7 @@ impl<'d> ControlHandler for Control<'d> {
if let RequestType::Class = req.request_type { if let RequestType::Class = req.request_type {
match req.request { match req.request {
HID_REQ_SET_IDLE => { HID_REQ_SET_IDLE => {
if let Some(handler) = self.request_handler.as_ref() { if let Some(handler) = self.request_handler {
let id = req.value as u8; let id = req.value as u8;
let id = (id != 0).then(|| ReportId::In(id)); let id = (id != 0).then(|| ReportId::In(id));
let dur = u64::from(req.value >> 8); let dur = u64::from(req.value >> 8);
@ -383,25 +402,8 @@ impl<'d> ControlHandler for Control<'d> {
} }
OutResponse::Accepted OutResponse::Accepted
} }
HID_REQ_SET_REPORT => match ( HID_REQ_SET_REPORT => match (ReportId::try_from(req.value), self.request_handler) {
ReportId::try_from(req.value), (Ok(id), Some(handler)) => handler.set_report(id, data),
self.out_lease,
self.request_handler.as_ref(),
) {
(Ok(ReportId::Out(_)), Some(lease), _) => {
match lease.try_borrow_mut(|buf| {
let len = buf.len().min(data.len());
buf[0..len].copy_from_slice(&data[0..len]);
len
}) {
Ok(()) => OutResponse::Accepted,
Err(_) => {
warn!("SET_REPORT received for output report with no reader listening.");
OutResponse::Rejected
}
}
}
(Ok(id), _, Some(handler)) => handler.set_report(id, data),
_ => OutResponse::Rejected, _ => OutResponse::Rejected,
}, },
HID_REQ_SET_PROTOCOL => { HID_REQ_SET_PROTOCOL => {
@ -429,10 +431,7 @@ impl<'d> ControlHandler for Control<'d> {
}, },
(RequestType::Class, HID_REQ_GET_REPORT) => { (RequestType::Class, HID_REQ_GET_REPORT) => {
let size = match ReportId::try_from(req.value) { let size = match ReportId::try_from(req.value) {
Ok(id) => self Ok(id) => self.request_handler.and_then(|x| x.get_report(id, buf)),
.request_handler
.as_ref()
.and_then(|x| x.get_report(id, buf)),
Err(_) => None, Err(_) => None,
}; };
@ -443,7 +442,7 @@ impl<'d> ControlHandler for Control<'d> {
} }
} }
(RequestType::Class, HID_REQ_GET_IDLE) => { (RequestType::Class, HID_REQ_GET_IDLE) => {
if let Some(handler) = self.request_handler.as_ref() { if let Some(handler) = self.request_handler {
let id = req.value as u8; let id = req.value as u8;
let id = (id != 0).then(|| ReportId::In(id)); let id = (id != 0).then(|| ReportId::In(id));
if let Some(dur) = handler.get_idle(id) { if let Some(dur) = handler.get_idle(id) {

View File

@ -120,9 +120,6 @@ pub trait Endpoint {
pub trait EndpointOut: Endpoint { pub trait EndpointOut: Endpoint {
type ReadFuture<'a>: Future<Output = Result<usize, ReadError>> + 'a type ReadFuture<'a>: Future<Output = Result<usize, ReadError>> + 'a
where
Self: 'a;
type DataReadyFuture<'a>: Future<Output = ()> + 'a
where where
Self: 'a; Self: 'a;
@ -131,11 +128,6 @@ pub trait EndpointOut: Endpoint {
/// ///
/// This should also clear any NAK flags and prepare the endpoint to receive the next packet. /// This should also clear any NAK flags and prepare the endpoint to receive the next packet.
fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>; fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>;
/// Waits until a packet of data is ready to be read from the endpoint.
///
/// A call to[`read()`](Self::read()) after this future completes should not block.
fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a>;
} }
pub trait ControlPipe { pub trait ControlPipe {

View File

@ -52,7 +52,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
let mut control_buf = [0; 16]; let mut control_buf = [0; 16];
let request_handler = MyRequestHandler {}; let request_handler = MyRequestHandler {};
let mut state = State::<5, 0>::new(); let mut control = State::<5, 0>::new();
let mut builder = UsbDeviceBuilder::new( let mut builder = UsbDeviceBuilder::new(
driver, driver,
@ -66,7 +66,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
// Create classes on the builder. // Create classes on the builder.
let mut hid = HidClass::new( let mut hid = HidClass::new(
&mut builder, &mut builder,
&mut state, &mut control,
MouseReport::desc(), MouseReport::desc(),
Some(&request_handler), Some(&request_handler),
60, 60,