Skip to content

Commit 56d437a

Browse files
committed
Feature: unsafe_get for ListArray of Numeric data types.
1 parent 78eb11f commit 56d437a

File tree

4 files changed

+56
-16
lines changed

4 files changed

+56
-16
lines changed

firebolt/arrays/base.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ struct ArrayData(Copyable, Movable, Representable, Stringable, Writable):
8484
if self.dtype.native == known_type:
8585
writer.write(self.buffers[0][].unsafe_get[known_type](index))
8686
return
87-
writer.write("Can't process data type:")
87+
writer.write("dtype=")
8888
writer.write(self.dtype)
8989

9090
fn write_to[W: Writer](self, mut writer: W):

firebolt/arrays/nested.mojo

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,31 @@ struct ListArray(Array):
2424
self.values = data.children[0]
2525
self.capacity = data.length
2626

27-
fn __init__[T: Array](out self, values: T, capacity: Int = 0):
28-
var bitmap = Bitmap.alloc(capacity)
29-
var offsets = Buffer.alloc[DType.uint32](capacity + 1)
30-
offsets.unsafe_set[DType.uint32](0, 0)
27+
fn __init__[T: Array](out self, var values: T, capacity: Int = 1):
28+
"""Initialize a list with the given values.
29+
30+
Default capacity is at least 1 to accomodate the values.
3131
32+
Args:
33+
values: Array to use as the first element in the ListArray.
34+
capacity: The capacity of the ListArray.
35+
"""
3236
var values_data = values.as_data()
3337
var list_dtype = list_(values_data.dtype)
3438

39+
var bitmap = Bitmap.alloc(capacity)
40+
bitmap.unsafe_set(0, True)
41+
var offsets = Buffer.alloc[DType.uint32](capacity + 1)
42+
offsets.unsafe_set[DType.uint32](0, 0)
43+
offsets.unsafe_set[DType.uint32](1, values_data.length)
44+
3545
self.capacity = capacity
3646
self.bitmap = ArcPointer(bitmap^)
3747
self.offsets = ArcPointer(offsets^)
3848
self.values = ArcPointer(values_data^)
3949
self.data = ArrayData(
4050
dtype=list_dtype,
41-
length=0,
51+
length=1,
4252
bitmap=self.bitmap,
4353
buffers=List(self.offsets),
4454
children=List(self.values),
@@ -68,6 +78,27 @@ struct ListArray(Array):
6878
)
6979
self.data.length += 1
7080

81+
fn unsafe_get(self, index: Int) raises -> ArrayData:
82+
"""Access the value at a given index in the list array."""
83+
var child_dtype = self.data.dtype.fields[0].dtype
84+
if not child_dtype.is_numeric():
85+
raise Error(
86+
"Only numeric dtype supported right now, got {}".format(
87+
child_dtype
88+
)
89+
)
90+
var start = Int(self.offsets[].unsafe_get[DType.int32](index))
91+
var end = Int(self.offsets[].unsafe_get[DType.int32](index + 1))
92+
ref first_child = self.data.children[0][]
93+
return ArrayData(
94+
dtype=child_dtype,
95+
bitmap=first_child.bitmap,
96+
buffers=first_child.buffers,
97+
offset=self.data.offset + start,
98+
length=end - start,
99+
children=List[ArcPointer[ArrayData]](),
100+
)
101+
71102
fn write_to[W: Writer](self, mut writer: W):
72103
"""
73104
Formats this ListArray to the provided Writer.

firebolt/arrays/tests/test_nested.mojo

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@ from firebolt.test_fixtures.bool_array import as_bool_array_scalar
77

88

99
def test_list_int_array():
10-
var ints = Int64Array()
11-
var lists = ListArray(ints)
12-
assert_equal(lists.data.dtype, list_(int64))
13-
10+
var ints = Int64Array(capacity=3)
1411
ints.append(1)
1512
ints.append(2)
1613
ints.append(3)
17-
lists.unsafe_append(True)
14+
var lists = ListArray(ints^)
15+
assert_equal(lists.data.dtype, list_(int64))
16+
1817
assert_equal(len(lists), 1)
1918

2019
var data = lists.as_data()
@@ -23,15 +22,18 @@ def test_list_int_array():
2322
var arr = data.as_list()
2423
assert_equal(len(arr), 1)
2524

25+
var first_value = lists.unsafe_get(0)
26+
assert_equal(first_value.__str__().strip(), "1 2 3")
27+
2628

2729
def test_list_bool_array():
2830
var bools = BoolArray()
29-
var lists = ListArray(bools)
3031

3132
bools.append(as_bool_array_scalar(True))
3233
bools.append(as_bool_array_scalar(False))
3334
bools.append(as_bool_array_scalar(True))
34-
lists.unsafe_append(True)
35+
36+
var lists = ListArray(bools^)
3537
assert_equal(len(lists), 1)
3638

3739

@@ -57,13 +59,13 @@ def test_struct_array():
5759

5860
def test_list_array_str_repr():
5961
var ints = Int64Array()
60-
var lists = ListArray(ints)
62+
var lists = ListArray(ints^)
6163

6264
var str_repr = lists.__str__()
6365
var repr_repr = lists.__repr__()
6466

65-
assert_equal(str_repr, "ListArray(length=0)")
66-
assert_equal(repr_repr, "ListArray(length=0)")
67+
assert_equal(str_repr, "ListArray(length=1)")
68+
assert_equal(repr_repr, "ListArray(length=1)")
6769
assert_equal(str_repr, repr_repr)
6870

6971

@@ -81,3 +83,8 @@ def test_struct_array_str_repr():
8183
assert_equal(str_repr, "StructArray(length=0)")
8284
assert_equal(repr_repr, "StructArray(length=0)")
8385
assert_equal(str_repr, repr_repr)
86+
87+
88+
fn main() raises -> None:
89+
"""Main entry point to get stdout during testing."""
90+
test_list_int_array()

firebolt/dtypes.mojo

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ struct DataType(
300300
writer.write("int32")
301301
elif self.code == INT64:
302302
writer.write("int64")
303+
elif self.code == LIST:
304+
writer.write("list")
303305
elif self.code == STRUCT:
304306
writer.write("struct")
305307
else:

0 commit comments

Comments
 (0)