From c10f3fecaa3c7c5d3aff574b79aeeb002ba5b3b0 Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Fri, 18 Oct 2024 12:17:55 -0400 Subject: [PATCH] Add async iterator over bytes in result of `get` (#11) * Add async iterator over bytes in result of `get` * Add synchronous iteration * Implement `aiter` and `iter` on `GetResult` * Chunk size * Add python test --- .../python/object_store_rs/_get.pyi | 68 ++++++++- object-store-rs/src/get.rs | 144 ++++++++++++++++-- pyproject.toml | 1 + tests/test_get.py | 46 ++++++ tests/test_hello.py | 2 - uv.lock | 14 ++ 6 files changed, 262 insertions(+), 13 deletions(-) create mode 100644 tests/test_get.py delete mode 100644 tests/test_hello.py diff --git a/object-store-rs/python/object_store_rs/_get.pyi b/object-store-rs/python/object_store_rs/_get.pyi index 5c875b7..7d9b7dc 100644 --- a/object-store-rs/python/object_store_rs/_get.pyi +++ b/object-store-rs/python/object_store_rs/_get.pyi @@ -82,7 +82,34 @@ class GetOptions(TypedDict): """ class GetResult: - """Result for a get request""" + """Result for a get request. + + You can materialize the entire buffer by using either `bytes` or `bytes_async`, or + you can stream the result using `stream`. `__iter__` and `__aiter__` are implemented + as aliases to `stream`, so you can alternatively call `iter()` or `aiter()` on + `GetResult` to start an iterator. + + Using as an async iterator: + ```py + resp = await obs.get_async(store, path) + # 5MB chunk size in stream + stream = resp.stream(min_chunk_size=5 * 1024 * 1024) + async for buf in stream: + print(len(buf)) + ``` + + Using as a sync iterator: + ```py + resp = obs.get(store, path) + # 20MB chunk size in stream + stream = resp.stream(min_chunk_size=20 * 1024 * 1024) + for buf in stream: + print(len(buf)) + ``` + + Note that after calling `bytes`, `bytes_async`, or `stream`, you will no longer be + able to call other methods on this object, such as the `meta` attribute. + """ def bytes(self) -> bytes: """ @@ -98,6 +125,45 @@ class GetResult: def meta(self) -> ObjectMeta: """The ObjectMeta for this object""" + def stream(self, min_chunk_size: int = 10 * 1024 * 1024) -> BytesStream: + """Return a chunked stream over the result's bytes. + + Args: + min_chunk_size: The minimum size in bytes for each chunk in the returned + `BytesStream`. All chunks except for the last chunk will be at least + this size. Defaults to 10*1024*1024 (10MB). + + Returns: + A chunked stream + """ + + def __aiter__(self) -> BytesStream: + """ + Return a chunked stream over the result's bytes with the default (10MB) chunk + size. + """ + + def __iter__(self) -> BytesStream: + """ + Return a chunked stream over the result's bytes with the default (10MB) chunk + size. + """ + +class BytesStream: + """An async stream of bytes.""" + + def __aiter__(self) -> BytesStream: + """Return `Self` as an async iterator.""" + + def __iter__(self) -> BytesStream: + """Return `Self` as an async iterator.""" + + async def __anext__(self) -> bytes: + """Return the next chunk of bytes in the stream.""" + + def __next__(self) -> bytes: + """Return the next chunk of bytes in the stream.""" + def get( store: ObjectStore, location: str, *, options: GetOptions | None = None ) -> GetResult: diff --git a/object-store-rs/src/get.rs b/object-store-rs/src/get.rs index f417659..e6c29a6 100644 --- a/object-store-rs/src/get.rs +++ b/object-store-rs/src/get.rs @@ -1,14 +1,23 @@ +use std::sync::Arc; + +use bytes::Bytes; use chrono::{DateTime, Utc}; +use futures::stream::BoxStream; +use futures::StreamExt; use object_store::{GetOptions, GetResult, ObjectStore}; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3_object_store::error::{PyObjectStoreError, PyObjectStoreResult}; use pyo3_object_store::PyObjectStore; +use tokio::sync::Mutex; use crate::list::PyObjectMeta; use crate::runtime::get_runtime; +/// 10MB default chunk size +const DEFAULT_BYTES_CHUNK_SIZE: usize = 10 * 1024 * 1024; + #[derive(FromPyObject)] pub(crate) struct PyGetOptions { if_match: Option, @@ -54,7 +63,7 @@ impl PyGetResult { let runtime = get_runtime(py)?; py.allow_threads(|| { let bytes = runtime.block_on(get_result.bytes())?; - Ok::<_, PyObjectStoreError>(PyBytesWrapper(bytes)) + Ok::<_, PyObjectStoreError>(PyBytesWrapper::new(bytes)) }) } @@ -68,7 +77,7 @@ impl PyGetResult { .bytes() .await .map_err(PyObjectStoreError::ObjectStoreError)?; - Ok(PyBytesWrapper(bytes)) + Ok(PyBytesWrapper::new(bytes)) }) } @@ -80,14 +89,129 @@ impl PyGetResult { .ok_or(PyValueError::new_err("Result has already been disposed."))?; Ok(PyObjectMeta::new(inner.meta.clone())) } + + #[pyo3(signature = (min_chunk_size = DEFAULT_BYTES_CHUNK_SIZE))] + fn stream(&mut self, min_chunk_size: usize) -> PyResult { + let get_result = self + .0 + .take() + .ok_or(PyValueError::new_err("Result has already been disposed."))?; + Ok(PyBytesStream::new(get_result.into_stream(), min_chunk_size)) + } + + fn __aiter__(&mut self) -> PyResult { + self.stream(DEFAULT_BYTES_CHUNK_SIZE) + } + + fn __iter__(&mut self) -> PyResult { + self.stream(DEFAULT_BYTES_CHUNK_SIZE) + } +} + +#[pyclass(name = "BytesStream")] +pub struct PyBytesStream { + stream: Arc>>>, + min_chunk_size: usize, +} + +impl PyBytesStream { + fn new(stream: BoxStream<'static, object_store::Result>, min_chunk_size: usize) -> Self { + Self { + stream: Arc::new(Mutex::new(stream)), + min_chunk_size, + } + } +} + +async fn next_stream( + stream: Arc>>>, + min_chunk_size: usize, + sync: bool, +) -> PyResult { + let mut stream = stream.lock().await; + let mut buffers: Vec = vec![]; + loop { + match stream.next().await { + Some(Ok(bytes)) => { + buffers.push(bytes); + let total_buffer_len = buffers.iter().fold(0, |acc, buf| acc + buf.len()); + if total_buffer_len >= min_chunk_size { + return Ok(PyBytesWrapper::new_multiple(buffers)); + } + } + Some(Err(e)) => return Err(PyObjectStoreError::from(e).into()), + None => { + if buffers.is_empty() { + // Depending on whether the iteration is sync or not, we raise either a + // StopIteration or a StopAsyncIteration + if sync { + return Err(PyStopIteration::new_err("stream exhausted")); + } else { + return Err(PyStopAsyncIteration::new_err("stream exhausted")); + } + } else { + return Ok(PyBytesWrapper::new_multiple(buffers)); + } + } + }; + } +} + +#[pymethods] +impl PyBytesStream { + fn __aiter__(slf: Py) -> Py { + slf + } + + fn __iter__(slf: Py) -> Py { + slf + } + + fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult> { + let stream = self.stream.clone(); + pyo3_async_runtimes::tokio::future_into_py( + py, + next_stream(stream, self.min_chunk_size, false), + ) + } + + fn __next__<'py>(&'py self, py: Python<'py>) -> PyResult { + let runtime = get_runtime(py)?; + let stream = self.stream.clone(); + runtime.block_on(next_stream(stream, self.min_chunk_size, true)) + } } -pub(crate) struct PyBytesWrapper(bytes::Bytes); +pub(crate) struct PyBytesWrapper(Vec); + +impl PyBytesWrapper { + pub fn new(buf: Bytes) -> Self { + Self(vec![buf]) + } -// TODO: return buffer protocol object + pub fn new_multiple(buffers: Vec) -> Self { + Self(buffers) + } +} + +// TODO: return buffer protocol object? This isn't possible on an array of Bytes, so if you want to +// support the buffer protocol in the future (e.g. for get_range) you may need to have a separate +// wrapper of Bytes impl IntoPy for PyBytesWrapper { fn into_py(self, py: Python<'_>) -> PyObject { - PyBytes::new_bound(py, &self.0).into_py(py) + let total_len = self.0.iter().fold(0, |acc, buf| acc + buf.len()); + // Copy all internal Bytes objects into a single PyBytes + // Since our inner callback is infallible, this will only panic on out of memory + PyBytes::new_bound_with(py, total_len, |target| { + let mut offset = 0; + for buf in self.0.iter() { + target[offset..offset + buf.len()].copy_from_slice(buf); + offset += buf.len(); + } + Ok(()) + }) + .unwrap() + .into_py(py) } } @@ -144,7 +268,7 @@ pub(crate) fn get_range( let range = offset..offset + length; py.allow_threads(|| { let out = runtime.block_on(store.as_ref().get_range(&location.into(), range))?; - Ok::<_, PyObjectStoreError>(PyBytesWrapper(out)) + Ok::<_, PyObjectStoreError>(PyBytesWrapper::new(out)) }) } @@ -163,7 +287,7 @@ pub(crate) fn get_range_async( .get_range(&location.into(), range) .await .map_err(PyObjectStoreError::ObjectStoreError)?; - Ok(PyBytesWrapper(out)) + Ok(PyBytesWrapper::new(out)) }) } @@ -183,7 +307,7 @@ pub(crate) fn get_ranges( .collect::>(); py.allow_threads(|| { let out = runtime.block_on(store.as_ref().get_ranges(&location.into(), &ranges))?; - Ok::<_, PyObjectStoreError>(out.into_iter().map(PyBytesWrapper).collect()) + Ok::<_, PyObjectStoreError>(out.into_iter().map(PyBytesWrapper::new).collect()) }) } @@ -206,6 +330,6 @@ pub(crate) fn get_ranges_async( .get_ranges(&location.into(), &ranges) .await .map_err(PyObjectStoreError::ObjectStoreError)?; - Ok(out.into_iter().map(PyBytesWrapper).collect::>()) + Ok(out.into_iter().map(PyBytesWrapper::new).collect::>()) }) } diff --git a/pyproject.toml b/pyproject.toml index 6027cda..326f408 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dev-dependencies = [ "mkdocstrings[python]>=0.26.1", "pandas>=2.2.3", "pip>=24.2", + "pytest-asyncio>=0.24.0", "pytest>=8.3.3", ] diff --git a/tests/test_get.py b/tests/test_get.py new file mode 100644 index 0000000..a1721f7 --- /dev/null +++ b/tests/test_get.py @@ -0,0 +1,46 @@ +import object_store_rs as obs +import pytest +from object_store_rs.store import MemoryStore + + +def test_stream_sync(): + store = MemoryStore() + + data = b"the quick brown fox jumps over the lazy dog," * 5000 + path = "big-data.txt" + + obs.put_file(store, path, data) + resp = obs.get(store, path) + stream = resp.stream(min_chunk_size=0) + + # Note: it looks from manual testing that with the local store we're only getting + # one chunk and not able to test the chunk sizing. + pos = 0 + for chunk in stream: + size = len(chunk) + assert chunk == data[pos : pos + size] + pos += size + + assert pos == len(data) + + +@pytest.mark.asyncio +async def test_stream_async(): + store = MemoryStore() + + data = b"the quick brown fox jumps over the lazy dog," * 5000 + path = "big-data.txt" + + await obs.put_file_async(store, path, data) + resp = await obs.get_async(store, path) + stream = resp.stream(min_chunk_size=0) + + # Note: it looks from manual testing that with the local store we're only getting + # one chunk and not able to test the chunk sizing. + pos = 0 + async for chunk in stream: + size = len(chunk) + assert chunk == data[pos : pos + size] + pos += size + + assert pos == len(data) diff --git a/tests/test_hello.py b/tests/test_hello.py deleted file mode 100644 index 03dd85e..0000000 --- a/tests/test_hello.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_hello_world(): - pass diff --git a/uv.lock b/uv.lock index 496a2db..46dd163 100644 --- a/uv.lock +++ b/uv.lock @@ -1071,6 +1071,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/c6cf50ce320cf8611df7a1254d86233b3df7cc07f9b5f5cbcb82e08aa534/pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276", size = 49855 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/31/6607dab48616902f76885dfcf62c08d929796fc3b2d2318faf9fd54dbed9/pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b", size = 18024 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1333,6 +1345,7 @@ dev = [ { name = "pandas" }, { name = "pip" }, { name = "pytest" }, + { name = "pytest-asyncio" }, ] [package.metadata] @@ -1351,6 +1364,7 @@ dev = [ { name = "pandas", specifier = ">=2.2.3" }, { name = "pip", specifier = ">=24.2" }, { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, ] [[package]]