Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions kioto/streams/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

from kioto.futures import pending, select, task_set
from kioto.internal.queue import SlotQueue
from kioto.sink.impl import Sink


T = TypeVar("T")
U = TypeVar("U")
R = TypeVar("R")


class _Sentinel:
Expand Down Expand Up @@ -78,6 +80,62 @@ def switch(self, coro: Callable[[T], Awaitable[U]]) -> "Stream[U]":
def debounce(self, duration: float) -> "Stream[T]":
return Debounce(self, duration)

def enumerate(self) -> "Stream[tuple[int, T]]":
return Enumerate(self)

async def unzip(self: "Stream[tuple[T, U]]") -> tuple[list[T], list[U]]:
left: list[T] = []
right: list[U] = []
async for l, r in self: # type: ignore[misc]
left.append(l)
right.append(r)
return left, right

async def count(self) -> int:
n = 0
async for _ in self:
n += 1
return n

def cycle(self) -> "Stream[T]":
return Cycle(self)

async def any(self, predicate: Callable[[T], bool]) -> bool:
async for val in self:
if predicate(val):
return True
return False

async def all(self, predicate: Callable[[T], bool]) -> bool:
async for val in self:
if not predicate(val):
return False
return True

def scan(self, acc: U, fn: Callable[[U, T], U]) -> "Stream[U]":
return Scan(self, acc, fn)

def skip_while(self, predicate: Callable[[T], Awaitable[bool]]) -> "Stream[T]":
return SkipWhile(self, predicate)

def take_while(self, predicate: Callable[[T], Awaitable[bool]]) -> "Stream[T]":
return TakeWhile(self, predicate)

def take_until(self, fut: Awaitable[R]) -> "TakeUntil[T, R]":
return TakeUntil(self, fut)

def take(self, n: int) -> "Stream[T]":
return Take(self, n)

def skip(self, n: int) -> "Stream[T]":
return Skip(self, n)

async def forward(self, sink: Sink) -> None:
async for item in self:
await sink.feed(item)
await sink.flush()
await sink.close()

async def fold(self, fn: Callable[[U, T], U], acc: U) -> U:
async for val in self:
acc = fn(acc, val)
Expand Down Expand Up @@ -388,6 +446,131 @@ async def __anext__(self) -> tuple[T, U]:
return (await anext(self.left), await anext(self.right))


class Enumerate(Stream[tuple[int, T]]):
def __init__(self, stream: Stream[T]):
self.stream = stream
self.index = 0

async def __anext__(self) -> tuple[int, T]:
val = await anext(self.stream)
idx = self.index
self.index += 1
return idx, val


async def _cycle(stream: Stream[T]) -> AsyncIterator[T]:
cache: list[T] = []
async for item in stream:
cache.append(item)
yield item
if not cache:
return
while True:
for item in cache:
yield item


class Cycle(Stream[T]):
def __init__(self, stream: Stream[T]):
self.stream = _cycle(stream)

async def __anext__(self) -> T:
return await anext(self.stream)


class Scan(Stream[U]):
def __init__(self, stream: Stream[T], acc: U, fn: Callable[[U, T], U]):
self.stream = stream
self.acc = acc
self.fn = fn

async def __anext__(self) -> U:
val = await anext(self.stream)
self.acc = self.fn(self.acc, val)
return self.acc


class SkipWhile(Stream[T]):
def __init__(self, stream: Stream[T], predicate: Callable[[T], Awaitable[bool]]):
self.stream = stream
self.predicate = predicate
self.skipping = True

async def __anext__(self) -> T:
while True:
val = await anext(self.stream)
if self.skipping and await self.predicate(val):
continue
self.skipping = False
return val


class TakeWhile(Stream[T]):
def __init__(self, stream: Stream[T], predicate: Callable[[T], Awaitable[bool]]):
self.stream = stream
self.predicate = predicate
self.done = False

async def __anext__(self) -> T:
if self.done:
raise StopAsyncIteration
val = await anext(self.stream)
if await self.predicate(val):
return val
self.done = True
raise StopAsyncIteration


class Take(Stream[T]):
def __init__(self, stream: Stream[T], n: int):
self.stream = stream
self.remaining = n

async def __anext__(self) -> T:
if self.remaining <= 0:
raise StopAsyncIteration
self.remaining -= 1
return await anext(self.stream)


class Skip(Stream[T]):
def __init__(self, stream: Stream[T], n: int):
self.stream = stream
self.n = n
self.skipped = False

async def __anext__(self) -> T:
if not self.skipped:
for _ in range(self.n):
await anext(self.stream)
self.skipped = True
return await anext(self.stream)


