From a1754ac8a820d9cae97cf214969faf3090b37c76 Mon Sep 17 00:00:00 2001 From: alexmoon Date: Tue, 5 Apr 2022 17:23:46 -0400 Subject: [PATCH] embassy-usb-hid bug fixes --- embassy-usb-hid/src/lib.rs | 87 ++++++++++++++++++------ examples/nrf/src/bin/usb_hid_keyboard.rs | 4 +- 2 files changed, 70 insertions(+), 21 deletions(-) diff --git a/embassy-usb-hid/src/lib.rs b/embassy-usb-hid/src/lib.rs index 996de6a5..f50c5f8c 100644 --- a/embassy-usb-hid/src/lib.rs +++ b/embassy-usb-hid/src/lib.rs @@ -9,6 +9,7 @@ pub(crate) mod fmt; use core::mem::MaybeUninit; use core::ops::Range; +use core::sync::atomic::{AtomicUsize, Ordering}; use embassy::time::Duration; use embassy_usb::driver::EndpointOut; @@ -61,12 +62,14 @@ impl ReportId { pub struct State<'a, const IN_N: usize, const OUT_N: usize> { control: MaybeUninit>, + out_report_offset: AtomicUsize, } impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { pub fn new() -> Self { State { control: MaybeUninit::uninit(), + out_report_offset: AtomicUsize::new(0), } } } @@ -94,9 +97,11 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> { max_packet_size: u16, ) -> Self { let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); - let control = state - .control - .write(Control::new(report_descriptor, request_handler)); + let control = state.control.write(Control::new( + report_descriptor, + request_handler, + &state.out_report_offset, + )); control.build(builder, None, &ep_in); Self { @@ -138,14 +143,19 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize> let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); - let control = state - .control - .write(Control::new(report_descriptor, request_handler)); + let control = state.control.write(Control::new( + report_descriptor, + request_handler, + &state.out_report_offset, + )); control.build(builder, Some(&ep_out), &ep_in); Self { input: ReportWriter { ep_in }, - output: ReportReader { ep_out, offset: 0 }, + output: ReportReader { + ep_out, + offset: &state.out_report_offset, + }, } } @@ -166,7 +176,7 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> { pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { ep_out: D::EndpointOut, - offset: usize, + offset: &'d AtomicUsize, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -188,6 +198,11 @@ impl From for ReadError { } impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { + /// Waits for the interrupt in endpoint to be enabled. + pub async fn ready(&mut self) -> () { + self.ep_in.wait_enabled().await + } + /// Tries to write an input report by serializing the given report structure. /// /// Panics if no endpoint is available. @@ -222,14 +237,27 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { } impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { + /// Waits for the interrupt out endpoint to be enabled. + pub async fn ready(&mut self) -> () { + self.ep_out.wait_enabled().await + } + /// Starts a task to deliver output reports from the Interrupt Out pipe to /// `handler`. - pub async fn run(mut self, handler: &T) -> ! { - assert!(self.offset == 0); + /// + /// Terminates when the interface becomes disabled. + /// + /// If `use_report_ids` is true, the first byte of the report will be used as + /// the `ReportId` value. Otherwise the `ReportId` value will be 0. + pub async fn run(mut self, use_report_ids: bool, handler: &T) -> ! { + let offset = self.offset.load(Ordering::Acquire); + assert!(offset == 0); let mut buf = [0; N]; loop { match self.read(&mut buf).await { - Ok(len) => { handler.set_report(ReportId::Out(0), &buf[0..len]); } + Ok(len) => { + let id = if use_report_ids { buf[0] } else { 0 }; + handler.set_report(ReportId::Out(id), &buf[..len]); } Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N), Err(ReadError::Disabled) => self.ep_out.wait_enabled().await, Err(ReadError::Sync(_)) => unreachable!(), @@ -257,17 +285,33 @@ impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { // Read packets from the endpoint let max_packet_size = usize::from(self.ep_out.info().max_packet_size); - let starting_offset = self.offset; - for chunk in buf[starting_offset..].chunks_mut(max_packet_size) { - let size = self.ep_out.read(chunk).await?; - self.offset += size; - if size < max_packet_size || self.offset == N { + let starting_offset = self.offset.load(Ordering::Acquire); + let mut total = starting_offset; + loop { + for chunk in buf[starting_offset..N].chunks_mut(max_packet_size) { + match self.ep_out.read(chunk).await { + Ok(size) => { + total += size; + if size < max_packet_size || total == N { + self.offset.store(0, Ordering::Release); + break; + } else { + self.offset.store(total, Ordering::Release); + } + } + Err(err) => { + self.offset.store(0, Ordering::Release); + return Err(err.into()); + } + } + } + + // Some hosts may send ZLPs even when not required by the HID spec, so we'll loop as long as total == 0. + if total > 0 { break; } } - let total = self.offset; - self.offset = 0; if starting_offset > 0 { Err(ReadError::Sync(starting_offset..total)) } else { @@ -313,6 +357,7 @@ pub trait RequestHandler { struct Control<'d> { report_descriptor: &'static [u8], request_handler: Option<&'d dyn RequestHandler>, + out_report_offset: &'d AtomicUsize, hid_descriptor: [u8; 9], } @@ -320,10 +365,12 @@ impl<'a> Control<'a> { fn new( report_descriptor: &'static [u8], request_handler: Option<&'a dyn RequestHandler>, + out_report_offset: &'a AtomicUsize, ) -> Self { Control { report_descriptor, request_handler, + out_report_offset, hid_descriptor: [ // Length of buf inclusive of size prefix 9, @@ -388,7 +435,9 @@ impl<'a> Control<'a> { } impl<'d> ControlHandler for Control<'d> { - fn reset(&mut self) {} + fn reset(&mut self) { + self.out_report_offset.store(0, Ordering::Release); + } fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse { trace!("HID control_out {:?} {=[u8]:x}", req, data); diff --git a/examples/nrf/src/bin/usb_hid_keyboard.rs b/examples/nrf/src/bin/usb_hid_keyboard.rs index af70a9a6..51136292 100644 --- a/examples/nrf/src/bin/usb_hid_keyboard.rs +++ b/examples/nrf/src/bin/usb_hid_keyboard.rs @@ -54,7 +54,7 @@ async fn main(_spawner: Spawner, p: Peripherals) { let mut control_buf = [0; 16]; let request_handler = MyRequestHandler {}; - let mut state = State::<64, 64>::new(); + let mut state = State::<8, 1>::new(); let mut builder = UsbDeviceBuilder::new( driver, @@ -117,7 +117,7 @@ async fn main(_spawner: Spawner, p: Peripherals) { }; let out_fut = async { - hid_out.run(&MyRequestHandler {}).await; + hid_out.run(false, &request_handler).await; }; // Run everything concurrently. // If we had made everything `'static` above instead, we could do this using separate tasks instead.