-
Notifications
You must be signed in to change notification settings - Fork 236
feat: rpad support column for second arg instead of just literal #2099
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
32d66c7
fb19f95
1a4b082
76ab555
7e8bf61
0fc9f93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,12 +18,11 @@ | |
use arrow::array::builder::GenericStringBuilder; | ||
use arrow::array::cast::as_dictionary_array; | ||
use arrow::array::types::Int32Type; | ||
use arrow::array::{make_array, Array, DictionaryArray}; | ||
use arrow::array::{make_array, Array, AsArray, DictionaryArray}; | ||
use arrow::array::{ArrayRef, OffsetSizeTrait}; | ||
use arrow::datatypes::DataType; | ||
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue}; | ||
use datafusion::physical_plan::ColumnarValue; | ||
use std::fmt::Write; | ||
use std::sync::Arc; | ||
|
||
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length | ||
|
@@ -43,17 +42,31 @@ fn spark_read_side_padding2( | |
match args { | ||
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => { | ||
match array.data_type() { | ||
DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate), | ||
DataType::LargeUtf8 => { | ||
spark_read_side_padding_internal::<i64>(array, *length, truncate) | ||
} | ||
DataType::Utf8 => spark_read_side_padding_internal::<i32>( | ||
array, | ||
truncate, | ||
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))), | ||
), | ||
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>( | ||
array, | ||
truncate, | ||
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))), | ||
), | ||
// Dictionary support required for SPARK-48498 | ||
DataType::Dictionary(_, value_type) => { | ||
let dict = as_dictionary_array::<Int32Type>(array); | ||
let col = if value_type.as_ref() == &DataType::Utf8 { | ||
spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)? | ||
spark_read_side_padding_internal::<i32>( | ||
dict.values(), | ||
truncate, | ||
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))), | ||
)? | ||
} else { | ||
spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)? | ||
spark_read_side_padding_internal::<i64>( | ||
dict.values(), | ||
truncate, | ||
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))), | ||
)? | ||
}; | ||
// col consists of an array, so arg of to_array() is not used. Can be anything | ||
let values = col.to_array(0)?; | ||
|
@@ -65,6 +78,21 @@ fn spark_read_side_padding2( | |
))), | ||
} | ||
} | ||
[ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() { | ||
DataType::Utf8 => spark_read_side_padding_internal::<i32>( | ||
array, | ||
truncate, | ||
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)), | ||
), | ||
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>( | ||
array, | ||
truncate, | ||
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)), | ||
), | ||
other => Err(DataFusionError::Internal(format!( | ||
"Unsupported data type {other:?} for function rpad/read_side_padding", | ||
))), | ||
}, | ||
other => Err(DataFusionError::Internal(format!( | ||
"Unsupported arguments {other:?} for function rpad/read_side_padding", | ||
))), | ||
|
@@ -73,42 +101,71 @@ fn spark_read_side_padding2( | |
|
||
fn spark_read_side_padding_internal<T: OffsetSizeTrait>( | ||
array: &ArrayRef, | ||
length: i32, | ||
truncate: bool, | ||
pad_type: ColumnarValue, | ||
) -> Result<ColumnarValue, DataFusionError> { | ||
let string_array = as_generic_string_array::<T>(array)?; | ||
let length = 0.max(length) as usize; | ||
let space_string = " ".repeat(length); | ||
match pad_type { | ||
ColumnarValue::Array(array_int) => { | ||
let int_pad_array = array_int.as_primitive::<Int32Type>(); | ||
|
||
let mut builder = | ||
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length); | ||
let mut builder = GenericStringBuilder::<T>::with_capacity( | ||
string_array.len(), | ||
string_array.len() * int_pad_array.len(), | ||
); | ||
|
||
for string in string_array.iter() { | ||
match string { | ||
Some(string) => { | ||
// It looks Spark's UTF8String is closer to chars rather than graphemes | ||
// https://stackoverflow.com/a/46290728 | ||
let char_len = string.chars().count(); | ||
if length <= char_len { | ||
if truncate { | ||
let idx = string | ||
.char_indices() | ||
.nth(length) | ||
.map(|(i, _)| i) | ||
.unwrap_or(string.len()); | ||
builder.append_value(&string[..idx]); | ||
} else { | ||
builder.append_value(string); | ||
} | ||
} else { | ||
// write_str updates only the value buffer, not null nor offset buffer | ||
// This is convenient for concatenating str(s) | ||
builder.write_str(string)?; | ||
builder.append_value(&space_string[char_len..]); | ||
for (string, length) in string_array.iter().zip(int_pad_array) { | ||
match string { | ||
Some(string) => builder.append_value(add_padding_string( | ||
string.parse().unwrap(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
length.unwrap() as usize, | ||
truncate, | ||
)), | ||
_ => builder.append_null(), | ||
} | ||
} | ||
_ => builder.append_null(), | ||
Ok(ColumnarValue::Array(Arc::new(builder.finish()))) | ||
} | ||
ColumnarValue::Scalar(const_pad_length) => { | ||
let length = 0.max(i32::try_from(const_pad_length)?) as usize; | ||
|
||
let mut builder = GenericStringBuilder::<T>::with_capacity( | ||
string_array.len(), | ||
string_array.len() * length, | ||
); | ||
|
||
for string in string_array.iter() { | ||
match string { | ||
Some(string) => builder.append_value(add_padding_string( | ||
string.parse().unwrap(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. its good to avoid unwraps and return Err instead |
||
length, | ||
truncate, | ||
)), | ||
_ => builder.append_null(), | ||
} | ||
} | ||
Ok(ColumnarValue::Array(Arc::new(builder.finish()))) | ||
} | ||
} | ||
} | ||
|
||
/// Pads `string` on the right with ASCII spaces up to `length` characters.
///
/// Length is measured in `char`s (Unicode scalar values), not bytes or
/// graphemes — Spark's `UTF8String` is closer to chars than graphemes
/// (https://stackoverflow.com/a/46290728).
///
/// * If the string already has at least `length` chars and `truncate` is
///   true, it is cut to exactly `length` chars; slicing at a `char_indices`
///   boundary keeps the result valid UTF-8.
/// * If `truncate` is false, a too-long string is returned unchanged.
fn add_padding_string(string: String, length: usize, truncate: bool) -> String {
    let char_len = string.chars().count();
    if length <= char_len {
        if truncate {
            // Byte offset of the char *after* the first `length` chars;
            // falls back to the full string when exactly `length` chars long.
            let idx = string
                .char_indices()
                .nth(length)
                .map(|(i, _)| i)
                .unwrap_or(string.len());
            // Direct slice-to-owned; the old `parse().unwrap()` took a
            // needlessly fallible path for an infallible conversion.
            string[..idx].to_string()
        } else {
            string
        }
    } else {
        // Append only the spaces actually needed, instead of allocating a
        // full-length space string and slicing off its front.
        let mut padded = string;
        padded.reserve(length - char_len);
        for _ in char_len..length {
            padded.push(' ');
        }
        padded
    }
}
Ok(ColumnarValue::Array(Arc::new(builder.finish()))) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apache/spark#46832
This seems related to padding. How does this affect dictionary encoded columns?