diff --git a/embassy-usb/Cargo.toml b/embassy-usb/Cargo.toml index 5a5a6d7a..af5986c1 100644 --- a/embassy-usb/Cargo.toml +++ b/embassy-usb/Cargo.toml @@ -10,3 +10,4 @@ defmt = { version = "0.3", optional = true } log = { version = "0.4.14", optional = true } cortex-m = "0.7.3" num-traits = { version = "0.2.14", default-features = false } +heapless = "0.7.10" \ No newline at end of file diff --git a/embassy-usb/src/builder.rs b/embassy-usb/src/builder.rs index bcb838ff..491acf4d 100644 --- a/embassy-usb/src/builder.rs +++ b/embassy-usb/src/builder.rs @@ -1,8 +1,11 @@ +use heapless::Vec; + use super::class::UsbClass; use super::descriptor::{BosWriter, DescriptorWriter}; use super::driver::{Driver, EndpointAllocError}; use super::types::*; use super::UsbDevice; +use super::MAX_CLASS_COUNT; #[derive(Debug, Copy, Clone)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -116,6 +119,7 @@ impl<'a> Config<'a> { /// Used to build new [`UsbDevice`]s. pub struct UsbDeviceBuilder<'d, D: Driver<'d>> { config: Config<'d>, + classes: Vec<&'d mut dyn UsbClass, MAX_CLASS_COUNT>, bus: D, next_interface_number: u8, @@ -165,6 +169,7 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { UsbDeviceBuilder { bus, config, + classes: Vec::new(), next_interface_number: 0, next_string_index: 4, @@ -175,7 +180,7 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { } /// Creates the [`UsbDevice`] instance with the configuration in this builder. - pub fn build(mut self, classes: &'d mut [&'d mut dyn UsbClass]) -> UsbDevice<'d, D> { + pub fn build(mut self) -> UsbDevice<'d, D> { self.config_descriptor.end_configuration(); self.bos_descriptor.end_bos(); @@ -185,10 +190,16 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { self.device_descriptor.into_buf(), self.config_descriptor.into_buf(), self.bos_descriptor.writer.into_buf(), - classes, + self.classes, ) } + pub fn add_class(&mut self, class: &'d mut dyn UsbClass) { + if self.classes.push(class).is_err() { + panic!("max class count reached") + } + } + /// Allocates a new interface number. pub fn alloc_interface(&mut self) -> InterfaceNumber { let number = self.next_interface_number; diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index ff3930af..2c00e881 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -14,6 +14,7 @@ pub mod types; mod util; use class::ControlInRequestStatus; +use heapless::Vec; use self::class::{RequestStatus, UsbClass}; use self::control::*; @@ -53,6 +54,8 @@ pub const CONFIGURATION_VALUE: u8 = 1; /// The default value for bAlternateSetting for all interfaces. pub const DEFAULT_ALTERNATE_SETTING: u8 = 0; +pub const MAX_CLASS_COUNT: usize = 4; + pub struct UsbDevice<'d, D: Driver<'d>> { bus: D::Bus, control: D::ControlPipe, @@ -67,7 +70,7 @@ pub struct UsbDevice<'d, D: Driver<'d>> { self_powered: bool, pending_address: u8, - classes: &'d mut [&'d mut dyn UsbClass], + classes: Vec<&'d mut dyn UsbClass, MAX_CLASS_COUNT>, } impl<'d, D: Driver<'d>> UsbDevice<'d, D> { @@ -77,7 +80,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { device_descriptor: &'d [u8], config_descriptor: &'d [u8], bos_descriptor: &'d [u8], - classes: &'d mut [&'d mut dyn UsbClass], + classes: Vec<&'d mut dyn UsbClass, MAX_CLASS_COUNT>, ) -> Self { let control = driver .alloc_control_pipe(config.max_packet_size_0 as u16) diff --git a/examples/nrf/src/bin/usb/cdc_acm.rs b/examples/nrf/src/bin/usb/cdc_acm.rs index 5e4abfea..92cc16eb 100644 --- a/examples/nrf/src/bin/usb/cdc_acm.rs +++ b/examples/nrf/src/bin/usb/cdc_acm.rs @@ -1,5 +1,8 @@ -use core::mem; +use core::cell::{Cell, UnsafeCell}; +use core::mem::{self, MaybeUninit}; +use core::sync::atomic::{AtomicBool, Ordering}; use defmt::info; +use embassy::blocking_mutex::CriticalSectionMutex; use embassy_usb::class::{ControlInRequestStatus, RequestStatus, UsbClass}; use embassy_usb::control::{self, Request}; use embassy_usb::driver::{Endpoint, EndpointIn, EndpointOut, ReadError, WriteError}; @@ -25,6 +28,18 @@ const REQ_SET_LINE_CODING: u8 = 0x20; const REQ_GET_LINE_CODING: u8 = 0x21; const REQ_SET_CONTROL_LINE_STATE: u8 = 0x22; +pub struct State { + control: MaybeUninit, +} + +impl State { + pub fn new() -> Self { + Self { + control: MaybeUninit::uninit(), + } + } +} + /// Packet level implementation of a CDC-ACM serial port. /// /// This class can be used directly and it has the least overhead due to directly reading and @@ -45,21 +60,32 @@ pub struct CdcAcmClass<'d, D: Driver<'d>> { pub data_if: InterfaceNumber, pub read_ep: D::EndpointOut, pub write_ep: D::EndpointIn, - pub control: CdcAcmControl, + control: &'d ControlShared, } -pub struct CdcAcmControl { - pub comm_if: InterfaceNumber, - pub line_coding: LineCoding, - pub dtr: bool, - pub rts: bool, +struct Control { + comm_if: InterfaceNumber, + shared: UnsafeCell, } -impl UsbClass for CdcAcmControl { +/// Shared data between Control and CdcAcmClass +struct ControlShared { + line_coding: CriticalSectionMutex>, + dtr: AtomicBool, + rts: AtomicBool, +} + +impl Control { + fn shared(&mut self) -> &ControlShared { + unsafe { &*(self.shared.get() as *const _) } + } +} +impl UsbClass for Control { fn reset(&mut self) { - self.line_coding = LineCoding::default(); - self.dtr = false; - self.rts = false; + let shared = self.shared(); + shared.line_coding.lock(|x| x.set(LineCoding::default())); + shared.dtr.store(false, Ordering::Relaxed); + shared.rts.store(false, Ordering::Relaxed); } fn control_out(&mut self, req: control::Request, data: &[u8]) -> RequestStatus { @@ -77,18 +103,25 @@ impl UsbClass for CdcAcmControl { RequestStatus::Accepted } REQ_SET_LINE_CODING if data.len() >= 7 => { - self.line_coding.data_rate = u32::from_le_bytes(data[0..4].try_into().unwrap()); - self.line_coding.stop_bits = data[4].into(); - self.line_coding.parity_type = data[5].into(); - self.line_coding.data_bits = data[6]; - info!("Set line coding to: {:?}", self.line_coding); + let coding = LineCoding { + data_rate: u32::from_le_bytes(data[0..4].try_into().unwrap()), + stop_bits: data[4].into(), + parity_type: data[5].into(), + data_bits: data[6], + }; + self.shared().line_coding.lock(|x| x.set(coding)); + info!("Set line coding to: {:?}", coding); RequestStatus::Accepted } REQ_SET_CONTROL_LINE_STATE => { - self.dtr = (req.value & 0x0001) != 0; - self.rts = (req.value & 0x0002) != 0; - info!("Set dtr {}, rts {}", self.dtr, self.rts); + let dtr = (req.value & 0x0001) != 0; + let rts = (req.value & 0x0002) != 0; + + let shared = self.shared(); + shared.dtr.store(dtr, Ordering::Relaxed); + shared.rts.store(rts, Ordering::Relaxed); + info!("Set dtr {}, rts {}", dtr, rts); RequestStatus::Accepted } @@ -112,11 +145,12 @@ impl UsbClass for CdcAcmControl { // REQ_GET_ENCAPSULATED_COMMAND is not really supported - it will be rejected below. REQ_GET_LINE_CODING if req.length == 7 => { info!("Sending line coding"); + let coding = self.shared().line_coding.lock(|x| x.get()); let mut data = [0; 7]; - data[0..4].copy_from_slice(&self.line_coding.data_rate.to_le_bytes()); - data[4] = self.line_coding.stop_bits as u8; - data[5] = self.line_coding.parity_type as u8; - data[6] = self.line_coding.data_bits; + data[0..4].copy_from_slice(&coding.data_rate.to_le_bytes()); + data[4] = coding.stop_bits as u8; + data[5] = coding.parity_type as u8; + data[6] = coding.data_bits; control.accept(&data) } _ => control.reject(), @@ -127,7 +161,11 @@ impl UsbClass for CdcAcmControl { impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { /// Creates a new CdcAcmClass with the provided UsbBus and max_packet_size in bytes. For /// full-speed devices, max_packet_size has to be one of 8, 16, 32 or 64. - pub fn new(builder: &mut UsbDeviceBuilder<'d, D>, max_packet_size: u16) -> Self { + pub fn new( + builder: &mut UsbDeviceBuilder<'d, D>, + state: &'d mut State, + max_packet_size: u16, + ) -> Self { let comm_if = builder.alloc_interface(); let comm_ep = builder.alloc_interrupt_endpoint_in(8, 255); let data_if = builder.alloc_interface(); @@ -207,22 +245,29 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { builder.config_descriptor.endpoint(write_ep.info()).unwrap(); builder.config_descriptor.endpoint(read_ep.info()).unwrap(); + let control = state.control.write(Control { + comm_if, + shared: UnsafeCell::new(ControlShared { + dtr: AtomicBool::new(false), + rts: AtomicBool::new(false), + line_coding: CriticalSectionMutex::new(Cell::new(LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + })), + }), + }); + + let control_shared = unsafe { &*(control.shared.get() as *const _) }; + builder.add_class(control); + CdcAcmClass { comm_ep, data_if, read_ep, write_ep, - control: CdcAcmControl { - comm_if, - dtr: false, - rts: false, - line_coding: LineCoding { - stop_bits: StopBits::One, - data_bits: 8, - parity_type: ParityType::None, - data_rate: 8_000, - }, - }, + control: control_shared, } } @@ -234,18 +279,18 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { /// Gets the current line coding. The line coding contains information that's mainly relevant /// for USB to UART serial port emulators, and can be ignored if not relevant. - pub fn line_coding(&self) -> &LineCoding { - &self.control.line_coding + pub fn line_coding(&self) -> LineCoding { + self.control.line_coding.lock(|x| x.get()) } /// Gets the DTR (data terminal ready) state pub fn dtr(&self) -> bool { - self.control.dtr + self.control.dtr.load(Ordering::Relaxed) } /// Gets the RTS (request to send) state pub fn rts(&self) -> bool { - self.control.rts + self.control.rts.load(Ordering::Relaxed) } /// Writes a single packet into the IN endpoint. @@ -264,88 +309,6 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { } } -/* -impl UsbClass for CdcAcmClass<'_, B> { - fn get_configuration_descriptors(&self, builder.config_descriptor: &mut Descriptorbuilder.config_descriptor) -> Result<()> { - - Ok(()) - } - - fn reset(&mut self) { - self.line_coding = LineCoding::default(); - self.dtr = false; - self.rts = false; - } - - fn control_in(&mut self, xfer: ControlIn) { - let req = xfer.request(); - - if !(req.request_type == control::RequestType::Class - && req.recipient == control::Recipient::Interface - && req.index == u8::from(self.comm_if) as u16) - { - return; - } - - match req.request { - // REQ_GET_ENCAPSULATED_COMMAND is not really supported - it will be rejected below. - REQ_GET_LINE_CODING if req.length == 7 => { - xfer.accept(|data| { - data[0..4].copy_from_slice(&self.line_coding.data_rate.to_le_bytes()); - data[4] = self.line_coding.stop_bits as u8; - data[5] = self.line_coding.parity_type as u8; - data[6] = self.line_coding.data_bits; - - Ok(7) - }) - .ok(); - } - _ => { - xfer.reject().ok(); - } - } - } - - fn control_out(&mut self, xfer: ControlOut) { - let req = xfer.request(); - - if !(req.request_type == control::RequestType::Class - && req.recipient == control::Recipient::Interface - && req.index == u8::from(self.comm_if) as u16) - { - return; - } - - match req.request { - REQ_SEND_ENCAPSULATED_COMMAND => { - // We don't actually support encapsulated commands but pretend we do for standards - // compatibility. - xfer.accept().ok(); - } - REQ_SET_LINE_CODING if xfer.data().len() >= 7 => { - self.line_coding.data_rate = - u32::from_le_bytes(xfer.data()[0..4].try_into().unwrap()); - self.line_coding.stop_bits = xfer.data()[4].into(); - self.line_coding.parity_type = xfer.data()[5].into(); - self.line_coding.data_bits = xfer.data()[6]; - - xfer.accept().ok(); - } - REQ_SET_CONTROL_LINE_STATE => { - self.dtr = (req.value & 0x0001) != 0; - self.rts = (req.value & 0x0002) != 0; - - xfer.accept().ok(); - } - _ => { - xfer.reject().ok(); - } - }; - } -} - - */ - /// Number of stop bits for LineCoding #[derive(Copy, Clone, PartialEq, Eq, defmt::Format)] pub enum StopBits { @@ -393,7 +356,7 @@ impl From for ParityType { /// /// This is provided by the host for specifying the standard UART parameters such as baud rate. Can /// be ignored if you don't plan to interface with a physical UART. -#[derive(defmt::Format)] +#[derive(Clone, Copy, defmt::Format)] pub struct LineCoding { stop_bits: StopBits, data_bits: u8, diff --git a/examples/nrf/src/bin/usb/main.rs b/examples/nrf/src/bin/usb/main.rs index 73ac3a21..398dd07b 100644 --- a/examples/nrf/src/bin/usb/main.rs +++ b/examples/nrf/src/bin/usb/main.rs @@ -16,12 +16,11 @@ use embassy_nrf::interrupt; use embassy_nrf::pac; use embassy_nrf::usb::Driver; use embassy_nrf::Peripherals; -use embassy_usb::class::UsbClass; use embassy_usb::driver::{EndpointIn, EndpointOut}; use embassy_usb::{Config, UsbDeviceBuilder}; use futures::future::join3; -use crate::cdc_acm::CdcAcmClass; +use crate::cdc_acm::{CdcAcmClass, State}; #[embassy::main] async fn main(_spawner: Spawner, p: Peripherals) { @@ -48,6 +47,9 @@ async fn main(_spawner: Spawner, p: Peripherals) { let mut device_descriptor = [0; 256]; let mut config_descriptor = [0; 256]; let mut bos_descriptor = [0; 256]; + + let mut state = State::new(); + let mut builder = UsbDeviceBuilder::new( driver, config, @@ -57,11 +59,10 @@ async fn main(_spawner: Spawner, p: Peripherals) { ); // Create classes on the builder. - let mut class = CdcAcmClass::new(&mut builder, 64); + let mut class = CdcAcmClass::new(&mut builder, &mut state, 64); // Build the builder. - let mut classes: [&mut dyn UsbClass; 1] = [&mut class.control]; - let mut usb = builder.build(&mut classes); + let mut usb = builder.build(); // Run the USB device. let fut1 = usb.run();