Skip to content

Commit

Permalink
fix(minor-router): unambiguous message hash computation (#542)
Browse files Browse the repository at this point in the history
  • Loading branch information
milapsheth authored Jul 25, 2024
1 parent 3359fc3 commit 852e68b
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 23 deletions.
4 changes: 2 additions & 2 deletions contracts/gateway/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ mod test {

let message = Message {
cc_id: CrossChainId::new("chain", "id").unwrap(),
source_address: "source_address".parse().unwrap(),
source_address: "source-address".parse().unwrap(),
destination_chain: "destination".parse().unwrap(),
destination_address: "destination_address".parse().unwrap(),
destination_address: "destination-address".parse().unwrap(),
payload_hash: [1; 32],
};

Expand Down
8 changes: 4 additions & 4 deletions contracts/router/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ mod test {
use permission_control::Permission;
use router_api::error::Error;
use router_api::{
ChainEndpoint, ChainName, CrossChainId, GatewayDirection, Message, CHAIN_NAME_DELIMITER,
ChainEndpoint, ChainName, CrossChainId, GatewayDirection, Message, FIELD_DELIMITER,
};

use super::*;
Expand Down Expand Up @@ -638,7 +638,7 @@ mod test {

register_chain(deps.as_mut(), &eth);
register_chain(deps.as_mut(), &polygon);
let new_gateway = Addr::unchecked("new_gateway");
let new_gateway = Addr::unchecked("new-gateway");

let _ = execute(
deps.as_mut(),
Expand Down Expand Up @@ -676,7 +676,7 @@ mod test {

register_chain(deps.as_mut(), &eth);
register_chain(deps.as_mut(), &polygon);
let new_gateway = Addr::unchecked("new_gateway");
let new_gateway = Addr::unchecked("new-gateway");

let _ = execute(
deps.as_mut(),
Expand Down Expand Up @@ -791,7 +791,7 @@ mod test {
#[test]
fn invalid_chain_name() {
assert_contract_err_string_contains(
ChainName::from_str(format!("bad{}", CHAIN_NAME_DELIMITER).as_str()).unwrap_err(),
ChainName::from_str(format!("bad{}", FIELD_DELIMITER).as_str()).unwrap_err(),
Error::InvalidChainName,
);

Expand Down
12 changes: 6 additions & 6 deletions contracts/voting-verifier/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ mod test {
.map(|i| Message {
cc_id: CrossChainId::new(source_chain(), message_id("id", i, msg_id_format))
.unwrap(),
source_address: format!("source_address{i}").parse().unwrap(),
source_address: format!("source-address{i}").parse().unwrap(),
destination_chain: format!("destination-chain{i}").parse().unwrap(),
destination_address: format!("destination_address{i}").parse().unwrap(),
destination_address: format!("destination-address{i}").parse().unwrap(),
payload_hash: [0; 32],
})
.collect()
Expand Down Expand Up @@ -269,17 +269,17 @@ mod test {
Message {
cc_id: CrossChainId::new(source_chain(), message_id("id", 1, &msg_id_format))
.unwrap(),
source_address: "source_address1".parse().unwrap(),
source_address: "source-address1".parse().unwrap(),
destination_chain: "destination-chain1".parse().unwrap(),
destination_address: "destination_address1".parse().unwrap(),
destination_address: "destination-address1".parse().unwrap(),
payload_hash: [0; 32],
},
Message {
cc_id: CrossChainId::new("other-chain", message_id("id", 2, &msg_id_format))
.unwrap(),
source_address: "source_address2".parse().unwrap(),
source_address: "source-address2".parse().unwrap(),
destination_chain: "destination-chain2".parse().unwrap(),
destination_address: "destination_address2".parse().unwrap(),
destination_address: "destination-address2".parse().unwrap(),
payload_hash: [0; 32],
},
]);
Expand Down
4 changes: 2 additions & 2 deletions contracts/voting-verifier/src/contract/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,9 @@ mod tests {
.as_str(),
)
.unwrap(),
source_address: format!("source_address{id}").parse().unwrap(),
source_address: format!("source-address{id}").parse().unwrap(),
destination_chain: format!("destination-chain{id}").parse().unwrap(),
destination_address: format!("destination_address{id}").parse().unwrap(),
destination_address: format!("destination-address{id}").parse().unwrap(),
payload_hash: [0; 32],
}
}
Expand Down
2 changes: 1 addition & 1 deletion contracts/voting-verifier/src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ mod test {
fn generate_msg(msg_id: nonempty::String) -> Message {
Message {
cc_id: CrossChainId::new("source-chain", msg_id).unwrap(),
source_address: "source_address".parse().unwrap(),
source_address: "source-address".parse().unwrap(),
destination_chain: "destination-chain".parse().unwrap(),
destination_address: "destination-address".parse().unwrap(),
payload_hash: [0; 32],
Expand Down
119 changes: 111 additions & 8 deletions packages/router-api/src/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use valuable::Valuable;

use crate::error::*;

pub const CHAIN_NAME_DELIMITER: char = '_';
/// Delimiter used when concatenating fields to prevent ambiguous encodings.
/// The delimiter must be prevented from being contained in values that are used as fields.
pub const FIELD_DELIMITER: char = '_';

#[cw_serde]
#[derive(Eq, Hash)]
Expand All @@ -41,11 +43,18 @@ pub struct Message {
impl Message {
pub fn hash(&self) -> Hash {
let mut hasher = Keccak256::new();
let delimiter_bytes = &[FIELD_DELIMITER as u8];

hasher.update(self.cc_id.to_string());
hasher.update(delimiter_bytes);
hasher.update(self.source_address.as_str());
hasher.update(delimiter_bytes);
hasher.update(self.destination_chain.as_ref());
hasher.update(delimiter_bytes);
hasher.update(self.destination_address.as_str());
hasher.update(delimiter_bytes);
hasher.update(self.payload_hash);

hasher.finalize().into()
}
}
Expand All @@ -68,6 +77,7 @@ impl From<Message> for Vec<Attribute> {
}

#[cw_serde]
#[serde(try_from = "String")]
#[derive(Eq, Hash)]
pub struct Address(nonempty::String);

Expand All @@ -91,6 +101,10 @@ impl TryFrom<String> for Address {
type Error = Report<Error>;

fn try_from(value: String) -> Result<Self, Self::Error> {
if value.contains(FIELD_DELIMITER) {
return Err(Report::new(Error::InvalidAddress));
}

Ok(Address(
value
.parse::<nonempty::String>()
Expand Down Expand Up @@ -150,7 +164,7 @@ impl KeyDeserialize for CrossChainId {
}
impl Display for CrossChainId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}{}{}", self.chain, CHAIN_NAME_DELIMITER, *self.id)
write!(f, "{}{}{}", self.chain, FIELD_DELIMITER, *self.id)
}
}

Expand Down Expand Up @@ -274,7 +288,7 @@ impl FromStr for ChainNameRaw {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.contains(CHAIN_NAME_DELIMITER) || s.is_empty() {
if s.contains(FIELD_DELIMITER) || s.is_empty() {
return Err(Error::InvalidChainName);
}

Expand Down Expand Up @@ -430,7 +444,7 @@ mod tests {
// will cause this test to fail, indicating that a migration is needed.
fn test_message_struct_unchanged() {
let expected_message_hash =
"e8052da3a89c90468cc6e4e242a827f8579fb0ea8e298b1650d73a0f7e81abc3";
"b0c6ee811cf4c205b08e36dbbad956212c4e291aedae44ab700265477bfea526";

let msg = dummy_message();

Expand All @@ -444,7 +458,7 @@ mod tests {
#[test]
fn hash_id_unchanged() {
let expected_message_hash =
"d30a374a795454706b43259998aafa741267ecbc8b6d5771be8d7b8c9a9db263";
"e6b9cc9b6962c997b44ded605ebfb4f861e2db2ddff7e8be84a7a79728cea61e";

let msg = dummy_message();

Expand All @@ -462,7 +476,7 @@ mod tests {

assert_eq!(
"chain name is invalid",
serde_json::from_str::<ChainName>(format!("\"chain{CHAIN_NAME_DELIMITER}\"").as_str())
serde_json::from_str::<ChainName>(format!("\"chain{FIELD_DELIMITER}\"").as_str())
.unwrap_err()
.to_string()
);
Expand Down Expand Up @@ -561,6 +575,87 @@ mod tests {
}
}

#[test]
fn should_not_deserialize_invalid_address() {
assert_eq!(
"address is invalid",
serde_json::from_str::<Address>("\"\"")
.unwrap_err()
.to_string()
);

assert_eq!(
"address is invalid",
serde_json::from_str::<Address>(format!("\"address{FIELD_DELIMITER}\"").as_str())
.unwrap_err()
.to_string()
);
}

#[test]
fn ensure_address_parsing_respect_restrictions() {
struct TestCase<'a> {
input: &'a str,
can_parse: bool,
}
let random_lower = random_address().to_lowercase();
let random_upper = random_address().to_uppercase();

let test_cases = [
TestCase {
input: "",
can_parse: false,
},
TestCase {
input: "address_with_prohibited_symbols",
can_parse: false,
},
TestCase {
input: "!@#$%^&*()+=-1234567890",
can_parse: true,
},
TestCase {
input: "0x4F4495243837681061C4743b74B3eEdf548D56A5",
can_parse: true,
},
TestCase {
input: "0x4f4495243837681061c4743b74b3eedf548d56a5",
can_parse: true,
},
TestCase {
input: "GARRAOPAA5MNY3Y5V2OOYXUMBC54UDHHJTUMLRQBY2DIZKT62G5WSJP4Copy",
can_parse: true,
},
TestCase {
input: "ETHEREUM-1",
can_parse: true,
},
TestCase {
input: random_lower.as_str(),
can_parse: true,
},
TestCase {
input: random_upper.as_str(),
can_parse: true,
},
];

let conversions: [fn(&str) -> Result<Address, _>; 2] = [
|input: &str| Address::from_str(input),
|input: &str| Address::try_from(input.to_string()),
];

for case in test_cases.into_iter() {
for conversion in conversions.into_iter() {
let result = conversion(case.input);
assert_eq!(result.is_ok(), case.can_parse, "input: {}", case.input);
if case.can_parse {
assert_eq!(result.unwrap().to_string(), case.input);
}
}
}
}

#[test]
fn json_schema_for_gateway_direction_flag_set_does_not_panic() {
let gen = &mut SchemaGenerator::default();
Expand All @@ -574,9 +669,9 @@ mod tests {
fn dummy_message() -> Message {
Message {
cc_id: CrossChainId::new("chain", "hash-index").unwrap(),
source_address: "source_address".parse().unwrap(),
source_address: "source-address".parse().unwrap(),
destination_chain: "destination-chain".parse().unwrap(),
destination_address: "destination_address".parse().unwrap(),
destination_address: "destination-address".parse().unwrap(),
payload_hash: [1; 32],
}
}
Expand All @@ -588,4 +683,12 @@ mod tests {
.map(char::from)
.collect()
}

fn random_address() -> String {
thread_rng()
.sample_iter(&Alphanumeric)
.take(10)
.map(char::from)
.collect()
}
}

0 comments on commit 852e68b

Please sign in to comment.