From 56d437a06dc97e9841c0dd0d28ef058ccd0bbeb5 Mon Sep 17 00:00:00 2001 From: Marius Seritan <39998+winding-lines@users.noreply.github.com> Date: Sun, 31 Aug 2025 18:11:18 -0700 Subject: [PATCH] Feature: unsafe_get for ListArray of Numeric data types. --- firebolt/arrays/base.mojo | 2 +- firebolt/arrays/nested.mojo | 41 ++++++++++++++++++++++---- firebolt/arrays/tests/test_nested.mojo | 27 ++++++++++------- firebolt/dtypes.mojo | 2 ++ 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/firebolt/arrays/base.mojo b/firebolt/arrays/base.mojo index a29a9c4..2965998 100644 --- a/firebolt/arrays/base.mojo +++ b/firebolt/arrays/base.mojo @@ -84,7 +84,7 @@ struct ArrayData(Copyable, Movable, Representable, Stringable, Writable): if self.dtype.native == known_type: writer.write(self.buffers[0][].unsafe_get[known_type](index)) return - writer.write("Can't process data type:") + writer.write("dtype=") writer.write(self.dtype) fn write_to[W: Writer](self, mut writer: W): diff --git a/firebolt/arrays/nested.mojo b/firebolt/arrays/nested.mojo index db0b9d7..be54abf 100644 --- a/firebolt/arrays/nested.mojo +++ b/firebolt/arrays/nested.mojo @@ -24,21 +24,31 @@ struct ListArray(Array): self.values = data.children[0] self.capacity = data.length - fn __init__[T: Array](out self, values: T, capacity: Int = 0): - var bitmap = Bitmap.alloc(capacity) - var offsets = Buffer.alloc[DType.uint32](capacity + 1) - offsets.unsafe_set[DType.uint32](0, 0) + fn __init__[T: Array](out self, var values: T, capacity: Int = 1): + """Initialize a list with the given values. + + Default capacity is at least 1 to accomodate the values. + Args: + values: Array to use as the first element in the ListArray. + capacity: The capacity of the ListArray. + """ var values_data = values.as_data() var list_dtype = list_(values_data.dtype) + var bitmap = Bitmap.alloc(capacity) + bitmap.unsafe_set(0, True) + var offsets = Buffer.alloc[DType.uint32](capacity + 1) + offsets.unsafe_set[DType.uint32](0, 0) + offsets.unsafe_set[DType.uint32](1, values_data.length) + self.capacity = capacity self.bitmap = ArcPointer(bitmap^) self.offsets = ArcPointer(offsets^) self.values = ArcPointer(values_data^) self.data = ArrayData( dtype=list_dtype, - length=0, + length=1, bitmap=self.bitmap, buffers=List(self.offsets), children=List(self.values), @@ -68,6 +78,27 @@ struct ListArray(Array): ) self.data.length += 1 + fn unsafe_get(self, index: Int) raises -> ArrayData: + """Access the value at a given index in the list array.""" + var child_dtype = self.data.dtype.fields[0].dtype + if not child_dtype.is_numeric(): + raise Error( + "Only numeric dtype supported right now, got {}".format( + child_dtype + ) + ) + var start = Int(self.offsets[].unsafe_get[DType.int32](index)) + var end = Int(self.offsets[].unsafe_get[DType.int32](index + 1)) + ref first_child = self.data.children[0][] + return ArrayData( + dtype=child_dtype, + bitmap=first_child.bitmap, + buffers=first_child.buffers, + offset=self.data.offset + start, + length=end - start, + children=List[ArcPointer[ArrayData]](), + ) + fn write_to[W: Writer](self, mut writer: W): """ Formats this ListArray to the provided Writer. diff --git a/firebolt/arrays/tests/test_nested.mojo b/firebolt/arrays/tests/test_nested.mojo index f39951d..5a1e5ab 100644 --- a/firebolt/arrays/tests/test_nested.mojo +++ b/firebolt/arrays/tests/test_nested.mojo @@ -7,14 +7,13 @@ from firebolt.test_fixtures.bool_array import as_bool_array_scalar def test_list_int_array(): - var ints = Int64Array() - var lists = ListArray(ints) - assert_equal(lists.data.dtype, list_(int64)) - + var ints = Int64Array(capacity=3) ints.append(1) ints.append(2) ints.append(3) - lists.unsafe_append(True) + var lists = ListArray(ints^) + assert_equal(lists.data.dtype, list_(int64)) + assert_equal(len(lists), 1) var data = lists.as_data() @@ -23,15 +22,18 @@ def test_list_int_array(): var arr = data.as_list() assert_equal(len(arr), 1) + var first_value = lists.unsafe_get(0) + assert_equal(first_value.__str__().strip(), "1 2 3") + def test_list_bool_array(): var bools = BoolArray() - var lists = ListArray(bools) bools.append(as_bool_array_scalar(True)) bools.append(as_bool_array_scalar(False)) bools.append(as_bool_array_scalar(True)) - lists.unsafe_append(True) + + var lists = ListArray(bools^) assert_equal(len(lists), 1) @@ -57,13 +59,13 @@ def test_struct_array(): def test_list_array_str_repr(): var ints = Int64Array() - var lists = ListArray(ints) + var lists = ListArray(ints^) var str_repr = lists.__str__() var repr_repr = lists.__repr__() - assert_equal(str_repr, "ListArray(length=0)") - assert_equal(repr_repr, "ListArray(length=0)") + assert_equal(str_repr, "ListArray(length=1)") + assert_equal(repr_repr, "ListArray(length=1)") assert_equal(str_repr, repr_repr) @@ -81,3 +83,8 @@ def test_struct_array_str_repr(): assert_equal(str_repr, "StructArray(length=0)") assert_equal(repr_repr, "StructArray(length=0)") assert_equal(str_repr, repr_repr) + + +fn main() raises -> None: + """Main entry point to get stdout during testing.""" + test_list_int_array() diff --git a/firebolt/dtypes.mojo b/firebolt/dtypes.mojo index 12a3700..94b6c14 100644 --- a/firebolt/dtypes.mojo +++ b/firebolt/dtypes.mojo @@ -300,6 +300,8 @@ struct DataType( writer.write("int32") elif self.code == INT64: writer.write("int64") + elif self.code == LIST: + writer.write("list") elif self.code == STRUCT: writer.write("struct") else: