1370: stm32/i2c: fix races when using dma. r=Dirbaio a=xoviat

This change addresses two races:

1. It removes the `chunks_transferred` state variable that is modified inside the interrupt. Analysis of the code reveals that the only time the waker can be woken is when `chunks_transferred` is incremented. Therefore, waking is enough to signal the `poll_fn` that the `chunks_transferred` has incremented. Moving to `remaining_len` clarifies the code, since there is no need to track how many chunks are remaining.
2. It moves the start of the transfer until after the waker is registered, which could theoretically occur if the clock speed is very low, but probably never would even if this wasn't fixed.

There is another race that I noticed: between writes the waker may not yet be registered. In that case, the code would simply be stuck and the `poll_fn` would never be woken. There is no way to resolve this without broadening the scope of the analysis, and this will likely never occur. 

Co-authored-by: xoviat <xoviat@users.noreply.github.com>
This commit is contained in:
bors[bot] 2023-04-19 21:36:04 +00:00 committed by GitHub
commit 41e90e22e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,5 @@
use core::cmp; use core::cmp;
use core::future::poll_fn; use core::future::poll_fn;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Poll; use core::task::Poll;
use embassy_embedded_hal::SetConfig; use embassy_embedded_hal::SetConfig;
@ -35,14 +34,12 @@ impl Default for Config {
pub struct State { pub struct State {
waker: AtomicWaker, waker: AtomicWaker,
chunks_transferred: AtomicUsize,
} }
impl State { impl State {
pub(crate) const fn new() -> Self { pub(crate) const fn new() -> Self {
Self { Self {
waker: AtomicWaker::new(), waker: AtomicWaker::new(),
chunks_transferred: AtomicUsize::new(0),
} }
} }
} }
@ -130,10 +127,7 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
let isr = regs.isr().read(); let isr = regs.isr().read();
if isr.tcr() || isr.tc() { if isr.tcr() || isr.tc() {
let state = T::state(); T::state().waker.wake();
let transferred = state.chunks_transferred.load(Ordering::Relaxed);
state.chunks_transferred.store(transferred + 1, Ordering::Relaxed);
state.waker.wake();
} }
// The flag can only be cleared by writting to nbytes, we won't do that here, so disable // The flag can only be cleared by writting to nbytes, we won't do that here, so disable
// the interrupt // the interrupt
@ -457,12 +451,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
TXDMA: crate::i2c::TxDma<T>, TXDMA: crate::i2c::TxDma<T>,
{ {
let total_len = write.len(); let total_len = write.len();
let completed_chunks = total_len / 255;
let total_chunks = if completed_chunks * 255 == total_len {
completed_chunks
} else {
completed_chunks + 1
};
let dma_transfer = unsafe { let dma_transfer = unsafe {
let regs = T::regs(); let regs = T::regs();
@ -480,7 +468,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
}; };
let state = T::state(); let state = T::state();
state.chunks_transferred.store(0, Ordering::Relaxed);
let mut remaining_len = total_len; let mut remaining_len = total_len;
let on_drop = OnDrop::new(|| { let on_drop = OnDrop::new(|| {
@ -495,6 +482,11 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
} }
}); });
poll_fn(|cx| {
state.waker.register(cx.waker());
let isr = unsafe { T::regs().isr().read() };
if remaining_len == total_len {
// NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers // NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers
if first_slice { if first_slice {
unsafe { unsafe {
@ -502,26 +494,23 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
address, address,
total_len.min(255), total_len.min(255),
Stop::Software, Stop::Software,
(total_chunks != 1) || !last_slice, (total_len > 255) || !last_slice,
&check_timeout, &check_timeout,
)?; )?;
} }
} else { } else {
unsafe { unsafe {
Self::master_continue(total_len.min(255), (total_chunks != 1) || !last_slice, &check_timeout)?; Self::master_continue(total_len.min(255), (total_len > 255) || !last_slice, &check_timeout)?;
T::regs().cr1().modify(|w| w.set_tcie(true)); T::regs().cr1().modify(|w| w.set_tcie(true));
} }
} }
} else if !(isr.tcr() || isr.tc()) {
poll_fn(|cx| { // poll_fn was woken without an interrupt present
state.waker.register(cx.waker()); return Poll::Pending;
let chunks_transferred = state.chunks_transferred.load(Ordering::Relaxed); } else if remaining_len == 0 {
if chunks_transferred == total_chunks {
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} else if chunks_transferred != 0 { } else {
remaining_len = remaining_len.saturating_sub(255); let last_piece = (remaining_len <= 255) && last_slice;
let last_piece = (chunks_transferred + 1 == total_chunks) && last_slice;
// NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers // NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers
unsafe { unsafe {
@ -531,6 +520,8 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
T::regs().cr1().modify(|w| w.set_tcie(true)); T::regs().cr1().modify(|w| w.set_tcie(true));
} }
} }
remaining_len = remaining_len.saturating_sub(255);
Poll::Pending Poll::Pending
}) })
.await?; .await?;
@ -559,12 +550,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
RXDMA: crate::i2c::RxDma<T>, RXDMA: crate::i2c::RxDma<T>,
{ {
let total_len = buffer.len(); let total_len = buffer.len();
let completed_chunks = total_len / 255;
let total_chunks = if completed_chunks * 255 == total_len {
completed_chunks
} else {
completed_chunks + 1
};
let dma_transfer = unsafe { let dma_transfer = unsafe {
let regs = T::regs(); let regs = T::regs();
@ -580,7 +565,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
}; };
let state = T::state(); let state = T::state();
state.chunks_transferred.store(0, Ordering::Relaxed);
let mut remaining_len = total_len; let mut remaining_len = total_len;
let on_drop = OnDrop::new(|| { let on_drop = OnDrop::new(|| {
@ -593,27 +577,29 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
} }
}); });
poll_fn(|cx| {
state.waker.register(cx.waker());
let isr = unsafe { T::regs().isr().read() };
if remaining_len == total_len {
// NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers // NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers
unsafe { unsafe {
Self::master_read( Self::master_read(
address, address,
total_len.min(255), total_len.min(255),
Stop::Software, Stop::Software,
total_chunks != 1, total_len > 255,
restart, restart,
&check_timeout, &check_timeout,
)?; )?;
} }
} else if !(isr.tcr() || isr.tc()) {
poll_fn(|cx| { // poll_fn was woken without an interrupt present
state.waker.register(cx.waker()); return Poll::Pending;
let chunks_transferred = state.chunks_transferred.load(Ordering::Relaxed); } else if remaining_len == 0 {
if chunks_transferred == total_chunks {
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} else if chunks_transferred != 0 { } else {
remaining_len = remaining_len.saturating_sub(255); let last_piece = remaining_len <= 255;
let last_piece = chunks_transferred + 1 == total_chunks;
// NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers // NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers
unsafe { unsafe {
@ -623,6 +609,8 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
T::regs().cr1().modify(|w| w.set_tcie(true)); T::regs().cr1().modify(|w| w.set_tcie(true));
} }
} }
remaining_len = remaining_len.saturating_sub(255);
Poll::Pending Poll::Pending
}) })
.await?; .await?;