diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index d1c94dbb..1cd5a9eb 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -508,6 +508,30 @@ pub struct ControlPipe<'d, T: Instance> { } impl<'d, T: Instance> ControlPipe<'d, T> { + async fn read(&mut self, buf: &mut [u8]) -> Result { + let regs = T::regs(); + + // Wait until ready + regs.intenset.write(|w| w.ep0datadone().set()); + poll_fn(|cx| { + EP_OUT_WAKERS[0].register(cx.waker()); + let regs = T::regs(); + if regs + .events_ep0datadone + .read() + .events_ep0datadone() + .bit_is_set() + { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + + unsafe { read_dma::(0, buf) } + } + async fn write(&mut self, buf: &[u8], last_chunk: bool) { let regs = T::regs(); regs.events_ep0datadone.reset(); @@ -595,29 +619,19 @@ impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> { let req = self.request.unwrap(); assert_eq!(req.direction, UsbDirection::Out); assert!(req.length > 0); - assert!(buf.len() >= usize::from(req.length)); - let regs = T::regs(); - - // Wait until ready - regs.intenset.write(|w| w.ep0datadone().set()); - poll_fn(|cx| { - EP_OUT_WAKERS[0].register(cx.waker()); - let regs = T::regs(); - if regs - .events_ep0datadone - .read() - .events_ep0datadone() - .bit_is_set() - { - Poll::Ready(()) - } else { - Poll::Pending + let req_length = usize::from(req.length); + let max_packet_size = usize::from(self.max_packet_size); + let mut total = 0; + for chunk in buf.chunks_mut(max_packet_size) { + let size = self.read(chunk).await?; + total += size; + if size < max_packet_size || total == req_length { + break; } - }) - .await; + } - unsafe { read_dma::(0, buf) } + Ok(total) } } @@ -697,16 +711,27 @@ impl Allocator { // Endpoint directions are allocated individually. - let alloc_index = match ep_type { - EndpointType::Isochronous => 8, - EndpointType::Control => 0, - EndpointType::Interrupt | EndpointType::Bulk => { - // Find rightmost zero bit in 1..=7 - let ones = (self.used >> 1).trailing_ones() as usize; - if ones >= 7 { - return Err(driver::EndpointAllocError); + let alloc_index = if let Some(ep_addr) = ep_addr { + match (ep_addr.index(), ep_type) { + (0, EndpointType::Control) => {} + (8, EndpointType::Isochronous) => {} + (n, EndpointType::Bulk) | (n, EndpointType::Interrupt) if n >= 1 && n <= 7 => {} + _ => return Err(driver::EndpointAllocError), + } + + ep_addr.index() + } else { + match ep_type { + EndpointType::Isochronous => 8, + EndpointType::Control => 0, + EndpointType::Interrupt | EndpointType::Bulk => { + // Find rightmost zero bit in 1..=7 + let ones = (self.used >> 1).trailing_ones() as usize; + if ones >= 7 { + return Err(driver::EndpointAllocError); + } + ones + 1 } - ones + 1 } };