@@ -18,12 +18,11 @@
use arrow::array::builder::GenericStringBuilder;
use arrow::array::cast::as_dictionary_array;
use arrow::array::types::Int32Type;
use arrow::array::{make_array, Array, DictionaryArray};
use arrow::array::{make_array, Array, AsArray, DictionaryArray};
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
use datafusion::physical_plan::ColumnarValue;
use std::fmt::Write;
use std::sync::Arc;

/// Similar to DataFusion `rpad`, but does not truncate when the string is already longer than the length
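/// For example (illustrative, not part of the diff): read_side_padding('ab', 4)
/// yields "ab  ", while read_side_padding('abcdef', 4) yields "abcdef" unchanged.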
@@ -43,17 +42,31 @@ fn spark_read_side_padding2(
match args {
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
match array.data_type() {
DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
DataType::LargeUtf8 => {
spark_read_side_padding_internal::<i64>(array, *length, truncate)
}
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
),
// Dictionary support required for SPARK-48498
Contributor:

apache/spark#46832

This seems related to padding. How does this affect dictionary-encoded columns?
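For context on the dictionary question (an illustrative sketch, not this PR's code; pad_dict_values and the format!-based padding are assumptions): only the dictionary's values array needs padding, and the keys are reused unchanged, so each distinct string is padded once regardless of how many rows reference it.

use std::sync::Arc;
use arrow::array::types::Int32Type;
use arrow::array::{Array, DictionaryArray, StringArray};

fn pad_dict_values(
    dict: &DictionaryArray<Int32Type>,
    len: usize,
) -> DictionaryArray<Int32Type> {
    let values = dict
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("utf8 values");
    // Pad each distinct value once; rows sharing a key pick up the padded value.
    // format!("{s:<len$}") pads to `len` chars and, like read-side padding,
    // never truncates strings that are already longer.
    let padded: StringArray = values
        .iter()
        .map(|v| v.map(|s| format!("{s:<len$}")))
        .collect();
    DictionaryArray::try_new(dict.keys().clone(), Arc::new(padded))
        .expect("keys are unchanged, so they stay in bounds")
}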

DataType::Dictionary(_, value_type) => {
let dict = as_dictionary_array::<Int32Type>(array);
let col = if value_type.as_ref() == &DataType::Utf8 {
spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
spark_read_side_padding_internal::<i32>(
dict.values(),
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
)?
} else {
spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
spark_read_side_padding_internal::<i64>(
dict.values(),
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
)?
};
// `col` already holds an array, so the row-count argument to `to_array()` is unused and can be anything
let values = col.to_array(0)?;
@@ -65,6 +78,21 @@
))),
}
}
[ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() {
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function rpad/read_side_padding",
))),
},
other => Err(DataFusionError::Internal(format!(
"Unsupported arguments {other:?} for function rpad/read_side_padding",
))),
@@ -73,42 +101,71 @@

fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
array: &ArrayRef,
length: i32,
truncate: bool,
pad_type: ColumnarValue,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
let length = 0.max(length) as usize;
let space_string = " ".repeat(length);
match pad_type {
ColumnarValue::Array(array_int) => {
let int_pad_array = array_int.as_primitive::<Int32Type>();

let mut builder =
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
let mut builder = GenericStringBuilder::<T>::with_capacity(
string_array.len(),
string_array.len() * int_pad_array.len(),
);

for string in string_array.iter() {
match string {
Some(string) => {
// It looks like Spark's UTF8String works on chars rather than graphemes
// https://stackoverflow.com/a/46290728
let char_len = string.chars().count();
if length <= char_len {
if truncate {
let idx = string
.char_indices()
.nth(length)
.map(|(i, _)| i)
.unwrap_or(string.len());
builder.append_value(&string[..idx]);
} else {
builder.append_value(string);
}
} else {
// write_str updates only the value buffer, not null nor offset buffer
// This is convenient for concatenating str(s)
builder.write_str(string)?;
builder.append_value(&space_string[char_len..]);
for (string, length) in string_array.iter().zip(int_pad_array) {
match string {
Some(string) => builder.append_value(add_padding_string(
string.parse().unwrap(),
Contributor:

ditto

length.unwrap() as usize,
truncate,
)),
_ => builder.append_null(),
}
}
_ => builder.append_null(),
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
ColumnarValue::Scalar(const_pad_length) => {
let length = 0.max(i32::try_from(const_pad_length)?) as usize;

let mut builder = GenericStringBuilder::<T>::with_capacity(
string_array.len(),
string_array.len() * length,
);

for string in string_array.iter() {
match string {
Some(string) => builder.append_value(add_padding_string(
string.parse().unwrap(),
Contributor:

It's good to avoid unwraps and return an Err instead.
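One hypothetical way to do that with this function's existing types (a sketch, not the PR's code; checked_pad_len is an invented helper):

use datafusion::common::DataFusionError;

fn checked_pad_len(length: Option<i32>) -> Result<usize, DataFusionError> {
    let len = length.ok_or_else(|| {
        DataFusionError::Internal("rpad: pad length must not be null".to_string())
    })?;
    // Clamp negative lengths to 0 instead of wrapping around with `as usize`.
    Ok(usize::try_from(len).unwrap_or(0))
}

The loop body could then call add_padding_string(string.to_string(), checked_pad_len(length)?, truncate) and stay panic-free.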

length,
truncate,
)),
_ => builder.append_null(),
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
}

fn add_padding_string(string: String, length: usize, truncate: bool) -> String {
Contributor:

Perhaps we can think of an impl like:

fn add_padding_string(input: String, length: usize, truncate: bool) -> String {
    let char_len = input.chars().count();

    if char_len >= length {
        if truncate {
            // Take the first `length` chars safely
            input.chars().take(length).collect()
        } else {
            input
        }
    } else {
        // Pad with only the needed spaces
        let padding = " ".repeat(length - char_len);
        input + &padding
    }
}

so we don't allocate spaces if it's not needed, and there's no unwrap.
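For instance, with that version (illustrative checks; the width is counted in chars, matching the UTF8String note above):

assert_eq!(add_padding_string("héllo".to_string(), 3, true), "hél");
assert_eq!(add_padding_string("héllo".to_string(), 3, false), "héllo");
assert_eq!(add_padding_string("hi".to_string(), 4, false), "hi  ");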

Contributor:

Referring to the string by index: is it unicode-safe? 🤔
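For what it's worth, the char_indices().nth(length) path does stay on char boundaries; raw byte indexing would not (illustrative):

let s = "héllo"; // 'é' is two bytes in UTF-8
// let _ = &s[..2]; // would panic: byte 2 is not a char boundary
let idx = s.char_indices().nth(2).map(|(i, _)| i).unwrap_or(s.len());
assert_eq!(&s[..idx], "hé"); // safe: idx always lands on a char boundary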

Contributor (author):

This is a great suggestion. My goal for now was to keep the original implementation intact and avoid introducing changes that don't directly solve the issue.

// It looks like Spark's UTF8String works on chars rather than graphemes
// https://stackoverflow.com/a/46290728
let space_string = " ".repeat(length);
let char_len = string.chars().count();
if length <= char_len {
if truncate {
let idx = string
.char_indices()
.nth(length)
.map(|(i, _)| i)
.unwrap_or(string.len());
string[..idx].parse().unwrap()
} else {
string
}
} else {
string + &space_string[char_len..]
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
10 changes: 10 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -407,6 +407,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
test("Verify rpad expr support for second arg instead of just literal") {
withTable("t1") {
val value = "IfIWasARoadIWouldBeBent"
sql("create table t1(c1 varchar(100), c2 int) using parquet")
sql(s"insert into t1 values('$value', 10)")
sql(s"insert into t1 values((${null}, 10))")
val res = sql("select rpad(c1,c2) , rpad(c1,5) from t1 order by c1")
checkSparkAnswerAndOperator(res)
}
}

test("dictionary arithmetic") {
// TODO: test ANSI mode