Skip to content

Commit 1b73dcb

Browse files
Feature: StructArray unsafe_get (#41)
1 parent 65b9076 commit 1b73dcb

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

firebolt/arrays/nested.mojo

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ struct StructArray(Array):
153153
offset=0,
154154
)
155155

156+
fn __init__(out self, *, var data: ArrayData):
157+
self.fields = data.dtype.fields.copy()
158+
self.capacity = data.length
159+
self.data = data^
160+
156161
fn __moveinit__(out self, deinit existing: Self):
157162
self.data = existing.data^
158163
self.fields = existing.fields^
@@ -185,6 +190,19 @@ struct StructArray(Array):
185190
writer.write(self.data.length)
186191
writer.write(")")
187192

193+
fn _index_for_field_name(self, name: StringSlice) raises -> Int:
194+
for idx, ref field in enumerate(self.data.dtype.fields):
195+
if field.name == name:
196+
return idx
197+
198+
raise Error("Field {} does not exist in this StructArray.".format(name))
199+
200+
fn unsafe_get(
201+
self, name: StringSlice
202+
) raises -> ref [self.data.children[0]] ArrayData:
203+
"""Access the field with the given name in the struct."""
204+
return self.data.children[self._index_for_field_name(name)][]
205+
188206
fn __str__(self) -> String:
189207
return String.write(self)
190208

firebolt/arrays/tests/test_nested.mojo

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,5 +175,40 @@ def test_struct_array_str_repr():
175175
assert_equal(str_repr, repr_repr)
176176

177177

178+
def build_struct() -> StructArray:
179+
var int_data_a = ArrayData.from_buffer[int32](
180+
Buffer.from_values[DType.int32](1, 2, 3, 4, 5), 5
181+
)
182+
var field_1 = Field("int_data_a", materialize[int32]())
183+
184+
var int_data_b = ArrayData.from_buffer[int32](
185+
Buffer.from_values[DType.int32](10, 20, 30), 3
186+
)
187+
var field_2 = Field("int_data_b", materialize[int32]())
188+
bitmap = Bitmap.alloc(2)
189+
bitmap.unsafe_range_set(0, 2, True)
190+
var struct_array_data = ArrayData(
191+
dtype=struct_(List(field_1^, field_2^)),
192+
length=2,
193+
bitmap=ArcPointer(bitmap^),
194+
offset=0,
195+
buffers=List[ArcPointer[Buffer]](),
196+
children=List(ArcPointer(int_data_a^), ArcPointer(int_data_b^)),
197+
)
198+
return StructArray(data=struct_array_data^)
199+
200+
201+
def test_struct_array_unsafe_get():
202+
var struct_array = build_struct()
203+
ref int_data_a = struct_array.unsafe_get("int_data_a")
204+
var int_a = Int32Array(int_data_a.copy())
205+
assert_equal(int_a.unsafe_get(0), 1)
206+
assert_equal(int_a.unsafe_get(4), 5)
207+
ref int_data_b = struct_array.unsafe_get("int_data_b")
208+
var int_b = Int32Array(int_data_b.copy())
209+
assert_equal(int_b.unsafe_get(0), 10)
210+
assert_equal(int_b.unsafe_get(2), 30)
211+
212+
178213
fn main() raises:
179214
test_list_of_list()

0 commit comments

Comments
 (0)