class TakeUntil(Stream[T], Generic[T, R]):
def __init__(self, stream: Stream[T], stop: Awaitable[R]):
self._stream = stream
self._tasks = task_set(anext=anext(stream), stop=stop)
self._result = None

def take_future(self):
return self._tasks._tasks.pop("stop", None)

def take_result(self):
return self._result

async def __anext__(self) -> T:
if self._tasks:
match await select(self._tasks):
case ("stop", result):
self._result = result
raise StopAsyncIteration
case ("anext", value):
# Re-poll the stream
self._tasks.update("anext", anext(self._stream))
return value


async def _switch(st: Stream[T], coro: Callable[[T], Awaitable[U]]) -> AsyncIterator[U]:
# Initialize a task set, with a coroutine to fetch the next item off the stream.
tasks = task_set(anext=anext(st))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "kioto"
version = "v0.1.12"
version = "v0.1.13"
description = "Async utilities library inspired by Tokio"
readme = "README.md"
authors = [{name="Brandon Ogle", email="oglebrandon@gmail.com"}]
Expand Down
134 changes: 134 additions & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from kioto import streams
from kioto.channels import channel
from kioto.sink.impl import Sink


@pytest.mark.asyncio
Expand Down Expand Up @@ -523,3 +524,136 @@ async def test_stream_select_once():
# NOTE: There was a bug causing the previous task to be re-yielded
assert name != "one"
assert value != 1


@pytest.mark.asyncio
async def test_enumerate():
stream = streams.iter(["a", "b"]).enumerate()
assert await anext(stream) == (0, "a")
assert await stream.collect() == [(1, "b")]


@pytest.mark.asyncio
async def test_unzip():
left, right = await streams.iter([(1, "a"), (2, "b")]).unzip()
assert left == [1, 2]
assert right == ["a", "b"]


@pytest.mark.asyncio
async def test_count():
assert await streams.iter(range(5)).count() == 5


@pytest.mark.asyncio
async def test_cycle():
stream = streams.iter([1, 2]).cycle()
assert [await anext(stream) for _ in range(4)] == [1, 2, 1, 2]


@pytest.mark.asyncio
async def test_any_all():
assert await streams.iter(range(5)).any(lambda x: x == 3)
assert await streams.iter(range(5)).all(lambda x: x < 5)
assert not await streams.iter(range(5)).all(lambda x: x < 3)


@pytest.mark.asyncio
async def test_scan():
stream = streams.iter(range(1, 4)).scan(0, lambda acc, val: acc + val)
assert await stream.collect() == [1, 3, 6]


@pytest.mark.asyncio
async def test_skip_while():
async def pred(x):
await asyncio.sleep(0)
return x < 3

stream = streams.iter(range(5)).skip_while(pred)
assert await anext(stream) == 3
assert await stream.collect() == [4]


@pytest.mark.asyncio
async def test_take_while():
async def pred(x):
await asyncio.sleep(0)
return x < 3

stream = streams.iter(range(5)).take_while(pred)
assert await anext(stream) == 0
assert await stream.collect() == [1, 2]


@pytest.mark.asyncio
async def test_take_until():
async def stopper():
await asyncio.sleep(0.05)
return "done"

@streams.async_stream
async def st():
for i in range(10):
await asyncio.sleep(0.01)
yield i

# Case 1: stopper interrupts the collect()
stream = st().take_until(stopper())
assert await stream.collect() == [0, 1, 2, 3]
assert stream.take_result() == "done"

# Future already resolved
assert stream.take_future() is None

# Iterate the rest of the stream
assert await stream.collect() == [4, 5, 6, 7, 8, 9]

# Case 2: remove the stopper to iterate uninterrupted
stream = st().take_until(stopper())
assert await anext(stream) == 0
assert await anext(stream) == 1
stopper = stream.take_future()
assert await stream.collect() == [2, 3, 4, 5, 6, 7, 8, 9]
assert await stopper == "done"


@pytest.mark.asyncio
async def test_take():
stream = streams.iter(range(5)).take(3)
assert await stream.collect() == [0, 1, 2]


@pytest.mark.asyncio
async def test_skip():
stream = streams.iter(range(5)).skip(2)
assert await anext(stream) == 2
assert await stream.collect() == [3, 4]


class _ForwardSink(Sink):
def __init__(self):
self.items = []
self.flushed = False
self.closed = False

async def feed(self, item):
self.items.append(item)

async def send(self, item):
self.items.append(item)

async def flush(self):
self.flushed = True

async def close(self):
self.closed = True


@pytest.mark.asyncio
async def test_forward():
sink = _ForwardSink()
await streams.iter([1, 2, 3]).forward(sink)
assert sink.items == [1, 2, 3]
assert sink.flushed
assert sink.closed
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.