Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions guppylang/std/array/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic

from guppylang.decorator import guppy
from guppylang.std.builtins import ArrayIter, SizedIter, _array_unsafe_getitem
from guppylang.std.option import Option, nothing, some

if TYPE_CHECKING:
from guppylang.std.builtins import array, owned

n = guppy.nat_var("n")

L = guppy.type_var("L", copyable=False, droppable=False)
L2 = guppy.type_var("L2", copyable=False, droppable=False)


@guppy.struct
class ArrayZipIter(Generic[L, L2, n]):
"""Zipped iterator over arrays."""

xs: array[L, n]
ys: array[L2, n]
i: int

@guppy
def __next__(
self: ArrayZipIter[L, L2, n] @ owned,
) -> Option[tuple[tuple[L, L2], ArrayZipIter[L, L2, n]]]:
if self.i < int(n):
x = _array_unsafe_getitem(self.xs, self.i)
y = _array_unsafe_getitem(self.ys, self.i)
elem = (x, y)
return some((elem, ArrayZipIter(self.xs, self.ys, self.i + 1)))
ArrayIter(self.xs, 0)._assert_all_used()
ArrayIter(self.ys, 0)._assert_all_used()
return nothing()


@guppy
def zip(
a: array[L, n] @ owned, b: array[L2, n] @ owned
) -> SizedIter[ArrayZipIter[L, L2, n], n]:
"""Zip two arrays together into an iterator of tuples.
Args:
a: First array.
b: Second array.
Returns:
An iterator of tuples, where each tuple contains elements from the two input
arrays at the same index.
"""
return SizedIter(ArrayZipIter(a, b, 0))


@guppy
def enumerate(a: array[L, n] @ owned) -> SizedIter[ArrayZipIter[int, L, n], n]:
"""
Enumerates the elements of an array, pairing each element with its index.

Args:
a: The input array of type `L` with size `n`.

Returns:
An iterator that yields tuples,
where each tuple contains an index (int) and the corresponding element
from the input array.
"""

return zip(array(i for i in range(n)), a)
112 changes: 112 additions & 0 deletions guppylang/std/array/bool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import annotations

from typing import no_type_check

from guppylang.decorator import guppy
from guppylang.std.builtins import array

n = guppy.nat_var("n")


@guppy
@no_type_check
def array_eq(a: array[bool, n], b: array[bool, n]) -> bool:
"""Check if two boolean arrays are equal element-wise.

Args:
a: First boolean array.
b: Second boolean array.

Returns:
True if all elements are equal, False otherwise.
"""
for i in range(n):
if a[i] != b[i]:
return False
return True


@guppy
@no_type_check
def array_any(arr: array[bool, n]) -> bool:
"""Check if any element in the boolean array is True.

Args:
arr: Boolean array.

Returns:
True if at least one element is True, False otherwise.
"""
for i in range(n):
if arr[i]:
return True
return False


@guppy
@no_type_check
def array_all(arr: array[bool, n]) -> bool:
"""Check if all elements in the boolean array are True.

Args:
arr: Boolean array.

Returns:
True if all elements are True, False otherwise.
"""
for i in range(n):
if not arr[i]:
return False
return True


@guppy
@no_type_check
def parity(bits: array[bool, n]) -> bool:
"""Compute the parity of a boolean array.

Args:
bits: Boolean array.

Returns:
True if the number of True elements is odd, False otherwise.
"""
out = False
for i in range(n):
out ^= bits[i]
return out


@guppy
@no_type_check
def bitwise_xor(x: array[bool, n], y: array[bool, n]) -> array[bool, n]:
"""Perform bitwise XOR operation on two boolean arrays.

Args:
x: First boolean array.
y: Second boolean array.

Returns:
Resultant boolean array after XOR operation.
"""
return array(x[i] ^ y[i] for i in range(n))


