Skip to content

Commit 2e2f827

Browse files
Fix bug in ListArray unsafe_get (#36)
1 parent 3a47eaa commit 2e2f827

File tree

4 files changed

+97
-10
lines changed

4 files changed

+97
-10
lines changed

firebolt/arrays/nested.mojo

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -80,17 +80,23 @@ struct ListArray(Array):
8080
)
8181
self.data.length += 1
8282

# Diff hunk for ListArray.unsafe_get. Each code line follows its gutter marker:
# "N-" lines appear to be the pre-commit version, "N+" lines the post-commit
# version, and bare digit runs (old and new line numbers fused) unchanged context.
# No bounds check on `index` is visible in either version.
83-
fn unsafe_get(self, index: Int) raises -> ArrayData:
84-
"""Access the value at a given index in the list array."""
85-
var child_dtype = self.data.dtype.fields[0].dtype
86-
var start = Int(self.offsets[].unsafe_get[DType.int32](index))
87-
var end = Int(self.offsets[].unsafe_get[DType.int32](index + 1))
# Post-commit signature: the result is declared as the `out` argument
# `array_data` instead of a `-> ArrayData` return type.
83+
fn unsafe_get(self, index: Int, out array_data: ArrayData) raises:
84+
"""Access the value at a given index in the list array.
85+
86+
Use an out argument to allow the caller to re-use memory while iterating over a pyarrow structure.
87+
"""
# The offsets buffer is now read at `self.data.offset + index` (and `+ 1` for
# the end); the old version read at `index` and `index + 1` directly.
88+
var start = Int(
89+
self.offsets[].unsafe_get[DType.int32](self.data.offset + index)
90+
)
91+
var end = Int(
92+
self.offsets[].unsafe_get[DType.int32](self.data.offset + index + 1)
93+
)
# The result reuses the first child's bitmap, buffers and children and
# describes the `end - start` elements beginning at `start`.
8894
ref first_child = self.data.children[0][]
8995
return ArrayData(
# The dtype is now taken from the child itself rather than from
# `self.data.dtype.fields[0].dtype`.
90-
dtype=child_dtype,
96+
dtype=first_child.dtype,
9197
bitmap=first_child.bitmap,
9298
buffers=first_child.buffers,
# `self.data.offset` is no longer added here; it is applied to the offsets
# lookup above instead, so the child offset is `start` alone.
93-
offset=self.data.offset + start,
99+
offset=start,
94100
length=end - start,
95101
children=first_child.children,
96102
)

firebolt/arrays/primitive.mojo

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ struct PrimitiveArray[T: DataType](Array):
8585
self.bitmap = data.bitmap
8686
self.buffer = data.buffers[0]
8787
self.capacity = data.length
88-
self.offset = offset
88+
self.offset = data.offset + offset
8989

9090
fn __init__(out self, capacity: Int = 0, offset: Int = 0):
9191
self.capacity = capacity

firebolt/arrays/tests/test_binary.mojo

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@ def test_string_builder():
1919
assert_equal(len(a), 2)
2020
assert_equal(a.capacity, 2)
2121

22-
var s = a.unsafe_get(0)
23-
assert_equal(String(s), "hello")
22+
assert_equal(String(a.unsafe_get(0)), "hello")
23+
assert_equal(String(a.unsafe_get(1)), "world")
2424

