Prevent accidental revert when using firmware updater
This change prevents accidentally overwriting the previous firmware before the new one has been marked as booted.
This commit is contained in:
		
							parent
							
								
									3c70f799a2
								
							
						
					
					
						commit
						76659d9003
					
				| @ -56,6 +56,16 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     // Make sure we are running a booted firmware to avoid reverting to a bad state.
 | ||||||
|  |     async fn verify_booted(&mut self, aligned: &mut [u8]) -> Result<(), FirmwareUpdaterError> { | ||||||
|  |         assert_eq!(aligned.len(), STATE::WRITE_SIZE); | ||||||
|  |         if self.get_state(aligned).await? == State::Boot { | ||||||
|  |             Ok(()) | ||||||
|  |         } else { | ||||||
|  |             Err(FirmwareUpdaterError::BadState) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /// Obtain the current state.
 |     /// Obtain the current state.
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// This is useful to check if the bootloader has just done a swap, in order
 |     /// This is useful to check if the bootloader has just done a swap, in order
 | ||||||
| @ -98,6 +108,8 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> { | |||||||
|         assert_eq!(_aligned.len(), STATE::WRITE_SIZE); |         assert_eq!(_aligned.len(), STATE::WRITE_SIZE); | ||||||
|         assert!(_update_len <= self.dfu.capacity() as u32); |         assert!(_update_len <= self.dfu.capacity() as u32); | ||||||
| 
 | 
 | ||||||
|  |         self.verify_booted(_aligned).await?; | ||||||
|  | 
 | ||||||
|         #[cfg(feature = "ed25519-dalek")] |         #[cfg(feature = "ed25519-dalek")] | ||||||
|         { |         { | ||||||
|             use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier}; |             use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier}; | ||||||
| @ -217,8 +229,16 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> { | |||||||
|     /// # Safety
 |     /// # Safety
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// Failing to meet alignment and size requirements may result in a panic.
 |     /// Failing to meet alignment and size requirements may result in a panic.
 | ||||||
|     pub async fn write_firmware(&mut self, offset: usize, data: &[u8]) -> Result<(), FirmwareUpdaterError> { |     pub async fn write_firmware( | ||||||
|  |         &mut self, | ||||||
|  |         aligned: &mut [u8], | ||||||
|  |         offset: usize, | ||||||
|  |         data: &[u8], | ||||||
|  |     ) -> Result<(), FirmwareUpdaterError> { | ||||||
|         assert!(data.len() >= DFU::ERASE_SIZE); |         assert!(data.len() >= DFU::ERASE_SIZE); | ||||||
|  |         assert_eq!(aligned.len(), STATE::WRITE_SIZE); | ||||||
|  | 
 | ||||||
|  |         self.verify_booted(aligned).await?; | ||||||
| 
 | 
 | ||||||
|         self.dfu.erase(offset as u32, (offset + data.len()) as u32).await?; |         self.dfu.erase(offset as u32, (offset + data.len()) as u32).await?; | ||||||
| 
 | 
 | ||||||
| @ -232,7 +252,14 @@ impl<DFU: NorFlash, STATE: NorFlash> FirmwareUpdater<DFU, STATE> { | |||||||
|     ///
 |     ///
 | ||||||
|     /// Using this instead of `write_firmware` allows for an optimized API in
 |     /// Using this instead of `write_firmware` allows for an optimized API in
 | ||||||
|     /// exchange for added complexity.
 |     /// exchange for added complexity.
 | ||||||
|     pub async fn prepare_update(&mut self) -> Result<&mut DFU, FirmwareUpdaterError> { |     ///
 | ||||||
|  |     /// # Safety
 | ||||||
|  |     ///
 | ||||||
|  |     /// The `aligned` buffer must have a size of STATE::WRITE_SIZE, and follow the alignment rules for the flash being written to.
 | ||||||
|  |     pub async fn prepare_update(&mut self, aligned: &mut [u8]) -> Result<&mut DFU, FirmwareUpdaterError> { | ||||||
|  |         assert_eq!(aligned.len(), STATE::WRITE_SIZE); | ||||||
|  |         self.verify_booted(aligned).await?; | ||||||
|  | 
 | ||||||
|         self.dfu.erase(0, self.dfu.capacity() as u32).await?; |         self.dfu.erase(0, self.dfu.capacity() as u32).await?; | ||||||
| 
 | 
 | ||||||
|         Ok(&mut self.dfu) |         Ok(&mut self.dfu) | ||||||
| @ -255,13 +282,14 @@ mod tests { | |||||||
|         let flash = Mutex::<NoopRawMutex, _>::new(MemFlash::<131072, 4096, 8>::default()); |         let flash = Mutex::<NoopRawMutex, _>::new(MemFlash::<131072, 4096, 8>::default()); | ||||||
|         let state = Partition::new(&flash, 0, 4096); |         let state = Partition::new(&flash, 0, 4096); | ||||||
|         let dfu = Partition::new(&flash, 65536, 65536); |         let dfu = Partition::new(&flash, 65536, 65536); | ||||||
|  |         let mut aligned = [0; 8]; | ||||||
| 
 | 
 | ||||||
|         let update = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66]; |         let update = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66]; | ||||||
|         let mut to_write = [0; 4096]; |         let mut to_write = [0; 4096]; | ||||||
|         to_write[..7].copy_from_slice(update.as_slice()); |         to_write[..7].copy_from_slice(update.as_slice()); | ||||||
| 
 | 
 | ||||||
|         let mut updater = FirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state }); |         let mut updater = FirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state }); | ||||||
|         block_on(updater.write_firmware(0, to_write.as_slice())).unwrap(); |         block_on(updater.write_firmware(&mut aligned, 0, to_write.as_slice())).unwrap(); | ||||||
|         let mut chunk_buf = [0; 2]; |         let mut chunk_buf = [0; 2]; | ||||||
|         let mut hash = [0; 20]; |         let mut hash = [0; 20]; | ||||||
|         block_on(updater.hash::<Sha1>(update.len() as u32, &mut chunk_buf, &mut hash)).unwrap(); |         block_on(updater.hash::<Sha1>(update.len() as u32, &mut chunk_buf, &mut hash)).unwrap(); | ||||||
|  | |||||||
| @ -58,6 +58,16 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     // Make sure we are running a booted firmware to avoid reverting to a bad state.
 | ||||||
|  |     fn verify_booted(&mut self, aligned: &mut [u8]) -> Result<(), FirmwareUpdaterError> { | ||||||
|  |         assert_eq!(aligned.len(), STATE::WRITE_SIZE); | ||||||
|  |         if self.get_state(aligned)? == State::Boot { | ||||||
|  |             Ok(()) | ||||||
|  |         } else { | ||||||
|  |             Err(FirmwareUpdaterError::BadState) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /// Obtain the current state.
 |     /// Obtain the current state.
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// This is useful to check if the bootloader has just done a swap, in order
 |     /// This is useful to check if the bootloader has just done a swap, in order
 | ||||||
| @ -100,6 +110,8 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> { | |||||||
|         assert_eq!(_aligned.len(), STATE::WRITE_SIZE); |         assert_eq!(_aligned.len(), STATE::WRITE_SIZE); | ||||||
|         assert!(_update_len <= self.dfu.capacity() as u32); |         assert!(_update_len <= self.dfu.capacity() as u32); | ||||||
| 
 | 
 | ||||||
|  |         self.verify_booted(_aligned)?; | ||||||
|  | 
 | ||||||
|         #[cfg(feature = "ed25519-dalek")] |         #[cfg(feature = "ed25519-dalek")] | ||||||
|         { |         { | ||||||
|             use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier}; |             use ed25519_dalek::{PublicKey, Signature, SignatureError, Verifier}; | ||||||
| @ -219,8 +231,15 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> { | |||||||
|     /// # Safety
 |     /// # Safety
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// Failing to meet alignment and size requirements may result in a panic.
 |     /// Failing to meet alignment and size requirements may result in a panic.
 | ||||||
|     pub fn write_firmware(&mut self, offset: usize, data: &[u8]) -> Result<(), FirmwareUpdaterError> { |     pub fn write_firmware( | ||||||
|  |         &mut self, | ||||||
|  |         aligned: &mut [u8], | ||||||
|  |         offset: usize, | ||||||
|  |         data: &[u8], | ||||||
|  |     ) -> Result<(), FirmwareUpdaterError> { | ||||||
|         assert!(data.len() >= DFU::ERASE_SIZE); |         assert!(data.len() >= DFU::ERASE_SIZE); | ||||||
|  |         assert_eq!(aligned.len(), STATE::WRITE_SIZE); | ||||||
|  |         self.verify_booted(aligned)?; | ||||||
| 
 | 
 | ||||||
|         self.dfu.erase(offset as u32, (offset + data.len()) as u32)?; |         self.dfu.erase(offset as u32, (offset + data.len()) as u32)?; | ||||||
| 
 | 
 | ||||||
| @ -234,7 +253,13 @@ impl<DFU: NorFlash, STATE: NorFlash> BlockingFirmwareUpdater<DFU, STATE> { | |||||||
|     ///
 |     ///
 | ||||||
|     /// Using this instead of `write_firmware` allows for an optimized API in
 |     /// Using this instead of `write_firmware` allows for an optimized API in
 | ||||||
|     /// exchange for added complexity.
 |     /// exchange for added complexity.
 | ||||||
|     pub fn prepare_update(&mut self) -> Result<&mut DFU, FirmwareUpdaterError> { |     ///
 | ||||||
|  |     /// # Safety
 | ||||||
|  |     ///
 | ||||||
|  |     /// The `aligned` buffer must have a size of STATE::WRITE_SIZE, and follow the alignment rules for the flash being written to.
 | ||||||
|  |     pub fn prepare_update(&mut self, aligned: &mut [u8]) -> Result<&mut DFU, FirmwareUpdaterError> { | ||||||
|  |         assert_eq!(aligned.len(), STATE::WRITE_SIZE); | ||||||
|  |         self.verify_booted(aligned)?; | ||||||
|         self.dfu.erase(0, self.dfu.capacity() as u32)?; |         self.dfu.erase(0, self.dfu.capacity() as u32)?; | ||||||
| 
 | 
 | ||||||
|         Ok(&mut self.dfu) |         Ok(&mut self.dfu) | ||||||
| @ -264,7 +289,8 @@ mod tests { | |||||||
|         to_write[..7].copy_from_slice(update.as_slice()); |         to_write[..7].copy_from_slice(update.as_slice()); | ||||||
| 
 | 
 | ||||||
|         let mut updater = BlockingFirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state }); |         let mut updater = BlockingFirmwareUpdater::new(FirmwareUpdaterConfig { dfu, state }); | ||||||
|         updater.write_firmware(0, to_write.as_slice()).unwrap(); |         let mut aligned = [0; 8]; | ||||||
|  |         updater.write_firmware(&mut aligned, 0, to_write.as_slice()).unwrap(); | ||||||
|         let mut chunk_buf = [0; 2]; |         let mut chunk_buf = [0; 2]; | ||||||
|         let mut hash = [0; 20]; |         let mut hash = [0; 20]; | ||||||
|         updater |         updater | ||||||
|  | |||||||
| @ -26,6 +26,8 @@ pub enum FirmwareUpdaterError { | |||||||
|     Flash(NorFlashErrorKind), |     Flash(NorFlashErrorKind), | ||||||
|     /// Signature errors.
 |     /// Signature errors.
 | ||||||
|     Signature(signature::Error), |     Signature(signature::Error), | ||||||
|  |     /// Bad state.
 | ||||||
|  |     BadState, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "defmt")] | #[cfg(feature = "defmt")] | ||||||
| @ -34,6 +36,7 @@ impl defmt::Format for FirmwareUpdaterError { | |||||||
|         match self { |         match self { | ||||||
|             FirmwareUpdaterError::Flash(_) => defmt::write!(fmt, "FirmwareUpdaterError::Flash(_)"), |             FirmwareUpdaterError::Flash(_) => defmt::write!(fmt, "FirmwareUpdaterError::Flash(_)"), | ||||||
|             FirmwareUpdaterError::Signature(_) => defmt::write!(fmt, "FirmwareUpdaterError::Signature(_)"), |             FirmwareUpdaterError::Signature(_) => defmt::write!(fmt, "FirmwareUpdaterError::Signature(_)"), | ||||||
|  |             FirmwareUpdaterError::BadState => defmt::write!(fmt, "FirmwareUpdaterError::BadState"), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -51,6 +51,8 @@ impl<const N: usize> AsMut<[u8]> for AlignedBuffer<N> { | |||||||
| 
 | 
 | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests { | mod tests { | ||||||
|  |     #![allow(unused_imports)] | ||||||
|  | 
 | ||||||
|     use embedded_storage::nor_flash::{NorFlash, ReadNorFlash}; |     use embedded_storage::nor_flash::{NorFlash, ReadNorFlash}; | ||||||
|     #[cfg(feature = "nightly")] |     #[cfg(feature = "nightly")] | ||||||
|     use embedded_storage_async::nor_flash::NorFlash as AsyncNorFlash; |     use embedded_storage_async::nor_flash::NorFlash as AsyncNorFlash; | ||||||
| @ -120,9 +122,13 @@ mod tests { | |||||||
|             dfu: flash.dfu(), |             dfu: flash.dfu(), | ||||||
|             state: flash.state(), |             state: flash.state(), | ||||||
|         }); |         }); | ||||||
|         block_on(updater.write_firmware(0, &UPDATE)).unwrap(); |         block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)).unwrap(); | ||||||
|         block_on(updater.mark_updated(&mut aligned)).unwrap(); |         block_on(updater.mark_updated(&mut aligned)).unwrap(); | ||||||
| 
 | 
 | ||||||
|  |         // Writing after marking updated is not allowed until marked as booted.
 | ||||||
|  |         let res: Result<(), FirmwareUpdaterError> = block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)); | ||||||
|  |         assert!(matches!(res, Err::<(), _>(FirmwareUpdaterError::BadState))); | ||||||
|  | 
 | ||||||
|         let flash = flash.into_blocking(); |         let flash = flash.into_blocking(); | ||||||
|         let mut bootloader = BootLoader::new(BootLoaderConfig { |         let mut bootloader = BootLoader::new(BootLoaderConfig { | ||||||
|             active: flash.active(), |             active: flash.active(), | ||||||
| @ -188,7 +194,7 @@ mod tests { | |||||||
|             dfu: flash.dfu(), |             dfu: flash.dfu(), | ||||||
|             state: flash.state(), |             state: flash.state(), | ||||||
|         }); |         }); | ||||||
|         block_on(updater.write_firmware(0, &UPDATE)).unwrap(); |         block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)).unwrap(); | ||||||
|         block_on(updater.mark_updated(&mut aligned)).unwrap(); |         block_on(updater.mark_updated(&mut aligned)).unwrap(); | ||||||
| 
 | 
 | ||||||
|         let flash = flash.into_blocking(); |         let flash = flash.into_blocking(); | ||||||
| @ -230,7 +236,7 @@ mod tests { | |||||||
|             dfu: flash.dfu(), |             dfu: flash.dfu(), | ||||||
|             state: flash.state(), |             state: flash.state(), | ||||||
|         }); |         }); | ||||||
|         block_on(updater.write_firmware(0, &UPDATE)).unwrap(); |         block_on(updater.write_firmware(&mut aligned, 0, &UPDATE)).unwrap(); | ||||||
|         block_on(updater.mark_updated(&mut aligned)).unwrap(); |         block_on(updater.mark_updated(&mut aligned)).unwrap(); | ||||||
| 
 | 
 | ||||||
|         let flash = flash.into_blocking(); |         let flash = flash.into_blocking(); | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user