From 890e93b4f009afe37b35e211567acfbf58be783a Mon Sep 17 00:00:00 2001 From: Thales Fragoso Date: Thu, 18 Feb 2021 21:57:35 -0300 Subject: [PATCH] Start working on usb serial --- embassy-stm32f4-examples/Cargo.toml | 5 +- .../src/bin/usb_serial.rs | 119 ++++++ embassy-stm32f4/Cargo.toml | 1 + embassy-stm32f4/src/cdc_acm.rs | 338 ++++++++++++++++++ embassy-stm32f4/src/lib.rs | 7 + embassy-stm32f4/src/usb.rs | 130 +++++++ embassy-stm32f4/src/usb_serial.rs | 290 +++++++++++++++ embassy-stm32f4/src/util/mod.rs | 12 + embassy-stm32f4/src/util/peripheral.rs | 78 ++++ embassy-stm32f4/src/util/ring_buffer.rs | 86 +++++ 10 files changed, 1064 insertions(+), 2 deletions(-) create mode 100644 embassy-stm32f4-examples/src/bin/usb_serial.rs create mode 100644 embassy-stm32f4/src/cdc_acm.rs create mode 100644 embassy-stm32f4/src/usb.rs create mode 100644 embassy-stm32f4/src/usb_serial.rs create mode 100644 embassy-stm32f4/src/util/mod.rs create mode 100644 embassy-stm32f4/src/util/peripheral.rs create mode 100644 embassy-stm32f4/src/util/ring_buffer.rs diff --git a/embassy-stm32f4-examples/Cargo.toml b/embassy-stm32f4-examples/Cargo.toml index 5bbaecc5..e4f2aa7a 100644 --- a/embassy-stm32f4-examples/Cargo.toml +++ b/embassy-stm32f4-examples/Cargo.toml @@ -46,7 +46,8 @@ cortex-m = "0.7.1" cortex-m-rt = "0.6.13" embedded-hal = { version = "0.2.4" } panic-probe = "0.1.0" -stm32f4xx-hal = { version = "0.8.3", features = ["rt"], git = "https://github.com/stm32-rs/stm32f4xx-hal.git"} +stm32f4xx-hal = { version = "0.8.3", features = ["rt", "usb_fs"], git = "https://github.com/stm32-rs/stm32f4xx-hal.git"} futures = { version = "0.3.8", default-features = false, features = ["async-await"] } rtt-target = { version = "0.3", features = ["cortex-m"] } -bxcan = "0.5.0" \ No newline at end of file +bxcan = "0.5.0" +usb-device = "0.2.7" diff --git a/embassy-stm32f4-examples/src/bin/usb_serial.rs b/embassy-stm32f4-examples/src/bin/usb_serial.rs new file mode 100644 index 00000000..d2ccb4b2 --- /dev/null +++ b/embassy-stm32f4-examples/src/bin/usb_serial.rs @@ -0,0 +1,119 @@ +#![no_std] +#![no_main] +#![feature(type_alias_impl_trait)] + +#[path = "../example_common.rs"] +mod example_common; +use example_common::*; + +use cortex_m_rt::entry; +use defmt::panic; +use embassy::executor::{task, Executor}; +use embassy::io::{AsyncBufReadExt, AsyncWriteExt}; +use embassy::time::{Duration, Timer}; +use embassy::util::Forever; +use embassy_stm32f4::interrupt::OwnedInterrupt; +use embassy_stm32f4::usb::Usb; +use embassy_stm32f4::usb_serial::UsbSerial; +use embassy_stm32f4::{interrupt, pac, rtc}; +use futures::future::{select, Either}; +use futures::pin_mut; +use stm32f4xx_hal::otg_fs::{UsbBus, USB}; +use stm32f4xx_hal::prelude::*; +use usb_device::bus::UsbBusAllocator; +use usb_device::prelude::*; + +#[task] +async fn run1(bus: &'static mut UsbBusAllocator>) { + info!("Async task"); + + let mut read_buf = [0u8; 128]; + let mut write_buf = [0u8; 128]; + let serial = UsbSerial::new(bus, &mut read_buf, &mut write_buf); + + let device = UsbDeviceBuilder::new(bus, UsbVidPid(0x16c0, 0x27dd)) + .manufacturer("Fake company") + .product("Serial port") + .serial_number("TEST") + .device_class(0x02) + .build(); + + let irq = interrupt::take!(OTG_FS); + irq.set_priority(interrupt::Priority::Level3); + + let usb = Usb::new(device, serial, irq); + pin_mut!(usb); + + let (mut read_interface, mut write_interface) = usb.as_mut().into_ref().take_serial(); + + let mut buf = [0u8; 5]; + loop { + let recv_fut = read_interface.read(&mut buf); + let timeout = Timer::after(Duration::from_ticks(32768 * 3)); + + match select(recv_fut, timeout).await { + Either::Left((recv, _)) => { + let recv = unwrap!(recv); + unwrap!(write_interface.write_all(&buf[..recv]).await); + } + Either::Right(_) => { + unwrap!(write_interface.write_all(b"Hello\r\n").await); + } + } + } +} + +static RTC: Forever> = Forever::new(); +static ALARM: Forever> = Forever::new(); +static EXECUTOR: Forever = Forever::new(); +static USB_BUS: Forever>> = Forever::new(); + +#[entry] +fn main() -> ! { + static mut EP_MEMORY: [u32; 1024] = [0; 1024]; + + info!("Hello World!"); + + let p = unwrap!(pac::Peripherals::take()); + + p.RCC.ahb1enr.modify(|_, w| w.dma1en().enabled()); + let rcc = p.RCC.constrain(); + let clocks = rcc + .cfgr + .use_hse(25.mhz()) + .sysclk(48.mhz()) + .require_pll48clk() + .freeze(); + + p.DBGMCU.cr.modify(|_, w| { + w.dbg_sleep().set_bit(); + w.dbg_standby().set_bit(); + w.dbg_stop().set_bit() + }); + + let rtc = RTC.put(rtc::RTC::new(p.TIM2, interrupt::take!(TIM2), clocks)); + rtc.start(); + + unsafe { embassy::time::set_clock(rtc) }; + + let alarm = ALARM.put(rtc.alarm1()); + let executor = EXECUTOR.put(Executor::new()); + executor.set_alarm(alarm); + + let gpioa = p.GPIOA.split(); + let usb = USB { + usb_global: p.OTG_FS_GLOBAL, + usb_device: p.OTG_FS_DEVICE, + usb_pwrclk: p.OTG_FS_PWRCLK, + pin_dm: gpioa.pa11.into_alternate_af10(), + pin_dp: gpioa.pa12.into_alternate_af10(), + hclk: clocks.hclk(), + }; + // Rust analyzer isn't recognizing the static ref magic `cortex-m` does + #[allow(unused_unsafe)] + let usb_bus = USB_BUS.put(UsbBus::new(usb, unsafe { EP_MEMORY })); + + executor.run(move |spawner| { + unwrap!(spawner.spawn(run1(usb_bus))); + }); +} diff --git a/embassy-stm32f4/Cargo.toml b/embassy-stm32f4/Cargo.toml index ae3273d6..b39a141b 100644 --- a/embassy-stm32f4/Cargo.toml +++ b/embassy-stm32f4/Cargo.toml @@ -41,3 +41,4 @@ embedded-dma = { version = "0.1.2" } stm32f4xx-hal = { version = "0.8.3", features = ["rt", "can"], git = "https://github.com/stm32-rs/stm32f4xx-hal.git"} bxcan = "0.5.0" nb = "*" +usb-device = "0.2.7" diff --git a/embassy-stm32f4/src/cdc_acm.rs b/embassy-stm32f4/src/cdc_acm.rs new file mode 100644 index 00000000..5a85b384 --- /dev/null +++ b/embassy-stm32f4/src/cdc_acm.rs @@ -0,0 +1,338 @@ +// Copied from https://github.com/mvirkkunen/usbd-serial +#![allow(dead_code)] + +use core::convert::TryInto; +use core::mem; +use usb_device::class_prelude::*; +use usb_device::Result; + +/// This should be used as `device_class` when building the `UsbDevice`. +pub const USB_CLASS_CDC: u8 = 0x02; + +const USB_CLASS_CDC_DATA: u8 = 0x0a; +const CDC_SUBCLASS_ACM: u8 = 0x02; +const CDC_PROTOCOL_NONE: u8 = 0x00; + +const CS_INTERFACE: u8 = 0x24; +const CDC_TYPE_HEADER: u8 = 0x00; +const CDC_TYPE_CALL_MANAGEMENT: u8 = 0x01; +const CDC_TYPE_ACM: u8 = 0x02; +const CDC_TYPE_UNION: u8 = 0x06; + +const REQ_SEND_ENCAPSULATED_COMMAND: u8 = 0x00; +#[allow(unused)] +const REQ_GET_ENCAPSULATED_COMMAND: u8 = 0x01; +const REQ_SET_LINE_CODING: u8 = 0x20; +const REQ_GET_LINE_CODING: u8 = 0x21; +const REQ_SET_CONTROL_LINE_STATE: u8 = 0x22; + +/// Packet level implementation of a CDC-ACM serial port. +/// +/// This class can be used directly and it has the least overhead due to directly reading and +/// writing USB packets with no intermediate buffers, but it will not act like a stream-like serial +/// port. The following constraints must be followed if you use this class directly: +/// +/// - `read_packet` must be called with a buffer large enough to hold max_packet_size bytes, and the +/// method will return a `WouldBlock` error if there is no packet to be read. +/// - `write_packet` must not be called with a buffer larger than max_packet_size bytes, and the +/// method will return a `WouldBlock` error if the previous packet has not been sent yet. +/// - If you write a packet that is exactly max_packet_size bytes long, it won't be processed by the +/// host operating system until a subsequent shorter packet is sent. A zero-length packet (ZLP) +/// can be sent if there is no other data to send. This is because USB bulk transactions must be +/// terminated with a short packet, even if the bulk endpoint is used for stream-like data. +pub struct CdcAcmClass<'a, B: UsbBus> { + comm_if: InterfaceNumber, + comm_ep: EndpointIn<'a, B>, + data_if: InterfaceNumber, + read_ep: EndpointOut<'a, B>, + write_ep: EndpointIn<'a, B>, + line_coding: LineCoding, + dtr: bool, + rts: bool, +} + +impl CdcAcmClass<'_, B> { + /// Creates a new CdcAcmClass with the provided UsbBus and max_packet_size in bytes. For + /// full-speed devices, max_packet_size has to be one of 8, 16, 32 or 64. + pub fn new(alloc: &UsbBusAllocator, max_packet_size: u16) -> CdcAcmClass<'_, B> { + CdcAcmClass { + comm_if: alloc.interface(), + comm_ep: alloc.interrupt(8, 255), + data_if: alloc.interface(), + read_ep: alloc.bulk(max_packet_size), + write_ep: alloc.bulk(max_packet_size), + line_coding: LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + }, + dtr: false, + rts: false, + } + } + + /// Gets the maximum packet size in bytes. + pub fn max_packet_size(&self) -> u16 { + // The size is the same for both endpoints. + self.read_ep.max_packet_size() + } + + /// Gets the current line coding. The line coding contains information that's mainly relevant + /// for USB to UART serial port emulators, and can be ignored if not relevant. + pub fn line_coding(&self) -> &LineCoding { + &self.line_coding + } + + /// Gets the DTR (data terminal ready) state + pub fn dtr(&self) -> bool { + self.dtr + } + + /// Gets the RTS (request to send) state + pub fn rts(&self) -> bool { + self.rts + } + + /// Writes a single packet into the IN endpoint. + pub fn write_packet(&mut self, data: &[u8]) -> Result { + self.write_ep.write(data) + } + + /// Reads a single packet from the OUT endpoint. + pub fn read_packet(&mut self, data: &mut [u8]) -> Result { + self.read_ep.read(data) + } + + /// Gets the address of the IN endpoint. + pub fn write_ep_address(&self) -> EndpointAddress { + self.write_ep.address() + } + + /// Gets the address of the OUT endpoint. + pub fn read_ep_address(&self) -> EndpointAddress { + self.read_ep.address() + } +} + +impl UsbClass for CdcAcmClass<'_, B> { + fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<()> { + writer.iad( + self.comm_if, + 2, + USB_CLASS_CDC, + CDC_SUBCLASS_ACM, + CDC_PROTOCOL_NONE, + )?; + + writer.interface( + self.comm_if, + USB_CLASS_CDC, + CDC_SUBCLASS_ACM, + CDC_PROTOCOL_NONE, + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_HEADER, // bDescriptorSubtype + 0x10, + 0x01, // bcdCDC (1.10) + ], + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_ACM, // bDescriptorSubtype + 0x00, // bmCapabilities + ], + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_UNION, // bDescriptorSubtype + self.comm_if.into(), // bControlInterface + self.data_if.into(), // bSubordinateInterface + ], + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_CALL_MANAGEMENT, // bDescriptorSubtype + 0x00, // bmCapabilities + self.data_if.into(), // bDataInterface + ], + )?; + + writer.endpoint(&self.comm_ep)?; + + writer.interface(self.data_if, USB_CLASS_CDC_DATA, 0x00, 0x00)?; + + writer.endpoint(&self.write_ep)?; + writer.endpoint(&self.read_ep)?; + + Ok(()) + } + + fn reset(&mut self) { + self.line_coding = LineCoding::default(); + self.dtr = false; + self.rts = false; + } + + fn control_in(&mut self, xfer: ControlIn) { + let req = xfer.request(); + + if !(req.request_type == control::RequestType::Class + && req.recipient == control::Recipient::Interface + && req.index == u8::from(self.comm_if) as u16) + { + return; + } + + match req.request { + // REQ_GET_ENCAPSULATED_COMMAND is not really supported - it will be rejected below. + REQ_GET_LINE_CODING if req.length == 7 => { + xfer.accept(|data| { + data[0..4].copy_from_slice(&self.line_coding.data_rate.to_le_bytes()); + data[4] = self.line_coding.stop_bits as u8; + data[5] = self.line_coding.parity_type as u8; + data[6] = self.line_coding.data_bits; + + Ok(7) + }) + .ok(); + } + _ => { + xfer.reject().ok(); + } + } + } + + fn control_out(&mut self, xfer: ControlOut) { + let req = xfer.request(); + + if !(req.request_type == control::RequestType::Class + && req.recipient == control::Recipient::Interface + && req.index == u8::from(self.comm_if) as u16) + { + return; + } + + match req.request { + REQ_SEND_ENCAPSULATED_COMMAND => { + // We don't actually support encapsulated commands but pretend we do for standards + // compatibility. + xfer.accept().ok(); + } + REQ_SET_LINE_CODING if xfer.data().len() >= 7 => { + self.line_coding.data_rate = + u32::from_le_bytes(xfer.data()[0..4].try_into().unwrap()); + self.line_coding.stop_bits = xfer.data()[4].into(); + self.line_coding.parity_type = xfer.data()[5].into(); + self.line_coding.data_bits = xfer.data()[6]; + + xfer.accept().ok(); + } + REQ_SET_CONTROL_LINE_STATE => { + self.dtr = (req.value & 0x0001) != 0; + self.rts = (req.value & 0x0002) != 0; + + xfer.accept().ok(); + } + _ => { + xfer.reject().ok(); + } + }; + } +} + +/// Number of stop bits for LineCoding +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum StopBits { + /// 1 stop bit + One = 0, + + /// 1.5 stop bits + OnePointFive = 1, + + /// 2 stop bits + Two = 2, +} + +impl From for StopBits { + fn from(value: u8) -> Self { + if value <= 2 { + unsafe { mem::transmute(value) } + } else { + StopBits::One + } + } +} + +/// Parity for LineCoding +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum ParityType { + None = 0, + Odd = 1, + Event = 2, + Mark = 3, + Space = 4, +} + +impl From for ParityType { + fn from(value: u8) -> Self { + if value <= 4 { + unsafe { mem::transmute(value) } + } else { + ParityType::None + } + } +} + +/// Line coding parameters +/// +/// This is provided by the host for specifying the standard UART parameters such as baud rate. Can +/// be ignored if you don't plan to interface with a physical UART. +pub struct LineCoding { + stop_bits: StopBits, + data_bits: u8, + parity_type: ParityType, + data_rate: u32, +} + +impl LineCoding { + /// Gets the number of stop bits for UART communication. + pub fn stop_bits(&self) -> StopBits { + self.stop_bits + } + + /// Gets the number of data bits for UART communication. + pub fn data_bits(&self) -> u8 { + self.data_bits + } + + /// Gets the parity type for UART communication. + pub fn parity_type(&self) -> ParityType { + self.parity_type + } + + /// Gets the data rate in bits per second for UART communication. + pub fn data_rate(&self) -> u32 { + self.data_rate + } +} + +impl Default for LineCoding { + fn default() -> Self { + LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + } + } +} diff --git a/embassy-stm32f4/src/lib.rs b/embassy-stm32f4/src/lib.rs index 0d490525..1788f5e7 100644 --- a/embassy-stm32f4/src/lib.rs +++ b/embassy-stm32f4/src/lib.rs @@ -316,3 +316,10 @@ pub mod exti; pub mod qei; pub mod rtc; pub mod serial; +pub mod usb; +pub mod usb_serial; +pub mod util; + +pub(crate) mod cdc_acm; + +pub use cortex_m_rt::interrupt; diff --git a/embassy-stm32f4/src/usb.rs b/embassy-stm32f4/src/usb.rs new file mode 100644 index 00000000..613b9ecb --- /dev/null +++ b/embassy-stm32f4/src/usb.rs @@ -0,0 +1,130 @@ +use core::cell::RefCell; +use core::marker::PhantomData; +use core::pin::Pin; + +use usb_device::bus::UsbBus; +use usb_device::class::UsbClass; +use usb_device::device::UsbDevice; + +use crate::interrupt; +use crate::usb_serial::{ReadInterface, UsbSerial, WriteInterface}; +use crate::util::peripheral::{PeripheralMutex, PeripheralState}; + +pub struct State<'bus, B: UsbBus, T: ClassSet> { + device: UsbDevice<'bus, B>, + pub(crate) classes: T, +} + +pub struct Usb<'bus, B: UsbBus, T: ClassSet> { + // Don't you dare moving out `PeripheralMutex` + inner: RefCell>>, +} + +impl<'bus, B, T> Usb<'bus, B, T> +where + B: UsbBus, + T: ClassSet, +{ + pub fn new>( + device: UsbDevice<'bus, B>, + class_set: S, + irq: interrupt::OTG_FSInterrupt, + ) -> Self { + let state = State { + device, + classes: class_set.into_class_set(), + }; + let mutex = PeripheralMutex::new(state, irq); + Self { + inner: RefCell::new(mutex), + } + } + + pub fn start(self: Pin<&mut Self>) { + let this = unsafe { self.get_unchecked_mut() }; + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + + // Use inner to register the irq + mutex.with(|_, _| {}); + } +} + +impl<'bus, 'c, B, T> Usb<'bus, B, T> +where + B: UsbBus, + T: ClassSet + SerialState<'bus, 'c, B>, +{ + pub fn take_serial<'a>( + self: Pin<&'a Self>, + ) -> ( + ReadInterface<'a, 'bus, 'c, B, T>, + WriteInterface<'a, 'bus, 'c, B, T>, + ) { + let this = self.get_ref(); + + let r = ReadInterface { + inner: &this.inner, + _buf_lifetime: PhantomData, + }; + + let w = WriteInterface { + inner: &this.inner, + _buf_lifetime: PhantomData, + }; + (r, w) + } +} + +impl<'bus, B, T> PeripheralState for State<'bus, B, T> +where + B: UsbBus, + T: ClassSet, +{ + type Interrupt = interrupt::OTG_FSInterrupt; + fn on_interrupt(&mut self) { + self.classes.poll_all(&mut self.device); + } +} + +pub trait ClassSet { + fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool; +} + +pub trait IntoClassSet> { + fn into_class_set(self) -> C; +} + +pub struct ClassSet1> { + class: T, + _bus: PhantomData, +} + +impl ClassSet for ClassSet1 +where + B: UsbBus, + T: UsbClass, +{ + fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool { + device.poll(&mut [&mut self.class]) + } +} + +impl> IntoClassSet> for T { + fn into_class_set(self) -> ClassSet1 { + ClassSet1 { + class: self, + _bus: PhantomData, + } + } +} + +pub trait SerialState<'bus, 'a, B: UsbBus> { + fn get_serial(&mut self) -> &mut UsbSerial<'bus, 'a, B>; +} + +impl<'bus, 'a, B: UsbBus> SerialState<'bus, 'a, B> for ClassSet1> { + fn get_serial(&mut self) -> &mut UsbSerial<'bus, 'a, B> { + &mut self.class + } +} diff --git a/embassy-stm32f4/src/usb_serial.rs b/embassy-stm32f4/src/usb_serial.rs new file mode 100644 index 00000000..284d7e5f --- /dev/null +++ b/embassy-stm32f4/src/usb_serial.rs @@ -0,0 +1,290 @@ +use core::cell::RefCell; +use core::marker::{PhantomData, PhantomPinned}; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use embassy::io::{self, AsyncBufRead, AsyncWrite}; +use embassy::util::WakerRegistration; +use usb_device::bus::UsbBus; +use usb_device::class_prelude::*; +use usb_device::UsbError; + +use crate::cdc_acm::CdcAcmClass; +use crate::usb::{ClassSet, SerialState, State}; +use crate::util::peripheral::PeripheralMutex; +use crate::util::ring_buffer::RingBuffer; + +pub struct ReadInterface<'a, 'bus, 'c, B: UsbBus, T: SerialState<'bus, 'c, B> + ClassSet> { + // Don't you dare moving out `PeripheralMutex` + pub(crate) inner: &'a RefCell>>, + pub(crate) _buf_lifetime: PhantomData<&'c T>, +} + +/// Write interface for USB CDC_ACM +/// +/// This interface is buffered, meaning that after the write returns the bytes might not be fully +/// on the wire just yet +pub struct WriteInterface<'a, 'bus, 'c, B: UsbBus, T: SerialState<'bus, 'c, B> + ClassSet> { + // Don't you dare moving out `PeripheralMutex` + pub(crate) inner: &'a RefCell>>, + pub(crate) _buf_lifetime: PhantomData<&'c T>, +} + +impl<'a, 'bus, 'c, B, T> AsyncBufRead for ReadInterface<'a, 'bus, 'c, B, T> +where + B: UsbBus, + T: SerialState<'bus, 'c, B> + ClassSet, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + mutex.with(|state, _irq| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + match serial.poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => { + let buf: &[u8] = buf; + let buf: &[u8] = unsafe { core::mem::transmute(buf) }; + Poll::Ready(Ok(buf)) + } + Poll::Ready(Err(_)) => Poll::Ready(Err(io::Error::Other)), + Poll::Pending => Poll::Pending, + } + }) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + mutex.with(|state, _irq| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + serial.consume(amt); + }) + } +} + +impl<'a, 'bus, 'c, B, T> AsyncWrite for WriteInterface<'a, 'bus, 'c, B, T> +where + B: UsbBus, + T: SerialState<'bus, 'c, B> + ClassSet, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + mutex.with(|state, _irq| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + serial.poll_write(cx, buf) + }) + } +} + +pub struct UsbSerial<'bus, 'a, B: UsbBus> { + inner: CdcAcmClass<'bus, B>, + read_buf: RingBuffer<'a>, + write_buf: RingBuffer<'a>, + read_waker: WakerRegistration, + write_waker: WakerRegistration, + write_state: WriteState, + read_error: bool, + write_error: bool, +} + +impl<'bus, 'a, B: UsbBus> AsyncBufRead for UsbSerial<'bus, 'a, B> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if this.read_error { + this.read_error = false; + return Poll::Ready(Err(io::Error::Other)); + } + + let buf = this.read_buf.pop_buf(); + if buf.is_empty() { + this.read_waker.register(cx.waker()); + return Poll::Pending; + } + Poll::Ready(Ok(buf)) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_mut().read_buf.pop(amt); + } +} + +impl<'bus, 'a, B: UsbBus> AsyncWrite for UsbSerial<'bus, 'a, B> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + + if this.write_error { + this.write_error = false; + return Poll::Ready(Err(io::Error::Other)); + } + + let write_buf = this.write_buf.push_buf(); + if write_buf.is_empty() { + this.write_waker.register(cx.waker()); + return Poll::Pending; + } + + let count = write_buf.len().min(buf.len()); + write_buf[..count].copy_from_slice(&buf[..count]); + this.write_buf.push(count); + + this.flush_write(); + Poll::Ready(Ok(count)) + } +} + +/// Keeps track of the type of the last written packet. +enum WriteState { + /// No packets in-flight + Idle, + + /// Short packet currently in-flight + Short, + + /// Full packet current in-flight. A full packet must be followed by a short packet for the host + /// OS to see the transaction. The data is the number of subsequent full packets sent so far. A + /// short packet is forced every SHORT_PACKET_INTERVAL packets so that the OS sees data in a + /// timely manner. + Full(usize), +} + +impl<'bus, 'a, B: UsbBus> UsbSerial<'bus, 'a, B> { + pub fn new( + alloc: &'bus UsbBusAllocator, + read_buf: &'a mut [u8], + write_buf: &'a mut [u8], + ) -> Self { + Self { + inner: CdcAcmClass::new(alloc, 64), + read_buf: RingBuffer::new(read_buf), + write_buf: RingBuffer::new(write_buf), + read_waker: WakerRegistration::new(), + write_waker: WakerRegistration::new(), + write_state: WriteState::Idle, + read_error: false, + write_error: false, + } + } + + fn flush_write(&mut self) { + /// If this many full size packets have been sent in a row, a short packet will be sent so that the + /// host sees the data in a timely manner. + const SHORT_PACKET_INTERVAL: usize = 10; + + let full_size_packets = match self.write_state { + WriteState::Full(c) => c, + _ => 0, + }; + + let ep_size = self.inner.max_packet_size() as usize; + let max_size = if full_size_packets > SHORT_PACKET_INTERVAL { + ep_size - 1 + } else { + ep_size + }; + + let buf = { + let buf = self.write_buf.pop_buf(); + if buf.len() > max_size { + &buf[..max_size] + } else { + buf + } + }; + + if !buf.is_empty() { + let count = match self.inner.write_packet(buf) { + Ok(c) => c, + Err(UsbError::WouldBlock) => 0, + Err(_) => { + self.write_error = true; + return; + } + }; + + if buf.len() == ep_size { + self.write_state = WriteState::Full(full_size_packets + 1); + } else { + self.write_state = WriteState::Short; + } + self.write_buf.pop(count); + } else if full_size_packets > 0 { + if let Err(e) = self.inner.write_packet(&[]) { + if !matches!(e, UsbError::WouldBlock) { + self.write_error = true; + } + return; + } + self.write_state = WriteState::Idle; + } + } +} + +impl UsbClass for UsbSerial<'_, '_, B> +where + B: UsbBus, +{ + fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<(), UsbError> { + self.inner.get_configuration_descriptors(writer) + } + + fn reset(&mut self) { + self.inner.reset(); + self.read_buf.clear(); + self.write_buf.clear(); + self.write_state = WriteState::Idle; + } + + fn endpoint_in_complete(&mut self, addr: EndpointAddress) { + if addr == self.inner.write_ep_address() { + self.write_waker.wake(); + + self.flush_write(); + } + } + + fn endpoint_out(&mut self, addr: EndpointAddress) { + if addr == self.inner.read_ep_address() { + let buf = self.read_buf.push_buf(); + let count = match self.inner.read_packet(buf) { + Ok(c) => c, + Err(UsbError::WouldBlock) => 0, + Err(_) => { + self.read_error = true; + return; + } + }; + + if count > 0 { + self.read_buf.push(count); + self.read_waker.wake(); + } + } + } + + fn control_in(&mut self, xfer: ControlIn) { + self.inner.control_in(xfer); + } + + fn control_out(&mut self, xfer: ControlOut) { + self.inner.control_out(xfer); + } +} diff --git a/embassy-stm32f4/src/util/mod.rs b/embassy-stm32f4/src/util/mod.rs new file mode 100644 index 00000000..cf330654 --- /dev/null +++ b/embassy-stm32f4/src/util/mod.rs @@ -0,0 +1,12 @@ +pub mod peripheral; +pub mod ring_buffer; + +/// Low power blocking wait loop using WFE/SEV. +pub fn low_power_wait_until(mut condition: impl FnMut() -> bool) { + while !condition() { + // WFE might "eat" an event that would have otherwise woken the executor. + cortex_m::asm::wfe(); + } + // Retrigger an event to be transparent to the executor. + cortex_m::asm::sev(); +} diff --git a/embassy-stm32f4/src/util/peripheral.rs b/embassy-stm32f4/src/util/peripheral.rs new file mode 100644 index 00000000..f2c7912f --- /dev/null +++ b/embassy-stm32f4/src/util/peripheral.rs @@ -0,0 +1,78 @@ +use core::cell::UnsafeCell; +use core::marker::{PhantomData, PhantomPinned}; +use core::pin::Pin; +use core::sync::atomic::{compiler_fence, Ordering}; + +use crate::interrupt::OwnedInterrupt; + +pub trait PeripheralState { + type Interrupt: OwnedInterrupt; + fn on_interrupt(&mut self); +} + +pub struct PeripheralMutex { + inner: Option<(UnsafeCell, S::Interrupt)>, + _not_send: PhantomData<*mut ()>, + _pinned: PhantomPinned, +} + +impl PeripheralMutex { + pub fn new(state: S, irq: S::Interrupt) -> Self { + Self { + inner: Some((UnsafeCell::new(state), irq)), + _not_send: PhantomData, + _pinned: PhantomPinned, + } + } + + pub fn with(self: Pin<&mut Self>, f: impl FnOnce(&mut S, &mut S::Interrupt) -> R) -> R { + let this = unsafe { self.get_unchecked_mut() }; + let (state, irq) = unwrap!(this.inner.as_mut()); + + irq.disable(); + compiler_fence(Ordering::SeqCst); + + irq.set_handler( + |p| { + // Safety: it's OK to get a &mut to the state, since + // - We're in the IRQ, no one else can't preempt us + // - We can't have preempted a with() call because the irq is disabled during it. + let state = unsafe { &mut *(p as *mut S) }; + state.on_interrupt(); + }, + state.get() as *mut (), + ); + + // Safety: it's OK to get a &mut to the state, since the irq is disabled. + let state = unsafe { &mut *state.get() }; + + let r = f(state, irq); + + compiler_fence(Ordering::SeqCst); + irq.enable(); + + r + } + + pub fn try_free(self: Pin<&mut Self>) -> Option<(S, S::Interrupt)> { + let this = unsafe { self.get_unchecked_mut() }; + this.inner.take().map(|(state, irq)| { + irq.disable(); + irq.remove_handler(); + (state.into_inner(), irq) + }) + } + + pub fn free(self: Pin<&mut Self>) -> (S, S::Interrupt) { + unwrap!(self.try_free()) + } +} + +impl Drop for PeripheralMutex { + fn drop(&mut self) { + if let Some((_state, irq)) = &mut self.inner { + irq.disable(); + irq.remove_handler(); + } + } +} diff --git a/embassy-stm32f4/src/util/ring_buffer.rs b/embassy-stm32f4/src/util/ring_buffer.rs new file mode 100644 index 00000000..0ef66f00 --- /dev/null +++ b/embassy-stm32f4/src/util/ring_buffer.rs @@ -0,0 +1,86 @@ +use crate::fmt::{assert, *}; + +pub struct RingBuffer<'a> { + buf: &'a mut [u8], + start: usize, + end: usize, + empty: bool, +} + +impl<'a> RingBuffer<'a> { + pub fn new(buf: &'a mut [u8]) -> Self { + Self { + buf, + start: 0, + end: 0, + empty: true, + } + } + + pub fn push_buf(&mut self) -> &mut [u8] { + if self.start == self.end && !self.empty { + trace!(" ringbuf: push_buf empty"); + return &mut self.buf[..0]; + } + + let n = if self.start <= self.end { + self.buf.len() - self.end + } else { + self.start - self.end + }; + + trace!(" ringbuf: push_buf {:?}..{:?}", self.end, self.end + n); + &mut self.buf[self.end..self.end + n] + } + + pub fn push(&mut self, n: usize) { + trace!(" ringbuf: push {:?}", n); + if n == 0 { + return; + } + + self.end = self.wrap(self.end + n); + self.empty = false; + } + + pub fn pop_buf(&mut self) -> &mut [u8] { + if self.empty { + trace!(" ringbuf: pop_buf empty"); + return &mut self.buf[..0]; + } + + let n = if self.end <= self.start { + self.buf.len() - self.start + } else { + self.end - self.start + }; + + trace!(" ringbuf: pop_buf {:?}..{:?}", self.start, self.start + n); + &mut self.buf[self.start..self.start + n] + } + + pub fn pop(&mut self, n: usize) { + trace!(" ringbuf: pop {:?}", n); + if n == 0 { + return; + } + + self.start = self.wrap(self.start + n); + self.empty = self.start == self.end; + } + + pub fn clear(&mut self) { + self.start = 0; + self.end = 0; + self.empty = true; + } + + fn wrap(&self, n: usize) -> usize { + assert!(n <= self.buf.len()); + if n == self.buf.len() { + 0 + } else { + n + } + } +}