Skip to content

Commit 78eb11f

Browse files
Fix write_to for ArrayData (#31)
We need to dispatch on the dynamic dtype.
1 parent 264cd68 commit 78eb11f

File tree

2 files changed

+51
-25
lines changed

2 files changed

+51
-25
lines changed

firebolt/arrays/base.mojo

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,29 @@ struct ArrayData(Copyable, Movable, Representable, Stringable, Writable):
6464
fn as_list(self) raises -> ListArray:
6565
return ListArray(self)
6666

67+
fn _dynamic_write[W: Writer](self, index: Int, mut writer: W):
68+
"""Write to the given stream dispatching on the dtype."""
69+
70+
@parameter
71+
for known_type in [
72+
DType.bool,
73+
DType.int16,
74+
DType.int32,
75+
DType.int64,
76+
DType.int8,
77+
DType.float32,
78+
DType.float64,
79+
DType.uint16,
80+
DType.uint32,
81+
DType.uint64,
82+
DType.uint8,
83+
]:
84+
if self.dtype.native == known_type:
85+
writer.write(self.buffers[0][].unsafe_get[known_type](index))
86+
return
87+
writer.write("Can't process data type:")
88+
writer.write(self.dtype)
89+
6790
fn write_to[W: Writer](self, mut writer: W):
6891
"""
6992
Formats this ArrayData to the provided Writer.
@@ -77,7 +100,8 @@ struct ArrayData(Copyable, Movable, Representable, Stringable, Writable):
77100

78101
for i in range(self.length):
79102
if self.is_valid(i):
80-
writer.write(self.buffers[0][].unsafe_get(i + self.offset))
103+
var real_index = i + self.offset
104+
self._dynamic_write(real_index, writer)
81105
else:
82106
writer.write("-")
83107
writer.write(" ")

firebolt/arrays/tests/test_base.mojo

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from testing import assert_true, assert_false, assert_equal
33
from memory import UnsafePointer, ArcPointer
44
from firebolt.arrays.base import ArrayData
55
from firebolt.buffers import Buffer, Bitmap
6-
from firebolt.dtypes import DType, int8, uint8
6+
from firebolt.dtypes import DType, int8, uint8, int64
77
from firebolt.test_fixtures.arrays import build_array_data, assert_bitmap_set
88

99

@@ -67,26 +67,28 @@ def test_array_data_write_to_with_offset():
6767
var bitmap = ArcPointer(Bitmap.alloc(10))
6868
var buffer = ArcPointer(Buffer.alloc[DType.uint8](10))
6969

70-
# Set up data with values at positions 1,2,3
71-
buffer[].unsafe_set[DType.uint8](1, 10)
72-
buffer[].unsafe_set[DType.uint8](2, 11)
73-
buffer[].unsafe_set[DType.uint8](3, 12)
74-
75-
# Set validity for positions 1,2,3
76-
bitmap[].unsafe_set(1, True)
77-
bitmap[].unsafe_set(2, True)
78-
bitmap[].unsafe_set(3, True)
79-
80-
# Create ArrayData with offset=1, so logical indices 0,1,2 map to physical indices 1,2,3
81-
var array_data = ArrayData(
82-
dtype=uint8,
83-
length=3,
84-
bitmap=bitmap,
85-
buffers=List(buffer),
86-
children=List[ArcPointer[ArrayData]](),
87-
offset=1,
88-
)
89-
90-
var writer = String()
91-
writer.write(array_data)
92-
assert_equal(writer.strip(), "10 11 12")
70+
@parameter
71+
for dtype in [uint8, int64]:
72+
# Set up data with values at positions 1,2,3
73+
buffer[].unsafe_set[dtype.native](1, 10)
74+
buffer[].unsafe_set[dtype.native](2, 11)
75+
buffer[].unsafe_set[dtype.native](3, 12)
76+
77+
# Set validity for positions 1,2,3
78+
bitmap[].unsafe_set(1, True)
79+
bitmap[].unsafe_set(2, True)
80+
bitmap[].unsafe_set(3, True)
81+
82+
# Create ArrayData with offset=1, so logical indices 0,1,2 map to physical indices 1,2,3
83+
var array_data = ArrayData(
84+
dtype=dtype,
85+
length=3,
86+
bitmap=bitmap,
87+
buffers=List(buffer),
88+
children=List[ArcPointer[ArrayData]](),
89+
offset=1,
90+
)
91+
92+
var writer = String()
93+
writer.write(array_data)
94+
assert_equal(writer.strip(), "10 11 12")

0 commit comments

Comments
 (0)