Prevent accidental revert when using firmware updater

This change prevents accidentally overwriting the previous firmware before
the new one has been marked as booted.
This commit is contained in:
Ulf Lilleengen 2023-06-19 22:37:23 +02:00
parent 3c70f799a2
commit 76659d9003
4 changed files with 72 additions and 9 deletions

View File

@ -56,6 +56,16 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> {
} }
} }
// Make sure we are running a booted firmware to avoid reverting to a bad state.
async fn verify_booted(&mut self, aligned: &mut [u8]) -> Result<(), FirmwareUpdaterError> {
assert_eq!(aligned.len(), STATE::WRITE_SIZE);
if self.get_state(aligned).await? == State::Boot {
Ok(())
} else {
Err(FirmwareUpdaterError::BadState)
}
}
/// Obtain the current state. /// Obtain the current state.
/// ///
/// This is useful to check if the bootloader has just done a swap, in order /// This is useful to check if the bootloader has just done a swap, in order
@ -98,6 +108,8 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> {
assert_eq!(_aligned.len(), STATE::WRITE_SIZE); assert_eq!(_aligned.len(), STATE::WRITE_SIZE);
assert!(_update_len <= self.dfu.capacity() as u32); assert!(_update_len <= self.dfu.capacity() as u32);
self.verify_booted(_aligned).await?;
#[cfg(feature = "ed25519-dalek")] #[cfg(feature = "ed25519-dalek")]
{ {
use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier}; use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier};
@ -217,8 +229,16 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> {
/// # Safety /// # Safety
/// ///
/// Failing to meet alignment and size requirements may result in a panic. /// Failing to meet alignment and size requirements may result in a panic.
pub async fn write_firmware(&mut self, offset: usize, data: &[u8]) -> Result<(), FirmwareUpdaterError> { pub async fn write_firmware(
&mut self,
aligned: &mut [u8],
offset: usize,
data: &[u8],
) -> Result<(), FirmwareUpdaterError> {
assert!(data.len() >= DFU::ERASE_SIZE); assert!(data.len() >= DFU::ERASE_SIZE);
assert_eq!(aligned.len(), STATE::WRITE_SIZE);
self.verify_booted(aligned).await?;
self.dfu.erase(offset as u32, (offset + data.len()) as u32).await?; self.dfu.erase(offset as u32, (offset + data.len()) as u32).await?;
@ -232,7 +252,14 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> {
/// ///
/// Using this instead of `write_firmware` allows for an optimized API in /// Using this instead of `write_firmware` allows for an optimized API in
/// exchange for added complexity. /// exchange for added complexity.
pub async fn prepare_update(&mut self) -> Result<&mut DFU, FirmwareUpdaterError> { ///
/// # Safety
///
/// The `aligned` buffer must have a size of STATE::WRITE_SIZE, and follow the alignment rules for the flash being written to.
pub async fn prepare_update(&mut self, aligned: &mut [u8]) -> Result<&mut DFU, FirmwareUpdaterError> {
assert_eq!(aligned.len(), STATE::WRITE_SIZE);
self.verify_booted(aligned).await?;
self.dfu.erase(0, self.dfu.capacity() as u32).await?; self.dfu.erase(0, self.dfu.capacity() as u32).await?;
Ok(&mut self.dfu) Ok(&mut self.dfu)
@ -255,13 +282,14 @@ mod tests {
let flash = Mutex::<NoopRawMutex, _>::new(MemFlash::<131072, 4096, 8>::default()); let flash = Mutex::<NoopRawMutex, _>::new(MemFlash::<131072, 4096, 8>::default());
let state = Partition::new(&flash, 0, 4096); let state = Partition::new(&flash, 0, 4096);
let dfu = Partition::new(&flash, 65536, 65536); let dfu = Partition::new(&flash, 65536, 65536);
let mut aligned = [0; 8];
let update = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66]; let update = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
let mut to_write = [0; 4096]; let mut to_write = [0; 4096];
to_write[..7].copy_from_slice(update.as_slice()); to_write[..7].copy_from_slice(update.as_slice());
let mut updater = FirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state }); let mut updater = FirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state });
block_on(updater.write_firmware(0, to_write.as_slice())).unwrap(); block_on(updater.write_firmware(&mut aligned, 0, to_write.as_slice())).unwrap();
let mut chunk_buf = [0; 2]; let mut chunk_buf = [0; 2];
let mut hash = [0; 20]; let mut hash = [0; 20];
block_on(updater.hash::<Sha1>(update.len() as u32, &mut chunk_buf, &mut hash)).unwrap(); block_on(updater.hash::<Sha1>(update.len() as u32, &mut chunk_buf, &mut hash)).unwrap();

View File

@ -58,6 +58,16 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> {
} }
} }
// Make sure we are running a booted firmware to avoid reverting to a bad state.
fn verify_booted(&mut self, aligned: &mut [u8]) -> Result<(), FirmwareUpdaterError> {
assert_eq!(aligned.len(), STATE::WRITE_SIZE);
if self.get_state(aligned)? == State::Boot {
Ok(())
} else {
Err(FirmwareUpdaterError::BadState)
}
}
/// Obtain the current state. /// Obtain the current state.
/// ///
/// This is useful to check if the bootloader has just done a swap, in order /// This is useful to check if the bootloader has just done a swap, in order
@ -100,6 +110,8 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> {
assert_eq!(_aligned.len(), STATE::WRITE_SIZE); assert_eq!(_aligned.len(), STATE::WRITE_SIZE);
assert!(_update_len <= self.dfu.capacity() as u32); assert!(_update_len <= self.dfu.capacity() as u32);
self.verify_booted(_aligned)?;
#[cfg(feature = "ed25519-dalek")] #[cfg(feature = "ed25519-dalek")]
{ {
use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier}; use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier};
@ -219,8 +231,15 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> {
/// # Safety /// # Safety
/// ///
/// Failing to meet alignment and size requirements may result in a panic. /// Failing to meet alignment and size requirements may result in a panic.
pub fn write_firmware(&mut self, offset: usize, data: &[u8]) -> Result<(), FirmwareUpdaterError> { pub fn write_firmware(
&mut self,
aligned: &mut [u8],
offset: usize,
data: &[u8],
) -> Result<(), FirmwareUpdaterError> {
assert!(data.len() >= DFU::ERASE_SIZE); assert!(data.len() >= DFU::ERASE_SIZE);
assert_eq!(aligned.len(), STATE::WRITE_SIZE);
self.verify_booted(aligned)?;
self.dfu.erase(offset as u32, (offset + data.len()) as u32)?; self.dfu.erase(offset as u32, (offset + data.len()) as u32)?;
@ -234,7 +253,13 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> {
/// ///
/// Using this instead of `write_firmware` allows for an optimized API in /// Using this instead of `write_firmware` allows for an optimized API in
/// exchange for added complexity. /// exchange for added complexity.
pub fn prepare_update(&mut self) -> Result<&mut DFU, FirmwareUpdaterError> { ///
/// # Safety
///
/// The `aligned` buffer must have a size of STATE::WRITE_SIZE, and follow the alignment rules for the flash being written to.
pub fn prepare_update(&mut self, aligned: &mut [u8]) -> Result<&mut DFU, FirmwareUpdaterError> {
assert_eq!(aligned.len(), STATE::WRITE_SIZE);
self.verify_booted(aligned)?;
self.dfu.erase(0, self.dfu.capacity() as u32)?; self.dfu.erase(0, self.dfu.capacity() as u32)?;
Ok(&mut self.dfu) Ok(&mut self.dfu)
@ -264,7 +289,8 @@ mod tests {
to_write[..7].copy_from_slice(update.as_slice()); to_write[..7].copy_from_slice(update.as_slice());
let mut updater = BlockingFirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state }); let mut updater = BlockingFirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state });
updater.write_firmware(0, to_write.as_slice()).unwrap(); let mut aligned = [0; 8];
updater.write_firmware(&mut aligned, 0, to_write.as_slice()).unwrap();
let mut chunk_buf = [0; 2]; let mut chunk_buf = [0; 2];
let mut hash = [0; 20]; let mut hash = [0; 20];
updater updater

View File

@ -26,6 +26,8 @@ pub enum FirmwareUpdaterError {
Flash(NorFlashErrorKind), Flash(NorFlashErrorKind),
/// Signature errors. /// Signature errors.
Signature(signature::Error), Signature(signature::Error),
/// Bad state.
BadState,
} }
#[cfg(feature = "defmt")] #[cfg(feature = "defmt")]
@ -34,6 +36,7 @@ impl defmt::Format for FirmwareUpdaterError {
match self { match self {
FirmwareUpdaterError::Flash(_) => defmt::write!(fmt, "FirmwareUpdaterError::Flash(_)"), FirmwareUpdaterError::Flash(_) => defmt::write!(fmt, "FirmwareUpdaterError::Flash(_)"),
FirmwareUpdaterError::Signature(_) => defmt::write!(fmt, "FirmwareUpdaterError::Signature(_)"), FirmwareUpdaterError::Signature(_) => defmt::write!(fmt, "FirmwareUpdaterError::Signature(_)"),
FirmwareUpdaterError::BadState => defmt::write!(fmt, "FirmwareUpdaterError::BadState"),
} }
} }
} }

