diff --git a/kioto/channels/api.py b/kioto/channels/api.py index c251af1..96f2696 100644 --- a/kioto/channels/api.py +++ b/kioto/channels/api.py @@ -1,30 +1,35 @@ +from __future__ import annotations + +from typing import Awaitable, TypeVar + from kioto.channels import impl -from typing import Any + +T = TypeVar("T") -def channel(capacity: int) -> tuple[impl.Sender, impl.Receiver]: - channel = impl.Channel(capacity) +def channel(capacity: int) -> tuple[impl.Sender[T], impl.Receiver[T]]: + channel: impl.Channel[T] = impl.Channel(capacity) sender = impl.Sender(channel) receiver = impl.Receiver(channel) return sender, receiver -def channel_unbounded() -> tuple[impl.Sender, impl.Receiver]: - channel = impl.Channel(None) +def channel_unbounded() -> tuple[impl.Sender[T], impl.Receiver[T]]: + channel: impl.Channel[T] = impl.Channel(None) sender = impl.Sender(channel) receiver = impl.Receiver(channel) return sender, receiver -def oneshot_channel(): - channel = impl.OneShotChannel() +def oneshot_channel() -> tuple[impl.OneShotSender[T], Awaitable[T]]: + channel: impl.OneShotChannel[T] = impl.OneShotChannel() sender = impl.OneShotSender(channel) receiver = impl.OneShotReceiver(channel) return sender, receiver() -def watch(initial_value: Any) -> tuple[impl.WatchSender, impl.WatchReceiver]: - channel = impl.WatchChannel(initial_value) +def watch(initial_value: T) -> tuple[impl.WatchSender[T], impl.WatchReceiver[T]]: + channel: impl.WatchChannel[T] = impl.WatchChannel(initial_value) sender = impl.WatchSender(channel) receiver = impl.WatchReceiver(channel) return sender, receiver diff --git a/kioto/channels/impl.py b/kioto/channels/impl.py index bcd58ec..91d5280 100644 --- a/kioto/channels/impl.py +++ b/kioto/channels/impl.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio import threading import weakref from collections import deque -from typing import Any, Callable +from typing import Callable, Deque, Generic, TypeVar, AsyncIterator, Awaitable from kioto.streams import Stream from kioto.sink import Sink @@ -12,6 +14,9 @@ from . import error +T = TypeVar("T") + + def notify_one(waiters): if waiters: tx = waiters.pop() @@ -25,11 +30,11 @@ def notify_all(waiters): tx.send(()) -def wait_for_notice(waiters): +def wait_for_notice(waiters: Deque[OneShotSender[None]]) -> Awaitable[None]: # Create a oneshot channel - channel = OneShotChannel() - sender = OneShotSender(channel) - receiver = OneShotReceiver(channel) + channel: OneShotChannel[None] = OneShotChannel() + sender: OneShotSender[None] = OneShotSender(channel) + receiver: OneShotReceiver[None] = OneShotReceiver(channel) # register the tx side for notification waiters.append(sender) @@ -433,13 +438,13 @@ async def __anext__(self) -> memoryview: raise StopAsyncIteration -class Channel: +class Channel(Generic[T]): """ Internal Channel class managing the asyncio.Queue and tracking senders and receivers. """ def __init__(self, maxsize: int | None): - self.sync_queue = deque([], maxlen=maxsize) + self.sync_queue: Deque[T] = deque([], maxlen=maxsize) self._senders = set() self._receivers = set() @@ -459,10 +464,10 @@ def capacity(self): def full(self): return self.size() == self.capacity() - def register_sender(self, sender: "Sender"): + def register_sender(self, sender: "Sender[T]"): self._senders.add(weakref.ref(sender, self.sender_dropped)) - def register_receiver(self, receiver: "Receiver"): + def register_receiver(self, receiver: "Receiver[T]"): self._receivers.add(weakref.ref(receiver, self.receiver_dropped)) def has_receivers(self) -> bool: @@ -494,16 +499,16 @@ def notify_receiver(self): notify_one(self._recv_waiters) -class Sender: +class Sender(Generic[T]): """ Sender class providing synchronous and asynchronous send methods. """ - def __init__(self, channel: Channel): + def __init__(self, channel: Channel[T]): self._channel = channel self._channel.register_sender(self) - async def send_async(self, item: Any): + async def send_async(self, item: T): """ Asynchronously send an item to the channel and wait until it's processed. @@ -525,7 +530,7 @@ async def send_async(self, item: Any): # TODO: wait for receiver notification await self._channel.wait_for_receiver() - def send(self, item: Any): + def send(self, item: T): """ Synchronously send an item to the channel. @@ -545,7 +550,7 @@ def send(self, item: Any): self._channel.sync_queue.append(item) self._channel.notify_receiver() - def into_sink(self) -> "SenderSink": + def into_sink(self) -> "SenderSink[T]": """ Convert this Sender into a SenderSink. @@ -561,16 +566,16 @@ def __deepcopy__(self, memo): raise TypeError("Sender instances cannot be deep copied.") -class Receiver: +class Receiver(Generic[T]): """ Receiver class providing synchronous and asynchronous recv methods. """ - def __init__(self, channel: Channel): + def __init__(self, channel: Channel[T]): self._channel = channel self._channel.register_receiver(self) - async def recv(self) -> Any: + async def recv(self) -> T: """ Asynchronously receive an item from the channel. @@ -592,7 +597,7 @@ async def recv(self) -> Any: await self._channel.wait_for_sender() - def into_stream(self) -> "ReceiverStream": + def into_stream(self) -> "ReceiverStream[T]": """ Convert this Receiver into a ReceiverStream. @@ -608,22 +613,22 @@ def __deepcopy__(self, memo): raise TypeError("Receiver instances cannot be deep copied.") -class SenderSink(Sink): +class SenderSink(Sink, Generic[T]): """ Sink implementation that wraps a Sender, allowing integration with Sink interfaces. """ - def __init__(self, sender: Sender): + def __init__(self, sender: Sender[T]): self._sender = sender self._channel = sender._channel self._closed = False - async def feed(self, item: Any): + async def feed(self, item: T): if self._closed: raise error.SenderSinkClosed await self._sender.send_async(item) - async def send(self, item: Any): + async def send(self, item: T): if self._closed: raise error.SenderSinkClosed await self._sender.send_async(item) @@ -638,34 +643,34 @@ async def close(self): self._closed = True -class ReceiverStream(Stream): +class ReceiverStream(Stream[T]): """ Stream implementation that wraps a Receiver, allowing integration with Stream interfaces. """ - def __init__(self, receiver: Receiver): + def __init__(self, receiver: Receiver[T]): self._receiver = receiver - async def __anext__(self): + async def __anext__(self) -> T: try: return await self._receiver.recv() except error.SendersDisconnected: raise StopAsyncIteration -class OneShotChannel(asyncio.Future): +class OneShotChannel(asyncio.Future[T]): def sender_dropped(self): if not self.done(): exception = error.SendersDisconnected self.set_exception(exception) -class OneShotSender: - def __init__(self, channel): +class OneShotSender(Generic[T]): + def __init__(self, channel: OneShotChannel[T]): self._channel = channel weakref.finalize(self, channel.sender_dropped) - def send(self, value): + def send(self, value: T): if self._channel.done(): raise error.SenderExhausted("Value has already been sent on channel") @@ -679,35 +684,35 @@ def setter(): loop.call_soon_threadsafe(setter) -class OneShotReceiver: - def __init__(self, channel): +class OneShotReceiver(Generic[T]): + def __init__(self, channel: OneShotChannel[T]): self._channel = channel - async def __call__(self): + async def __call__(self) -> T: return await self._channel -class WatchChannel: - def __init__(self, initial_value: Any): +class WatchChannel(Generic[T]): + def __init__(self, initial_value: T): # Tracks the version of the current value self._version = 0 # Deque with maxlen=1 to store the current value - self._queue = deque([initial_value], maxlen=1) + self._queue: Deque[T] = deque([initial_value], maxlen=1) self._lock = threading.Lock() - self._waiters = deque() + self._waiters: Deque[OneShotSender[None]] = deque() - self._senders = weakref.WeakSet() - self._receivers = weakref.WeakSet() + self._senders: weakref.WeakSet[WatchSender[T]] = weakref.WeakSet() + self._receivers: weakref.WeakSet[WatchReceiver[T]] = weakref.WeakSet() - def register_sender(self, sender: "WatchSender"): + def register_sender(self, sender: "WatchSender[T]"): """ Register a new sender to the channel. """ self._senders.add(sender) - def register_receiver(self, receiver: "WatchReceiver"): + def register_receiver(self, receiver: "WatchReceiver[T]"): """ Register a new receiver to the channel. """ @@ -725,7 +730,7 @@ def has_receivers(self) -> bool: """ return len(self._receivers) > 0 - def get_current_value(self) -> Any: + def get_current_value(self) -> T: """ Retrieve the current value from the channel. """ @@ -737,7 +742,7 @@ def notify(self): """ notify_all(self._waiters) - async def wait(self): + async def wait(self) -> None: # Create a oneshot channel channel = OneShotChannel() sender = OneShotSender(channel) @@ -749,7 +754,7 @@ async def wait(self): # wait for notification await receiver() - def set_value(self, value: Any): + def set_value(self, value: T): """ Set a new value in the channel and increment the version. """ @@ -759,16 +764,16 @@ def set_value(self, value: Any): self.notify() -class WatchSender: +class WatchSender(Generic[T]): """ Sender class providing methods to send and modify values in the watch channel. """ - def __init__(self, channel: WatchChannel): + def __init__(self, channel: WatchChannel[T]): self._channel = channel self._channel.register_sender(self) - def subscribe(self) -> "WatchReceiver": + def subscribe(self) -> "WatchReceiver[T]": """ Create a new receiver who is subscribed to this sender """ @@ -780,7 +785,7 @@ def receiver_count(self) -> int: """ return len(self._channel._receivers) - def send(self, value: Any): + def send(self, value: T): """ Asynchronously send a new value to the channel. @@ -795,7 +800,7 @@ def send(self, value: Any): self._channel.set_value(value) - def send_modify(self, func: Callable[[Any], Any]): + def send_modify(self, func: Callable[[T], T]): """ Modify the current value using a provided function and send the updated value. @@ -812,7 +817,7 @@ def send_modify(self, func: Callable[[Any], Any]): new_value = func(current) self._channel.set_value(new_value) - def send_if_modified(self, func: Callable[[Any], Any]): + def send_if_modified(self, func: Callable[[T], T]): """ Modify the current value using a provided function and send the updated value only if it has changed. @@ -830,7 +835,7 @@ def send_if_modified(self, func: Callable[[Any], Any]): if new_value != current: self._channel.set_value(new_value) - def borrow(self) -> Any: + def borrow(self) -> T: """ Borrow the current value without marking it as seen. @@ -840,17 +845,17 @@ def borrow(self) -> Any: return self._channel.get_current_value() -class WatchReceiver: +class WatchReceiver(Generic[T]): """ Receiver class providing methods to access and await changes in the watch channel. """ - def __init__(self, channel: WatchChannel): + def __init__(self, channel: WatchChannel[T]): self._channel = channel self._last_version = channel._version # Initialize with the current version self._channel.register_receiver(self) - def borrow(self) -> Any: + def borrow(self) -> T: """ Borrow the current value without marking it as seen. @@ -859,7 +864,7 @@ def borrow(self) -> Any: """ return self._channel.get_current_value() - def borrow_and_update(self) -> Any: + def borrow_and_update(self) -> T: """ Borrow the current value and mark it as seen. @@ -892,7 +897,7 @@ async def changed(self): # as senders would not be able to gain access to the underlying channel. await self._channel.wait() - def into_stream(self) -> "WatchReceiverStream": + def into_stream(self) -> "WatchReceiverStream[T]": """ Convert this WatchReceiver into a WatchReceiverStream. @@ -902,7 +907,7 @@ def into_stream(self) -> "WatchReceiverStream": return WatchReceiverStream(self) -async def _watch_stream(receiver): +async def _watch_stream(receiver: WatchReceiver[T]) -> AsyncIterator[T]: # Return the initial value in the watch yield receiver.borrow() @@ -915,16 +920,16 @@ async def _watch_stream(receiver): break -class WatchReceiverStream(Stream): +class WatchReceiverStream(Stream[T]): """ Stream implementation that wraps a WatchReceiver, allowing integration with Stream interfaces. """ - def __init__(self, receiver: Receiver): + def __init__(self, receiver: WatchReceiver[T]): self._stream = _watch_stream(receiver) - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[T]: return self._stream - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self._stream) diff --git a/kioto/streams/api.py b/kioto/streams/api.py index 5b73436..c3f0512 100644 --- a/kioto/streams/api.py +++ b/kioto/streams/api.py @@ -1,52 +1,58 @@ +from __future__ import annotations + import functools +from typing import Any, Callable, Iterable, TypeVar, AsyncIterator from kioto import futures from kioto.streams import impl +T = TypeVar("T") +U = TypeVar("U") + # This is the python equivalent to tokio stream::iter(iterable) -def iter(iterable) -> impl.Stream: +def iter(iterable: Iterable[T]) -> impl.Stream[T]: """ Create a stream that yields values from the input iterable """ return impl.Iter(iterable) -def once(value) -> impl.Stream: +def once(value: T) -> impl.Stream[T]: """ Create a stream that yields a single value """ return impl.Once(value) -def pending() -> impl.Stream: +def pending() -> impl.Stream[Any]: """ Create that never yields a value """ return impl.Pending() -def repeat(val: any) -> impl.Stream: +def repeat(val: T) -> impl.Stream[T]: """ Create a stream which produces the same item repeatedly. """ return impl.Repeat(val) -def repeat_with(fn: callable) -> impl.Stream: +def repeat_with(fn: Callable[[], T]) -> impl.Stream[T]: """ Create a stream with produces values by repeatedly calling the input fn """ return impl.RepeatWith(fn) -def async_stream(f): +def async_stream(f: Callable[..., AsyncIterator[T]]): """ Decorator that converts an async generator function into a Stream object """ @functools.wraps(f) - def stream(*args, **kwargs) -> impl.Stream: + def stream(*args, **kwargs) -> impl.Stream[T]: # Take an async generator function and return a Stream object # that inherits all of the stream methods return impl.Stream.from_generator(f(*args, **kwargs)) @@ -59,7 +65,7 @@ def stream(*args, **kwargs) -> impl.Stream: @async_stream -async def select(**streams): +async def select(**streams: impl.Stream[Any]) -> AsyncIterator[tuple[str, Any]]: group = impl.StreamSet(streams) while group.task_set(): try: diff --git a/kioto/streams/impl.py b/kioto/streams/impl.py index 24dfad3..5342c8a 100644 --- a/kioto/streams/impl.py +++ b/kioto/streams/impl.py @@ -1,10 +1,26 @@ +from __future__ import annotations + import asyncio import builtins +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterable, + Optional, + TypeVar, + Generic, +) from kioto.futures import pending, select, task_set from kioto.internal.queue import SlotQueue +T = TypeVar("T") +U = TypeVar("U") + + class _Sentinel: """Special marker object used to signal the end of the stream.""" @@ -12,101 +28,101 @@ def __repr__(self): return "<_Sentinel>" -class Stream: - def __aiter__(self): +class Stream(Generic[T]): + def __aiter__(self) -> AsyncIterator[T]: return self @staticmethod - def from_generator(gen): + def from_generator(gen: Iterable[T] | AsyncIterator[T]) -> "Stream[T]": return _GenStream(gen) - def map(self, fn): + def map(self, fn: Callable[[T], U]) -> "Stream[U]": return Map(self, fn) - def then(self, coro): + def then(self, coro: Callable[[T], Awaitable[U]]) -> "Stream[U]": return Then(self, coro) - def filter(self, predicate): + def filter(self, predicate: Callable[[T], bool]) -> "Stream[T]": return Filter(self, predicate) - def buffered(self, n): + def buffered(self, n: int) -> "Stream[T]": return Buffered(self, n) - def buffered_unordered(self, n): + def buffered_unordered(self, n: int) -> "Stream[T]": return BufferedUnordered(self, n) - def flatten(self): + def flatten(self: "Stream[Stream[U]]") -> "Stream[U]": return Flatten(self) - def flat_map(self, fn): + def flat_map(self, fn: Callable[[T], "Stream[U]"]) -> "Stream[U]": return FlatMap(self, fn) - def chunks(self, n): + def chunks(self, n: int) -> "Stream[list[T]]": return Chunks(self, n) - def ready_chunks(self, n): + def ready_chunks(self, n: int) -> "Stream[list[T]]": return ReadyChunks(self, n) - def filter_map(self, fn): + def filter_map(self, fn: Callable[[T], Optional[U]]) -> "Stream[U]": return FilterMap(self, fn) - def chain(self, stream): + def chain(self, stream: "Stream[T]") -> "Stream[T]": return Chain(self, stream) - def zip(self, stream): + def zip(self, stream: "Stream[U]") -> "Stream[tuple[T, U]]": return Zip(self, stream) - def switch(self, coro): + def switch(self, coro: Callable[[T], Awaitable["Stream[U]"]]) -> "Stream[U]": return Switch(self, coro) - def debounce(self, duration): + def debounce(self, duration: float) -> "Stream[T]": return Debounce(self, duration) - async def fold(self, fn, acc): + async def fold(self, fn: Callable[[U, T], U], acc: U) -> U: async for val in self: acc = fn(acc, val) return acc - async def collect(self): - return [i async for i in aiter(self)] + async def collect(self) -> list[T]: + return [i async for i in builtins.aiter(self)] -class Iter(Stream): - def __init__(self, iterable): +class Iter(Stream[T]): + def __init__(self, iterable: Iterable[T]): self.iterable = builtins.iter(iterable) - async def __anext__(self): + async def __anext__(self) -> T: try: return next(self.iterable) except StopIteration: raise StopAsyncIteration -class Map(Stream): - def __init__(self, stream, fn): +class Map(Stream[U]): + def __init__(self, stream: Stream[T], fn: Callable[[T], U]): self.fn = fn self.stream = stream - async def __anext__(self): + async def __anext__(self) -> U: return self.fn(await anext(self.stream)) -class Then(Stream): - def __init__(self, stream, fn): +class Then(Stream[U]): + def __init__(self, stream: Stream[T], fn: Callable[[T], Awaitable[U]]): self.fn = fn self.stream = stream - async def __anext__(self): + async def __anext__(self) -> U: arg = await anext(self.stream) return await self.fn(arg) -class Filter(Stream): - def __init__(self, stream, predicate): +class Filter(Stream[T]): + def __init__(self, stream: Stream[T], predicate: Callable[[T], bool]): self.predicate = predicate self.stream = stream - async def __anext__(self): + async def __anext__(self) -> T: while True: val = await anext(self.stream) if self.predicate(val): @@ -162,17 +178,17 @@ async def push_tasks(): await spawner_task -class Buffered(Stream): +class Buffered(Stream[T]): """ Buffered stream that spawns tasks from an underlying stream with a specified buffer size. Results are yielded as soon as individual tasks complete. """ - def __init__(self, stream, buffer_size: int): + def __init__(self, stream: Stream[Awaitable[T]], buffer_size: int): self.stream = _buffered(stream, buffer_size) - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self.stream) @@ -237,7 +253,7 @@ async def spawn_later(spawned_task): yield result -class BufferedUnordered(Stream): +class BufferedUnordered(Stream[T]): """ Stream implementation that yields results from tasks in an unordered fashion. @@ -245,48 +261,48 @@ class BufferedUnordered(Stream): As soon as any task completes, its result is yielded and its slot is freed for reuse. """ - def __init__(self, stream, buffer_size: int): + def __init__(self, stream: Stream[Awaitable[T]], buffer_size: int): self.stream = _buffered_unordered(stream, buffer_size) - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self.stream) -async def _flatten(nested_st): +async def _flatten(nested_st: Stream[Stream[T]]) -> AsyncIterator[T]: async for stream in nested_st: async for val in stream: yield val -class Flatten(Stream): - def __init__(self, stream): +class Flatten(Stream[T]): + def __init__(self, stream: Stream[Stream[T]]): self.stream = _flatten(stream) - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self.stream) -async def _flat_map(stream, fn): +async def _flat_map(stream: Stream[T], fn: Callable[[T], Stream[U]]) -> AsyncIterator[U]: async for stream in stream.map(fn): async for val in stream: yield val -class FlatMap(Stream): - def __init__(self, stream, fn): +class FlatMap(Stream[U]): + def __init__(self, stream: Stream[T], fn: Callable[[T], Stream[U]]): self.stream = _flat_map(stream, fn) - async def __anext__(self): + async def __anext__(self) -> U: return await anext(self.stream) -class Chunks(Stream): - def __init__(self, stream, n): +class Chunks(Stream[list[T]]): + def __init__(self, stream: Stream[T], n: int): self.stream = stream self.n = n - async def __anext__(self): - chunk = [] + async def __anext__(self) -> list[T]: + chunk: list[T] = [] for _ in range(self.n): try: chunk.append(await anext(self.stream)) @@ -301,19 +317,19 @@ async def spawn(n): queue = asyncio.Queue(maxsize=n) -class ReadyChunks(Stream): - def __init__(self, stream, n): +class ReadyChunks(Stream[list[T]]): + def __init__(self, stream: Stream[T], n: int): self.n = n self.stream = stream - self.pending = None - self.buffer = asyncio.Queue(maxsize=n) + self.pending: asyncio.Task | None = None + self.buffer: asyncio.Queue[T] = asyncio.Queue(maxsize=n) - async def push_anext(self): + async def push_anext(self) -> None: elem = await anext(self.stream) await self.buffer.put(elem) - async def __anext__(self): - chunk = [] + async def __anext__(self) -> list[T]: + chunk: list[T] = [] # Guarantee that we have at least one element in the buffer if self.pending: await self.pending @@ -334,12 +350,12 @@ async def __anext__(self): return chunk -class FilterMap(Stream): - def __init__(self, stream, fn): +class FilterMap(Stream[U]): + def __init__(self, stream: Stream[T], fn: Callable[[T], Optional[U]]): self.stream = stream self.fn = fn - async def __anext__(self): + async def __anext__(self) -> U: while True: match self.fn(await anext(self.stream)): case None: @@ -348,36 +364,36 @@ async def __anext__(self): return result -async def _chain(left, right): +async def _chain(left: Stream[T], right: Stream[T]) -> AsyncIterator[T]: async for val in left: yield val async for val in right: yield val -class Chain(Stream): - def __init__(self, left, right): +class Chain(Stream[T]): + def __init__(self, left: Stream[T], right: Stream[T]): self.stream = _chain(left, right) - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self.stream) -class Zip(Stream): - def __init__(self, left, right): +class Zip(Stream[tuple[T, U]]): + def __init__(self, left: Stream[T], right: Stream[U]): self.left = left self.right = right - async def __anext__(self): + async def __anext__(self) -> tuple[T, U]: return (await anext(self.left), await anext(self.right)) -class Switch(Stream): - def __init__(self, stream, coro): +class Switch(Stream[U]): + def __init__(self, stream: Stream[T], coro: Callable[[T], Awaitable[Stream[U]]]): self.coro = coro self.stream = stream - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[U]: # Initialize a task set, with a coroutine to fetch the next item off the stream. tasks = task_set(anext=anext(self.stream)) @@ -403,14 +419,14 @@ async def __aiter__(self): yield result -class Debounce(Stream): - def __init__(self, stream, duration): +class Debounce(Stream[T]): + def __init__(self, stream: Stream[T], duration: float): self.stream = stream self.duration = duration - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[T]: # Initialize a task set with tasks to get the next elem and a delay - pending = None + pending: T | None = None tasks = task_set(anext=anext(self.stream), delay=asyncio.sleep(self.duration)) while tasks: @@ -437,47 +453,47 @@ async def __aiter__(self): yield elem -async def _once(value): +async def _once(value: T) -> AsyncIterator[T]: yield value -class Once(Stream): - def __init__(self, value): +class Once(Stream[T]): + def __init__(self, value: T): self.stream = _once(value) - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self.stream) -class Pending(Stream): - async def __anext__(self): +class Pending(Stream[Any]): + async def __anext__(self) -> Any: return await pending() -class Repeat(Stream): - def __init__(self, value): +class Repeat(Stream[T]): + def __init__(self, value: T): self.value = value - async def __anext__(self): + async def __anext__(self) -> T: return self.value -class RepeatWith(Stream): - def __init__(self, fn): +class RepeatWith(Stream[T]): + def __init__(self, fn: Callable[[], T]): self.fn = fn - async def __anext__(self): + async def __anext__(self) -> T: return self.fn() -class _GenStream(Stream): - def __init__(self, gen): +class _GenStream(Stream[T]): + def __init__(self, gen: Iterable[T] | AsyncIterator[T]): if hasattr(gen, "__aiter__"): self.gen = gen else: self.gen = Iter(gen) - async def __anext__(self): + async def __anext__(self) -> T: return await anext(self.gen)