From 15cc97d794d8b4baa6c1a8f1ed6c64468701c9e7 Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Mon, 28 Mar 2022 03:19:07 +0200 Subject: [PATCH] usb: associate ControlHandlers with interfaces, automatically route requests. --- embassy-usb/src/builder.rs | 32 ++++++---- embassy-usb/src/class.rs | 27 +++------ embassy-usb/src/lib.rs | 91 +++++++++++++++-------------- examples/nrf/src/bin/usb/cdc_acm.rs | 53 ++++++----------- 4 files changed, 94 insertions(+), 109 deletions(-) diff --git a/embassy-usb/src/builder.rs b/embassy-usb/src/builder.rs index 491acf4d..98b55adf 100644 --- a/embassy-usb/src/builder.rs +++ b/embassy-usb/src/builder.rs @@ -1,11 +1,11 @@ use heapless::Vec; -use super::class::UsbClass; +use super::class::ControlHandler; use super::descriptor::{BosWriter, DescriptorWriter}; use super::driver::{Driver, EndpointAllocError}; use super::types::*; use super::UsbDevice; -use super::MAX_CLASS_COUNT; +use super::MAX_INTERFACE_COUNT; #[derive(Debug, Copy, Clone)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -119,7 +119,7 @@ impl<'a> Config<'a> { /// Used to build new [`UsbDevice`]s. pub struct UsbDeviceBuilder<'d, D: Driver<'d>> { config: Config<'d>, - classes: Vec<&'d mut dyn UsbClass, MAX_CLASS_COUNT>, + interfaces: Vec<(u8, &'d mut dyn ControlHandler), MAX_INTERFACE_COUNT>, bus: D, next_interface_number: u8, @@ -169,7 +169,7 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { UsbDeviceBuilder { bus, config, - classes: Vec::new(), + interfaces: Vec::new(), next_interface_number: 0, next_string_index: 4, @@ -190,16 +190,10 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { self.device_descriptor.into_buf(), self.config_descriptor.into_buf(), self.bos_descriptor.writer.into_buf(), - self.classes, + self.interfaces, ) } - pub fn add_class(&mut self, class: &'d mut dyn UsbClass) { - if self.classes.push(class).is_err() { - panic!("max class count reached") - } - } - /// Allocates a new interface number. pub fn alloc_interface(&mut self) -> InterfaceNumber { let number = self.next_interface_number; @@ -208,6 +202,22 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { InterfaceNumber::new(number) } + /// Allocates a new interface number, with a handler that will be called + /// for all the control requests directed to it. + pub fn alloc_interface_with_handler( + &mut self, + handler: &'d mut dyn ControlHandler, + ) -> InterfaceNumber { + let number = self.next_interface_number; + self.next_interface_number += 1; + + if self.interfaces.push((number, handler)).is_err() { + panic!("max class count reached") + } + + InterfaceNumber::new(number) + } + /// Allocates a new string index. pub fn alloc_string(&mut self) -> StringIndex { let index = self.next_string_index; diff --git a/embassy-usb/src/class.rs b/embassy-usb/src/class.rs index a0141e31..754e3a20 100644 --- a/embassy-usb/src/class.rs +++ b/embassy-usb/src/class.rs @@ -3,22 +3,15 @@ use crate::control::Request; #[derive(Copy, Clone, Eq, PartialEq, Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RequestStatus { - Unhandled, Accepted, Rejected, } -impl Default for RequestStatus { - fn default() -> Self { - RequestStatus::Unhandled - } -} - /// A trait for implementing USB classes. /// /// All methods are optional callbacks that will be called by /// [`UsbDevice::run()`](crate::UsbDevice::run) -pub trait UsbClass { +pub trait ControlHandler { /// Called after a USB reset after the bus reset sequence is complete. fn reset(&mut self) {} @@ -35,7 +28,9 @@ pub trait UsbClass { /// /// * `req` - The request from the SETUP packet. /// * `data` - The data from the request. - fn control_out(&mut self, req: Request, data: &[u8]) -> RequestStatus; + fn control_out(&mut self, req: Request, data: &[u8]) -> RequestStatus { + RequestStatus::Rejected + } /// Called when a control request is received with direction DeviceToHost. /// @@ -56,7 +51,9 @@ pub trait UsbClass { &mut self, req: Request, control: ControlIn<'a>, - ) -> ControlInRequestStatus<'a>; + ) -> ControlInRequestStatus<'a> { + control.reject() + } } /// Handle for a control IN transfer. When implementing a class, use the methods of this object to @@ -84,14 +81,6 @@ impl<'a> ControlIn<'a> { ControlIn { buf } } - /// Ignores the request and leaves it unhandled. - pub fn ignore(self) -> ControlInRequestStatus<'a> { - ControlInRequestStatus { - status: RequestStatus::Unhandled, - data: &[], - } - } - /// Accepts the transfer with the supplied buffer. pub fn accept(self, data: &[u8]) -> ControlInRequestStatus<'a> { assert!(data.len() < self.buf.len()); @@ -108,7 +97,7 @@ impl<'a> ControlIn<'a> { /// Rejects the transfer by stalling the pipe. pub fn reject(self) -> ControlInRequestStatus<'a> { ControlInRequestStatus { - status: RequestStatus::Unhandled, + status: RequestStatus::Rejected, data: &[], } } diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 7d631d17..9076123a 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -13,10 +13,9 @@ pub mod driver; pub mod types; mod util; -use class::ControlInRequestStatus; use heapless::Vec; -use self::class::{RequestStatus, UsbClass}; +use self::class::{ControlHandler, RequestStatus}; use self::control::*; use self::descriptor::*; use self::driver::*; @@ -54,7 +53,7 @@ pub const CONFIGURATION_VALUE: u8 = 1; /// The default value for bAlternateSetting for all interfaces. pub const DEFAULT_ALTERNATE_SETTING: u8 = 0; -pub const MAX_CLASS_COUNT: usize = 4; +pub const MAX_INTERFACE_COUNT: usize = 4; pub struct UsbDevice<'d, D: Driver<'d>> { bus: D::Bus, @@ -70,7 +69,7 @@ pub struct UsbDevice<'d, D: Driver<'d>> { self_powered: bool, pending_address: u8, - classes: Vec<&'d mut dyn UsbClass, MAX_CLASS_COUNT>, + interfaces: Vec<(u8, &'d mut dyn ControlHandler), MAX_INTERFACE_COUNT>, } impl<'d, D: Driver<'d>> UsbDevice<'d, D> { @@ -80,7 +79,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { device_descriptor: &'d [u8], config_descriptor: &'d [u8], bos_descriptor: &'d [u8], - classes: Vec<&'d mut dyn UsbClass, MAX_CLASS_COUNT>, + interfaces: Vec<(u8, &'d mut dyn ControlHandler), MAX_INTERFACE_COUNT>, ) -> Self { let control = driver .alloc_control_pipe(config.max_packet_size_0 as u16) @@ -101,7 +100,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { remote_wakeup_enabled: false, self_powered: false, pending_address: 0, - classes, + interfaces, } } @@ -118,8 +117,8 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { self.remote_wakeup_enabled = false; self.pending_address = 0; - for c in self.classes.iter_mut() { - c.reset(); + for (_, h) in self.interfaces.iter_mut() { + h.reset(); } } Event::Resume => {} @@ -153,24 +152,6 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } async fn handle_control_out(&mut self, req: Request) { - { - let mut buf = [0; 128]; - let data = if req.length > 0 { - let size = self.control.data_out(&mut buf).await.unwrap(); - &buf[0..size] - } else { - &[] - }; - - for c in self.classes.iter_mut() { - match c.control_out(req, data) { - RequestStatus::Accepted => return self.control.accept(), - RequestStatus::Rejected => return self.control.reject(), - RequestStatus::Unhandled => (), - } - } - } - const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16; const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16; const DEFAULT_ALTERNATE_SETTING_U16: u16 = DEFAULT_ALTERNATE_SETTING as u16; @@ -224,29 +205,33 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } _ => self.control.reject(), }, + (RequestType::Class, Recipient::Interface) => { + let mut buf = [0; 128]; + let data = if req.length > 0 { + let size = self.control.data_out(&mut buf).await.unwrap(); + &buf[0..size] + } else { + &[] + }; + + let handler = self + .interfaces + .iter_mut() + .find(|(i, _)| req.index == *i as _) + .map(|(_, h)| h); + match handler { + Some(handler) => match handler.control_out(req, data) { + RequestStatus::Accepted => return self.control.accept(), + RequestStatus::Rejected => return self.control.reject(), + }, + None => self.control.reject(), + } + } _ => self.control.reject(), } } async fn handle_control_in(&mut self, req: Request) { - let mut buf = [0; 128]; - for c in self.classes.iter_mut() { - match c.control_in(req, class::ControlIn::new(&mut buf)) { - ControlInRequestStatus { - status: RequestStatus::Accepted, - data, - } => return self.control.accept_in(data).await, - ControlInRequestStatus { - status: RequestStatus::Rejected, - .. - } => return self.control.reject(), - ControlInRequestStatus { - status: RequestStatus::Unhandled, - .. - } => (), - } - } - match (req.request_type, req.recipient) { (RequestType::Standard, Recipient::Device) => match req.request { Request::GET_STATUS => { @@ -294,6 +279,24 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } _ => self.control.reject(), }, + (RequestType::Class, Recipient::Interface) => { + let mut buf = [0; 128]; + let handler = self + .interfaces + .iter_mut() + .find(|(i, _)| req.index == *i as _) + .map(|(_, h)| h); + match handler { + Some(handler) => { + let resp = handler.control_in(req, class::ControlIn::new(&mut buf)); + match resp.status { + RequestStatus::Accepted => self.control.accept_in(resp.data).await, + RequestStatus::Rejected => self.control.reject(), + } + } + None => self.control.reject(), + } + } _ => self.control.reject(), } } diff --git a/examples/nrf/src/bin/usb/cdc_acm.rs b/examples/nrf/src/bin/usb/cdc_acm.rs index 92cc16eb..f4d42979 100644 --- a/examples/nrf/src/bin/usb/cdc_acm.rs +++ b/examples/nrf/src/bin/usb/cdc_acm.rs @@ -3,7 +3,7 @@ use core::mem::{self, MaybeUninit}; use core::sync::atomic::{AtomicBool, Ordering}; use defmt::info; use embassy::blocking_mutex::CriticalSectionMutex; -use embassy_usb::class::{ControlInRequestStatus, RequestStatus, UsbClass}; +use embassy_usb::class::{ControlHandler, ControlInRequestStatus, RequestStatus}; use embassy_usb::control::{self, Request}; use embassy_usb::driver::{Endpoint, EndpointIn, EndpointOut, ReadError, WriteError}; use embassy_usb::{driver::Driver, types::*, UsbDeviceBuilder}; @@ -64,7 +64,6 @@ pub struct CdcAcmClass<'d, D: Driver<'d>> { } struct Control { - comm_if: InterfaceNumber, shared: UnsafeCell, } @@ -80,7 +79,7 @@ impl Control { unsafe { &*(self.shared.get() as *const _) } } } -impl UsbClass for Control { +impl ControlHandler for Control { fn reset(&mut self) { let shared = self.shared(); shared.line_coding.lock(|x| x.set(LineCoding::default())); @@ -89,13 +88,6 @@ impl UsbClass for Control { } fn control_out(&mut self, req: control::Request, data: &[u8]) -> RequestStatus { - if !(req.request_type == control::RequestType::Class - && req.recipient == control::Recipient::Interface - && req.index == u8::from(self.comm_if) as u16) - { - return RequestStatus::Unhandled; - } - match req.request { REQ_SEND_ENCAPSULATED_COMMAND => { // We don't actually support encapsulated commands but pretend we do for standards @@ -134,13 +126,6 @@ impl UsbClass for Control { req: Request, control: embassy_usb::class::ControlIn<'a>, ) -> ControlInRequestStatus<'a> { - if !(req.request_type == control::RequestType::Class - && req.recipient == control::Recipient::Interface - && req.index == u8::from(self.comm_if) as u16) - { - return control.ignore(); - } - match req.request { // REQ_GET_ENCAPSULATED_COMMAND is not really supported - it will be rejected below. REQ_GET_LINE_CODING if req.length == 7 => { @@ -166,7 +151,22 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { state: &'d mut State, max_packet_size: u16, ) -> Self { - let comm_if = builder.alloc_interface(); + let control = state.control.write(Control { + shared: UnsafeCell::new(ControlShared { + dtr: AtomicBool::new(false), + rts: AtomicBool::new(false), + line_coding: CriticalSectionMutex::new(Cell::new(LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + })), + }), + }); + + let control_shared = unsafe { &*(control.shared.get() as *const _) }; + + let comm_if = builder.alloc_interface_with_handler(control); let comm_ep = builder.alloc_interrupt_endpoint_in(8, 255); let data_if = builder.alloc_interface(); let read_ep = builder.alloc_bulk_endpoint_out(max_packet_size); @@ -245,23 +245,6 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { builder.config_descriptor.endpoint(write_ep.info()).unwrap(); builder.config_descriptor.endpoint(read_ep.info()).unwrap(); - let control = state.control.write(Control { - comm_if, - shared: UnsafeCell::new(ControlShared { - dtr: AtomicBool::new(false), - rts: AtomicBool::new(false), - line_coding: CriticalSectionMutex::new(Cell::new(LineCoding { - stop_bits: StopBits::One, - data_bits: 8, - parity_type: ParityType::None, - data_rate: 8_000, - })), - }), - }); - - let control_shared = unsafe { &*(control.shared.get() as *const _) }; - builder.add_class(control); - CdcAcmClass { comm_ep, data_if,