From 7b91597e9c96410be72918860eee24ef1de6c8a5 Mon Sep 17 00:00:00 2001 From: Ali Somay Date: Tue, 15 Aug 2023 19:53:08 +0200 Subject: [PATCH] Replace case specific implementation with generic descriptor buffer overwriting functionality --- embassy-usb/src/descriptor.rs | 224 ++++++++-------------------------- 1 file changed, 48 insertions(+), 176 deletions(-) diff --git a/embassy-usb/src/descriptor.rs b/embassy-usb/src/descriptor.rs index 3c62ee4e..be9f825a 100644 --- a/embassy-usb/src/descriptor.rs +++ b/embassy-usb/src/descriptor.rs @@ -36,26 +36,12 @@ pub mod capability_type { pub const PLATFORM: u8 = 5; } -/// A data structure to hold the initial descriptor position of a compound descriptor set. -/// -/// It is meant to be used in the [`DescriptorWriter`] -pub(crate) struct CompoundDescriptorSetTracker { - initial_descriptor_pos: usize, -} - -impl CompoundDescriptorSetTracker { - pub(crate) fn new(initial_descriptor_pos: usize) -> Self { - Self { initial_descriptor_pos } - } -} - /// A writer for USB descriptors. pub(crate) struct DescriptorWriter<'a> { pub buf: &'a mut [u8], position: usize, num_interfaces_mark: Option, num_endpoints_mark: Option, - tracker: Option, } impl<'a> DescriptorWriter<'a> { @@ -65,32 +51,10 @@ impl<'a> DescriptorWriter<'a> { position: 0, num_interfaces_mark: None, num_endpoints_mark: None, - tracker: None, } } - /// Starts tracking the total length of a compound descriptor set. - pub fn start_tracking_total_length_of_compound_descriptor_set(&mut self, initial_descriptor_pos: usize) { - self.tracker = Some(CompoundDescriptorSetTracker::new(initial_descriptor_pos)); - } - - /// Ends tracking the total length of a compound descriptor set and updates the initial descriptor of the set. - pub fn end_tracking_total_length_of_compound_descriptor_set_and_update_the_initial_descriptor( - &mut self, - offset: usize, - ) { - if let Some(tracker) = self.tracker.as_mut() { - let total_length = u16::try_from(self.position - tracker.initial_descriptor_pos) - .expect("\"Total Length\" fields in class-specific descriptors are always 2 bytes long."); - let total_length_bytes = total_length.to_le_bytes(); - let total_length_offset = tracker.initial_descriptor_pos + offset; - let total_length_length = tracker.initial_descriptor_pos + offset + total_length_bytes.len(); - // Write in little endian - self.buf[total_length_offset..total_length_length].copy_from_slice(&total_length_bytes) - } - self.tracker = None; - } - + /// Consumes this writer and returns the descriptor buffer. pub fn into_buf(self) -> &'a mut [u8] { &mut self.buf[..self.position] } @@ -100,8 +64,23 @@ impl<'a> DescriptorWriter<'a> { self.position } + /// Overwrites a part of the descriptor buffer starting from the provided position. + /// + /// # Panics + /// + /// Panics if writing the provided data results with exceeding the buffer size. + pub fn overwrite(&mut self, position: usize, data: &[u8]) { + let end = position + data.len(); + if end > self.buf.len() { + panic!("Descriptor buffer full"); + } + self.buf[position..end].copy_from_slice(data); + } + /// Writes an arbitrary (usually class-specific) descriptor. - pub fn write(&mut self, descriptor_type: u8, descriptor: &[u8]) { + /// + /// Returns the byte length of the descriptor which has been written. + pub fn write(&mut self, descriptor_type: u8, descriptor: &[u8]) -> usize { let length = descriptor.len(); if (self.position + 2 + length) > self.buf.len() || (length + 2) > 255 { @@ -116,6 +95,9 @@ impl<'a> DescriptorWriter<'a> { self.buf[start..start + length].copy_from_slice(descriptor); self.position = start + length; + + // Total length of the written descriptor + length + 2 } pub(crate) fn device(&mut self, config: &Config) { @@ -139,7 +121,7 @@ impl<'a> DescriptorWriter<'a> { config.serial_number.map_or(0, |_| 3), // iSerialNumber 1, // bNumConfigurations ], - ) + ); } pub(crate) fn configuration(&mut self, config: &Config) { @@ -157,7 +139,7 @@ impl<'a> DescriptorWriter<'a> { | if config.supports_remote_wakeup { 0x20 } else { 0x00 }, // bmAttributes (config.max_power / 2) as u8, // bMaxPower ], - ) + ); } #[allow(unused)] @@ -318,7 +300,7 @@ pub struct BosWriter<'a> { impl<'a> BosWriter<'a> { pub(crate) fn new(writer: DescriptorWriter<'a>) -> Self { Self { - writer: writer, + writer, num_caps_mark: None, } } @@ -376,146 +358,36 @@ mod tests { use super::*; const CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE: u8 = 0x24; - const CLASS_SPECIFIC_ENDPOINT_DESCRIPTOR_TYPE: u8 = 0x25; + #[test] - fn test_can_track_total_length() { - let mut writer_buf = [0u8; 256]; + fn test_overwriting_descriptors() { + let mut writer_buf = [0u8; 12]; + let mut writer = DescriptorWriter::new(&mut writer_buf); + + // Track the length. + let mut written_len = 0; + + // Write some imaginary descriptors to fill the buffer. + written_len += writer.write(CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, &[0x0, 0x1, 0x2, 0x3]); + written_len += writer.write(CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, &[0x4, 0x5, 0x6, 0x7]); + + writer.overwrite(3, &[0xAA_u8, 0xBB_u8, 0xCC_u8][..]); + assert_eq!(written_len, 12); + assert_eq!(&writer.buf[3..6], &[0xAA_u8, 0xBB_u8, 0xCC_u8][..]); + + writer.overwrite(4, &[0xFF_u8, 0xFF_u8][..]); + assert_eq!(&writer.buf[2..6], &[0x00_u8, 0xAA_u8, 0xFF_u8, 0xFF_u8][..]); + } + + #[test] + #[should_panic] + fn test_exceeding_descriptor_buffer_size_when_overwriting_descriptors_panics() { + let mut writer_buf = [0u8; 12]; let mut writer = DescriptorWriter::new(&mut writer_buf); // Write some imaginary descriptors to fill the buffer. writer.write(CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, &[0x0, 0x1, 0x2, 0x3]); writer.write(CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, &[0x4, 0x5, 0x6, 0x7]); - - // Start a compound descriptor set. Real example. - let position_of_the_initial_descriptor_in_the_compound_set = writer.position(); - - writer.start_tracking_total_length_of_compound_descriptor_set( - position_of_the_initial_descriptor_in_the_compound_set, - ); - - // 7 bytes - writer.write( - CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, - &[ - 0x01, // bDescriptorSubtype HEADER subtype. - 0x00, // bcdADC Revision of class specification - 1.0 - 0x01, // bcdADC - // - // We can write anything to the total length field here, it will be overwritten. - // - 0x00, // wTotalLength Total length of the class specific descriptor set. - 0x00, // wTotalLength - ], - ); - - // From here on the rest of the descriptors are part of the compound set. - // There can be many combinations. - - // Here is one example. - - // 6 bytes - writer.write( - CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, - &[ - 0x02, // bDescriptorSubtype HEADER subtype. - 0x01, // bJackType - 0x01, // bJackID - 0x00, // iJack Unused - ], - ); - - // 9 bytes - writer.write( - CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, - &[ - 0x03, // bDescriptorSubtype HEADER subtype. - 0x01, // bJackType - 0x02, // bJackID - 0x01, // bNrInputPins Number of Input Pins of this Jack. - 0x02, // BaSourceID(1) ID of the Entity to which this Pin is connected. - 0x01, // BaSourcePin(1) Output Pin number of the Entity to which this Input Pin is connected. - 0x00, // iJack Unused - ], - ); - - // 5 bytes - writer.write( - CLASS_SPECIFIC_ENDPOINT_DESCRIPTOR_TYPE, - &[ - 0x01, // bDescriptorSubtype - 0x01, // bNumEmbMIDIJack Number of embedded MIDI IN Jacks. - 0x01, // BaAssocJackID(1) ID of the Embedded MIDI IN Jack. - ], - ); - - // 5 bytes - writer.write( - CLASS_SPECIFIC_ENDPOINT_DESCRIPTOR_TYPE, - &[ - 0x01, // bDescriptorSubtype - 0x01, // bNumEmbMIDIJack Number of embedded MIDI OUT Jacks. - 0x02, // BaAssocJackID(1) ID of the Embedded MIDI OUT Jack. - ], - ); - - // 7 + 6 + 9 + 5 + 5 = 32 bytes in total. - - // Here we end the compound set. - // We need to give the offset of the total length bytes in the initial descriptor so they can be updated. - // 2 bytes of header written by our writer + 3 bytes will be our offset. - writer.end_tracking_total_length_of_compound_descriptor_set_and_update_the_initial_descriptor(5); - - let position_of_the_buffer_when_we_finished_the_compound_set = writer.position(); - - let total_length_bytes_le = &writer.buf[(position_of_the_initial_descriptor_in_the_compound_set + 5) - ..(position_of_the_initial_descriptor_in_the_compound_set + 5 + 2)]; - - let total_length_we_have_written = - u16::from_le_bytes([total_length_bytes_le[0], total_length_bytes_le[1]]) as usize; - - let actual_total_length = *&writer.buf[position_of_the_initial_descriptor_in_the_compound_set - ..position_of_the_buffer_when_we_finished_the_compound_set] - .len(); - - assert_eq!(total_length_we_have_written, 32); - assert_eq!(total_length_we_have_written, actual_total_length); - - // Now let's try writing one more compound set to see if we reset the tracker correctly. - - let position_of_the_initial_descriptor_in_the_compound_set = writer.position(); - - writer.start_tracking_total_length_of_compound_descriptor_set( - position_of_the_initial_descriptor_in_the_compound_set, - ); - - writer.write( - CLASS_SPECIFIC_INTERFACE_DESCRIPTOR_TYPE, - &[ - 0x01, // bDescriptorSubtype HEADER subtype. - 0x00, // bcdADC Revision of class specification - 1.0 - 0x01, // bcdADC - // - // We can write anything to the total length field here, it will be overwritten. - // - 0x00, // wTotalLength Total length of the class specific descriptor set. - 0x00, // wTotalLength - ], - ); - - writer.end_tracking_total_length_of_compound_descriptor_set_and_update_the_initial_descriptor(5); - - let position_of_the_buffer_when_we_finished_the_compound_set = writer.position(); - - let total_length_bytes_le = &writer.buf[(position_of_the_initial_descriptor_in_the_compound_set + 5) - ..(position_of_the_initial_descriptor_in_the_compound_set + 5 + 2)]; - - let total_length_we_have_written = - u16::from_le_bytes([total_length_bytes_le[0], total_length_bytes_le[1]]) as usize; - - let actual_total_length = *&writer.buf[position_of_the_initial_descriptor_in_the_compound_set - ..position_of_the_buffer_when_we_finished_the_compound_set] - .len(); - assert_eq!(total_length_we_have_written, 7); - assert_eq!(total_length_we_have_written, actual_total_length); + writer.overwrite(10, &[0xAA_u8, 0xBB_u8, 0xCC_u8][..]); } }