Use atomics to share state instead of a RefCell
				
					
				
			This commit is contained in:
		| @@ -1,14 +1,14 @@ | ||||
| use core::cell::RefCell; | ||||
| use core::convert::Infallible; | ||||
| use core::future::Future; | ||||
| use core::marker::PhantomData; | ||||
| use core::ptr::NonNull; | ||||
| use core::ptr; | ||||
| use core::sync::atomic::AtomicPtr; | ||||
| use core::sync::atomic::Ordering; | ||||
| use core::task::Poll; | ||||
| use core::task::Waker; | ||||
|  | ||||
| use embassy::interrupt::InterruptExt; | ||||
| use embassy::traits; | ||||
| use embassy::util::CriticalSectionMutex; | ||||
| use embassy::util::AtomicWaker; | ||||
| use embassy::util::OnDrop; | ||||
| use embassy::util::Unborrow; | ||||
| use embassy_extras::unborrow; | ||||
| @@ -25,25 +25,18 @@ impl RNG { | ||||
|     } | ||||
| } | ||||
|  | ||||
| static STATE: CriticalSectionMutex<RefCell<State>> = | ||||
|     CriticalSectionMutex::new(RefCell::new(State { | ||||
|         buffer: None, | ||||
|         waker: None, | ||||
|         index: 0, | ||||
|     })); | ||||
| static STATE: State = State { | ||||
|     ptr: AtomicPtr::new(ptr::null_mut()), | ||||
|     end: AtomicPtr::new(ptr::null_mut()), | ||||
|     waker: AtomicWaker::new(), | ||||
| }; | ||||
|  | ||||
| struct State { | ||||
|     buffer: Option<NonNull<[u8]>>, | ||||
|     waker: Option<Waker>, | ||||
|     index: usize, | ||||
|     ptr: AtomicPtr<u8>, | ||||
|     end: AtomicPtr<u8>, | ||||
|     waker: AtomicWaker, | ||||
| } | ||||
|  | ||||
| // SAFETY: `NonNull` is `!Send` because of the possibility of it being aliased. | ||||
| // However, `buffer` is only used within `on_interrupt`, | ||||
| // and the original `&mut` passed to `fill_bytes` cannot be used because the safety contract of `Rng::new` | ||||
| // means that it must still be borrowed by `RngFuture`, and so `rustc` will not let it be accessed. | ||||
| unsafe impl Send for State {} | ||||
|  | ||||
| /// A wrapper around an nRF RNG peripheral. | ||||
| /// | ||||
| /// It has a non-blocking API, through `embassy::traits::Rng`, and a blocking api through `rand`. | ||||
| @@ -70,7 +63,7 @@ impl<'d> Rng<'d> { | ||||
|             phantom: PhantomData, | ||||
|         }; | ||||
|  | ||||
|         Self::stop(); | ||||
|         this.stop(); | ||||
|         this.disable_irq(); | ||||
|  | ||||
|         this.irq.set_handler(Self::on_interrupt); | ||||
| @@ -81,25 +74,54 @@ impl<'d> Rng<'d> { | ||||
|     } | ||||
|  | ||||
|     fn on_interrupt(_: *mut ()) { | ||||
|         critical_section::with(|cs| { | ||||
|             let mut state = STATE.borrow(cs).borrow_mut(); | ||||
|             // SAFETY: the safety requirements on `Rng::new` make sure that the original `&mut`'s lifetime is still valid, | ||||
|             // meaning it can't be aliased and is a valid pointer. | ||||
|             let buffer = unsafe { state.buffer.unwrap().as_mut() }; | ||||
|             buffer[state.index] = RNG::regs().value.read().value().bits(); | ||||
|             state.index += 1; | ||||
|             if state.index == buffer.len() { | ||||
|                 // Stop the RNG within the interrupt so that it doesn't get triggered again on the way to waking the future. | ||||
|                 Self::stop(); | ||||
|                 if let Some(waker) = state.waker.take() { | ||||
|                     waker.wake(); | ||||
|         // Clear the event. | ||||
|         RNG::regs().events_valrdy.reset(); | ||||
|  | ||||
|         // Mutate the slice within a critical section, | ||||
|         // so that the future isn't dropped in between us loading the pointer and actually dereferencing it. | ||||
|         let (ptr, end) = critical_section::with(|_| { | ||||
|             let ptr = STATE.ptr.load(Ordering::Relaxed); | ||||
|             // We need to make sure we haven't already filled the whole slice, | ||||
|             // in case the interrupt fired again before the executor got back to the future. | ||||
|             let end = STATE.end.load(Ordering::Relaxed); | ||||
|             if !ptr.is_null() && ptr != end { | ||||
|                 // If the future was dropped, the pointer would have been set to null, | ||||
|                 // so we're still good to mutate the slice. | ||||
|                 // The safety contract of `Rng::new` means that the future can't have been dropped | ||||
|                 // without calling its destructor. | ||||
|                 unsafe { | ||||
|                     *ptr = RNG::regs().value.read().value().bits(); | ||||
|                 } | ||||
|             } | ||||
|             RNG::regs().events_valrdy.reset(); | ||||
|             (ptr, end) | ||||
|         }); | ||||
|  | ||||
|         if ptr.is_null() || ptr == end { | ||||
|             // If the future was dropped, there's nothing to do. | ||||
|             // If `ptr == end`, we were called by mistake, so return. | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         let new_ptr = unsafe { ptr.add(1) }; | ||||
|         match STATE | ||||
|             .ptr | ||||
|             .compare_exchange(ptr, new_ptr, Ordering::Relaxed, Ordering::Relaxed) | ||||
|         { | ||||
|             Ok(ptr) => { | ||||
|                 let end = STATE.end.load(Ordering::Relaxed); | ||||
|                 // It doesn't matter if `end` was changed under our feet, because then this will just be false. | ||||
|                 if ptr == end { | ||||
|                     STATE.waker.wake(); | ||||
|                 } | ||||
|             } | ||||
|             Err(_) => { | ||||
|                 // If the future was dropped or finished, there's no point trying to wake it. | ||||
|                 // It will have already stopped the RNG, so there's no need to do that either. | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn stop() { | ||||
|     fn stop(&self) { | ||||
|         RNG::regs().tasks_stop.write(|w| unsafe { w.bits(1) }) | ||||
|     } | ||||
|  | ||||
| @@ -140,37 +162,41 @@ impl<'d> traits::rng::Rng for Rng<'d> { | ||||
|  | ||||
|     fn fill_bytes<'a>(&'a mut self, dest: &'a mut [u8]) -> Self::RngFuture<'a> { | ||||
|         async move { | ||||
|             critical_section::with(|cs| { | ||||
|                 let mut state = STATE.borrow(cs).borrow_mut(); | ||||
|                 state.buffer = Some(dest.into()); | ||||
|             }); | ||||
|             if dest.len() == 0 { | ||||
|                 return Ok(()); // Nothing to fill | ||||
|             } | ||||
|  | ||||
|             let range = dest.as_mut_ptr_range(); | ||||
|             // Even if we've preempted the interrupt, it can't preempt us again, | ||||
|             // so we don't need to worry about the order we write these in. | ||||
|             STATE.ptr.store(range.start, Ordering::Relaxed); | ||||
|             STATE.end.store(range.end, Ordering::Relaxed); | ||||
|  | ||||
|             self.enable_irq(); | ||||
|             self.start(); | ||||
|  | ||||
|             let on_drop = OnDrop::new(|| { | ||||
|                 Self::stop(); | ||||
|                 self.stop(); | ||||
|                 self.disable_irq(); | ||||
|  | ||||
|                 // The interrupt is now disabled and can't preempt us anymore, so the order doesn't matter here. | ||||
|                 STATE.ptr.store(ptr::null_mut(), Ordering::Relaxed); | ||||
|                 STATE.end.store(ptr::null_mut(), Ordering::Relaxed); | ||||
|             }); | ||||
|  | ||||
|             poll_fn(|cx| { | ||||
|                 critical_section::with(|cs| { | ||||
|                     let mut state = STATE.borrow(cs).borrow_mut(); | ||||
|                     state.waker = Some(cx.waker().clone()); | ||||
|                     // SAFETY: see safety message in interrupt handler. | ||||
|                     // Also, both here and in the interrupt handler, we're in a critical section, | ||||
|                     // so they can't interfere with each other. | ||||
|                     let buffer = unsafe { state.buffer.unwrap().as_ref() }; | ||||
|                 STATE.waker.register(cx.waker()); | ||||
|  | ||||
|                     if state.index == buffer.len() { | ||||
|                         // Reset the state for next time | ||||
|                         state.buffer = None; | ||||
|                         state.index = 0; | ||||
|                         Poll::Ready(()) | ||||
|                     } else { | ||||
|                         Poll::Pending | ||||
|                     } | ||||
|                 }) | ||||
|                 // The interrupt will never modify `end`, so load it first and then get the most up-to-date `ptr`. | ||||
|                 let end = STATE.end.load(Ordering::Relaxed); | ||||
|                 let ptr = STATE.ptr.load(Ordering::Relaxed); | ||||
|  | ||||
|                 if ptr == end { | ||||
|                     // We're done. | ||||
|                     Poll::Ready(()) | ||||
|                 } else { | ||||
|                     Poll::Pending | ||||
|                 } | ||||
|             }) | ||||
|             .await; | ||||
|  | ||||
| @@ -193,7 +219,7 @@ impl<'d> RngCore for Rng<'d> { | ||||
|             *byte = regs.value.read().value().bits(); | ||||
|         } | ||||
|  | ||||
|         Self::stop(); | ||||
|         self.stop(); | ||||
|     } | ||||
|  | ||||
|     fn next_u32(&mut self) -> u32 { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user