From 67319480568a4548f2c784282c8844cd08cda7aa Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Wed, 6 Apr 2022 01:23:42 +0200 Subject: [PATCH] Add async Mutex. --- embassy/src/lib.rs | 4 +- embassy/src/mutex.rs | 167 ++++++++++++++++++++++++++++++++++ examples/nrf/src/bin/mutex.rs | 44 +++++++++ 3 files changed, 213 insertions(+), 2 deletions(-) create mode 100644 embassy/src/mutex.rs create mode 100644 examples/nrf/src/bin/mutex.rs diff --git a/embassy/src/lib.rs b/embassy/src/lib.rs index 6b24b598..ec697b40 100644 --- a/embassy/src/lib.rs +++ b/embassy/src/lib.rs @@ -10,15 +10,15 @@ pub(crate) mod fmt; pub mod blocking_mutex; pub mod channel; -pub mod waitqueue; - pub mod executor; #[cfg(cortex_m)] pub mod interrupt; pub mod io; +pub mod mutex; #[cfg(feature = "time")] pub mod time; pub mod util; +pub mod waitqueue; #[cfg(feature = "nightly")] pub use embassy_macros::{main, task}; diff --git a/embassy/src/mutex.rs b/embassy/src/mutex.rs new file mode 100644 index 00000000..27353bd4 --- /dev/null +++ b/embassy/src/mutex.rs @@ -0,0 +1,167 @@ +/// Async mutex. +/// +/// The mutex is generic over a blocking [`RawMutex`](crate::blocking_mutex::raw::RawMutex). +/// The raw mutex is used to guard access to the internal "is locked" flag. It +/// is held for very short periods only, while locking and unlocking. It is *not* held +/// for the entire time the async Mutex is locked. +use core::cell::{RefCell, UnsafeCell}; +use core::ops::{Deref, DerefMut}; +use core::task::Poll; +use futures::future::poll_fn; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex as BlockingMutex; +use crate::waitqueue::WakerRegistration; + +/// Error returned by [`Mutex::try_lock`] +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct TryLockError; + +struct State { + locked: bool, + waker: WakerRegistration, +} + +pub struct Mutex +where + M: RawMutex, + T: ?Sized, +{ + state: BlockingMutex>, + inner: UnsafeCell, +} + +unsafe impl Send for Mutex {} +unsafe impl Sync for Mutex {} + +/// Async mutex. +impl Mutex +where + M: RawMutex, +{ + /// Create a new mutex with the given value. + #[cfg(feature = "nightly")] + pub const fn new(value: T) -> Self { + Self { + inner: UnsafeCell::new(value), + state: BlockingMutex::new(RefCell::new(State { + locked: false, + waker: WakerRegistration::new(), + })), + } + } + + /// Create a new mutex with the given value. + #[cfg(not(feature = "nightly"))] + pub fn new(value: T) -> Self { + Self { + inner: UnsafeCell::new(value), + state: BlockingMutex::new(RefCell::new(State { + locked: false, + waker: WakerRegistration::new(), + })), + } + } +} + +impl Mutex +where + M: RawMutex, + T: ?Sized, +{ + /// Lock the mutex. + /// + /// This will wait for the mutex to be unlocked if it's already locked. + pub async fn lock(&self) -> MutexGuard<'_, M, T> { + poll_fn(|cx| { + let ready = self.state.lock(|s| { + let mut s = s.borrow_mut(); + if s.locked { + s.waker.register(cx.waker()); + false + } else { + s.locked = true; + true + } + }); + + if ready { + Poll::Ready(MutexGuard { mutex: self }) + } else { + Poll::Pending + } + }) + .await + } + + /// Attempt to immediately lock the mutex. + /// + /// If the mutex is already locked, this will return an error instead of waiting. + pub fn try_lock(&self) -> Result, TryLockError> { + self.state.lock(|s| { + let mut s = s.borrow_mut(); + if s.locked { + Err(TryLockError) + } else { + s.locked = true; + Ok(()) + } + })?; + + Ok(MutexGuard { mutex: self }) + } +} + +/// Async mutex guard. +/// +/// Owning an instance of this type indicates having +/// successfully locked the mutex, and grants access to the contents. +/// +/// Dropping it unlocks the mutex. +pub struct MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + mutex: &'a Mutex, +} + +impl<'a, M, T> Drop for MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn drop(&mut self) { + self.mutex.state.lock(|s| { + let mut s = s.borrow_mut(); + s.locked = false; + s.waker.wake(); + }) + } +} + +impl<'a, M, T> Deref for MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + type Target = T; + fn deref(&self) -> &Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &*(self.mutex.inner.get() as *const T) } + } +} + +impl<'a, M, T> DerefMut for MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &mut *(self.mutex.inner.get()) } + } +} diff --git a/examples/nrf/src/bin/mutex.rs b/examples/nrf/src/bin/mutex.rs new file mode 100644 index 00000000..db1b72f6 --- /dev/null +++ b/examples/nrf/src/bin/mutex.rs @@ -0,0 +1,44 @@ +#![no_std] +#![no_main] +#![feature(type_alias_impl_trait)] + +use defmt::{info, unwrap}; +use embassy::blocking_mutex::raw::ThreadModeRawMutex; +use embassy::executor::Spawner; +use embassy::mutex::Mutex; +use embassy::time::{Duration, Timer}; +use embassy_nrf::Peripherals; + +use defmt_rtt as _; // global logger +use panic_probe as _; + +static MUTEX: Mutex = Mutex::new(0); + +#[embassy::task] +async fn my_task() { + loop { + { + let mut m = MUTEX.lock().await; + info!("start long operation"); + *m += 1000; + + // Hold the mutex for a long time. + Timer::after(Duration::from_secs(1)).await; + info!("end long operation: count = {}", *m); + } + + Timer::after(Duration::from_secs(1)).await; + } +} + +#[embassy::main] +async fn main(spawner: Spawner, _p: Peripherals) { + unwrap!(spawner.spawn(my_task())); + + loop { + Timer::after(Duration::from_millis(300)).await; + let mut m = MUTEX.lock().await; + *m += 1; + info!("short operation: count = {}", *m); + } +}