diff --git a/aioreactive/__init__.py b/aioreactive/__init__.py index 75cf439..cec59b4 100644 --- a/aioreactive/__init__.py +++ b/aioreactive/__init__.py @@ -314,6 +314,18 @@ def merge(self, other: AsyncObservable[_TSource]) -> AsyncRx[_TSource]: AsyncRx.create, ) + def reduce( + self, accumulator: Callable[[_TResult, _TSource], _TResult], initial: _TResult + ) -> "AsyncRx[_TResult]": + return pipe(self, reduce(accumulator, initial), AsyncRx[_TResult]) + + def reduce_async( + self, + accumulator: Callable[[_TResult, _TSource], Awaitable[_TResult]], + initial: _TResult, + ) -> "AsyncRx[_TResult]": + return pipe(self, reduce_async(accumulator, initial), AsyncRx[_TResult]) + def skip(self, count: int) -> AsyncObservable[_TSource]: """Skip items from start of the stream. @@ -838,6 +850,46 @@ def of_async(workflow: Awaitable[_TSource]) -> AsyncObservable[_TSource]: return of_async(workflow) +def reduce( + accumulator: Callable[[_TResult, _TSource], _TResult], + initial: _TResult, +) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]: + """The reduce operator. + + If an error occurs either in the accumulator or the source the subscription is disposed and the error is thrown. + + Args: + accumulator: An accumulator function + initial: The initial value + + Returns: + The reduce operator function + """ + from .transform import reduce as _reduce + + return _reduce(accumulator, initial) + + +def reduce_async( + accumulator: Callable[[_TResult, _TSource], Awaitable[_TResult]], + initial: _TResult, +) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]: + """The async reduce operator. + + See the `reduce` operator. + + Args: + accumulator (Callable[[_TResult, _TSource], Awaitable[_TResult]]): An async accumulator function + initial (_TResult): The initial value + + Returns: + Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]: The operator function + """ + from .transform import reduce_async as _reduce + + return _reduce(accumulator, initial) + + def retry( retry_count: int, ) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TSource]]: diff --git a/aioreactive/transform.py b/aioreactive/transform.py index 87f60a3..8a5eeab 100644 --- a/aioreactive/transform.py +++ b/aioreactive/transform.py @@ -510,3 +510,28 @@ def scan_async( The scan operator. """ return _scan(accumulator, initial) + + +def reduce( + accumulator: Callable[[_TResult, _TSource], _TResult], + initial: _TResult, +) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]: + async def _reduce(current: _TResult, value: _TSource) -> _TResult: + return accumulator(current, value) + + def _operator(Observable: AsyncObservable[_TSource]) -> AsyncObservable[_TResult]: + return pipe(Observable, reduce_async(_reduce, initial)) + + return _operator + + +def reduce_async( + accumulator: Callable[[_TResult, _TSource], Awaitable[_TResult]], + initial: _TResult, +) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]: + def _operator(source: AsyncObservable[_TSource]) -> AsyncObservable[_TResult]: + from .filtering import take_last + + return pipe(source, scan_async(accumulator, initial), take_last(1)) + + return _operator diff --git a/tests/test_reduce.py b/tests/test_reduce.py new file mode 100644 index 0000000..068eb73 --- /dev/null +++ b/tests/test_reduce.py @@ -0,0 +1,59 @@ +import asyncio + +import pytest +from expression import pipe + +import aioreactive as rx +from aioreactive.notification import OnCompleted, OnNext +from aioreactive.testing import AsyncTestObserver, VirtualTimeEventLoop + + +class MyException(Exception): + pass + + +@pytest.fixture() +def event_loop(): + loop = VirtualTimeEventLoop() + try: + yield loop + finally: + loop.close() + + +def sync_sum(a: int, b: int) -> int: + return a + b + + +async def async_sum(a: int, b: int) -> int: + await asyncio.sleep(0.2) + return a + b + + +@pytest.mark.asyncio +async def test_reduce(): + xs = rx.from_iterable([1, 2, 3, 4]) + observer = AsyncTestObserver() + ys = pipe(xs, rx.reduce(sync_sum, 0)) + await rx.run(ys, observer) + + values = list(map(lambda t: t[1], observer.values)) + assert values == [ + OnNext(10), + OnCompleted, + ] + + +@pytest.mark.asyncio +async def test_reduce_async(): + xs = rx.from_iterable([1, 2, 3, 4]) + observer = AsyncTestObserver() + ys = pipe(xs, rx.reduce_async(async_sum, 0)) + + await rx.run(ys, observer) + + values = list(map(lambda t: t[1], observer.values)) + assert values == [ + OnNext(10), + OnCompleted, + ]