diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index 8d589aed..842abf16 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -101,37 +101,6 @@ impl<'d, T: Instance> Driver<'d, T> { } } } - - fn set_stalled(ep_addr: EndpointAddress, stalled: bool) { - let regs = T::regs(); - - unsafe { - if ep_addr.index() == 0 { - regs.tasks_ep0stall - .write(|w| w.tasks_ep0stall().bit(stalled)); - } else { - regs.epstall.write(|w| { - w.ep().bits(ep_addr.index() as u8 & 0b111); - w.io().bit(ep_addr.is_in()); - w.stall().bit(stalled) - }); - } - } - - //if stalled { - // self.busy_in_endpoints &= !(1 << ep_addr.index()); - //} - } - - fn is_stalled(ep_addr: EndpointAddress) -> bool { - let regs = T::regs(); - - let i = ep_addr.index(); - match ep_addr.direction() { - UsbDirection::Out => regs.halted.epout[i].read().getstatus().is_halted(), - UsbDirection::In => regs.halted.epin[i].read().getstatus().is_halted(), - } - } } impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> { @@ -294,11 +263,28 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { } fn endpoint_set_stalled(&mut self, ep_addr: EndpointAddress, stalled: bool) { - Driver::::set_stalled(ep_addr, stalled) + let regs = T::regs(); + unsafe { + if ep_addr.index() == 0 { + regs.tasks_ep0stall + .write(|w| w.tasks_ep0stall().bit(stalled)); + } else { + regs.epstall.write(|w| { + w.ep().bits(ep_addr.index() as u8 & 0b111); + w.io().bit(ep_addr.is_in()); + w.stall().bit(stalled) + }); + } + } } fn endpoint_is_stalled(&mut self, ep_addr: EndpointAddress) -> bool { - Driver::::is_stalled(ep_addr) + let regs = T::regs(); + let i = ep_addr.index(); + match ep_addr.direction() { + UsbDirection::Out => regs.halted.epout[i].read().getstatus().is_halted(), + UsbDirection::In => regs.halted.epin[i].read().getstatus().is_halted(), + } } fn endpoint_set_enabled(&mut self, ep_addr: EndpointAddress, enabled: bool) { @@ -464,14 +450,6 @@ impl<'d, T: Instance, Dir: EndpointDir> driver::Endpoint for Endpoint<'d, T, Dir &self.info } - fn set_stalled(&self, stalled: bool) { - Driver::::set_stalled(self.info.addr, stalled) - } - - fn is_stalled(&self) -> bool { - Driver::::is_stalled(self.info.addr) - } - type WaitEnabledFuture<'a> = impl Future + 'a where Self: 'a; fn wait_enabled(&mut self) -> Self::WaitEnabledFuture<'_> { @@ -638,6 +616,8 @@ impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> { type SetupFuture<'a> = impl Future + 'a where Self: 'a; type DataOutFuture<'a> = impl Future> + 'a where Self: 'a; type DataInFuture<'a> = impl Future> + 'a where Self: 'a; + type AcceptFuture<'a> = impl Future + 'a where Self: 'a; + type RejectFuture<'a> = impl Future + 'a where Self: 'a; fn max_packet_size(&self) -> usize { usize::from(self.max_packet_size) @@ -679,7 +659,12 @@ impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> { } } - fn data_out<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::DataOutFuture<'a> { + fn data_out<'a>( + &'a mut self, + buf: &'a mut [u8], + _first: bool, + _last: bool, + ) -> Self::DataOutFuture<'a> { async move { let regs = T::regs(); @@ -716,13 +701,17 @@ impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> { } } - fn data_in<'a>(&'a mut self, buf: &'a [u8], last_packet: bool) -> Self::DataInFuture<'a> { + fn data_in<'a>( + &'a mut self, + buf: &'a [u8], + _first: bool, + last: bool, + ) -> Self::DataInFuture<'a> { async move { let regs = T::regs(); regs.events_ep0datadone.reset(); - regs.shorts - .write(|w| w.ep0datadone_ep0status().bit(last_packet)); + regs.shorts.write(|w| w.ep0datadone_ep0status().bit(last)); // This starts a TX on EP0. events_ep0datadone notifies when done. unsafe { write_dma::(0, buf) } @@ -753,15 +742,19 @@ impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> { } } - fn accept(&mut self) { - let regs = T::regs(); - regs.tasks_ep0status - .write(|w| w.tasks_ep0status().bit(true)); + fn accept<'a>(&'a mut self) -> Self::AcceptFuture<'a> { + async move { + let regs = T::regs(); + regs.tasks_ep0status + .write(|w| w.tasks_ep0status().bit(true)); + } } - fn reject(&mut self) { - let regs = T::regs(); - regs.tasks_ep0stall.write(|w| w.tasks_ep0stall().bit(true)); + fn reject<'a>(&'a mut self) -> Self::RejectFuture<'a> { + async move { + let regs = T::regs(); + regs.tasks_ep0stall.write(|w| w.tasks_ep0stall().bit(true)); + } } } diff --git a/embassy-usb/src/builder.rs b/embassy-usb/src/builder.rs index 698a5f76..09904949 100644 --- a/embassy-usb/src/builder.rs +++ b/embassy-usb/src/builder.rs @@ -104,7 +104,7 @@ impl<'a> Config<'a> { device_class: 0x00, device_sub_class: 0x00, device_protocol: 0x00, - max_packet_size_0: 8, + max_packet_size_0: 64, vendor_id: vid, product_id: pid, device_release: 0x0010, diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs index 8454b041..0680df7a 100644 --- a/embassy-usb/src/driver.rs +++ b/embassy-usb/src/driver.rs @@ -118,17 +118,8 @@ pub trait Endpoint { /// Get the endpoint address fn info(&self) -> &EndpointInfo; - /// Sets or clears the STALL condition for an endpoint. If the endpoint is an OUT endpoint, it - /// should be prepared to receive data again. - fn set_stalled(&self, stalled: bool); - - /// Gets whether the STALL condition is set for an endpoint. - fn is_stalled(&self) -> bool; - /// Waits for the endpoint to be enabled. fn wait_enabled(&mut self) -> Self::WaitEnabledFuture<'_>; - - // TODO enable/disable? } pub trait EndpointOut: Endpoint { @@ -151,6 +142,12 @@ pub trait ControlPipe { where Self: 'a; type DataInFuture<'a>: Future> + 'a + where + Self: 'a; + type AcceptFuture<'a>: Future + 'a + where + Self: 'a; + type RejectFuture<'a>: Future + 'a where Self: 'a; @@ -164,22 +161,28 @@ pub trait ControlPipe { /// /// Must be called after `setup()` for requests with `direction` of `Out` /// and `length` greater than zero. - fn data_out<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::DataOutFuture<'a>; + fn data_out<'a>( + &'a mut self, + buf: &'a mut [u8], + first: bool, + last: bool, + ) -> Self::DataOutFuture<'a>; /// Sends a DATA IN packet with `data` in response to a control read request. /// /// If `last_packet` is true, the STATUS packet will be ACKed following the transfer of `data`. - fn data_in<'a>(&'a mut self, data: &'a [u8], last_packet: bool) -> Self::DataInFuture<'a>; + fn data_in<'a>(&'a mut self, data: &'a [u8], first: bool, last: bool) + -> Self::DataInFuture<'a>; /// Accepts a control request. /// /// Causes the STATUS packet for the current request to be ACKed. - fn accept(&mut self); + fn accept<'a>(&'a mut self) -> Self::AcceptFuture<'a>; /// Rejects a control request. /// /// Sets a STALL condition on the pipe to indicate an error. - fn reject(&mut self); + fn reject<'a>(&'a mut self) -> Self::RejectFuture<'a>; } pub trait EndpointIn: Endpoint { diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 7b85a288..b691bf11 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -119,7 +119,14 @@ struct Inner<'d, D: Driver<'d>> { suspended: bool, remote_wakeup_enabled: bool, self_powered: bool, - pending_address: u8, + + /// Our device address, or 0 if none. + address: u8, + /// When receiving a set addr control request, we have to apply it AFTER we've + /// finished handling the control request, as the status stage still has to be + /// handled with addr 0. + /// If true, do a set_addr after finishing the current control req. + set_address_pending: bool, interfaces: Vec, MAX_INTERFACE_COUNT>, } @@ -154,7 +161,8 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { suspended: false, remote_wakeup_enabled: false, self_powered: false, - pending_address: 0, + address: 0, + set_address_pending: false, interfaces, }, } @@ -255,6 +263,11 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { UsbDirection::In => self.handle_control_in(req).await, UsbDirection::Out => self.handle_control_out(req).await, } + + if self.inner.set_address_pending { + self.inner.bus.set_address(self.inner.address); + self.inner.set_address_pending = false; + } } async fn handle_control_in(&mut self, req: Request) { @@ -266,7 +279,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { // a full-length packet is a short packet, thinking we're done sending data. // See https://github.com/hathach/tinyusb/issues/184 const DEVICE_DESCRIPTOR_LEN: usize = 18; - if self.inner.pending_address == 0 + if self.inner.address == 0 && max_packet_size < DEVICE_DESCRIPTOR_LEN && (max_packet_size as usize) < resp_length { @@ -279,12 +292,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { let len = data.len().min(resp_length); let need_zlp = len != resp_length && (len % usize::from(max_packet_size)) == 0; - let mut chunks = data[0..len] + let chunks = data[0..len] .chunks(max_packet_size) .chain(need_zlp.then(|| -> &[u8] { &[] })); - while let Some(chunk) = chunks.next() { - match self.control.data_in(chunk, chunks.size_hint().0 == 0).await { + for (first, last, chunk) in first_last(chunks) { + match self.control.data_in(chunk, first, last).await { Ok(()) => {} Err(e) => { warn!("control accept_in failed: {:?}", e); @@ -293,7 +306,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } } } - InResponse::Rejected => self.control.reject(), + InResponse::Rejected => self.control.reject().await, } } @@ -302,8 +315,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { let max_packet_size = self.control.max_packet_size(); let mut total = 0; - for chunk in self.control_buf[..req_length].chunks_mut(max_packet_size) { - let size = match self.control.data_out(chunk).await { + let chunks = self.control_buf[..req_length].chunks_mut(max_packet_size); + for (first, last, chunk) in first_last(chunks) { + let size = match self.control.data_out(chunk, first, last).await { Ok(x) => x, Err(e) => { warn!("usb: failed to read CONTROL OUT data stage: {:?}", e); @@ -323,8 +337,8 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { trace!(" control out data: {:02x?}", data); match self.inner.handle_control_out(req, data) { - OutResponse::Accepted => self.control.accept(), - OutResponse::Rejected => self.control.reject(), + OutResponse::Accepted => self.control.accept().await, + OutResponse::Rejected => self.control.reject().await, } } } @@ -337,7 +351,7 @@ impl<'d, D: Driver<'d>> Inner<'d, D> { self.device_state = UsbDeviceState::Default; self.suspended = false; self.remote_wakeup_enabled = false; - self.pending_address = 0; + self.address = 0; for iface in self.interfaces.iter_mut() { iface.current_alt_setting = 0; @@ -389,11 +403,11 @@ impl<'d, D: Driver<'d>> Inner<'d, D> { OutResponse::Accepted } (Request::SET_ADDRESS, addr @ 1..=127) => { - self.pending_address = addr as u8; - self.bus.set_address(self.pending_address); + self.address = addr as u8; + self.set_address_pending = true; self.device_state = UsbDeviceState::Addressed; if let Some(h) = &self.handler { - h.addressed(self.pending_address); + h.addressed(self.address); } OutResponse::Accepted } @@ -655,3 +669,15 @@ impl<'d, D: Driver<'d>> Inner<'d, D> { } } } + +fn first_last(iter: T) -> impl Iterator { + let mut iter = iter.peekable(); + let mut first = true; + core::iter::from_fn(move || { + let val = iter.next()?; + let is_first = first; + first = false; + let is_last = iter.peek().is_none(); + Some((is_first, is_last, val)) + }) +}