@guppy
@no_type_check
def pack_bits_dlo(ar: array[bool, n]) -> int:
"""Pack bits into an integer assuming decreasing lexicographical order.

The first element of the array is considered the most significant bit.

Args:
ar: Boolean array.

Returns:
Integer representation of the packed bits.
"""
out = 0
for i in range(n):
if ar[i]:
out += 1 << (n - 1 - i)
return out
Empty file.
124 changes: 124 additions & 0 deletions tests/integration/array_lib/test_bool_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import no_type_check
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.array.bool import (
array_eq,
array_any,
array_all,
parity,
bitwise_xor,
pack_bits_dlo,
)
from guppylang.std.builtins import array


def test_bool_array_eq(validate, run_int_fn):
module = GuppyModule("test")
module.load(array_eq)

@guppy(module)
@no_type_check
def main() -> int:
yes = array_eq(array(True, False, True), array(True, False, True))
no = not array_eq(array(True, False, True), array(False, True, False))

return 2 * int(yes) + int(no)

package = module.compile()
validate(package)
run_int_fn(package, expected=3)


def test_array_any(validate, run_int_fn) -> None:
module = GuppyModule("test")
module.load(array_any)

@guppy(module)
@no_type_check
def main() -> int:
yes = array_any(array(False, False, True))
no = not array_any(array(False, False, False))

return 2 * int(yes) + int(no)

package = module.compile()
validate(package)
run_int_fn(package, expected=3)


def test_array_all(validate, run_int_fn) -> None:
module = GuppyModule("test")
module.load(array_all)

@guppy(module)
@no_type_check
def main() -> int:
yes = array_all(array(True, True, True))
no = not array_all(array(True, False, True))

return 2 * int(yes) + int(no)

package = module.compile()
validate(package)
run_int_fn(package, expected=3)


def test_parity_check(validate, run_int_fn) -> None:
module = GuppyModule("test")
module.load(parity)

@guppy(module)
@no_type_check
def main() -> int:
yes = parity(array(True, True, True))
no = not parity(array(True, False, True))

return 2 * int(yes) + int(no)

package = module.compile()
validate(package)
run_int_fn(package, expected=3)


def test_bitwise_xor(validate, run_int_fn) -> None:
module = GuppyModule("test")
module.load(bitwise_xor, array_eq)

@guppy(module)
@no_type_check
def main() -> int:
first = array_eq(
bitwise_xor(array(True, False, True), array(False, True, True)),
array(True, True, False),
)
second = array_eq(
bitwise_xor(array(True, True, True), array(True, True, True)),
array(False, False, False),
)

return 2 * int(first) + int(second)

package = module.compile()
validate(package)
run_int_fn(package, expected=3)


def test_packbits_dlo(validate, run_int_fn) -> None:
module = GuppyModule("test")
module.load(pack_bits_dlo)

@guppy(module)
@no_type_check
def main() -> int:
five = pack_bits_dlo(array(True, False, True))
four = pack_bits_dlo(array(True, False, False))
empty = pack_bits_dlo(array())
one = pack_bits_dlo(array(True))
zero = pack_bits_dlo(array(False))
two_zero = pack_bits_dlo(array(False, False))

return five + four + empty + one + zero + two_zero

package = module.compile()
validate(package)
run_int_fn(package, expected=10)
53 changes: 53 additions & 0 deletions tests/integration/array_lib/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import no_type_check
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

from guppylang.std.array import zip, enumerate
from guppylang.std.builtins import array


def test_zip(validate, run_int_fn):
module = GuppyModule("test")
module.load(zip)

@guppy(module)
@no_type_check
def main() -> int:
pyi = array(13, 2352, 358)
pyb = array(True, False, True)

total = 0
for i, b in zip(pyi, pyb):
total += i * (int(b) + 1)

return total

package = module.compile()
validate(package)

run_int_fn(package, expected=3094)


def test_enumerate(validate, run_int_fn):
module = GuppyModule("test")
module.load(enumerate)

@guppy(module)
@no_type_check
def main() -> int:
pyi = array(13, 2352, 358)

total = 0
for i, v in enumerate(pyi):
total += v * (i + 1)

return total

package = module.compile()
validate(package)

run_int_fn(package, expected=5791)


if __name__ == "__main__":
test_zip(None, None)
Loading