STD driver needs a reentrant mutex; logic fixed to be reentrancy-safe

This commit is contained in:
ivmarkov 2023-01-26 20:14:53 +00:00
parent ffa75e1e39
commit 34b67fe137

View File

@ -1,10 +1,12 @@
use std::cell::UnsafeCell; use std::cell::{RefCell, UnsafeCell};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::{Condvar, Mutex, Once}; use std::sync::{Condvar, Mutex, Once};
use std::time::{Duration as StdDuration, Instant as StdInstant}; use std::time::{Duration as StdDuration, Instant as StdInstant};
use std::{mem, ptr, thread}; use std::{mem, ptr, thread};
use atomic_polyfill::{AtomicU8, Ordering}; use atomic_polyfill::{AtomicU8, Ordering};
use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
use embassy_sync::blocking_mutex::Mutex as EmbassyMutex;
use crate::driver::{AlarmHandle, Driver}; use crate::driver::{AlarmHandle, Driver};
@ -35,7 +37,10 @@ struct TimeDriver {
alarm_count: AtomicU8, alarm_count: AtomicU8,
once: Once, once: Once,
alarms: UninitCell<Mutex<[AlarmState; ALARM_COUNT]>>, // The STD Driver implementation requires the alarms' mutex to be reentrant, which the STD Mutex isn't
// Fortunately, mutexes based on the `critical-section` crate are reentrant, because the critical sections
// themselves are reentrant
alarms: UninitCell<EmbassyMutex<CriticalSectionRawMutex, RefCell<[AlarmState; ALARM_COUNT]>>>,
zero_instant: UninitCell<StdInstant>, zero_instant: UninitCell<StdInstant>,
signaler: UninitCell<Signaler>, signaler: UninitCell<Signaler>,
} }
@ -53,7 +58,8 @@ crate::time_driver_impl!(static DRIVER: TimeDriver = TimeDriver {
impl TimeDriver { impl TimeDriver {
fn init(&self) { fn init(&self) {
self.once.call_once(|| unsafe { self.once.call_once(|| unsafe {
self.alarms.write(Mutex::new([ALARM_NEW; ALARM_COUNT])); self.alarms
.write(EmbassyMutex::new(RefCell::new([ALARM_NEW; ALARM_COUNT])));
self.zero_instant.write(StdInstant::now()); self.zero_instant.write(StdInstant::now());
self.signaler.write(Signaler::new()); self.signaler.write(Signaler::new());
@ -66,26 +72,38 @@ impl TimeDriver {
loop { loop {
let now = DRIVER.now(); let now = DRIVER.now();
let mut next_alarm = u64::MAX; let next_alarm = unsafe { DRIVER.alarms.as_ref() }.lock(|alarms| {
{ loop {
let alarms = &mut *unsafe { DRIVER.alarms.as_ref() }.lock().unwrap(); let pending = alarms
for alarm in alarms { .borrow_mut()
if alarm.timestamp <= now { .iter_mut()
.find(|alarm| alarm.timestamp <= now)
.map(|alarm| {
alarm.timestamp = u64::MAX; alarm.timestamp = u64::MAX;
// Call after clearing alarm, so the callback can set another alarm. (alarm.callback, alarm.ctx)
});
if let Some((callback, ctx)) = pending {
// safety: // safety:
// - we can ignore the possiblity of `f` being unset (null) because of the safety contract of `allocate_alarm`. // - we can ignore the possiblity of `f` being unset (null) because of the safety contract of `allocate_alarm`.
// - other than that we only store valid function pointers into alarm.callback // - other than that we only store valid function pointers into alarm.callback
let f: fn(*mut ()) = unsafe { mem::transmute(alarm.callback) }; let f: fn(*mut ()) = unsafe { mem::transmute(callback) };
f(alarm.ctx); f(ctx);
} else { } else {
next_alarm = next_alarm.min(alarm.timestamp); // No alarm due
} break;
} }
} }
alarms
.borrow()
.iter()
.map(|alarm| alarm.timestamp)
.min()
.unwrap_or(u64::MAX)
});
// Ensure we don't overflow // Ensure we don't overflow
let until = zero let until = zero
.checked_add(StdDuration::from_micros(next_alarm)) .checked_add(StdDuration::from_micros(next_alarm))
@ -121,18 +139,23 @@ impl Driver for TimeDriver {
fn set_alarm_callback(&self, alarm: AlarmHandle, callback: fn(*mut ()), ctx: *mut ()) { fn set_alarm_callback(&self, alarm: AlarmHandle, callback: fn(*mut ()), ctx: *mut ()) {
self.init(); self.init();
let mut alarms = unsafe { self.alarms.as_ref() }.lock().unwrap(); unsafe { self.alarms.as_ref() }.lock(|alarms| {
let mut alarms = alarms.borrow_mut();
let alarm = &mut alarms[alarm.id() as usize]; let alarm = &mut alarms[alarm.id() as usize];
alarm.callback = callback as *const (); alarm.callback = callback as *const ();
alarm.ctx = ctx; alarm.ctx = ctx;
});
} }
fn set_alarm(&self, alarm: AlarmHandle, timestamp: u64) -> bool { fn set_alarm(&self, alarm: AlarmHandle, timestamp: u64) -> bool {
self.init(); self.init();
let mut alarms = unsafe { self.alarms.as_ref() }.lock().unwrap(); unsafe { self.alarms.as_ref() }.lock(|alarms| {
let mut alarms = alarms.borrow_mut();
let alarm = &mut alarms[alarm.id() as usize]; let alarm = &mut alarms[alarm.id() as usize];
alarm.timestamp = timestamp; alarm.timestamp = timestamp;
unsafe { self.signaler.as_ref() }.signal(); unsafe { self.signaler.as_ref() }.signal();
});
true true
} }