2525
assert_equal(
2626
a.__str__().strip(),

firebolt/arrays/tests/test_nested.mojo

Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,70 @@ from firebolt.dtypes import *
66
from firebolt.test_fixtures.bool_array import as_bool_array_scalar
77

88

# Test fixture added by this commit: builds a three-level structure (values,
# a list over the values, a list over those lists) and wraps the outermost
# ArrayData in a ListArray. `data_type` selects the element type of the values.
9+
fn build_list_of_list[data_type: DataType]() raises -> ListArray:
10+
"""Build a test ListArray.
11+
12+
See: https://elferherrera.github.io/arrow_guide/arrays_nested.html
13+
"""
14+
15+
# Define the innermost values: ten entries holding 1..10 (slot i stores i + 1),
# with all ten bitmap bits set to True.
16+
var bitmap = ArcPointer(Bitmap.alloc(10))
17+
bitmap[].unsafe_range_set(0, 10, True)
18+
var buffer = ArcPointer(Buffer.alloc[data_type.native](10))
19+
for i in range(10):
20+
buffer[].unsafe_set[data_type.native](i, i + 1)
21+
22+
var value_data = ArrayData(
23+
dtype=data_type,
24+
length=10,
25+
bitmap=bitmap,
26+
buffers=List(buffer),
27+
children=List[ArcPointer[ArrayData]](),
28+
offset=0,
29+
)
30+
# Define the inner list level: seven int32 offsets into the ten values, giving
# the (start, end) pairs (0,2), (2,4), (4,7), (7,7), (7,8), (8,10).
# NOTE(review): Buffer.alloc is called without an element-type parameter here,
# unlike the values buffer above, yet the slots are written as DType.int32 —
# confirm the default allocation is large enough for 7 int32 entries.
31+
# Define the PrimitiveArrays.
32+
var value_offset = ArcPointer(Buffer.alloc(7))
33+
value_offset[].unsafe_set[DType.int32](0, 0)
34+
value_offset[].unsafe_set[DType.int32](1, 2)
35+
value_offset[].unsafe_set[DType.int32](2, 4)
36+
value_offset[].unsafe_set[DType.int32](3, 7)
37+
value_offset[].unsafe_set[DType.int32](4, 7)
38+
value_offset[].unsafe_set[DType.int32](5, 8)
39+
value_offset[].unsafe_set[DType.int32](6, 10)
40+
# Six bitmap bits, all True except index 3 — the empty (7,7) entry, presumably
# marked as null.
41+
var list_bitmap = ArcPointer(Bitmap.alloc(6))
42+
list_bitmap[].unsafe_range_set(0, 6, True)
43+
list_bitmap[].unsafe_set(3, False)
44+
var list_data = ArrayData(
45+
dtype=list_(data_type),
46+
length=6,
47+
buffers=List(value_offset),
48+
children=List(ArcPointer(value_data)),
49+
bitmap=list_bitmap,
50+
offset=0,
51+
)
52+
# Outermost level: offsets 0, 2, 5, 6 index into the six inner lists.
# NOTE(review): four offsets describe three (start, end) pairs, but the
# ArrayData below is built with length=4 and a 4-bit bitmap — confirm whether
# a fifth offset is missing or the length should be 3.
# NOTE(review): as above, Buffer.alloc(4) carries no element type although
# int32 values are stored — confirm the allocation size.
53+
# Now define the master array data.
54+
var top_offsets = Buffer.alloc(4)
55+
top_offsets.unsafe_set[DType.int32](0, 0)
56+
top_offsets.unsafe_set[DType.int32](1, 2)
57+
top_offsets.unsafe_set[DType.int32](2, 5)
58+
top_offsets.unsafe_set[DType.int32](3, 6)
59+
var top_bitmap = ArcPointer(Bitmap.alloc(4))
60+
top_bitmap[].unsafe_range_set(0, 4, True)
# `top_offsets^` transfers the buffer into the ArcPointer rather than copying it.
61+
return ListArray(
62+
ArrayData(
63+
dtype=list_(list_(data_type)),
64+
length=4,
65+
buffers=List(ArcPointer(top_offsets^)),
66+
children=List(ArcPointer(list_data)),
67+
bitmap=top_bitmap,
68+
offset=0,
69+
)
70+
)
71+
72+
973
def test_list_int_array():
1074
var ints = Int64Array(capacity=3)
1175
ints.append(1)
@@ -57,6 +121,19 @@ def test_list_str():
57121
assert_equal(first_value.unsafe_get(1), "world")
58122

59123

# Test added by this commit: walks the list-of-list fixture from
# build_list_of_list down to its int64 values and checks the first two inner
# lists read back as 1, 2 and 3, 4.
# NOTE(review): only outer element 0 is exercised, and the fixture's first top
# offset is 0, so a non-zero `self.data.offset` in ListArray.unsafe_get does not
# appear to be covered here — consider also checking list2.unsafe_get(1).
124+
def test_list_of_list():
125+
list2 = build_list_of_list[int64]()
# `top` wraps the ArrayData returned for outer element 0 as a ListArray.
126+
top = ListArray(list2.unsafe_get(0))
# First inner list of that element, viewed as an Int64Array.
127+
middle_0 = top.unsafe_get(0)
128+
bottom = Int64Array(middle_0)
129+
assert_equal(bottom.unsafe_get(1), 2)
130+
assert_equal(bottom.unsafe_get(0), 1)
# Second inner list of the same element.
131+
middle_1 = top.unsafe_get(1)
132+
bottom = Int64Array(middle_1)
133+
assert_equal(bottom.unsafe_get(0), 3)
134+
assert_equal(bottom.unsafe_get(1), 4)
135+
136+
60137
def test_struct_array():
61138
var fields = List[Field](
62139
Field("id", int64),
@@ -103,3 +180,7 @@ def test_struct_array_str_repr():
103180
assert_equal(str_repr, "StructArray(length=0)")
104181
assert_equal(repr_repr, "StructArray(length=0)")
105182
assert_equal(str_repr, repr_repr)
183+
184+
# Entry point for running this test file directly; it calls test_list_of_list
# and nothing else.
# NOTE(review): the other tests in this file are not invoked from here —
# presumably a test runner discovers them; confirm.
185+
fn main() raises:
186+
test_list_of_list()

0 commit comments

Comments
 (0)