Skip to content

Commit

Permalink
Merge pull request #37 from francipvb/features/reduce-operator
Browse files Browse the repository at this point in the history
Implemented the `reduce` and `reduce_async` operators
  • Loading branch information
dbrattli authored Jan 14, 2024
2 parents 5e22a31 + 92de0d8 commit 51608bd
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
52 changes: 52 additions & 0 deletions aioreactive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,18 @@ def merge(self, other: AsyncObservable[_TSource]) -> AsyncRx[_TSource]:
AsyncRx.create,
)

def reduce(
self, accumulator: Callable[[_TResult, _TSource], _TResult], initial: _TResult
) -> "AsyncRx[_TResult]":
return pipe(self, reduce(accumulator, initial), AsyncRx[_TResult])

def reduce_async(
self,
accumulator: Callable[[_TResult, _TSource], Awaitable[_TResult]],
initial: _TResult,
) -> "AsyncRx[_TResult]":
return pipe(self, reduce_async(accumulator, initial), AsyncRx[_TResult])

def skip(self, count: int) -> AsyncObservable[_TSource]:
"""Skip items from start of the stream.
Expand Down Expand Up @@ -838,6 +850,46 @@ def of_async(workflow: Awaitable[_TSource]) -> AsyncObservable[_TSource]:
return of_async(workflow)


def reduce(
accumulator: Callable[[_TResult, _TSource], _TResult],
initial: _TResult,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
"""The reduce operator.
If an error occurs either in the accumulator or the source the subscription is disposed and the error is thrown.
Args:
accumulator: An accumulator function
initial: The initial value
Returns:
The reduce operator function
"""
from .transform import reduce as _reduce

return _reduce(accumulator, initial)


def reduce_async(
accumulator: Callable[[_TResult, _TSource], Awaitable[_TResult]],
initial: _TResult,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
"""The async reduce operator.
See the `reduce` operator.
Args:
accumulator (Callable[[_TResult, _TSource], Awaitable[_TResult]]): An async accumulator function
initial (_TResult): The initial value
Returns:
Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]: The operator function
"""
from .transform import reduce_async as _reduce

return _reduce(accumulator, initial)


def retry(
retry_count: int,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TSource]]:
Expand Down
25 changes: 25 additions & 0 deletions aioreactive/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,28 @@ def scan_async(
The scan operator.
"""
return _scan(accumulator, initial)


def reduce(
accumulator: Callable[[_TResult, _TSource], _TResult],
initial: _TResult,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
async def _reduce(current: _TResult, value: _TSource) -> _TResult:
return accumulator(current, value)

def _operator(Observable: AsyncObservable[_TSource]) -> AsyncObservable[_TResult]:
return pipe(Observable, reduce_async(_reduce, initial))

return _operator


def reduce_async(
accumulator: Callable[[_TResult, _TSource], Awaitable[_TResult]],
initial: _TResult,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
def _operator(source: AsyncObservable[_TSource]) -> AsyncObservable[_TResult]:
from .filtering import take_last

return pipe(source, scan_async(accumulator, initial), take_last(1))

return _operator
59 changes: 59 additions & 0 deletions tests/test_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio

import pytest
from expression import pipe

import aioreactive as rx
from aioreactive.notification import OnCompleted, OnNext
from aioreactive.testing import AsyncTestObserver, VirtualTimeEventLoop


class MyException(Exception):
pass


@pytest.fixture()
def event_loop():
loop = VirtualTimeEventLoop()
try:
yield loop
finally:
loop.close()


def sync_sum(a: int, b: int) -> int:
return a + b


async def async_sum(a: int, b: int) -> int:
await asyncio.sleep(0.2)
return a + b


@pytest.mark.asyncio
async def test_reduce():
xs = rx.from_iterable([1, 2, 3, 4])
observer = AsyncTestObserver()
ys = pipe(xs, rx.reduce(sync_sum, 0))
await rx.run(ys, observer)

values = list(map(lambda t: t[1], observer.values))
assert values == [
OnNext(10),
OnCompleted,
]


@pytest.mark.asyncio
async def test_reduce_async():
xs = rx.from_iterable([1, 2, 3, 4])
observer = AsyncTestObserver()
ys = pipe(xs, rx.reduce_async(async_sum, 0))

await rx.run(ys, observer)

values = list(map(lambda t: t[1], observer.values))
assert values == [
OnNext(10),
OnCompleted,
]

0 comments on commit 51608bd

Please sign in to comment.