embassy-usb-hid bug fixes

This commit is contained in:
alexmoon 2022-04-05 17:23:46 -04:00 committed by Dario Nieuwenhuis
parent 22a47aeeb2
commit a1754ac8a8
2 changed files with 70 additions and 21 deletions

View File

@ -9,6 +9,7 @@ pub(crate) mod fmt;
use core::mem::MaybeUninit; use core::mem::MaybeUninit;
use core::ops::Range; use core::ops::Range;
use core::sync::atomic::{AtomicUsize, Ordering};
use embassy::time::Duration; use embassy::time::Duration;
use embassy_usb::driver::EndpointOut; use embassy_usb::driver::EndpointOut;
@ -61,12 +62,14 @@ impl ReportId {
pub struct State<'a, const IN_N: usize, const OUT_N: usize> { pub struct State<'a, const IN_N: usize, const OUT_N: usize> {
control: MaybeUninit<Control<'a>>, control: MaybeUninit<Control<'a>>,
out_report_offset: AtomicUsize,
} }
impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> {
pub fn new() -> Self { pub fn new() -> Self {
State { State {
control: MaybeUninit::uninit(), control: MaybeUninit::uninit(),
out_report_offset: AtomicUsize::new(0),
} }
} }
} }
@ -94,9 +97,11 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> {
max_packet_size: u16, max_packet_size: u16,
) -> Self { ) -> Self {
let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
let control = state let control = state.control.write(Control::new(
.control report_descriptor,
.write(Control::new(report_descriptor, request_handler)); request_handler,
&state.out_report_offset,
));
control.build(builder, None, &ep_in); control.build(builder, None, &ep_in);
Self { Self {
@ -138,14 +143,19 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize>
let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms);
let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
let control = state let control = state.control.write(Control::new(
.control report_descriptor,
.write(Control::new(report_descriptor, request_handler)); request_handler,
&state.out_report_offset,
));
control.build(builder, Some(&ep_out), &ep_in); control.build(builder, Some(&ep_out), &ep_in);
Self { Self {
input: ReportWriter { ep_in }, input: ReportWriter { ep_in },
output: ReportReader { ep_out, offset: 0 }, output: ReportReader {
ep_out,
offset: &state.out_report_offset,
},
} }
} }
@ -166,7 +176,7 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> {
pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { pub struct ReportReader<'d, D: Driver<'d>, const N: usize> {
ep_out: D::EndpointOut, ep_out: D::EndpointOut,
offset: usize, offset: &'d AtomicUsize,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -188,6 +198,11 @@ impl From<embassy_usb::driver::ReadError> for ReadError {
} }
impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
/// Waits for the interrupt in endpoint to be enabled.
pub async fn ready(&mut self) -> () {
self.ep_in.wait_enabled().await
}
/// Tries to write an input report by serializing the given report structure. /// Tries to write an input report by serializing the given report structure.
/// ///
/// Panics if no endpoint is available. /// Panics if no endpoint is available.
@ -222,14 +237,27 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
} }
impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
/// Waits for the interrupt out endpoint to be enabled.
pub async fn ready(&mut self) -> () {
self.ep_out.wait_enabled().await
}
/// Starts a task to deliver output reports from the Interrupt Out pipe to /// Starts a task to deliver output reports from the Interrupt Out pipe to
/// `handler`. /// `handler`.
pub async fn run<T: RequestHandler>(mut self, handler: &T) -> ! { ///
assert!(self.offset == 0); /// Terminates when the interface becomes disabled.
///
/// If `use_report_ids` is true, the first byte of the report will be used as
/// the `ReportId` value. Otherwise the `ReportId` value will be 0.
pub async fn run<T: RequestHandler>(mut self, use_report_ids: bool, handler: &T) -> ! {
let offset = self.offset.load(Ordering::Acquire);
assert!(offset == 0);
let mut buf = [0; N]; let mut buf = [0; N];
loop { loop {
match self.read(&mut buf).await { match self.read(&mut buf).await {
Ok(len) => { handler.set_report(ReportId::Out(0), &buf[0..len]); } Ok(len) => {
let id = if use_report_ids { buf[0] } else { 0 };
handler.set_report(ReportId::Out(id), &buf[..len]); }
Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N), Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N),
Err(ReadError::Disabled) => self.ep_out.wait_enabled().await, Err(ReadError::Disabled) => self.ep_out.wait_enabled().await,
Err(ReadError::Sync(_)) => unreachable!(), Err(ReadError::Sync(_)) => unreachable!(),
@ -257,17 +285,33 @@ impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
// Read packets from the endpoint // Read packets from the endpoint
let max_packet_size = usize::from(self.ep_out.info().max_packet_size); let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
let starting_offset = self.offset; let starting_offset = self.offset.load(Ordering::Acquire);
for chunk in buf[starting_offset..].chunks_mut(max_packet_size) { let mut total = starting_offset;
let size = self.ep_out.read(chunk).await?; loop {
self.offset += size; for chunk in buf[starting_offset..N].chunks_mut(max_packet_size) {
if size < max_packet_size || self.offset == N { match self.ep_out.read(chunk).await {
Ok(size) => {
total += size;
if size < max_packet_size || total == N {
self.offset.store(0, Ordering::Release);
break;
} else {
self.offset.store(total, Ordering::Release);
}
}
Err(err) => {
self.offset.store(0, Ordering::Release);
return Err(err.into());
}
}
}
// Some hosts may send ZLPs even when not required by the HID spec, so we'll loop as long as total == 0.
if total > 0 {
break; break;
} }
} }
let total = self.offset;
self.offset = 0;
if starting_offset > 0 { if starting_offset > 0 {
Err(ReadError::Sync(starting_offset..total)) Err(ReadError::Sync(starting_offset..total))
} else { } else {
@ -313,6 +357,7 @@ pub trait RequestHandler {
struct Control<'d> { struct Control<'d> {
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
request_handler: Option<&'d dyn RequestHandler>, request_handler: Option<&'d dyn RequestHandler>,
out_report_offset: &'d AtomicUsize,
hid_descriptor: [u8; 9], hid_descriptor: [u8; 9],
} }
@ -320,10 +365,12 @@ impl<'a> Control<'a> {
fn new( fn new(
report_descriptor: &'static [u8], report_descriptor: &'static [u8],
request_handler: Option<&'a dyn RequestHandler>, request_handler: Option<&'a dyn RequestHandler>,
out_report_offset: &'a AtomicUsize,
) -> Self { ) -> Self {
Control { Control {
report_descriptor, report_descriptor,
request_handler, request_handler,
out_report_offset,
hid_descriptor: [ hid_descriptor: [
// Length of buf inclusive of size prefix // Length of buf inclusive of size prefix
9, 9,
@ -388,7 +435,9 @@ impl<'a> Control<'a> {
} }
impl<'d> ControlHandler for Control<'d> { impl<'d> ControlHandler for Control<'d> {
fn reset(&mut self) {} fn reset(&mut self) {
self.out_report_offset.store(0, Ordering::Release);
}
fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse { fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse {
trace!("HID control_out {:?} {=[u8]:x}", req, data); trace!("HID control_out {:?} {=[u8]:x}", req, data);

View File

@ -54,7 +54,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
let mut control_buf = [0; 16]; let mut control_buf = [0; 16];
let request_handler = MyRequestHandler {}; let request_handler = MyRequestHandler {};
let mut state = State::<64, 64>::new(); let mut state = State::<8, 1>::new();
let mut builder = UsbDeviceBuilder::new( let mut builder = UsbDeviceBuilder::new(
driver, driver,
@ -117,7 +117,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
}; };
let out_fut = async { let out_fut = async {
hid_out.run(&MyRequestHandler {}).await; hid_out.run(false, &request_handler).await;
}; };
// Run everything concurrently. // Run everything concurrently.
// If we had made everything `'static` above instead, we could do this using separate tasks instead. // If we had made everything `'static` above instead, we could do this using separate tasks instead.