From c1c852a92ced6c078a9a741dbadc56fec597ad4a Mon Sep 17 00:00:00 2001 From: Marius Seritan <39998+winding-lines@users.noreply.github.com> Date: Sun, 7 Sep 2025 10:38:42 -0700 Subject: [PATCH] Improve StringArray write_to --- firebolt/arrays/base.mojo | 4 ++++ firebolt/arrays/binary.mojo | 9 ++++++++- firebolt/arrays/tests/test_binary.mojo | 5 +++++ firebolt/dtypes.mojo | 4 ++++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/firebolt/arrays/base.mojo b/firebolt/arrays/base.mojo index 2965998..b8d2d4b 100644 --- a/firebolt/arrays/base.mojo +++ b/firebolt/arrays/base.mojo @@ -84,6 +84,10 @@ struct ArrayData(Copyable, Movable, Representable, Stringable, Writable): if self.dtype.native == known_type: writer.write(self.buffers[0][].unsafe_get[known_type](index)) return + if self.dtype.is_string(): + # Should print a StringArray through the element specific write_to. + writer.write("") + return writer.write("dtype=") writer.write(self.dtype) diff --git a/firebolt/arrays/binary.mojo b/firebolt/arrays/binary.mojo index 8be054c..af73e23 100644 --- a/firebolt/arrays/binary.mojo +++ b/firebolt/arrays/binary.mojo @@ -116,7 +116,14 @@ struct StringArray(Array): writer.write("StringArray( length=") writer.write(self.data.length) - writer.write(")") + writer.write(", data= [") + for i in range(self.data.length): + writer.write('"') + writer.write(self.unsafe_get((i))) + writer.write('", ') + if i > 1: + break + writer.write(" ])") fn __str__(self) -> String: return String.write(self) diff --git a/firebolt/arrays/tests/test_binary.mojo b/firebolt/arrays/tests/test_binary.mojo index fea53ad..ac74847 100644 --- a/firebolt/arrays/tests/test_binary.mojo +++ b/firebolt/arrays/tests/test_binary.mojo @@ -21,3 +21,8 @@ def test_string_builder(): var s = a.unsafe_get(0) assert_equal(String(s), "hello") + + assert_equal( + a.__str__().strip(), + 'StringArray( length=2, data= ["hello", "world", ])', + ) diff --git a/firebolt/dtypes.mojo b/firebolt/dtypes.mojo index 94b6c14..1fac92c 100644 --- a/firebolt/dtypes.mojo +++ b/firebolt/dtypes.mojo @@ -384,6 +384,10 @@ struct DataType( fn is_numeric(self) -> Bool: return self.is_integer() or self.is_floating_point() + @always_inline + fn is_string(self) -> Bool: + return self.code == STRING + @always_inline fn is_list(self) -> Bool: return self.code == LIST