diff --git a/embassy-extras/Cargo.toml b/embassy-extras/Cargo.toml index 5d07901a..8415e32e 100644 --- a/embassy-extras/Cargo.toml +++ b/embassy-extras/Cargo.toml @@ -17,4 +17,5 @@ embassy = { version = "0.1.0", path = "../embassy" } defmt = { version = "0.2.0", optional = true } log = { version = "0.4.11", optional = true } cortex-m = "0.7.1" +critical-section = "0.2.1" usb-device = "0.2.7" diff --git a/embassy-extras/src/peripheral.rs b/embassy-extras/src/peripheral.rs index 68972c54..e3e06d69 100644 --- a/embassy-extras/src/peripheral.rs +++ b/embassy-extras/src/peripheral.rs @@ -1,15 +1,38 @@ use core::cell::UnsafeCell; use core::marker::{PhantomData, PhantomPinned}; use core::pin::Pin; +use core::ptr; use embassy::interrupt::{Interrupt, InterruptExt}; -pub trait PeripheralState { +/// # Safety +/// When types implementing this trait are used with `Peripheral` or `PeripheralMutex`, +/// their lifetime must not end without first calling `Drop` on the `Peripheral` or `PeripheralMutex`. +pub unsafe trait PeripheralStateUnchecked { type Interrupt: Interrupt; fn on_interrupt(&mut self); } -pub struct PeripheralMutex { +// `PeripheralMutex` is safe because `Pin` guarantees that the memory it references will not be invalidated or reused +// without calling `Drop`. However, it provides no guarantees about references contained within the state still being valid, +// so this `'static` bound is necessary. +pub trait PeripheralState: 'static { + type Interrupt: Interrupt; + fn on_interrupt(&mut self); +} + +// SAFETY: `T` has to live for `'static` to implement `PeripheralState`, thus its lifetime cannot end. +unsafe impl PeripheralStateUnchecked for T +where + T: PeripheralState, +{ + type Interrupt = T::Interrupt; + fn on_interrupt(&mut self) { + self.on_interrupt() + } +} + +pub struct PeripheralMutex { state: UnsafeCell, irq_setup_done: bool, @@ -19,7 +42,7 @@ pub struct PeripheralMutex { _pinned: PhantomPinned, } -impl PeripheralMutex { +impl PeripheralMutex { pub fn new(state: S, irq: S::Interrupt) -> Self { Self { irq, @@ -39,11 +62,17 @@ impl PeripheralMutex { this.irq.disable(); this.irq.set_handler(|p| { - // Safety: it's OK to get a &mut to the state, since - // - We're in the IRQ, no one else can't preempt us - // - We can't have preempted a with() call because the irq is disabled during it. - let state = unsafe { &mut *(p as *mut S) }; - state.on_interrupt(); + critical_section::with(|_| { + if p.is_null() { + // The state was dropped, so we can't operate on it. + return; + } + // Safety: it's OK to get a &mut to the state, since + // - We're in a critical section, no one can preempt us (and call with()) + // - We can't have preempted a with() call because the irq is disabled during it. + let state = unsafe { &mut *(p as *mut S) }; + state.on_interrupt(); + }) }); this.irq .set_handler_context((&mut this.state) as *mut _ as *mut ()); @@ -67,9 +96,12 @@ impl PeripheralMutex { } } -impl Drop for PeripheralMutex { +impl Drop for PeripheralMutex { fn drop(&mut self) { self.irq.disable(); self.irq.remove_handler(); + // Set the context to null so that the interrupt will know we're dropped + // if we pre-empted it before it entered a critical section. + self.irq.set_handler_context(ptr::null_mut()); } } diff --git a/embassy-extras/src/peripheral_shared.rs b/embassy-extras/src/peripheral_shared.rs index c6211339..a9fca8ca 100644 --- a/embassy-extras/src/peripheral_shared.rs +++ b/embassy-extras/src/peripheral_shared.rs @@ -1,16 +1,27 @@ -use core::cell::UnsafeCell; use core::marker::{PhantomData, PhantomPinned}; use core::pin::Pin; +use core::ptr; use embassy::interrupt::{Interrupt, InterruptExt}; -pub trait PeripheralState { +/// # Safety +/// When types implementing this trait are used with `Peripheral` or `PeripheralMutex`, +/// their lifetime must not end without first calling `Drop` on the `Peripheral` or `PeripheralMutex`. +pub unsafe trait PeripheralStateUnchecked { type Interrupt: Interrupt; fn on_interrupt(&self); } -pub struct Peripheral { - state: UnsafeCell, +// `Peripheral` is safe because `Pin` guarantees that the memory it references will not be invalidated or reused +// without calling `Drop`. However, it provides no guarantees about references contained within the state still being valid, +// so this `'static` bound is necessary. +pub trait PeripheralState: 'static { + type Interrupt: Interrupt; + fn on_interrupt(&self); +} + +pub struct Peripheral { + state: S, irq_setup_done: bool, irq: S::Interrupt, @@ -19,13 +30,13 @@ pub struct Peripheral { _pinned: PhantomPinned, } -impl Peripheral { +impl Peripheral { pub fn new(irq: S::Interrupt, state: S) -> Self { Self { irq, irq_setup_done: false, - state: UnsafeCell::new(state), + state, _not_send: PhantomData, _pinned: PhantomPinned, } @@ -39,8 +50,16 @@ impl Peripheral { this.irq.disable(); this.irq.set_handler(|p| { - let state = unsafe { &*(p as *const S) }; - state.on_interrupt(); + // We need to be in a critical section so that no one can preempt us + // and drop the state after we check whether `p.is_null()`. + critical_section::with(|_| { + if p.is_null() { + // The state was dropped, so we can't operate on it. + return; + } + let state = unsafe { &*(p as *const S) }; + state.on_interrupt(); + }); }); this.irq .set_handler_context((&this.state) as *const _ as *mut ()); @@ -49,15 +68,17 @@ impl Peripheral { this.irq_setup_done = true; } - pub fn state(self: Pin<&mut Self>) -> &S { - let this = unsafe { self.get_unchecked_mut() }; - unsafe { &*this.state.get() } + pub fn state<'a>(self: Pin<&'a mut Self>) -> &'a S { + &self.into_ref().get_ref().state } } -impl Drop for Peripheral { +impl Drop for Peripheral { fn drop(&mut self) { self.irq.disable(); self.irq.remove_handler(); + // Set the context to null so that the interrupt will know we're dropped + // if we pre-empted it before it entered a critical section. + self.irq.set_handler_context(ptr::null_mut()); } } diff --git a/embassy-extras/src/usb/mod.rs b/embassy-extras/src/usb/mod.rs index 182cd87d..330eb922 100644 --- a/embassy-extras/src/usb/mod.rs +++ b/embassy-extras/src/usb/mod.rs @@ -9,7 +9,7 @@ use usb_device::device::UsbDevice; mod cdc_acm; pub mod usb_serial; -use crate::peripheral::{PeripheralMutex, PeripheralState}; +use crate::peripheral::{PeripheralMutex, PeripheralStateUnchecked}; use embassy::interrupt::Interrupt; use usb_serial::{ReadInterface, UsbSerial, WriteInterface}; @@ -55,10 +55,12 @@ where } } - pub fn start(self: Pin<&mut Self>) { - let this = unsafe { self.get_unchecked_mut() }; + /// # Safety + /// The `UsbDevice` passed to `Self::new` must not be dropped without calling `Drop` on this `Usb` first. + pub unsafe fn start(self: Pin<&mut Self>) { + let this = self.get_unchecked_mut(); let mut mutex = this.inner.borrow_mut(); - let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + let mutex = Pin::new_unchecked(&mut *mutex); // Use inner to register the irq mutex.register_interrupt(); @@ -125,7 +127,8 @@ where } } -impl<'bus, B, T, I> PeripheralState for State<'bus, B, T, I> +// SAFETY: The safety contract of `PeripheralStateUnchecked` is forwarded to `Usb::start`. +unsafe impl<'bus, B, T, I> PeripheralStateUnchecked for State<'bus, B, T, I> where B: UsbBus, T: ClassSet, diff --git a/embassy-nrf/src/buffered_uarte.rs b/embassy-nrf/src/buffered_uarte.rs index a5a37b98..1fa98a6b 100644 --- a/embassy-nrf/src/buffered_uarte.rs +++ b/embassy-nrf/src/buffered_uarte.rs @@ -7,7 +7,7 @@ use core::task::{Context, Poll}; use embassy::interrupt::InterruptExt; use embassy::io::{AsyncBufRead, AsyncWrite, Result}; use embassy::util::{Unborrow, WakerRegistration}; -use embassy_extras::peripheral::{PeripheralMutex, PeripheralState}; +use embassy_extras::peripheral::{PeripheralMutex, PeripheralStateUnchecked}; use embassy_extras::ring_buffer::RingBuffer; use embassy_extras::{low_power_wait_until, unborrow}; @@ -283,7 +283,8 @@ impl<'a, U: UarteInstance, T: TimerInstance> Drop for State<'a, U, T> { } } -impl<'a, U: UarteInstance, T: TimerInstance> PeripheralState for State<'a, U, T> { +// SAFETY: the safety contract of `PeripheralStateUnchecked` is forwarded to `BufferedUarte::new`. +unsafe impl<'a, U: UarteInstance, T: TimerInstance> PeripheralStateUnchecked for State<'a, U, T> { type Interrupt = U::Interrupt; fn on_interrupt(&mut self) { trace!("irq: start");