Add support for USB classes handling control requests.

This commit is contained in:
alexmoon
2022-03-25 16:46:14 -04:00
committed by Dario Nieuwenhuis
parent 5c0db627fe
commit bdc6e0481c
8 changed files with 703 additions and 291 deletions

View File

@ -9,6 +9,7 @@ use embassy::time::{with_timeout, Duration};
use embassy::util::Unborrow;
use embassy::waitqueue::AtomicWaker;
use embassy_hal_common::unborrow;
use embassy_usb::control::Request;
use embassy_usb::driver::{self, Event, ReadError, WriteError};
use embassy_usb::types::{EndpointAddress, EndpointInfo, EndpointType, UsbDirection};
use futures::future::poll_fn;
@ -134,6 +135,7 @@ impl<'d, T: Instance> Driver<'d, T> {
impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> {
type EndpointOut = Endpoint<'d, T, Out>;
type EndpointIn = Endpoint<'d, T, In>;
type ControlPipe = ControlPipe<'d, T>;
type Bus = Bus<'d, T>;
fn alloc_endpoint_in(
@ -174,6 +176,19 @@ impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> {
}))
}
fn alloc_control_pipe(
&mut self,
max_packet_size: u16,
) -> Result<Self::ControlPipe, driver::EndpointAllocError> {
self.alloc_endpoint_out(Some(0x00.into()), EndpointType::Control, max_packet_size, 0)?;
self.alloc_endpoint_in(Some(0x80.into()), EndpointType::Control, max_packet_size, 0)?;
Ok(ControlPipe {
_phantom: PhantomData,
max_packet_size,
request: None,
})
}
fn enable(self) -> Self::Bus {
let regs = T::regs();
@ -344,99 +359,110 @@ impl<'d, T: Instance, Dir> driver::Endpoint for Endpoint<'d, T, Dir> {
}
}
unsafe fn read_dma<T: Instance>(i: usize, buf: &mut [u8]) -> Result<usize, ReadError> {
let regs = T::regs();
// Check that the packet fits into the buffer
let size = regs.size.epout[0].read().bits() as usize;
if size > buf.len() {
return Err(ReadError::BufferOverflow);
}
if i == 0 {
regs.events_ep0datadone.reset();
}
let epout = [
&regs.epout0,
&regs.epout1,
&regs.epout2,
&regs.epout3,
&regs.epout4,
&regs.epout5,
&regs.epout6,
&regs.epout7,
];
epout[i].ptr.write(|w| w.bits(buf.as_ptr() as u32));
// MAXCNT must match SIZE
epout[i].maxcnt.write(|w| w.bits(size as u32));
dma_start();
regs.events_endepout[i].reset();
regs.tasks_startepout[i].write(|w| w.tasks_startepout().set_bit());
while regs.events_endepout[i]
.read()
.events_endepout()
.bit_is_clear()
{}
regs.events_endepout[i].reset();
dma_end();
regs.size.epout[i].reset();
Ok(size)
}
unsafe fn write_dma<T: Instance>(i: usize, buf: &[u8]) -> Result<(), WriteError> {
let regs = T::regs();
if buf.len() > 64 {
return Err(WriteError::BufferOverflow);
}
// EasyDMA can't read FLASH, so we copy through RAM
let mut ram_buf: MaybeUninit<[u8; 64]> = MaybeUninit::uninit();
let ptr = ram_buf.as_mut_ptr() as *mut u8;
core::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len());
let epin = [
&regs.epin0,
&regs.epin1,
&regs.epin2,
&regs.epin3,
&regs.epin4,
&regs.epin5,
&regs.epin6,
&regs.epin7,
];
// Set the buffer length so the right number of bytes are transmitted.
// Safety: `buf.len()` has been checked to be <= the max buffer length.
epin[i].ptr.write(|w| w.bits(ptr as u32));
epin[i].maxcnt.write(|w| w.maxcnt().bits(buf.len() as u8));
regs.events_endepin[i].reset();
dma_start();
regs.tasks_startepin[i].write(|w| w.bits(1));
while regs.events_endepin[i].read().bits() == 0 {}
dma_end();
Ok(())
}
impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a;
fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
async move {
let regs = T::regs();
let i = self.info.addr.index();
assert!(i != 0);
if i == 0 {
if buf.len() == 0 {
regs.tasks_ep0status.write(|w| unsafe { w.bits(1) });
return Ok(0);
// Wait until ready
poll_fn(|cx| {
EP_OUT_WAKERS[i].register(cx.waker());
let r = READY_ENDPOINTS.load(Ordering::Acquire);
if r & (1 << (i + 16)) != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
// Wait for SETUP packet
regs.events_ep0setup.reset();
regs.intenset.write(|w| w.ep0setup().set());
poll_fn(|cx| {
EP_OUT_WAKERS[0].register(cx.waker());
let regs = T::regs();
if regs.events_ep0setup.read().bits() != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
if buf.len() < 8 {
return Err(ReadError::BufferOverflow);
}
buf[0] = regs.bmrequesttype.read().bits() as u8;
buf[1] = regs.brequest.read().brequest().bits();
buf[2] = regs.wvaluel.read().wvaluel().bits();
buf[3] = regs.wvalueh.read().wvalueh().bits();
buf[4] = regs.windexl.read().windexl().bits();
buf[5] = regs.windexh.read().windexh().bits();
buf[6] = regs.wlengthl.read().wlengthl().bits();
buf[7] = regs.wlengthh.read().wlengthh().bits();
Ok(8)
} else {
// Wait until ready
poll_fn(|cx| {
EP_OUT_WAKERS[i].register(cx.waker());
let r = READY_ENDPOINTS.load(Ordering::Acquire);
if r & (1 << (i + 16)) != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
// Check that the packet fits into the buffer
let size = regs.size.epout[i].read().bits();
if size as usize > buf.len() {
return Err(ReadError::BufferOverflow);
}
let epout = [
&regs.epout0,
&regs.epout1,
&regs.epout2,
&regs.epout3,
&regs.epout4,
&regs.epout5,
&regs.epout6,
&regs.epout7,
];
epout[i]
.ptr
.write(|w| unsafe { w.bits(buf.as_ptr() as u32) });
// MAXCNT must match SIZE
epout[i].maxcnt.write(|w| unsafe { w.bits(size) });
dma_start();
regs.events_endepout[i].reset();
regs.tasks_startepout[i].write(|w| w.tasks_startepout().set_bit());
while regs.events_endepout[i]
.read()
.events_endepout()
.bit_is_clear()
{}
regs.events_endepout[i].reset();
dma_end();
Ok(size as usize)
}
unsafe { read_dma::<T>(i, buf) }
}
}
}
@ -446,87 +472,181 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> {
fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> {
async move {
let regs = T::regs();
let i = self.info.addr.index();
assert!(i != 0);
// Wait until ready.
if i != 0 {
poll_fn(|cx| {
EP_IN_WAKERS[i].register(cx.waker());
let r = READY_ENDPOINTS.load(Ordering::Acquire);
if r & (1 << i) != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << i), Ordering::AcqRel);
}
if i == 0 {
regs.events_ep0datadone.reset();
}
assert!(buf.len() <= 64);
// EasyDMA can't read FLASH, so we copy through RAM
let mut ram_buf: MaybeUninit<[u8; 64]> = MaybeUninit::uninit();
let ptr = ram_buf.as_mut_ptr() as *mut u8;
unsafe { core::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len()) };
let epin = [
&regs.epin0,
&regs.epin1,
&regs.epin2,
&regs.epin3,
&regs.epin4,
&regs.epin5,
&regs.epin6,
&regs.epin7,
];
// Set the buffer length so the right number of bytes are transmitted.
// Safety: `buf.len()` has been checked to be <= the max buffer length.
unsafe {
epin[i].ptr.write(|w| w.bits(ptr as u32));
epin[i].maxcnt.write(|w| w.maxcnt().bits(buf.len() as u8));
}
regs.events_endepin[i].reset();
dma_start();
regs.tasks_startepin[i].write(|w| unsafe { w.bits(1) });
while regs.events_endepin[i].read().bits() == 0 {}
dma_end();
if i == 0 {
regs.intenset.write(|w| w.ep0datadone().set());
let res = with_timeout(
Duration::from_millis(10),
poll_fn(|cx| {
EP_IN_WAKERS[0].register(cx.waker());
let regs = T::regs();
if regs.events_ep0datadone.read().bits() != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
}),
)
.await;
if res.is_err() {
// todo wrong error
return Err(driver::WriteError::BufferOverflow);
poll_fn(|cx| {
EP_IN_WAKERS[i].register(cx.waker());
let r = READY_ENDPOINTS.load(Ordering::Acquire);
if r & (1 << i) != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
// Mark as not ready
READY_ENDPOINTS.fetch_and(!(1 << i), Ordering::AcqRel);
unsafe { write_dma::<T>(i, buf) }
}
}
}
pub struct ControlPipe<'d, T: Instance> {
_phantom: PhantomData<&'d mut T>,
max_packet_size: u16,
request: Option<Request>,
}
impl<'d, T: Instance> ControlPipe<'d, T> {
async fn write(&mut self, buf: &[u8], last_chunk: bool) {
let regs = T::regs();
regs.events_ep0datadone.reset();
unsafe {
write_dma::<T>(0, buf).unwrap();
}
regs.shorts
.modify(|_, w| w.ep0datadone_ep0status().bit(last_chunk));
regs.intenset.write(|w| w.ep0datadone().set());
let res = with_timeout(
Duration::from_millis(10),
poll_fn(|cx| {
EP_IN_WAKERS[0].register(cx.waker());
let regs = T::regs();
if regs.events_ep0datadone.read().bits() != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
}),
)
.await;
if res.is_err() {
error!("ControlPipe::write timed out.");
}
}
}
impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> {
type SetupFuture<'a> = impl Future<Output = Request> + 'a where Self: 'a;
type DataOutFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a;
type AcceptInFuture<'a> = impl Future<Output = ()> + 'a where Self: 'a;
fn setup<'a>(&'a mut self) -> Self::SetupFuture<'a> {
async move {
assert!(self.request.is_none());
let regs = T::regs();
// Wait for SETUP packet
regs.intenset.write(|w| w.ep0setup().set());
poll_fn(|cx| {
EP_OUT_WAKERS[0].register(cx.waker());
let regs = T::regs();
if regs.events_ep0setup.read().bits() != 0 {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
// Reset shorts
regs.shorts
.modify(|_, w| w.ep0datadone_ep0status().clear_bit());
regs.events_ep0setup.reset();
let mut buf = [0; 8];
buf[0] = regs.bmrequesttype.read().bits() as u8;
buf[1] = regs.brequest.read().brequest().bits();
buf[2] = regs.wvaluel.read().wvaluel().bits();
buf[3] = regs.wvalueh.read().wvalueh().bits();
buf[4] = regs.windexl.read().windexl().bits();
buf[5] = regs.windexh.read().windexh().bits();
buf[6] = regs.wlengthl.read().wlengthl().bits();
buf[7] = regs.wlengthh.read().wlengthh().bits();
let req = Request::parse(&buf);
if req.direction == UsbDirection::Out {
regs.tasks_ep0rcvout
.write(|w| w.tasks_ep0rcvout().set_bit());
}
Ok(())
self.request = Some(req);
req
}
}
fn data_out<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::DataOutFuture<'a> {
async move {
let req = self.request.unwrap();
assert_eq!(req.direction, UsbDirection::Out);
assert!(req.length > 0);
assert!(buf.len() >= usize::from(req.length));
let regs = T::regs();
// Wait until ready
regs.intenset.write(|w| w.ep0datadone().set());
poll_fn(|cx| {
EP_OUT_WAKERS[0].register(cx.waker());
let regs = T::regs();
if regs
.events_ep0datadone
.read()
.events_ep0datadone()
.bit_is_set()
{
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
unsafe { read_dma::<T>(0, buf) }
}
}
fn accept(&mut self) {
let regs = T::regs();
regs.tasks_ep0status
.write(|w| w.tasks_ep0status().bit(true));
self.request = None;
}
fn accept_in<'a>(&'a mut self, buf: &'a [u8]) -> Self::AcceptInFuture<'a> {
async move {
info!("control accept {=[u8]:x}", buf);
let req = self.request.unwrap();
assert_eq!(req.direction, UsbDirection::In);
let req_len = usize::from(req.length);
let len = buf.len().min(req_len);
let need_zlp = len != req_len && (len % usize::from(self.max_packet_size)) == 0;
let mut chunks = buf[0..len]
.chunks(usize::from(self.max_packet_size))
.chain(need_zlp.then(|| -> &[u8] { &[] }));
while let Some(chunk) = chunks.next() {
self.write(chunk, chunks.size_hint().0 == 0).await;
}
self.request = None;
}
}
fn reject(&mut self) {
let regs = T::regs();
regs.tasks_ep0stall.write(|w| w.tasks_ep0stall().bit(true));
self.request = None;
}
}
fn dma_start() {