diff --git a/kioto/streams/impl.py b/kioto/streams/impl.py index 8fcc3a7..669c878 100644 --- a/kioto/streams/impl.py +++ b/kioto/streams/impl.py @@ -15,10 +15,12 @@ from kioto.futures import pending, select, task_set from kioto.internal.queue import SlotQueue +from kioto.sink.impl import Sink T = TypeVar("T") U = TypeVar("U") +R = TypeVar("R") class _Sentinel: @@ -78,6 +80,62 @@ def switch(self, coro: Callable[[T], Awaitable[U]]) -> "Stream[U]": def debounce(self, duration: float) -> "Stream[T]": return Debounce(self, duration) + def enumerate(self) -> "Stream[tuple[int, T]]": + return Enumerate(self) + + async def unzip(self: "Stream[tuple[T, U]]") -> tuple[list[T], list[U]]: + left: list[T] = [] + right: list[U] = [] + async for l, r in self: # type: ignore[misc] + left.append(l) + right.append(r) + return left, right + + async def count(self) -> int: + n = 0 + async for _ in self: + n += 1 + return n + + def cycle(self) -> "Stream[T]": + return Cycle(self) + + async def any(self, predicate: Callable[[T], bool]) -> bool: + async for val in self: + if predicate(val): + return True + return False + + async def all(self, predicate: Callable[[T], bool]) -> bool: + async for val in self: + if not predicate(val): + return False + return True + + def scan(self, acc: U, fn: Callable[[U, T], U]) -> "Stream[U]": + return Scan(self, acc, fn) + + def skip_while(self, predicate: Callable[[T], Awaitable[bool]]) -> "Stream[T]": + return SkipWhile(self, predicate) + + def take_while(self, predicate: Callable[[T], Awaitable[bool]]) -> "Stream[T]": + return TakeWhile(self, predicate) + + def take_until(self, fut: Awaitable[R]) -> "TakeUntil[T, R]": + return TakeUntil(self, fut) + + def take(self, n: int) -> "Stream[T]": + return Take(self, n) + + def skip(self, n: int) -> "Stream[T]": + return Skip(self, n) + + async def forward(self, sink: Sink) -> None: + async for item in self: + await sink.feed(item) + await sink.flush() + await sink.close() + async def fold(self, fn: Callable[[U, T], U], acc: U) -> U: async for val in self: acc = fn(acc, val) @@ -388,6 +446,131 @@ async def __anext__(self) -> tuple[T, U]: return (await anext(self.left), await anext(self.right)) +class Enumerate(Stream[tuple[int, T]]): + def __init__(self, stream: Stream[T]): + self.stream = stream + self.index = 0 + + async def __anext__(self) -> tuple[int, T]: + val = await anext(self.stream) + idx = self.index + self.index += 1 + return idx, val + + +async def _cycle(stream: Stream[T]) -> AsyncIterator[T]: + cache: list[T] = [] + async for item in stream: + cache.append(item) + yield item + if not cache: + return + while True: + for item in cache: + yield item + + +class Cycle(Stream[T]): + def __init__(self, stream: Stream[T]): + self.stream = _cycle(stream) + + async def __anext__(self) -> T: + return await anext(self.stream) + + +class Scan(Stream[U]): + def __init__(self, stream: Stream[T], acc: U, fn: Callable[[U, T], U]): + self.stream = stream + self.acc = acc + self.fn = fn + + async def __anext__(self) -> U: + val = await anext(self.stream) + self.acc = self.fn(self.acc, val) + return self.acc + + +class SkipWhile(Stream[T]): + def __init__(self, stream: Stream[T], predicate: Callable[[T], Awaitable[bool]]): + self.stream = stream + self.predicate = predicate + self.skipping = True + + async def __anext__(self) -> T: + while True: + val = await anext(self.stream) + if self.skipping and await self.predicate(val): + continue + self.skipping = False + return val + + +class TakeWhile(Stream[T]): + def __init__(self, stream: Stream[T], predicate: Callable[[T], Awaitable[bool]]): + self.stream = stream + self.predicate = predicate + self.done = False + + async def __anext__(self) -> T: + if self.done: + raise StopAsyncIteration + val = await anext(self.stream) + if await self.predicate(val): + return val + self.done = True + raise StopAsyncIteration + + +class Take(Stream[T]): + def __init__(self, stream: Stream[T], n: int): + self.stream = stream + self.remaining = n + + async def __anext__(self) -> T: + if self.remaining <= 0: + raise StopAsyncIteration + self.remaining -= 1 + return await anext(self.stream) + + +class Skip(Stream[T]): + def __init__(self, stream: Stream[T], n: int): + self.stream = stream + self.n = n + self.skipped = False + + async def __anext__(self) -> T: + if not self.skipped: + for _ in range(self.n): + await anext(self.stream) + self.skipped = True + return await anext(self.stream) + + +class TakeUntil(Stream[T], Generic[T, R]): + def __init__(self, stream: Stream[T], stop: Awaitable[R]): + self._stream = stream + self._tasks = task_set(anext=anext(stream), stop=stop) + self._result = None + + def take_future(self): + return self._tasks._tasks.pop("stop", None) + + def take_result(self): + return self._result + + async def __anext__(self) -> T: + if self._tasks: + match await select(self._tasks): + case ("stop", result): + self._result = result + raise StopAsyncIteration + case ("anext", value): + # Re-poll the stream + self._tasks.update("anext", anext(self._stream)) + return value + + async def _switch(st: Stream[T], coro: Callable[[T], Awaitable[U]]) -> AsyncIterator[U]: # Initialize a task set, with a coroutine to fetch the next item off the stream. tasks = task_set(anext=anext(st)) diff --git a/pyproject.toml b/pyproject.toml index b8ee7fd..e26d5f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kioto" -version = "v0.1.12" +version = "v0.1.13" description = "Async utilities library inspired by Tokio" readme = "README.md" authors = [{name="Brandon Ogle", email="oglebrandon@gmail.com"}] diff --git a/tests/test_streams.py b/tests/test_streams.py index 67664a0..21c9fa6 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -4,6 +4,7 @@ from kioto import streams from kioto.channels import channel +from kioto.sink.impl import Sink @pytest.mark.asyncio @@ -523,3 +524,136 @@ async def test_stream_select_once(): # NOTE: There was a bug causing the previous task to be re-yielded assert name != "one" assert value != 1 + + +@pytest.mark.asyncio +async def test_enumerate(): + stream = streams.iter(["a", "b"]).enumerate() + assert await anext(stream) == (0, "a") + assert await stream.collect() == [(1, "b")] + + +@pytest.mark.asyncio +async def test_unzip(): + left, right = await streams.iter([(1, "a"), (2, "b")]).unzip() + assert left == [1, 2] + assert right == ["a", "b"] + + +@pytest.mark.asyncio +async def test_count(): + assert await streams.iter(range(5)).count() == 5 + + +@pytest.mark.asyncio +async def test_cycle(): + stream = streams.iter([1, 2]).cycle() + assert [await anext(stream) for _ in range(4)] == [1, 2, 1, 2] + + +@pytest.mark.asyncio +async def test_any_all(): + assert await streams.iter(range(5)).any(lambda x: x == 3) + assert await streams.iter(range(5)).all(lambda x: x < 5) + assert not await streams.iter(range(5)).all(lambda x: x < 3) + + +@pytest.mark.asyncio +async def test_scan(): + stream = streams.iter(range(1, 4)).scan(0, lambda acc, val: acc + val) + assert await stream.collect() == [1, 3, 6] + + +@pytest.mark.asyncio +async def test_skip_while(): + async def pred(x): + await asyncio.sleep(0) + return x < 3 + + stream = streams.iter(range(5)).skip_while(pred) + assert await anext(stream) == 3 + assert await stream.collect() == [4] + + +@pytest.mark.asyncio +async def test_take_while(): + async def pred(x): + await asyncio.sleep(0) + return x < 3 + + stream = streams.iter(range(5)).take_while(pred) + assert await anext(stream) == 0 + assert await stream.collect() == [1, 2] + + +@pytest.mark.asyncio +async def test_take_until(): + async def stopper(): + await asyncio.sleep(0.05) + return "done" + + @streams.async_stream + async def st(): + for i in range(10): + await asyncio.sleep(0.01) + yield i + + # Case 1: stopper interrupts the collect() + stream = st().take_until(stopper()) + assert await stream.collect() == [0, 1, 2, 3] + assert stream.take_result() == "done" + + # Future already resolved + assert stream.take_future() is None + + # Iterate the rest of the stream + assert await stream.collect() == [4, 5, 6, 7, 8, 9] + + # Case 2: remove the stopper to iterate uninterrupted + stream = st().take_until(stopper()) + assert await anext(stream) == 0 + assert await anext(stream) == 1 + stopper = stream.take_future() + assert await stream.collect() == [2, 3, 4, 5, 6, 7, 8, 9] + assert await stopper == "done" + + +@pytest.mark.asyncio +async def test_take(): + stream = streams.iter(range(5)).take(3) + assert await stream.collect() == [0, 1, 2] + + +@pytest.mark.asyncio +async def test_skip(): + stream = streams.iter(range(5)).skip(2) + assert await anext(stream) == 2 + assert await stream.collect() == [3, 4] + + +class _ForwardSink(Sink): + def __init__(self): + self.items = [] + self.flushed = False + self.closed = False + + async def feed(self, item): + self.items.append(item) + + async def send(self, item): + self.items.append(item) + + async def flush(self): + self.flushed = True + + async def close(self): + self.closed = True + + +@pytest.mark.asyncio +async def test_forward(): + sink = _ForwardSink() + await streams.iter([1, 2, 3]).forward(sink) + assert sink.items == [1, 2, 3] + assert sink.flushed + assert sink.closed diff --git a/uv.lock b/uv.lock index dcff705..1250074 100644 --- a/uv.lock +++ b/uv.lock @@ -104,7 +104,7 @@ wheels = [ [[package]] name = "kioto" -version = "0.1.12" +version = "0.1.13" source = { editable = "." } dependencies = [ { name = "aiofiles" },