From 025a023563ada3c64d8c508b9c5d291670472b44 Mon Sep 17 00:00:00 2001 From: Ahmed TAHRI Date: Tue, 23 Apr 2024 18:22:30 +0200 Subject: [PATCH] :zap: migrate Buffer to Rust instead of pure Python --- .github/workflows/CI.yml | 6 +- CHANGELOG.rst | 6 + Cargo.lock | 2 +- Cargo.toml | 2 +- qh3/__init__.py | 2 +- qh3/_buffer.py | 156 --------------------- qh3/_hazmat.pyi | 26 ++++ qh3/buffer.py | 2 +- src/buffer.rs | 291 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 + tests/test_buffer.py | 23 +++- 11 files changed, 356 insertions(+), 166 deletions(-) delete mode 100644 qh3/_buffer.py create mode 100644 src/buffer.rs diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 34f20e59e..0b090d63c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -40,7 +40,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-13, windows-latest] python_version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', 'pypy-3.9', 'pypy-3.10'] exclude: # circumvent wierd issue with qh3.asyncio+windows+proactor loop... @@ -342,11 +342,11 @@ jobs: needs: - test - lint - runs-on: macos-latest + runs-on: macos-13 strategy: fail-fast: false matrix: - target: [ x86_64, aarch64 ] + target: [ x86_64, aarch64, universal2 ] python_version: [ '3.10', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9', 'pypy-3.10' ] steps: - uses: actions/checkout@3df4ab11eba7bda6032a0b82a6bb43b11571feac diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 22c8f1c50..7177a175d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,9 @@ +1.0.4 (2024-04-23) +===================== + +**Changed** +- Buffer management has been migrated over to Rust in order to improve the overall performance. 
+ 1.0.3 (2024-04-20) ===================== diff --git a/Cargo.lock b/Cargo.lock index 9443770f7..a07197f97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1017,7 +1017,7 @@ dependencies = [ [[package]] name = "qh3" -version = "1.0.3" +version = "1.0.4" dependencies = [ "aes", "aws-lc-rs", diff --git a/Cargo.toml b/Cargo.toml index d14e79109..468f833c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "qh3" -version = "1.0.3" +version = "1.0.4" edition = "2021" rust-version = "1.75" license = "BSD-3" diff --git a/qh3/__init__.py b/qh3/__init__.py index 321169ab3..d1d023590 100644 --- a/qh3/__init__.py +++ b/qh3/__init__.py @@ -11,7 +11,7 @@ from .quic.packet import QuicProtocolVersion from .tls import CipherSuite, SessionTicket -__version__ = "1.0.3" +__version__ = "1.0.4" __all__ = ( "connect", diff --git a/qh3/_buffer.py b/qh3/_buffer.py deleted file mode 100644 index 79929bf3e..000000000 --- a/qh3/_buffer.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations - -import struct - -uint16 = struct.Struct(">H") -uint32 = struct.Struct(">L") -uint64 = struct.Struct(">Q") - - -class BufferReadError(ValueError): - def __init__(self, message: str = "Read out of bounds") -> None: - super().__init__(message) - - -class BufferWriteError(ValueError): - def __init__(self, message: str = "Write out of bounds") -> None: - super().__init__(message) - - -class Buffer: - def __init__(self, capacity: int = 0, data: bytes | None = None): - self._pos = 0 - self._data = memoryview(bytearray(capacity if data is None else data)) - self._capacity = len(self._data) - - @property - def capacity(self) -> int: - return self._capacity - - @property - def data(self) -> bytes: - return bytes(self._data[0 : self._pos]) - - def data_slice(self, start: int, end: int) -> bytes: - if ( - start < 0 - or self._capacity < start - or end < 0 - or self._capacity < end - or end < start - ): - raise BufferReadError() - return bytes(self._data[start:end]) - - def 
eof(self) -> bool: - return self._pos == self._capacity - - def seek(self, pos: int) -> None: - if pos < 0 or pos > self._capacity: - raise BufferReadError("Seek out of bounds") - self._pos = pos - - def tell(self) -> int: - return self._pos - - def pull_bytes(self, length: int) -> bytes: - if length < 0 or self._capacity < self._pos + length: - raise BufferReadError() - result = bytes(self._data[self._pos : (self._pos + length)]) - self._pos += length - return result - - def pull_uint8(self) -> int: - try: - result = self._data[self._pos] - except IndexError: - raise BufferReadError() - self._pos += 1 - return result - - def pull_uint16(self) -> int: - try: - (result,) = uint16.unpack_from(self._data, self._pos) - except struct.error: - raise BufferReadError() - self._pos += 2 - return result - - def pull_uint32(self) -> int: - try: - (result,) = uint32.unpack_from(self._data, self._pos) - except struct.error: - raise BufferReadError() - self._pos += 4 - return result - - def pull_uint64(self) -> int: - try: - (result,) = uint64.unpack_from(self._data, self._pos) - except struct.error: - raise BufferReadError() - self._pos += 8 - return result - - def pull_uint_var(self) -> int: - try: - first = self._data[self._pos] - except IndexError: - raise BufferReadError() - type = first >> 6 - if type == 0: - self._pos += 1 - return first - elif type == 1: - return self.pull_uint16() & 0x3FFF - elif type == 2: - return self.pull_uint32() & 0x3FFFFFFF - else: - return self.pull_uint64() & 0x3FFFFFFFFFFFFFFF - - def push_bytes(self, value: bytes) -> None: - end_pos = self._pos + len(value) - if self._capacity < end_pos: - raise BufferWriteError() - self._data[self._pos : end_pos] = value - self._pos = end_pos - - def push_uint8(self, value: int) -> None: - try: - self._data[self._pos] = value - except IndexError: - raise BufferWriteError() - self._pos += 1 - - def push_uint16(self, value: int) -> None: - try: - uint16.pack_into(self._data, self._pos, value) - except 
struct.error: - raise BufferWriteError() - self._pos += 2 - - def push_uint32(self, value: int) -> None: - try: - uint32.pack_into(self._data, self._pos, value) - except struct.error: - raise BufferWriteError() - self._pos += 4 - - def push_uint64(self, value: int) -> None: - try: - uint64.pack_into(self._data, self._pos, value) - except struct.error: - raise BufferWriteError() - self._pos += 8 - - def push_uint_var(self, value: int) -> None: - if value <= 0x3F: - self.push_uint8(value) - elif value <= 0x3FFF: - self.push_uint16(value | 0x4000) - elif value <= 0x3FFFFFFF: - self.push_uint32(value | 0x80000000) - elif value <= 0x3FFFFFFFFFFFFFFF: - self.push_uint64(value | 0xC000000000000000) - else: - raise ValueError("Integer is too big for a variable-length integer") diff --git a/qh3/_hazmat.pyi b/qh3/_hazmat.pyi index f7534e36c..acc1622b0 100644 --- a/qh3/_hazmat.pyi +++ b/qh3/_hazmat.pyi @@ -207,3 +207,29 @@ class OCSPResponse: class OCSPRequest: def __init__(self, peer_certificate: bytes, issuer_certificate: bytes) -> None: ... def public_bytes(self) -> bytes: ... + +class BufferReadError(ValueError): ... +class BufferWriteError(ValueError): ... + +class Buffer: + def __init__(self, capacity: int = 0, data: bytes | None = None) -> None: ... + @property + def capacity(self) -> int: ... + @property + def data(self) -> bytes: ... + def data_slice(self, start: int, end: int) -> bytes: ... + def eof(self) -> bool: ... + def seek(self, pos: int) -> None: ... + def tell(self) -> int: ... + def pull_bytes(self, length: int) -> bytes: ... + def pull_uint8(self) -> int: ... + def pull_uint16(self) -> int: ... + def pull_uint32(self) -> int: ... + def pull_uint64(self) -> int: ... + def pull_uint_var(self) -> int: ... + def push_bytes(self, value: bytes) -> None: ... + def push_uint8(self, value: int) -> None: ... + def push_uint16(self, value: int) -> None: ... + def push_uint32(self, value: int) -> None: ... + def push_uint64(self, value: int) -> None: ... 
+    def push_uint_var(self, value: int) -> None: ...
diff --git a/qh3/buffer.py b/qh3/buffer.py
index f1fd9feb3..856457f96 100644
--- a/qh3/buffer.py
+++ b/qh3/buffer.py
@@ -1,4 +1,4 @@
-from ._buffer import Buffer, BufferReadError, BufferWriteError  # noqa
+from ._hazmat import Buffer, BufferReadError, BufferWriteError  # noqa
 
 UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
 UINT_VAR_MAX_SIZE = 8
diff --git a/src/buffer.rs b/src/buffer.rs
new file mode 100644
index 000000000..444e6e4c8
--- /dev/null
+++ b/src/buffer.rs
@@ -0,0 +1,225 @@
+use pyo3::exceptions::PyValueError;
+use pyo3::pyclass;
+use pyo3::types::PyBytes;
+use pyo3::{pymethods, PyErr, PyResult, Python};
+
+pyo3::create_exception!(_hazmat, BufferReadError, PyValueError);
+pyo3::create_exception!(_hazmat, BufferWriteError, PyValueError);
+
+/// Fixed-capacity byte buffer with a read/write cursor, exposed to Python as
+/// `qh3._hazmat.Buffer`.
+///
+/// Multi-byte integers are big-endian. Invariants: `capacity == data.len()`
+/// and `pos <= capacity`.
+#[pyclass(module = "qh3._hazmat")]
+pub struct Buffer {
+    /// Cursor: offset of the next byte to read or write.
+    pos: u64,
+    /// Backing storage; never resized after construction.
+    data: Vec<u8>,
+    /// Cached `data.len()`.
+    capacity: u64,
+}
+
+/// Builds the exception raised when a read would pass the end of the buffer.
+fn read_out_of_bounds() -> PyErr {
+    BufferReadError::new_err("Read out of bounds")
+}
+
+/// Builds the exception raised when a write would pass the end of the buffer.
+fn write_out_of_bounds() -> PyErr {
+    BufferWriteError::new_err("Write out of bounds")
+}
+
+// Private helpers live outside `#[pymethods]` so they are not exported to Python.
+impl Buffer {
+    /// Index range of the next `length` bytes at the cursor, or `None` when
+    /// that range overflows `u64` or extends past the capacity.
+    fn span(&self, length: u64) -> Option<std::ops::Range<usize>> {
+        let end = self.pos.checked_add(length)?;
+        if end > self.capacity {
+            return None;
+        }
+        // Lossless casts: both offsets are bounded by `data.len()`.
+        Some(self.pos as usize..end as usize)
+    }
+
+    /// Reads exactly `N` bytes at the cursor and advances past them.
+    fn pull_array<const N: usize>(&mut self) -> PyResult<[u8; N]> {
+        let range = self.span(N as u64).ok_or_else(read_out_of_bounds)?;
+        let end = range.end as u64;
+        let mut bytes = [0u8; N];
+        bytes.copy_from_slice(&self.data[range]);
+        self.pos = end;
+        Ok(bytes)
+    }
+
+    /// Writes `bytes` at the cursor and advances past them.
+    fn push_slice(&mut self, bytes: &[u8]) -> PyResult<()> {
+        let range = self
+            .span(bytes.len() as u64)
+            .ok_or_else(write_out_of_bounds)?;
+        let end = range.end as u64;
+        self.data[range].copy_from_slice(bytes);
+        self.pos = end;
+        Ok(())
+    }
+}
+
+#[pymethods]
+impl Buffer {
+    /// Creates a buffer holding a copy of `data` when given, otherwise
+    /// `capacity` zero bytes. With neither argument the buffer is empty,
+    /// matching the `capacity: int = 0` default declared in `_hazmat.pyi`.
+    #[new]
+    pub fn py_new(capacity: Option<u64>, data: Option<&PyBytes>) -> PyResult<Self> {
+        if let Some(payload) = data {
+            let bytes = payload.as_bytes().to_vec();
+            let capacity = bytes.len() as u64;
+            return Ok(Buffer { pos: 0, data: bytes, capacity });
+        }
+
+        let capacity = capacity.unwrap_or(0);
+        // A size that does not fit `usize` becomes an exception, not a panic.
+        let size = usize::try_from(capacity)
+            .map_err(|_| PyValueError::new_err("capacity is too large for this platform"))?;
+
+        Ok(Buffer { pos: 0, data: vec![0; size], capacity })
+    }
+
+    /// Total number of bytes the buffer can hold.
+    #[getter]
+    pub fn capacity(&self) -> u64 {
+        self.capacity
+    }
+
+    /// Copy of the bytes between the start of the buffer and the cursor.
+    #[getter]
+    pub fn data<'a>(&self, py: Python<'a>) -> &'a PyBytes {
+        PyBytes::new(py, &self.data[..self.pos as usize])
+    }
+
+    /// Copy of the bytes in `[start, end)`; the cursor is not moved.
+    ///
+    /// Raises `BufferReadError` when the range is reversed or passes the capacity.
+    pub fn data_slice<'a>(&self, py: Python<'a>, start: u64, end: u64) -> PyResult<&'a PyBytes> {
+        if end < start || self.capacity < end {
+            return Err(read_out_of_bounds());
+        }
+
+        Ok(PyBytes::new(py, &self.data[start as usize..end as usize]))
+    }
+
+    /// True when the cursor is at the end of the buffer.
+    pub fn eof(&self) -> bool {
+        self.pos == self.capacity
+    }
+
+    /// Moves the cursor to `pos`; raises `BufferReadError` past the capacity.
+    pub fn seek(&mut self, pos: u64) -> PyResult<()> {
+        if pos > self.capacity {
+            // Same message as the pure-Python implementation being replaced.
+            return Err(BufferReadError::new_err("Seek out of bounds"));
+        }
+
+        self.pos = pos;
+        Ok(())
+    }
+
+    /// Current cursor offset.
+    pub fn tell(&self) -> u64 {
+        self.pos
+    }
+
+    /// Reads `length` bytes at the cursor and advances past them.
+    pub fn pull_bytes<'a>(&mut self, py: Python<'a>, length: u64) -> PyResult<&'a PyBytes> {
+        let range = self.span(length).ok_or_else(read_out_of_bounds)?;
+        let end = range.end as u64;
+        let extract = PyBytes::new(py, &self.data[range]);
+        self.pos = end;
+
+        Ok(extract)
+    }
+
+    /// Reads an unsigned 8-bit integer.
+    pub fn pull_uint8(&mut self) -> PyResult<u8> {
+        Ok(u8::from_be_bytes(self.pull_array::<1>()?))
+    }
+
+    /// Reads a big-endian unsigned 16-bit integer.
+    pub fn pull_uint16(&mut self) -> PyResult<u16> {
+        Ok(u16::from_be_bytes(self.pull_array::<2>()?))
+    }
+
+    /// Reads a big-endian unsigned 32-bit integer.
+    pub fn pull_uint32(&mut self) -> PyResult<u32> {
+        Ok(u32::from_be_bytes(self.pull_array::<4>()?))
+    }
+
+    /// Reads a big-endian unsigned 64-bit integer.
+    pub fn pull_uint64(&mut self) -> PyResult<u64> {
+        Ok(u64::from_be_bytes(self.pull_array::<8>()?))
+    }
+
+    /// Reads a variable-length integer: the top two bits of the first byte
+    /// select a 1, 2, 4 or 8 byte encoding and are masked out of the value.
+    pub fn pull_uint_var(&mut self) -> PyResult<u64> {
+        // Peek at the first byte without consuming it.
+        let first = *self
+            .data
+            .get(self.pos as usize)
+            .ok_or_else(read_out_of_bounds)?;
+
+        match first >> 6 {
+            0 => {
+                self.pos += 1;
+                Ok(first.into())
+            }
+            1 => Ok((self.pull_uint16()? & 0x3FFF).into()),
+            2 => Ok((self.pull_uint32()? & 0x3FFF_FFFF).into()),
+            _ => Ok(self.pull_uint64()? & 0x3FFF_FFFF_FFFF_FFFF),
+        }
+    }
+
+    /// Writes `data` at the cursor and advances past it.
+    pub fn push_bytes(&mut self, data: &PyBytes) -> PyResult<()> {
+        self.push_slice(data.as_bytes())
+    }
+
+    /// Writes an unsigned 8-bit integer.
+    pub fn push_uint8(&mut self, value: u8) -> PyResult<()> {
+        self.push_slice(&value.to_be_bytes())
+    }
+
+    /// Writes a big-endian unsigned 16-bit integer.
+    pub fn push_uint16(&mut self, value: u16) -> PyResult<()> {
+        self.push_slice(&value.to_be_bytes())
+    }
+
+    /// Writes a big-endian unsigned 32-bit integer.
+    pub fn push_uint32(&mut self, value: u32) -> PyResult<()> {
+        self.push_slice(&value.to_be_bytes())
+    }
+
+    /// Writes a big-endian unsigned 64-bit integer.
+    pub fn push_uint64(&mut self, value: u64) -> PyResult<()> {
+        self.push_slice(&value.to_be_bytes())
+    }
+
+    /// Writes `value` as a variable-length integer using the shortest
+    /// encoding; raises `ValueError` when it needs more than 62 bits.
+    pub fn push_uint_var(&mut self, value: u64) -> PyResult<()> {
+        // Each cast is lossless: the match arm bounds `value` below the target width.
+        match value {
+            0..=0x3F => self.push_uint8(value as u8),
+            0x40..=0x3FFF => self.push_uint16(value as u16 | 0x4000),
+            0x4000..=0x3FFF_FFFF => self.push_uint32(value as u32 | 0x8000_0000),
+            0x4000_0000..=0x3FFF_FFFF_FFFF_FFFF => {
+                self.push_uint64(value | 0xC000_0000_0000_0000)
+            }
+            _ => Err(PyValueError::new_err(
+                "Integer is too big for a variable-length integer",
+            )),
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 54d4fa220..7cdb862d2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,6 +10,7 @@ mod private_key;
 mod pkcs8;
 mod hpk;
 mod ocsp;
+mod buffer;
 
 pub use self::headers::{QpackDecoder, QpackEncoder, StreamBlocked, EncoderStreamError, DecoderStreamError, DecompressionFailed};
 pub use self::aead::{AeadChaCha20Poly1305, AeadAes128Gcm, AeadAes256Gcm};
@@ -20,6 +21,7 @@ pub use self::agreement::{X25519KeyExchange, ECDHP256KeyExchange, ECDHP384KeyExc
 pub use self::pkcs8::{PrivateKeyInfo, KeyType};
 pub use self::hpk::{QUICHeaderProtection};
 pub use self::ocsp::{OCSPResponse, OCSPCertStatus, OCSPResponseStatus, ReasonFlags, OCSPRequest};
+pub use self::buffer::{Buffer, BufferReadError,
BufferWriteError}; pyo3::create_exception!(_hazmat, CryptoError, PyException); @@ -69,5 +71,9 @@ fn _hazmat(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Buffer + m.add("BufferReadError", py.get_type::())?; + m.add("BufferWriteError", py.get_type::())?; + m.add_class::()?; Ok(()) } diff --git a/tests/test_buffer.py b/tests/test_buffer.py index a67f73941..bf6d44803 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -9,7 +9,7 @@ def test_data_slice(self): self.assertEqual(buf.data_slice(0, 8), b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.data_slice(1, 3), b"\x07\x06") - with self.assertRaises(BufferReadError): + with self.assertRaises(OverflowError): buf.data_slice(-1, 3) with self.assertRaises(BufferReadError): buf.data_slice(0, 9) @@ -20,9 +20,26 @@ def test_pull_bytes(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") self.assertEqual(buf.pull_bytes(3), b"\x08\x07\x06") + def test_internal_fixed_size(self): + buf = Buffer(8) + + buf.push_bytes(b"foobar") # push 6 bytes, 2 left free bytes + self.assertEqual(buf.data, b"foobar") + buf.seek(8) # setting cursor to the end of buf capacity + self.assertEqual(buf.data, b"foobar\x00\x00") # the two NULL bytes should be there + + def test_internal_push_zero_bytes(self): + buf = Buffer(6) + + buf.push_bytes(b"foobar") # push 6 bytes, 0 left free bytes + self.assertEqual(buf.data, b"foobar") + self.assertIsNone(buf.push_bytes(b"")) # this should not trigger any exception + with self.assertRaises(BufferWriteError): + buf.push_bytes(b"x") # this should! 
+ def test_pull_bytes_negative(self): buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") - with self.assertRaises(BufferReadError): + with self.assertRaises(OverflowError): buf.pull_bytes(-1) def test_pull_bytes_truncated(self): @@ -134,7 +151,7 @@ def test_seek(self): self.assertTrue(buf.eof()) self.assertEqual(buf.tell(), 8) - with self.assertRaises(BufferReadError): + with self.assertRaises(OverflowError): buf.seek(-1) self.assertEqual(buf.tell(), 8) with self.assertRaises(BufferReadError):