Skip to content

Commit

Permalink
Do not assume base64 input/output and use bytes everywhere (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul authored Sep 13, 2024
1 parent 4836a8d commit 84b4341
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ crate-type = ["cdylib"]
[dependencies]
paste = "1.0.15"
thiserror = "1.0.63"
vodozemac = "0.7.0"
vodozemac = { git = "https://github.com/matrix-org/vodozemac.git", rev = "12f9036bf7f2536c172273602afcdc9aeddf8cf7" }

[package.metadata.maturin]
name = "vodozemac"
Expand Down
12 changes: 8 additions & 4 deletions src/account.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::collections::HashMap;

use pyo3::{prelude::*, types::PyType};
use pyo3::{
prelude::*,
types::{PyBytes, PyType},
};
use vodozemac::olm::SessionConfig;

use crate::{
convert_to_pybytes,
error::{LibolmPickleError, PickleError, SessionError},
types::{Curve25519PublicKey, Ed25519PublicKey, Ed25519Signature, PreKeyMessage},
};
Expand Down Expand Up @@ -70,7 +74,7 @@ impl Account {
self.inner.curve25519_key().into()
}

fn sign(&self, message: &str) -> Ed25519Signature {
fn sign(&self, message: &[u8]) -> Ed25519Signature {
self.inner.sign(message).into()
}

Expand Down Expand Up @@ -127,7 +131,7 @@ impl Account {
&mut self,
identity_key: &Curve25519PublicKey,
message: &PreKeyMessage,
) -> Result<(Session, String), SessionError> {
) -> Result<(Session, Py<PyBytes>), SessionError> {
let result = self
.inner
.create_inbound_session(identity_key.inner, &message.inner)?;
Expand All @@ -136,7 +140,7 @@ impl Account {
Session {
inner: result.session,
},
String::from_utf8(result.plaintext)?,
convert_to_pybytes(result.plaintext.as_slice()),
))
}
}
27 changes: 20 additions & 7 deletions src/group_sessions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use pyo3::{prelude::*, types::PyType};
use pyo3::{
prelude::*,
types::{PyBytes, PyType},
};
use vodozemac::megolm::SessionConfig;

use crate::{
convert_to_pybytes,
error::{LibolmPickleError, MegolmDecryptionError, PickleError, SessionKeyDecodeError},
types::{ExportedSessionKey, MegolmMessage, SessionKey},
};
Expand Down Expand Up @@ -35,7 +39,7 @@ impl GroupSession {
self.inner.session_key().into()
}

