Skip to content

Commit ba8a90a

Browse files
committed
Feature: unsafe_get for ListArray of Numeric data types.
1 parent 78eb11f commit ba8a90a

File tree

4 files changed

+60
-16
lines changed

4 files changed

+60
-16
lines changed

firebolt/arrays/base.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ struct ArrayData(Copyable, Movable, Representable, Stringable, Writable):
8484
if self.dtype.native == known_type:
8585
writer.write(self.buffers[0][].unsafe_get[known_type](index))
8686
return
87-
writer.write("Can't process data type:")
87+
writer.write("dtype=")
8888
writer.write(self.dtype)
8989

9090
fn write_to[W: Writer](self, mut writer: W):

firebolt/arrays/nested.mojo

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,35 @@ struct ListArray(Array):
2424
self.values = data.children[0]
2525
self.capacity = data.length
2626

27-
fn __init__[T: Array](out self, values: T, capacity: Int = 0):
28-
var bitmap = Bitmap.alloc(capacity)
29-
var offsets = Buffer.alloc[DType.uint32](capacity + 1)
30-
offsets.unsafe_set[DType.uint32](0, 0)
27+
fn __init__[T: Array](out self, var values: T, capacity: Int = 1):
28+
"""Initialize a list with the given values.
29+
30+
Default capacity is at least 1 to accommodate the values.
3131
32+
Args:
33+
values: Array to use as the first element in the ListArray.
34+
capacity: The capacity of the ListArray.
35+
"""
3236
var values_data = values.as_data()
37+
try:
38+
print("ListArray entering __init__ {}".format(values))
39+
except e:
40+
print(e)
3341
var list_dtype = list_(values_data.dtype)
3442

43+
var bitmap = Bitmap.alloc(capacity)
44+
bitmap.unsafe_set(0, True)
45+
var offsets = Buffer.alloc[DType.uint32](capacity + 1)
46+
offsets.unsafe_set[DType.uint32](0, 0)
47+
offsets.unsafe_set[DType.uint32](1, values_data.length)
48+
3549
self.capacity = capacity
3650
self.bitmap = ArcPointer(bitmap^)
3751
self.offsets = ArcPointer(offsets^)
3852
self.values = ArcPointer(values_data^)
3953
self.data = ArrayData(
4054
dtype=list_dtype,
41-
length=0,
55+
length=1,
4256
bitmap=self.bitmap,
4357
buffers=List(self.offsets),
4458
children=List(self.values),
@@ -68,6 +82,27 @@ struct ListArray(Array):
6882
)
6983
self.data.length += 1
7084

85+
fn unsafe_get(self, index: Int) raises -> ArrayData:
86+
"""Access the value at a given index in the list array."""
87+
var child_dtype = self.data.dtype.fields[0].dtype
88+
if not child_dtype.is_numeric():
89+
raise Error(
90+
"Only numeric dtype supported right now, got {}".format(
91+
child_dtype
92+
)
93+
)
94+
var start = Int(self.offsets[].unsafe_get[DType.int32](index))
95+
var end = Int(self.offsets[].unsafe_get[DType.int32](index + 1))
96+
ref first_child = self.data.children[0][]
97+
return ArrayData(
98+
dtype=child_dtype,
99+
bitmap=first_child.bitmap,
100+
buffers=first_child.buffers,
101+
offset=self.data.offset + start,
102+
length=end - start,
103+
children=List[ArcPointer[ArrayData]](),
104+
)
105+
71106
fn write_to[W: Writer](self, mut writer: W):
72107
"""
73108
Formats this ListArray to the provided Writer.

firebolt/arrays/tests/test_nested.mojo

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@ from firebolt.test_fixtures.bool_array import as_bool_array_scalar
77

88

99
def test_list_int_array():
10-
var ints = Int64Array()
11-
var lists = ListArray(ints)
12-
assert_equal(lists.data.dtype, list_(int64))
13-
10+
var ints = Int64Array(capacity=3)
1411
ints.append(1)
1512
ints.append(2)
1613
ints.append(3)
17-
lists.unsafe_append(True)
14+
var lists = ListArray(ints^)
15+
assert_equal(lists.data.dtype, list_(int64))
16+
1817
assert_equal(len(lists), 1)
1918

2019
var data = lists.as_data()
@@ -23,15 +22,18 @@ def test_list_int_array():
2322
var arr = data.as_list()
2423
assert_equal(len(arr), 1)
2524

25+
var first_value = lists.unsafe_get(0)
26+
assert_equal(first_value.__str__().strip(), "1 2 3")
27+
2628

2729
def test_list_bool_array():
2830
var bools = BoolArray()
29-
var lists = ListArray(bools)
3031

3132
bools.append(as_bool_array_scalar(True))
3233
bools.append(as_bool_array_scalar(False))
3334
bools.append(as_bool_array_scalar(True))
34-
lists.unsafe_append(True)
35+
36+
var lists = ListArray(bools^)
3537
assert_equal(len(lists), 1)
3638

3739

@@ -57,13 +59,13 @@ def test_struct_array():
5759

5860
def test_list_array_str_repr():
5961
var ints = Int64Array()
60-
var lists = ListArray(ints)
62+
var lists = ListArray(ints^)
6163

6264
var str_repr = lists.__str__()
6365
var repr_repr = lists.__repr__()
6466

65-
assert_equal(str_repr, "ListArray(length=0)")
66-
assert_equal(repr_repr, "ListArray(length=0)")
67+
assert_equal(str_repr, "ListArray(length=1)")
68+
assert_equal(repr_repr, "ListArray(length=1)")
6769
assert_equal(str_repr, repr_repr)
6870

6971

@@ -81,3 +83,8 @@ def test_struct_array_str_repr():
8183
assert_equal(str_repr, "StructArray(length=0)")
8284
assert_equal(repr_repr, "StructArray(length=0)")
8385
assert_equal(str_repr, repr_repr)
86+
87+
88+
fn main() raises -> None:
89+
"""Main entry point to get stdout during testing."""
90+
test_list_int_array()

firebolt/dtypes.mojo

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ struct DataType(
300300
writer.write("int32")
301301
elif self.code == INT64:
302302
writer.write("int64")
303+
elif self.code == LIST:
304+
writer.write("list")
303305
elif self.code == STRUCT:
304306
writer.write("struct")
305307
else:

0 commit comments

Comments
 (0)