Skip to content

Commit 5056a2b

Browse files
authored
fix(rust): Increase precision when constructing float Series (#25323)
1 parent 49adf41 commit 5056a2b

File tree

2 files changed

+147
-2
lines changed

2 files changed

+147
-2
lines changed

crates/polars-core/src/series/any_value.rs

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::fmt::Write;
22

33
use arrow::bitmap::MutableBitmap;
44
use num_traits::AsPrimitive;
5+
use polars_compute::cast::SerPrimitive;
56

67
#[cfg(feature = "dtype-categorical")]
78
use crate::chunked_array::builder::CategoricalChunkedBuilder;
@@ -308,17 +309,119 @@ fn any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult<Strin
308309
Ok(builder.finish())
309310
}
310311
fn any_values_to_string_nonstrict(values: &[AnyValue]) -> StringChunked {
312+
fn _write_any_value(av: &AnyValue<'_>, buffer: &mut String, float_buf: &mut Vec<u8>) {
313+
match av {
314+
AnyValue::String(s) => buffer.push_str(s),
315+
AnyValue::Float64(f) => {
316+
float_buf.clear();
317+
SerPrimitive::write(float_buf, *f);
318+
let s = std::str::from_utf8(float_buf).unwrap();
319+
buffer.push_str(s);
320+
},
321+
AnyValue::Float32(f) => {
322+
float_buf.clear();
323+
SerPrimitive::write(float_buf, *f as f64);
324+
let s = std::str::from_utf8(float_buf).unwrap();
325+
buffer.push_str(s);
326+
},
327+
#[cfg(feature = "dtype-f16")]
328+
AnyValue::Float16(f) => {
329+
float_buf.clear();
330+
SerPrimitive::write(float_buf, f64::from(*f));
331+
let s = std::str::from_utf8(float_buf).unwrap();
332+
buffer.push_str(s);
333+
},
334+
#[cfg(feature = "dtype-struct")]
335+
AnyValue::StructOwned(payload) => {
336+
buffer.push('{');
337+
let mut iter = payload.0.iter().peekable();
338+
while let Some(child) = iter.next() {
339+
_write_any_value(child, buffer, float_buf);
340+
if iter.peek().is_some() {
341+
buffer.push(',')
342+
}
343+
}
344+
buffer.push('}');
345+
},
346+
#[cfg(feature = "dtype-struct")]
347+
AnyValue::Struct(_, _, flds) => {
348+
let mut vals = Vec::with_capacity(flds.len());
349+
av._materialize_struct_av(&mut vals);
350+
351+
buffer.push('{');
352+
let mut iter = vals.iter().peekable();
353+
while let Some(child) = iter.next() {
354+
_write_any_value(child, buffer, float_buf);
355+
if iter.peek().is_some() {
356+
buffer.push(',')
357+
}
358+
}
359+
buffer.push('}');
360+
},
361+
#[cfg(feature = "dtype-array")]
362+
AnyValue::Array(vals, _) => {
363+
buffer.push('[');
364+
let mut iter = vals.iter().peekable();
365+
while let Some(child) = iter.next() {
366+
_write_any_value(&child, buffer, float_buf);
367+
if iter.peek().is_some() {
368+
buffer.push(',');
369+
}
370+
}
371+
buffer.push(']');
372+
},
373+
AnyValue::List(vals) => {
374+
buffer.push('[');
375+
let mut iter = vals.iter().peekable();
376+
while let Some(child) = iter.next() {
377+
_write_any_value(&child, buffer, float_buf);
378+
if iter.peek().is_some() {
379+
buffer.push(',');
380+
}
381+
}
382+
buffer.push(']');
383+
},
384+
av => {
385+
write!(buffer, "{av}").unwrap();
386+
},
387+
}
388+
}
389+
311390
let mut builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, values.len());
312391
let mut owned = String::new(); // Amortize allocations.
392+
let mut float_buf = vec![];
313393
for av in values {
394+
owned.clear();
395+
float_buf.clear();
396+
314397
match av {
315398
AnyValue::String(s) => builder.append_value(s),
316399
AnyValue::StringOwned(s) => builder.append_value(s),
317400
AnyValue::Null => builder.append_null(),
318401
AnyValue::Binary(_) | AnyValue::BinaryOwned(_) => builder.append_null(),
402+
403+
// Explicitly convert and dump floating-point values to strings
404+
// to preserve as much precision as possible.
405+
// Using write!(..., "{av}") steps through Display formatting
406+
// which rounds to an arbitrary precision thus losing information.
407+
AnyValue::Float64(f) => {
408+
SerPrimitive::write(&mut float_buf, *f);
409+
let s = std::str::from_utf8(&float_buf).unwrap();
410+
builder.append_value(s);
411+
},
412+
AnyValue::Float32(f) => {
413+
SerPrimitive::write(&mut float_buf, *f as f64); // promote to f64 for serialization
414+
let s = std::str::from_utf8(&float_buf).unwrap();
415+
builder.append_value(s);
416+
},
417+
#[cfg(feature = "dtype-f16")]
418+
AnyValue::Float16(f) => {
419+
SerPrimitive::write(&mut float_buf, f64::from(*f));
420+
let s = std::str::from_utf8(&float_buf).unwrap();
421+
builder.append_value(s);
422+
},
319423
av => {
320-
owned.clear();
321-
write!(owned, "{av}").unwrap();
424+
_write_any_value(av, &mut owned, &mut float_buf);
322425
builder.append_value(&owned);
323426
},
324427
}

py-polars/tests/unit/constructors/test_any_value_fallbacks.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import TYPE_CHECKING, Any
88

99
import pytest
10+
from numpy import array
1011

1112
import polars as pl
1213
from polars._plr import PySeries
@@ -408,3 +409,44 @@ def test_categorical_lit_18874() -> None:
408409
]
409410
),
410411
)
412+
413+
414+
@pytest.mark.parametrize(
    ("values", "expected"),
    [
        # Float64 should preserve ~17 significant digits, Float32 ~6.
        pytest.param(
            [0.123, 0.123456789],
            ["0.123", "0.123456789"],
            id="basic_floats",
        ),
        pytest.param(
            [[0.123, 0.123456789]],
            ["[0.123,0.123456789]"],
            id="nested_list",
        ),
        pytest.param(
            [array([0.123, 0.123456789])],
            ["[0.123,0.123456789]"],
            id="nested_array",
        ),
        pytest.param(
            [{"a": 0.123, "b": 0.123456789}],
            ["{0.123,0.123456789}"],
            id="basic_struct",
        ),
        pytest.param(
            [[{"a": 0.123, "b": 0.123456789}]],
            ["[{0.123,0.123456789}]"],
            id="list_of_structs",
        ),
        pytest.param(
            [{"x": [0.1, 0.2]}, [{"y": 0.3}]],
            ["{[0.1,0.2]}", "[{0.3}]"],
            id="nested_mixed",
        ),
        pytest.param(
            [None, {"a": None, "b": 1.0}, [None, 2.0]],
            [None, "{null,1.0}", "[null,2.0]"],
            id="mixed_nulls",
        ),
        pytest.param([[], {}], ["[]", "{}"], id="empty_containers"),
        pytest.param([[0.5]], ["[0.5]"], id="single_element_list"),
        pytest.param([{"a": 0.5}], ["{0.5}"], id="single_element_struct"),
    ],
)
def test_float_to_string_precision_25257(
    values: list[Any], expected: list[Any]
) -> None:
    """Check non-strict String Series construction from bare and nested floats."""
    # Build the Series under a display precision of 1 to verify that the
    # conversion is decoupled from Display formatting.
    with pl.Config(float_precision=1):
        result = pl.Series(values, strict=False, dtype=pl.String)

    expected_series = pl.Series(expected)
    assert (result == expected_series).all()

0 commit comments

Comments
 (0)