fn encrypt(&mut self, plaintext: &str) -> MegolmMessage {
fn encrypt(&mut self, plaintext: &[u8]) -> MegolmMessage {
self.inner.encrypt(plaintext).into()
}

Expand Down Expand Up @@ -67,11 +71,20 @@ impl GroupSession {
#[pyclass]
pub struct DecryptedMessage {
#[pyo3(get)]
plaintext: String,
plaintext: Py<PyBytes>,
#[pyo3(get)]
message_index: u32,
}

impl DecryptedMessage {
fn new(plaintext: &[u8], message_index: u32) -> Self {
DecryptedMessage {
plaintext: convert_to_pybytes(plaintext),
message_index,
}
}
}

#[pyclass]
pub struct InboundGroupSession {
pub(super) inner: vodozemac::megolm::InboundGroupSession,
Expand Down Expand Up @@ -122,10 +135,10 @@ impl InboundGroupSession {
) -> Result<DecryptedMessage, MegolmDecryptionError> {
let ret = self.inner.decrypt(&message.inner)?;

Ok(DecryptedMessage {
plaintext: String::from_utf8(ret.plaintext)?,
message_index: ret.message_index,
})
Ok(DecryptedMessage::new(
ret.plaintext.as_slice(),
ret.message_index,
))
}

fn pickle(&self, pickle_key: &[u8]) -> Result<String, PickleError> {
Expand Down
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod session;
mod types;

use error::*;
use pyo3::prelude::*;
use pyo3::{prelude::*, types::PyBytes};

#[pymodule(name = "vodozemac")]
fn my_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down Expand Up @@ -58,3 +58,7 @@ fn my_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {

Ok(())
}

pub(crate) fn convert_to_pybytes(bytes: &[u8]) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new_bound(py, bytes).into())
}
14 changes: 10 additions & 4 deletions src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use pyo3::{prelude::*, types::PyType};
use pyo3::{
prelude::*,
types::{PyBytes, PyType},
};

use crate::{
convert_to_pybytes,
types::{AnyOlmMessage, PreKeyMessage},
LibolmPickleError, PickleError, SessionError,
};
Expand Down Expand Up @@ -56,12 +60,14 @@ impl Session {
Ok(Self { inner: session })
}

fn encrypt(&mut self, plaintext: &str) -> AnyOlmMessage {
fn encrypt(&mut self, plaintext: &[u8]) -> AnyOlmMessage {
let message = self.inner.encrypt(plaintext);
AnyOlmMessage { inner: message }
}

fn decrypt(&mut self, message: &AnyOlmMessage) -> Result<String, SessionError> {
Ok(String::from_utf8(self.inner.decrypt(&message.inner)?)?)
fn decrypt(&mut self, message: &AnyOlmMessage) -> Result<Py<PyBytes>, SessionError> {
Ok(convert_to_pybytes(
self.inner.decrypt(&message.inner)?.as_slice(),
))
}
}
4 changes: 2 additions & 2 deletions src/types/ed25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ impl Ed25519PublicKey {

pub fn verify_signature(
&self,
message: &str,
message: &[u8],
signature: &Ed25519Signature,
) -> Result<(), SignatureError> {
self.inner.verify(message.as_bytes(), &signature.inner)?;
self.inner.verify(message, &signature.inner)?;

Ok(())
}
Expand Down
22 changes: 13 additions & 9 deletions src/types/messages.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::error::*;
use pyo3::{prelude::*, types::PyType};
use crate::{convert_to_pybytes, error::*};
use pyo3::{
prelude::*,
types::{PyBytes, PyType},
};

#[pyclass]
pub struct AnyOlmMessage {
Expand All @@ -9,16 +12,16 @@ pub struct AnyOlmMessage {
#[pymethods]
impl AnyOlmMessage {
#[classmethod]
pub fn pre_key(_cls: &Bound<'_, PyType>, message: &str) -> Result<Self, SessionError> {
pub fn pre_key(_cls: &Bound<'_, PyType>, message: &[u8]) -> Result<Self, SessionError> {
Ok(Self {
inner: vodozemac::olm::PreKeyMessage::from_base64(message)?.into(),
inner: vodozemac::olm::PreKeyMessage::from_bytes(message)?.into(),
})
}

#[classmethod]
pub fn normal(_cls: &Bound<'_, PyType>, message: &str) -> Result<Self, SessionError> {
pub fn normal(_cls: &Bound<'_, PyType>, message: &[u8]) -> Result<Self, SessionError> {
Ok(Self {
inner: vodozemac::olm::Message::from_base64(message)?.into(),
inner: vodozemac::olm::Message::from_bytes(message)?.into(),
})
}

Expand All @@ -36,15 +39,16 @@ impl AnyOlmMessage {
pub fn from_parts(
_cls: &Bound<'_, PyType>,
message_type: usize,
ciphertext: &str,
ciphertext: &[u8],
) -> Result<Self, DecodeError> {
Ok(Self {
inner: vodozemac::olm::OlmMessage::from_parts(message_type, ciphertext)?,
})
}

pub fn to_parts(&self) -> (usize, String) {
self.inner.clone().to_parts()
pub fn to_parts(&self) -> (usize, Py<PyBytes>) {
let (message_type, ciphertext) = self.inner.clone().to_parts();
(message_type, convert_to_pybytes(ciphertext.as_slice()))
}
}

Expand Down
6 changes: 3 additions & 3 deletions tests/account_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def test_publish_one_time_keys(self):

def test_signing(self):
alice = Account()
signature = alice.sign("This is a test")
signature = alice.sign(b"This is a test")

alice.ed25519_key.verify_signature("This is a test", signature)
alice.ed25519_key.verify_signature(b"This is a test", signature)
with pytest.raises(SignatureException):
alice.ed25519_key.verify_signature("This should fail", signature)
alice.ed25519_key.verify_signature(b"This should fail", signature)
16 changes: 8 additions & 8 deletions tests/group_session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_inbound_export(self):
imported = InboundGroupSession.import_session(
inbound.export_at(inbound.first_known_index)
)
message = imported.decrypt(outbound.encrypt("Test"))
assert message.plaintext == "Test"
message = imported.decrypt(outbound.encrypt(b"Test"))
assert message.plaintext == b"Test"
assert message.message_index == 0

def test_first_index(self):
Expand All @@ -68,24 +68,24 @@ def test_first_index(self):
def test_encrypt(self):
outbound = GroupSession()
inbound = InboundGroupSession(outbound.session_key)
message = inbound.decrypt(outbound.encrypt("Test"))
assert "Test", 0 == inbound.decrypt(outbound.encrypt("Test"))
message = inbound.decrypt(outbound.encrypt(b"Test"))
assert b"Test", 0 == inbound.decrypt(outbound.encrypt(b"Test"))

def test_decrypt_twice(self):
outbound = GroupSession()
inbound = InboundGroupSession(outbound.session_key)
outbound.encrypt("Test 1")
message = inbound.decrypt(outbound.encrypt("Test 2"))
outbound.encrypt(b"Test 1")
message = inbound.decrypt(outbound.encrypt(b"Test 2"))
assert isinstance(message.message_index, int)
assert message.message_index == 1
assert message.plaintext == "Test 2"
assert message.plaintext == b"Test 2"

def test_decrypt_failure(self):
outbound = GroupSession()
inbound = InboundGroupSession(outbound.session_key)
eve_outbound = GroupSession()
with pytest.raises(MegolmDecryptionException):
inbound.decrypt(eve_outbound.encrypt("Test"))
inbound.decrypt(eve_outbound.encrypt(b"Test"))

def test_id(self):
outbound = GroupSession()
Expand Down
18 changes: 9 additions & 9 deletions tests/session_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_wrong_passphrase_pickle(self):
Session.from_pickle(pickle, PICKLE_KEY)

def test_encrypt(self):
plaintext = "It's a secret to everybody"
plaintext = b"It's a secret to everybody"
alice, bob, session = self._create_session()
message = session.encrypt(plaintext)

Expand All @@ -69,10 +69,10 @@ def test_encrypt(self):

def test_empty_message(self):
with pytest.raises(DecodeException):
AnyOlmMessage.from_parts(0, "x")
AnyOlmMessage.from_parts(0, b"x")

def test_two_messages(self):
plaintext = "It's a secret to everybody"
plaintext = b"It's a secret to everybody"
alice, bob, session = self._create_session()
message = session.encrypt(plaintext)
message = message.to_pre_key()
Expand All @@ -82,13 +82,13 @@ def test_two_messages(self):
)
assert plaintext == decrypted

bob_plaintext = "Grumble, Grumble"
bob_plaintext = b"Grumble, Grumble"
bob_message = bob_session.encrypt(bob_plaintext)

assert bob_plaintext == session.decrypt(bob_message)

def test_matches(self):
plaintext = "It's a secret to everybody"
plaintext = b"It's a secret to everybody"
alice, bob, session = self._create_session()
message = session.encrypt(plaintext)
message = message.to_pre_key()
Expand All @@ -98,7 +98,7 @@ def test_matches(self):
)
assert plaintext == decrypted

message2 = session.encrypt("Hey! Listen!")
message2 = session.encrypt(b"Hey! Listen!")
message2 = message2.to_pre_key()

assert bob_session.session_matches(message2) is True
Expand All @@ -107,13 +107,13 @@ def test_invalid(self):
alice, bob, session = self._create_session()
_, _, another_session = self._create_session()

message = another_session.encrypt("It's a secret to everybody")
message = another_session.encrypt(b"It's a secret to everybody")
message = message.to_pre_key()

assert not session.session_matches(message)

def test_does_not_match(self):
plaintext = "It's a secret to everybody"
plaintext = b"It's a secret to everybody"
alice, bob, session = self._create_session()
message = session.encrypt(plaintext)
message = message.to_pre_key()
Expand All @@ -129,7 +129,7 @@ def test_does_not_match(self):
assert bob_session.session_matches(new_message) is False

def test_message_to_parts(self):
plaintext = "It's a secret to everybody"
plaintext = b"It's a secret to everybody"
alice, bob, session = self._create_session()
message = session.encrypt(plaintext)

Expand Down

0 comments on commit 84b4341

Please sign in to comment.