diff --git a/utils/key-utils/src/lib.rs b/utils/key-utils/src/lib.rs index 2c743012ce..2d0a4ff04d 100644 --- a/utils/key-utils/src/lib.rs +++ b/utils/key-utils/src/lib.rs @@ -2,18 +2,29 @@ use bs58::{decode, decode::Error as Bs58DecodeError}; use core::convert::TryFrom; use secp256k1::{SecretKey, XOnlyPublicKey}; use serde::{Deserialize, Serialize}; -use std::fmt::Display; +use std::fmt::{Display, Write}; +use std::str::FromStr; #[derive(Debug)] pub enum Error { Bs58Decode(Bs58DecodeError), Secp256k1(secp256k1::Error), - Custom, + KeyVersion(u16), + KeyLength, + Custom(String), } impl Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Key Utils Error") + match self { + Self::Bs58Decode(error) => write!(f, "Base58 code error: {error}"), + Self::Secp256k1(error) => write!(f, "Secp256k1 error: {error}"), + Self::KeyVersion(obtained) => { + write!(f, "Unknown public key version. version found: {obtained}") + } + Self::KeyLength => write!(f, "Bad key length"), + Self::Custom(error) => write!(f, "Custom error: {error}"), + } } } @@ -29,7 +40,7 @@ impl From for Error { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] #[serde(into = "String", try_from = "String")] pub struct Secp256k1SecretKey(pub SecretKey); @@ -37,6 +48,14 @@ impl TryFrom for Secp256k1SecretKey { type Error = Error; fn try_from(value: String) -> Result { + value.parse() + } +} + +impl FromStr for Secp256k1SecretKey { + type Err = Error; + + fn from_str(value: &str) -> Result { let decoded = decode(value).with_check(None).into_vec()?; let secret = SecretKey::from_slice(&decoded)?; Ok(Secp256k1SecretKey(secret)) @@ -45,12 +64,18 @@ impl TryFrom for Secp256k1SecretKey { impl From for String { fn from(secret: Secp256k1SecretKey) -> Self { - let bytes = secret.0.secret_bytes(); - bs58::encode(bytes).with_check().into_string() + secret.to_string() } } -#[derive(Debug, Clone, Serialize, Deserialize)] +impl Display for Secp256k1SecretKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let bytes = self.0.secret_bytes(); + f.write_str(&bs58::encode(bytes).with_check().into_string()) + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] #[serde(into = "String", try_from = "String")] pub struct Secp256k1PublicKey(pub XOnlyPublicKey); @@ -58,14 +83,24 @@ impl TryFrom for Secp256k1PublicKey { type Error = Error; fn try_from(value: String) -> Result { + value.parse() + } +} + +impl FromStr for Secp256k1PublicKey { + type Err = Error; + + fn from_str(value: &str) -> Result { let decoded = decode(value).with_check(None).into_vec()?; if decoded.len() < 34 { - return Err(Error::Custom); + return Err(Error::KeyLength); } - if decoded[..2] != [1, 0] { - return Err(Error::Custom); + let key_version = + u16::from_le_bytes(decoded[..2].try_into().expect("Invalid array length")); + if key_version != 1 { + return Err(Error::KeyVersion(key_version)); } - let public = XOnlyPublicKey::from_slice(&decoded[2..]).expect("Invalid public key"); + let public = XOnlyPublicKey::from_slice(&decoded[2..]).map_err(Error::Secp256k1)?; Ok(Secp256k1PublicKey(public)) } } @@ -97,28 +132,63 @@ impl Secp256k1SecretKey { } } +impl From for Secp256k1PublicKey { + fn from(value: Secp256k1SecretKey) -> Self { + let context = secp256k1::Secp256k1::new(); + let (x_coordinate, _) = value.0.public_key(&context).x_only_public_key(); + Self(x_coordinate) + } +} + #[cfg(test)] mod test { use super::*; - - #[derive(Serialize, Deserialize)] - struct Test { - public_key: Secp256k1PublicKey, - secret_key: Secp256k1SecretKey, - } + use secp256k1::rand; + use secp256k1::rand::Rng; #[test] fn deserialize_serialize_toml() { - let pub_ = "3VANfft6ei6jQq1At7d8nmiZzVhBFS4CiQujdgim1ign"; - let secr = "7qbpUjScc865jyX2kiB4NVJANoC7GA7TAJupdzXWkc62"; - let string = r#" - public_key = "3VANfft6ei6jQq1At7d8nmiZzVhBFS4CiQujdgim1ign" - secret_key = "7qbpUjScc865jyX2kiB4NVJANoC7GA7TAJupdzXWkc62" - "#; - let test: Test = toml::from_str(&string).unwrap(); - let ser_p: String = test.public_key.try_into().unwrap(); - let ser_s: String = test.secret_key.try_into().unwrap(); - assert_eq!(ser_p, pub_); - assert_eq!(ser_s, secr); + let secret_key = "zmBEmPhqo3A92FkiLVvyCz6htc3e53ph3ZbD4ASqGaLjwnFLi"; + let public_key = "9bDuixKmZqAJnrmP746n8zU1wyAQRrus7th9dxnkPg6RzQvCnan"; + let bad_public_key1 = "9bDuixKmZqAJnrmP746n8zU1wyAQRrus7th9dxnkPg6RzQvCnam"; // invalid checksum (swapped char) + let bad_public_key2 = "2myPhc5vkPzuC5FXNK5tee79WmP7uoLh55SxezoF8iqwF3E3rnPY"; // invalid version (version 12) + let bad_public_key3 = "2wmHTKZkLg2QzXyEXGMBXzKP7JXDUt8yy9SA5hoQwERc92qR6c"; // invalid length (1 B missing) + + let error = bad_public_key1 + .parse::() + .expect_err("Bad bud public key failed to raise error"); + assert!( + matches!(error, Error::Bs58Decode(_)), + "expected failed checksum error, got {}", + error + ); + let error = bad_public_key2 + .parse::() + .expect_err("Bad bud public key failed to raise error"); + assert!( + matches!(error, Error::KeyVersion(_)), + "expected invalid key version error, got {}", + error + ); + let error = bad_public_key3 + .parse::() + .expect_err("Bad bud public key failed to raise error"); + assert!( + matches!(error, Error::KeyLength), + "expected invalid key length error, got {}", + error + ); + + let parsed_key = secret_key + .parse::() + .expect("Invalid test key"); + + let calculated_public_key = Secp256k1PublicKey::from(parsed_key); + assert_eq!(calculated_public_key.to_string(), public_key); + + let parsed_public_key = public_key + .parse::() + .expect("Invalid test pubkey"); + assert_eq!(calculated_public_key.0, parsed_public_key.0); } }