diff --git a/pod/src/slice.rs b/pod/src/slice.rs index 043f408d..ca885aba 100644 --- a/pod/src/slice.rs +++ b/pod/src/slice.rs @@ -29,7 +29,7 @@ impl<'data, T: Pod> PodSlice<'data, T> { } let (length, data) = data.split_at(LENGTH_SIZE); let length = pod_from_bytes::(length)?; - let _max_length = max_len_for_type::(data.len())?; + let _max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; let data = pod_slice_from_bytes(data)?; Ok(Self { length, data }) } @@ -70,7 +70,7 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { if init { *length = 0.into(); } - let max_length = max_len_for_type::(data.len())?; + let max_length = max_len_for_type::(data.len(), u32::from(*length) as usize)?; let data = pod_slice_from_bytes_mut(data)?; Ok(Self { length, @@ -109,22 +109,30 @@ impl<'data, T: Pod> PodSliceMut<'data, T> { } } -fn max_len_for_type(data_len: usize) -> Result { - let size: usize = std::mem::size_of::(); +fn max_len_for_type(data_len: usize, length_val: usize) -> Result { + let item_size = std::mem::size_of::(); let max_len = data_len - .checked_div(size) + .checked_div(item_size) .ok_or(PodSliceError::CalculationFailure)?; - // check that it isn't over or under allocated - if max_len.saturating_mul(size) != data_len { + + // Make sure the max length that can be stored in the buffer isn't less + // than the length value. + if max_len < length_val { + Err(PodSliceError::BufferTooSmall)? + } + + // Make sure the buffer is cleanly divisible by `size_of::`; not over or + // under allocated. + if max_len.saturating_mul(item_size) != data_len { if max_len == 0 { // Size of T is greater than buffer size - Err(PodSliceError::BufferTooSmall.into()) + Err(PodSliceError::BufferTooSmall)? } else { - Err(PodSliceError::BufferTooLarge.into()) + Err(PodSliceError::BufferTooLarge)? } - } else { - Ok(max_len) } + + Ok(max_len) } #[cfg(test)] @@ -171,9 +179,11 @@ mod tests { #[test] fn test_pod_slice_buffer_too_large() { - // 1 `TestStruct` + length = 37 bytes - // we pass 38 to trigger BufferTooLarge - let pod_slice_bytes = [1; 38]; + // Length is 1. We pass one test struct with 6 trailing bytes to + // trigger BufferTooLarge. + let data_len = LENGTH_SIZE + std::mem::size_of::() + 6; + let mut pod_slice_bytes = vec![1; data_len]; + pod_slice_bytes[0..4].copy_from_slice(&[1, 0, 0, 0]); let err = PodSlice::::unpack(&pod_slice_bytes) .err() .unwrap(); @@ -184,6 +194,32 @@ mod tests { ); } + #[test] + fn test_pod_slice_buffer_larger_than_length_value() { + // If the buffer is longer than the u32 length value declares, it + // should still unpack successfully, as long as the length of the rest + // of the buffer can be divided by `size_of::`. + let length: u32 = 12; + let length_le = length.to_le_bytes(); + + // First set up the data to have room for extra items. + let data_len = PodSlice::::size_of(length as usize + 2).unwrap(); + let mut data = vec![0; data_len]; + + // Now write the bogus length - which is smaller - into the first 4 + // bytes. + data[..LENGTH_SIZE].copy_from_slice(&length_le); + + let pod_slice = PodSlice::::unpack(&data).unwrap(); + let pod_slice_len = u32::from(*pod_slice.length); + let data = pod_slice.data(); + let data_vec = data.to_vec(); + + assert_eq!(pod_slice_len, length); + assert_eq!(data.len(), length as usize); + assert_eq!(data_vec.len(), length as usize); + } + #[test] fn test_pod_slice_buffer_too_small() { // 1 `TestStruct` + length = 37 bytes @@ -199,6 +235,31 @@ mod tests { ); } + #[test] + fn test_pod_slice_buffer_shorter_than_length_value() { + // If the buffer is shorter than the u32 length value declares, we + // should get a BufferTooSmall error. + let length: u32 = 12; + let length_le = length.to_le_bytes(); + for num_items in 0..length { + // First set up the data to have `num_elements` items. + let data_len = PodSlice::::size_of(num_items as usize).unwrap(); + let mut data = vec![0; data_len]; + + // Now write the bogus length - which is larger - into the first 4 + // bytes. + data[..LENGTH_SIZE].copy_from_slice(&length_le); + + // Expect an error on unpacking. + let err = PodSlice::::unpack(&data).err().unwrap(); + assert_eq!( + err, + PodSliceError::BufferTooSmall.into(), + "Expected an `PodSliceError::BufferTooSmall` error" + ); + } + } + #[test] fn test_pod_slice_mut() { // slice can fit 2 `TestStruct`