Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions firebolt/arrays/nested.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,23 @@ struct ListArray(Array):
)
self.data.length += 1

fn unsafe_get(self, index: Int) raises -> ArrayData:
"""Access the value at a given index in the list array."""
var child_dtype = self.data.dtype.fields[0].dtype
var start = Int(self.offsets[].unsafe_get[DType.int32](index))
var end = Int(self.offsets[].unsafe_get[DType.int32](index + 1))
fn unsafe_get(self, index: Int, out array_data: ArrayData) raises:
"""Access the value at a given index in the list array.

Use an out argument to allow the caller to re-use memory while iterating over a pyarrow structure.
"""
var start = Int(
self.offsets[].unsafe_get[DType.int32](self.data.offset + index)
)
var end = Int(
self.offsets[].unsafe_get[DType.int32](self.data.offset + index + 1)
)
ref first_child = self.data.children[0][]
return ArrayData(
dtype=child_dtype,
dtype=first_child.dtype,
bitmap=first_child.bitmap,
buffers=first_child.buffers,
offset=self.data.offset + start,
offset=start,
length=end - start,
children=first_child.children,
)
Expand Down
2 changes: 1 addition & 1 deletion firebolt/arrays/primitive.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct PrimitiveArray[T: DataType](Array):
self.bitmap = data.bitmap
self.buffer = data.buffers[0]
self.capacity = data.length
self.offset = offset
self.offset = data.offset + offset

fn __init__(out self, capacity: Int = 0, offset: Int = 0):
self.capacity = capacity
Expand Down
4 changes: 2 additions & 2 deletions firebolt/arrays/tests/test_binary.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_string_builder():
assert_equal(len(a), 2)
assert_equal(a.capacity, 2)

var s = a.unsafe_get(0)
assert_equal(String(s), "hello")
assert_equal(String(a.unsafe_get(0)), "hello")
assert_equal(String(a.unsafe_get(1)), "world")

assert_equal(
a.__str__().strip(),
Expand Down
81 changes: 81 additions & 0 deletions firebolt/arrays/tests/test_nested.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,70 @@ from firebolt.dtypes import *
from firebolt.test_fixtures.bool_array import as_bool_array_scalar


fn build_list_of_list[data_type: DataType]() raises -> ListArray:
"""Build a test ListArray.

See: https://elferherrera.github.io/arrow_guide/arrays_nested.html
"""

# Define all the values.
var bitmap = ArcPointer(Bitmap.alloc(10))
bitmap[].unsafe_range_set(0, 10, True)
var buffer = ArcPointer(Buffer.alloc[data_type.native](10))
for i in range(10):
buffer[].unsafe_set[data_type.native](i, i + 1)

var value_data = ArrayData(
dtype=data_type,
length=10,
bitmap=bitmap,
buffers=List(buffer),
children=List[ArcPointer[ArrayData]](),
offset=0,
)

# Define the PrimitiveArrays.
var value_offset = ArcPointer(Buffer.alloc(7))
value_offset[].unsafe_set[DType.int32](0, 0)
value_offset[].unsafe_set[DType.int32](1, 2)
value_offset[].unsafe_set[DType.int32](2, 4)
value_offset[].unsafe_set[DType.int32](3, 7)
value_offset[].unsafe_set[DType.int32](4, 7)
value_offset[].unsafe_set[DType.int32](5, 8)
value_offset[].unsafe_set[DType.int32](6, 10)

var list_bitmap = ArcPointer(Bitmap.alloc(6))
list_bitmap[].unsafe_range_set(0, 6, True)
list_bitmap[].unsafe_set(3, False)
var list_data = ArrayData(
dtype=list_(data_type),
length=6,
buffers=List(value_offset),
children=List(ArcPointer(value_data)),
bitmap=list_bitmap,
offset=0,
)

# Now define the master array data.
var top_offsets = Buffer.alloc(4)
top_offsets.unsafe_set[DType.int32](0, 0)
top_offsets.unsafe_set[DType.int32](1, 2)
top_offsets.unsafe_set[DType.int32](2, 5)
top_offsets.unsafe_set[DType.int32](3, 6)
var top_bitmap = ArcPointer(Bitmap.alloc(4))
top_bitmap[].unsafe_range_set(0, 4, True)
return ListArray(
ArrayData(
dtype=list_(list_(data_type)),
length=4,
buffers=List(ArcPointer(top_offsets^)),
children=List(ArcPointer(list_data)),
bitmap=top_bitmap,
offset=0,
)
)


def test_list_int_array():
var ints = Int64Array(capacity=3)
ints.append(1)
Expand Down Expand Up @@ -57,6 +121,19 @@ def test_list_str():
assert_equal(first_value.unsafe_get(1), "world")


def test_list_of_list():
list2 = build_list_of_list[int64]()
top = ListArray(list2.unsafe_get(0))
middle_0 = top.unsafe_get(0)
bottom = Int64Array(middle_0)
assert_equal(bottom.unsafe_get(1), 2)
assert_equal(bottom.unsafe_get(0), 1)
middle_1 = top.unsafe_get(1)
bottom = Int64Array(middle_1)
assert_equal(bottom.unsafe_get(0), 3)
assert_equal(bottom.unsafe_get(1), 4)


def test_struct_array():
var fields = List[Field](
Field("id", int64),
Expand Down Expand Up @@ -103,3 +180,7 @@ def test_struct_array_str_repr():
assert_equal(str_repr, "StructArray(length=0)")
assert_equal(repr_repr, "StructArray(length=0)")
assert_equal(str_repr, repr_repr)


fn main() raises:
test_list_of_list()