diff --git a/README.md b/README.md index a09ffbc..d364ec3 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,25 @@ Simple functions can be passed in sequence to compose more complex filters ``` +### `Stream.fold` +```python +Stream.fold(self, initial: 'T', fn: 'Callable[[T], U]', *, workers: 'int' = 1, use_threads: 'bool' = False) -> 'U' +``` +Fold the results into a single value. `fold` triggers an action so will incur a `collect`. + +```python +>>> Stream.from_iterable([1, 2, 3, 4]).fold(0, lambda a, b: a + b) == 10 +>>> Stream.from_iterable([[1], [2], [3], [4]]).fold([0], lambda a, b: a + b) == [0, 1, 2, 3, 4] +>>> Stream.from_iterable([1, 2, 3, 4]).fold(1, lambda a, b: a * b) == 24 +``` + +As `fold` triggers an action, the parameters will be forwarded to the `par_collect` call if the `workers` are greater than 1. +This will only affect the `collect` that is used to create the iterable to reduce, not the `fold` operation itself. +```python +>>> Stream.from_iterable([1, 2, 3, 4]).map(some_expensive_fn).fold(0, add, workers=4, use_threads=False) +``` + + ### `Stream.from_iterable` ```python Stream.from_iterable(it: 'Iterable') -> 'Self' @@ -217,7 +236,7 @@ Each partition is independently replayable. >>> part2.collect() == (1, 3) ``` -As `partition` triggers an action, the parameters will be forwarded to the `collect` call if the `workers` are greater than 1. +As `partition` triggers an action, the parameters will be forwarded to the `par_collect` call if the `workers` are greater than 1. 
```python >>> Stream.from_iterable(range(10)).map(add_one, add_one).partition(divisible_by_3, workers=4) >>> part1.map(add_one).par_collect() == (4, 7, 10) diff --git a/coverage.svg b/coverage.svg index 16c3628..11e2b54 100644 --- a/coverage.svg +++ b/coverage.svg @@ -1,16 +1,16 @@ - - - - - - - - - - coverage - 100.0% - + + + + + + + + + + coverage + 100.0% + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4445401..2da537a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "danom" -version = "0.5.0" +version = "0.6.0" description = "Functional streams and monads" readme = "README.md" license = "MIT" diff --git a/src/danom/_stream.py b/src/danom/_stream.py index 12ab713..ac586e3 100644 --- a/src/danom/_stream.py +++ b/src/danom/_stream.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Iterable from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from enum import Enum, auto, unique +from functools import reduce from typing import Self import attrs @@ -15,6 +16,10 @@ class _BaseStream(ABC): seq: Iterable = attrs.field(validator=attrs.validators.instance_of(Iterable), repr=False) ops: tuple = attrs.field(default=(), validator=attrs.validators.instance_of(tuple), repr=False) + @classmethod + @abstractmethod + def from_iterable(cls, it: Iterable) -> Self: ... + @abstractmethod def map[T, U](self, *fns: Callable[[T], U]) -> Self: ... @@ -24,9 +29,17 @@ def filter[T](self, *fns: Callable[[T], bool]) -> Self: ... @abstractmethod def partition[T](self, fn: Callable[[T], bool]) -> tuple[Self, Self]: ... + @abstractmethod + def fold[T, U]( + self, initial: T, fn: Callable[[T], U], *, workers: int = 1, use_threads: bool = False + ) -> U: ... + @abstractmethod def collect(self) -> tuple: ... + @abstractmethod + def par_collect(self) -> tuple: ... 
+ @attrs.define(frozen=True) class Stream(_BaseStream): @@ -97,7 +110,7 @@ def partition[T]( >>> part2.collect() == (1, 3) ``` - As `partition` triggers an action, the parameters will be forwarded to the `collect` call if the `workers` are greater than 1. + As `partition` triggers an action, the parameters will be forwarded to the `par_collect` call if the `workers` are greater than 1. ```python >>> Stream.from_iterable(range(10)).map(add_one, add_one).partition(divisible_by_3, workers=4) >>> part1.map(add_one).par_collect() == (4, 7, 10) @@ -114,6 +127,27 @@ def partition[T]( Stream(seq=(x for x in seq_tuple if not fn(x))), ) + def fold[T, U]( + self, initial: T, fn: Callable[[T], U], *, workers: int = 1, use_threads: bool = False + ) -> U: + """Fold the results into a single value. `fold` triggers an action so will incur a `collect`. + + ```python + >>> Stream.from_iterable([1, 2, 3, 4]).fold(0, lambda a, b: a + b) == 10 + >>> Stream.from_iterable([[1], [2], [3], [4]]).fold([0], lambda a, b: a + b) == [0, 1, 2, 3, 4] + >>> Stream.from_iterable([1, 2, 3, 4]).fold(1, lambda a, b: a * b) == 24 + ``` + + As `fold` triggers an action, the parameters will be forwarded to the `par_collect` call if the `workers` are greater than 1. + This will only affect the `collect` that is used to create the iterable to reduce, not the `fold` operation itself. + ```python + >>> Stream.from_iterable([1, 2, 3, 4]).map(some_expensive_fn).fold(0, add, workers=4, use_threads=False) + ``` + """ + if workers > 1: + return reduce(fn, self.par_collect(workers=workers, use_threads=use_threads), initial) + return reduce(fn, self.collect(), initial) + def collect(self) -> tuple: """Materialise the sequence from the `Stream`. 
diff --git a/tests/conftest.py b/tests/conftest.py index 46086f8..5e36d26 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,10 @@ from src.danom._result import Result +def add[T](a: T, b: T) -> T: + return a + b + + def has_len(value: str) -> bool: return len(value) > 0 diff --git a/tests/test_stream.py b/tests/test_stream.py index 8a5f82a..3a2f514 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,7 +1,7 @@ import pytest from src.danom import Stream -from tests.conftest import add_one, divisible_by_3, divisible_by_5 +from tests.conftest import add, add_one, divisible_by_3, divisible_by_5 @pytest.mark.parametrize( @@ -50,3 +50,14 @@ def test_stream_to_par_stream(): ) assert part1.map(add_one).collect() == (4, 7, 10) assert part2.collect() == (2, 4, 5, 7, 8, 10, 11) + + +@pytest.mark.parametrize( + ("starting", "initial", "fn", "workers", "expected_result"), + [ + pytest.param(range(10), 0, add, 1, 45), + pytest.param(range(10), 0, add, 4, 45), + ], +) +def test_fold(starting, initial, fn, workers, expected_result): + assert Stream.from_iterable(starting).fold(initial, fn, workers=workers) == expected_result diff --git a/uv.lock b/uv.lock index e49466f..45d5f0b 100644 --- a/uv.lock +++ b/uv.lock @@ -189,7 +189,7 @@ wheels = [ [[package]] name = "danom" -version = "0.5.0" +version = "0.6.0" source = { editable = "." } dependencies = [ { name = "attrs" },