Skip to content

Commit

Permalink
finish documenting all functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rudymatela committed Oct 4, 2024
1 parent d65a20a commit 0e3638f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
3 changes: 0 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
TODO for leancheck.py
=====================

* add documentation to module and to all functions
- reorder functions if necessary

* simplify code

later
Expand Down
30 changes: 29 additions & 1 deletion src/leancheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,22 @@ class Enumerator:
which can be registered using the `Enumerator.register()` method.
This class supports computing sums and products of enumerations:
>>> print(Enumerator[int] + Enumerator[bool])
[0, False, True, 1, 2, 3, ...]
Use `*` to take the product of two enumerations:
>>> print(Enumerator[int] * Enumerator[bool])
[(0, False), (0, True), (1, False), (1, True), (2, False), (2, True), ...]
"""

tiers: typing.Callable[[], typing.Generator]
"""
Generate tiers of values.
>>> list(Enumerator[bool].tiers())
[[False, True]]
"""

def __init__(self, tiers):
Expand Down Expand Up @@ -308,6 +324,9 @@ def __add__(self, other):
>>> print(Enumerator[int] + Enumerator[bool])
[0, False, True, 1, 2, 3, ...]
>>> Enumerator[int] + Enumerator[bool]
Enumerator(lambda: (xs for xs in [[0, False, True], [1], [2], [3], [4], [5], ...]))
"""
return Enumerator(lambda: _zippend(self.tiers(), other.tiers()))

Expand All @@ -317,6 +336,9 @@ def __mul__(self, other):
>>> print(Enumerator[int] * Enumerator[bool])
[(0, False), (0, True), (1, False), (1, True), (2, False), (2, True), ...]
>>> Enumerator[int] * Enumerator[bool]
Enumerator(lambda: (xs for xs in [[(0, False), (0, True)], [(1, False), (1, True)], [(2, False), (2, True)], [(3, False), (3, True)], [(4, False), (4, True)], [(5, False), (5, True)], ...]))
"""
return Enumerator(lambda: _pproduct(self.tiers(), other.tiers()))

Expand All @@ -335,6 +357,12 @@ def __str__(self):
return "[" + ', '.join(xs) + "]"

def map(self, f):
"""
Applies a function to all values in the enumeration.
>>> Enumerator[int].map(lambda x: x*2)
Enumerator(lambda: (xs for xs in [[0], [2], [4], [6], [8], [10], ...]))
"""
return Enumerator(lambda: _mmap(f, self.tiers()))

@classmethod
Expand Down Expand Up @@ -462,7 +490,7 @@ def _intercalate(generator1, generator2):


def _zippend(*iiterables):
return itertools.starmap(itertools.chain,itertools.zip_longest(*iiterables, fillvalue=[]))
return map(list,itertools.starmap(itertools.chain,itertools.zip_longest(*iiterables, fillvalue=[])))


def _pproduct(xss, yss, with_f=None):
Expand Down

0 comments on commit 0e3638f

Please sign in to comment.