diff --git a/Cargo.toml b/Cargo.toml index 154cb3f..2a07a88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] paste = "1.0.15" thiserror = "1.0.63" -vodozemac = "0.7.0" +vodozemac = { git = "https://github.com/matrix-org/vodozemac.git" } [package.metadata.maturin] name = "vodozemac" diff --git a/src/account.rs b/src/account.rs index f64536c..3426eb0 100644 --- a/src/account.rs +++ b/src/account.rs @@ -1,9 +1,13 @@ use std::collections::HashMap; -use pyo3::{prelude::*, types::PyType}; +use pyo3::{ + prelude::*, + types::{PyBytes, PyType}, +}; use vodozemac::olm::SessionConfig; use crate::{ + convert_to_pybytes, error::{LibolmPickleError, PickleError, SessionError}, types::{Curve25519PublicKey, Ed25519PublicKey, Ed25519Signature, PreKeyMessage}, }; @@ -70,7 +74,7 @@ impl Account { self.inner.curve25519_key().into() } - fn sign(&self, message: &str) -> Ed25519Signature { + fn sign(&self, message: &[u8]) -> Ed25519Signature { self.inner.sign(message).into() } @@ -127,7 +131,7 @@ impl Account { &mut self, identity_key: &Curve25519PublicKey, message: &PreKeyMessage, - ) -> Result<(Session, String), SessionError> { + ) -> Result<(Session, Py), SessionError> { let result = self .inner .create_inbound_session(identity_key.inner, &message.inner)?; @@ -136,7 +140,7 @@ impl Account { Session { inner: result.session, }, - String::from_utf8(result.plaintext)?, + convert_to_pybytes(result.plaintext), )) } } diff --git a/src/group_sessions.rs b/src/group_sessions.rs index f4a038a..ce7390e 100644 --- a/src/group_sessions.rs +++ b/src/group_sessions.rs @@ -1,7 +1,11 @@ -use pyo3::{prelude::*, types::PyType}; +use pyo3::{ + prelude::*, + types::{PyBytes, PyType}, +}; use vodozemac::megolm::SessionConfig; use crate::{ + convert_to_pybytes, error::{LibolmPickleError, MegolmDecryptionError, PickleError, SessionKeyDecodeError}, types::{ExportedSessionKey, MegolmMessage, SessionKey}, }; @@ -35,7 +39,7 @@ impl GroupSession { self.inner.session_key().into() } - fn encrypt(&mut self, plaintext: &str) -> MegolmMessage { + fn encrypt(&mut self, plaintext: &[u8]) -> MegolmMessage { self.inner.encrypt(plaintext).into() } @@ -67,11 +71,20 @@ impl GroupSession { #[pyclass] pub struct DecryptedMessage { #[pyo3(get)] - plaintext: String, + plaintext: Py, #[pyo3(get)] message_index: u32, } +impl DecryptedMessage { + fn new(plaintext: &[u8], message_index: u32) -> Self { + DecryptedMessage { + plaintext: convert_to_pybytes(plaintext), + message_index, + } + } +} + #[pyclass] pub struct InboundGroupSession { pub(super) inner: vodozemac::megolm::InboundGroupSession, @@ -122,10 +135,10 @@ impl InboundGroupSession { ) -> Result { let ret = self.inner.decrypt(&message.inner)?; - Ok(DecryptedMessage { - plaintext: String::from_utf8(ret.plaintext)?, - message_index: ret.message_index, - }) + Ok(DecryptedMessage::new( + ret.plaintext.as_slice(), + ret.message_index, + )) } fn pickle(&self, pickle_key: &[u8]) -> Result { diff --git a/src/lib.rs b/src/lib.rs index cff6206..662ed13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ mod session; mod types; use error::*; -use pyo3::prelude::*; +use pyo3::{prelude::*, types::PyBytes}; #[pymodule(name = "vodozemac")] fn my_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -58,3 +58,7 @@ fn my_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { Ok(()) } + +pub(crate) fn convert_to_pybytes(bytes: impl AsRef<[u8]>) -> Py { + Python::with_gil(|py| PyBytes::new_bound(py, bytes.as_ref()).into()) +} diff --git a/src/session.rs b/src/session.rs index 98100b1..02ff978 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,6 +1,10 @@ -use pyo3::{prelude::*, types::PyType}; +use pyo3::{ + prelude::*, + types::{PyBytes, PyType}, +}; use crate::{ + convert_to_pybytes, types::{AnyOlmMessage, PreKeyMessage}, LibolmPickleError, PickleError, SessionError, }; @@ -56,12 +60,12 @@ impl Session { Ok(Self { inner: session }) } - fn encrypt(&mut self, plaintext: &str) -> AnyOlmMessage { + fn encrypt(&mut self, plaintext: &[u8]) -> AnyOlmMessage { let message = self.inner.encrypt(plaintext); AnyOlmMessage { inner: message } } - fn decrypt(&mut self, message: &AnyOlmMessage) -> Result { - Ok(String::from_utf8(self.inner.decrypt(&message.inner)?)?) + fn decrypt(&mut self, message: &AnyOlmMessage) -> Result, SessionError> { + Ok(convert_to_pybytes(self.inner.decrypt(&message.inner)?)) } } diff --git a/src/types/ed25519.rs b/src/types/ed25519.rs index 13a9ad2..e0553eb 100644 --- a/src/types/ed25519.rs +++ b/src/types/ed25519.rs @@ -21,10 +21,10 @@ impl Ed25519PublicKey { pub fn verify_signature( &self, - message: &str, + message: &[u8], signature: &Ed25519Signature, ) -> Result<(), SignatureError> { - self.inner.verify(message.as_bytes(), &signature.inner)?; + self.inner.verify(message, &signature.inner)?; Ok(()) } diff --git a/src/types/messages.rs b/src/types/messages.rs index 410232a..ea9f6d5 100644 --- a/src/types/messages.rs +++ b/src/types/messages.rs @@ -1,5 +1,8 @@ -use crate::error::*; -use pyo3::{prelude::*, types::PyType}; +use crate::{convert_to_pybytes, error::*}; +use pyo3::{ + prelude::*, + types::{PyBytes, PyType}, +}; #[pyclass] pub struct AnyOlmMessage { @@ -9,16 +12,16 @@ pub struct AnyOlmMessage { #[pymethods] impl AnyOlmMessage { #[classmethod] - pub fn pre_key(_cls: &Bound<'_, PyType>, message: &str) -> Result { + pub fn pre_key(_cls: &Bound<'_, PyType>, message: &[u8]) -> Result { Ok(Self { - inner: vodozemac::olm::PreKeyMessage::from_base64(message)?.into(), + inner: vodozemac::olm::PreKeyMessage::from_bytes(message)?.into(), }) } #[classmethod] - pub fn normal(_cls: &Bound<'_, PyType>, message: &str) -> Result { + pub fn normal(_cls: &Bound<'_, PyType>, message: &[u8]) -> Result { Ok(Self { - inner: vodozemac::olm::Message::from_base64(message)?.into(), + inner: vodozemac::olm::Message::from_bytes(message)?.into(), }) } @@ -36,15 +39,16 @@ impl AnyOlmMessage { pub fn from_parts( _cls: &Bound<'_, PyType>, message_type: usize, - ciphertext: &str, + ciphertext: &[u8], ) -> Result { Ok(Self { inner: vodozemac::olm::OlmMessage::from_parts(message_type, ciphertext)?, }) } - pub fn to_parts(&self) -> (usize, String) { - self.inner.clone().to_parts() + pub fn to_parts(&self) -> (usize, Py) { + let (message_type, ciphertext) = self.inner.clone().to_parts(); + (message_type, convert_to_pybytes(ciphertext)) } } diff --git a/tests/account_test.py b/tests/account_test.py index 976c482..4e9b177 100644 --- a/tests/account_test.py +++ b/tests/account_test.py @@ -57,8 +57,8 @@ def test_publish_one_time_keys(self): def test_signing(self): alice = Account() - signature = alice.sign("This is a test") + signature = alice.sign(b"This is a test") - alice.ed25519_key.verify_signature("This is a test", signature) + alice.ed25519_key.verify_signature(b"This is a test", signature) with pytest.raises(SignatureException): - alice.ed25519_key.verify_signature("This should fail", signature) + alice.ed25519_key.verify_signature(b"This should fail", signature) diff --git a/tests/group_session_test.py b/tests/group_session_test.py index 7fcd102..5fe160b 100644 --- a/tests/group_session_test.py +++ b/tests/group_session_test.py @@ -54,8 +54,8 @@ def test_inbound_export(self): imported = InboundGroupSession.import_session( inbound.export_at(inbound.first_known_index) ) - message = imported.decrypt(outbound.encrypt("Test")) - assert message.plaintext == "Test" + message = imported.decrypt(outbound.encrypt(b"Test")) + assert message.plaintext == b"Test" assert message.message_index == 0 def test_first_index(self): @@ -68,24 +68,24 @@ def test_first_index(self): def test_encrypt(self): outbound = GroupSession() inbound = InboundGroupSession(outbound.session_key) - message = inbound.decrypt(outbound.encrypt("Test")) - assert "Test", 0 == inbound.decrypt(outbound.encrypt("Test")) + message = inbound.decrypt(outbound.encrypt(b"Test")) + assert b"Test", 0 == inbound.decrypt(outbound.encrypt(b"Test")) def test_decrypt_twice(self): outbound = GroupSession() inbound = InboundGroupSession(outbound.session_key) - outbound.encrypt("Test 1") - message = inbound.decrypt(outbound.encrypt("Test 2")) + outbound.encrypt(b"Test 1") + message = inbound.decrypt(outbound.encrypt(b"Test 2")) assert isinstance(message.message_index, int) assert message.message_index == 1 - assert message.plaintext == "Test 2" + assert message.plaintext == b"Test 2" def test_decrypt_failure(self): outbound = GroupSession() inbound = InboundGroupSession(outbound.session_key) eve_outbound = GroupSession() with pytest.raises(MegolmDecryptionException): - inbound.decrypt(eve_outbound.encrypt("Test")) + inbound.decrypt(eve_outbound.encrypt(b"Test")) def test_id(self): outbound = GroupSession() diff --git a/tests/session_test.py b/tests/session_test.py index 247e1e8..a31479c 100644 --- a/tests/session_test.py +++ b/tests/session_test.py @@ -55,7 +55,7 @@ def test_wrong_passphrase_pickle(self): Session.from_pickle(pickle, PICKLE_KEY) def test_encrypt(self): - plaintext = "It's a secret to everybody" + plaintext = b"It's a secret to everybody" alice, bob, session = self._create_session() message = session.encrypt(plaintext) @@ -69,10 +69,10 @@ def test_encrypt(self): def test_empty_message(self): with pytest.raises(DecodeException): - AnyOlmMessage.from_parts(0, "x") + AnyOlmMessage.from_parts(0, b"x") def test_two_messages(self): - plaintext = "It's a secret to everybody" + plaintext = b"It's a secret to everybody" alice, bob, session = self._create_session() message = session.encrypt(plaintext) message = message.to_pre_key() @@ -82,13 +82,13 @@ def test_two_messages(self): ) assert plaintext == decrypted - bob_plaintext = "Grumble, Grumble" + bob_plaintext = b"Grumble, Grumble" bob_message = bob_session.encrypt(bob_plaintext) assert bob_plaintext == session.decrypt(bob_message) def test_matches(self): - plaintext = "It's a secret to everybody" + plaintext = b"It's a secret to everybody" alice, bob, session = self._create_session() message = session.encrypt(plaintext) message = message.to_pre_key() @@ -98,7 +98,7 @@ def test_matches(self): ) assert plaintext == decrypted - message2 = session.encrypt("Hey! Listen!") + message2 = session.encrypt(b"Hey! Listen!") message2 = message2.to_pre_key() assert bob_session.session_matches(message2) is True @@ -107,13 +107,13 @@ def test_invalid(self): alice, bob, session = self._create_session() _, _, another_session = self._create_session() - message = another_session.encrypt("It's a secret to everybody") + message = another_session.encrypt(b"It's a secret to everybody") message = message.to_pre_key() assert not session.session_matches(message) def test_does_not_match(self): - plaintext = "It's a secret to everybody" + plaintext = b"It's a secret to everybody" alice, bob, session = self._create_session() message = session.encrypt(plaintext) message = message.to_pre_key() @@ -129,7 +129,7 @@ def test_does_not_match(self): assert bob_session.session_matches(new_message) is False def test_message_to_parts(self): - plaintext = "It's a secret to everybody" + plaintext = b"It's a secret to everybody" alice, bob, session = self._create_session() message = session.encrypt(plaintext)