View File

@ -51,6 +51,8 @@ impl<const N: usize> AsMut<[u8]> for AlignedBuffer<N> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#![allow(unused_imports)]
use embedded_storage::nor_flash::{NorFlash, ReadNorFlash}; use embedded_storage::nor_flash::{NorFlash, ReadNorFlash};
#[cfg(feature = "nightly")] #[cfg(feature = "nightly")]
use embedded_storage_async::nor_flash::NorFlash as AsyncNorFlash; use embedded_storage_async::nor_flash::NorFlash as AsyncNorFlash;
@ -120,9 +122,13 @@ mod tests {
dfu: flash.dfu(), dfu: flash.dfu(),
state: flash.state(), state: flash.state(),
}); });
block_on(updater.write_firmware(0, &UPDATE)).unwrap(); block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)).unwrap();
block_on(updater.mark_updated(&mut aligned)).unwrap(); block_on(updater.mark_updated(&mut aligned)).unwrap();
// Writing after marking updated is not allowed until marked as booted.
let res: Result<(), FirmwareUpdaterError> = block_on(updater.write_firmware(&mut aligned, 0, &UPDATE));
assert!(matches!(res, Err::<(), _>(FirmwareUpdaterError::BadState)));
let flash = flash.into_blocking(); let flash = flash.into_blocking();
let mut bootloader = BootLoader::new(BootLoaderConfig { let mut bootloader = BootLoader::new(BootLoaderConfig {
active: flash.active(), active: flash.active(),
@ -188,7 +194,7 @@ mod tests {
dfu: flash.dfu(), dfu: flash.dfu(),
state: flash.state(), state: flash.state(),
}); });
block_on(updater.write_firmware(0, &UPDATE)).unwrap(); block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)).unwrap();
block_on(updater.mark_updated(&mut aligned)).unwrap(); block_on(updater.mark_updated(&mut aligned)).unwrap();
let flash = flash.into_blocking(); let flash = flash.into_blocking();
@ -230,7 +236,7 @@ mod tests {
dfu: flash.dfu(), dfu: flash.dfu(),
state: flash.state(), state: flash.state(),
}); });
block_on(updater.write_firmware(0, &UPDATE)).unwrap(); block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)).unwrap();
block_on(updater.mark_updated(&mut aligned)).unwrap(); block_on(updater.mark_updated(&mut aligned)).unwrap();
let flash = flash.into_blocking(); let flash = flash.into_blocking();