Skip to content

Commit cd0ffac

Browse files
authored
Merge pull request #745 from paupino/borsh-validation
add validation to borsh deser
2 parents 7823092 + cd1f748 commit cd0ffac

File tree

4 files changed

+83
-5
lines changed

4 files changed

+83
-5
lines changed

src/borsh.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use std::io;
2+
3+
use borsh::BorshDeserialize;
4+
5+
use crate::{
6+
Decimal, Error,
7+
constants::{SCALE_MASK, SCALE_SHIFT, SIGN_MASK},
8+
};
9+
10+
impl borsh::BorshDeserialize for Decimal {
11+
/// An implementation of [`BorshDeserialize`] that checks the received data to ensure it's a
12+
/// valid instance of [`Self`].
13+
fn deserialize_reader<__R: io::Read>(reader: &mut __R) -> Result<Self, io::Error> {
14+
const FLAG_MASK: u32 = SCALE_MASK | SIGN_MASK;
15+
16+
let flags: u32 = BorshDeserialize::deserialize_reader(reader)?;
17+
if flags & FLAG_MASK != flags {
18+
return Err(io::Error::new(
19+
io::ErrorKind::InvalidData,
20+
"Invalid flag representation",
21+
));
22+
}
23+
24+
let negative = flags & SIGN_MASK != 0;
25+
26+
let scale = (flags & SCALE_MASK) >> SCALE_SHIFT;
27+
if scale > Self::MAX_SCALE {
28+
return Err(io::Error::new(
29+
io::ErrorKind::InvalidData,
30+
Error::ScaleExceedsMaximumPrecision(scale),
31+
));
32+
}
33+
34+
let hi = BorshDeserialize::deserialize_reader(reader)?;
35+
let lo = BorshDeserialize::deserialize_reader(reader)?;
36+
let mid = BorshDeserialize::deserialize_reader(reader)?;
37+
38+
Ok(Self::from_parts(lo, mid, hi, negative, scale))
39+
}
40+
}

src/decimal.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,9 @@ impl From<UnpackedDecimal> for Decimal {
117117
#[cfg_attr(feature = "diesel", derive(FromSqlRow, AsExpression), diesel(sql_type = Numeric))]
118118
#[cfg_attr(feature = "c-repr", repr(C))]
119119
#[cfg_attr(feature = "align16", repr(align(16)))]
120-
#[cfg_attr(
121-
feature = "borsh",
122-
derive(borsh::BorshDeserialize, borsh::BorshSerialize, borsh::BorshSchema)
123-
)]
120+
// [`borsh::BorshDeserialize`] is implemented manually so that the result can be checked to be a
121+
// valid instance of [`Self`].
122+
#[cfg_attr(feature = "borsh", derive(borsh::BorshSerialize, borsh::BorshSchema))]
124123
#[cfg_attr(
125124
feature = "rkyv",
126125
derive(Archive, Deserialize, Serialize),

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ pub mod str;
1414
// We purposely place this here for documentation ordering
1515
mod arithmetic_impls;
1616

17+
#[cfg(feature = "borsh")]
18+
mod borsh;
1719
#[cfg(feature = "rust-fuzz")]
1820
mod fuzz;
1921
#[cfg(feature = "maths")]

tests/decimal_tests.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,10 @@ fn it_can_serialize_deserialize() {
138138

139139
#[cfg(feature = "borsh")]
140140
mod borsh_tests {
141-
use rust_decimal::Decimal;
142141
use std::str::FromStr;
143142

143+
use rust_decimal::Decimal;
144+
144145
#[test]
145146
fn it_can_serialize_deserialize_borsh() {
146147
let tests = [
@@ -163,6 +164,42 @@ mod borsh_tests {
163164
assert_eq!(test.to_string(), b.to_string());
164165
}
165166
}
167+
168+
#[test]
169+
fn invalid_flags_errors() {
170+
let mut bytes: Vec<u8> = Vec::new();
171+
// Invalid flags
172+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
173+
// high
174+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
175+
// lo
176+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
177+
// mid
178+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
179+
180+
let _err =
181+
<Decimal as borsh::BorshDeserialize>::deserialize(&mut bytes.as_slice()).expect_err("Invalid flags passed");
182+
}
183+
184+
#[test]
185+
fn invalid_scale_errors() {
186+
let mut bytes: Vec<u8> = Vec::new();
187+
// Invalid scale
188+
borsh::BorshSerialize::serialize(&0x00FF_0000_u32, &mut bytes).unwrap();
189+
// high
190+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
191+
// lo
192+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
193+
// mid
194+
borsh::BorshSerialize::serialize(&u32::MAX, &mut bytes).unwrap();
195+
196+
let err =
197+
<Decimal as borsh::BorshDeserialize>::deserialize(&mut bytes.as_slice()).expect_err("Invalid scale passed");
198+
assert_eq!(
199+
err.downcast::<rust_decimal::Error>().expect("Expected str flags error"),
200+
rust_decimal::Error::ScaleExceedsMaximumPrecision(0xFF)
201+
);
202+
}
166203
}
167204

168205
#[cfg(feature = "ndarray")]

0 commit comments

Comments
 (0)