Commit b78cb65

rpad_bug_fix
1 parent 3f84dbc commit b78cb65

2 files changed: +111 -50 lines


native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 102 additions & 40 deletions
@@ -18,10 +18,10 @@
 use arrow::array::builder::GenericStringBuilder;
 use arrow::array::cast::as_dictionary_array;
 use arrow::array::types::Int32Type;
-use arrow::array::{make_array, Array, DictionaryArray};
+use arrow::array::{make_array, Array, AsArray, DictionaryArray};
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
-use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use datafusion::common::{cast::as_generic_string_array, DataFusionError, HashMap, ScalarValue};
 use datafusion::physical_plan::ColumnarValue;
 use std::fmt::Write;
 use std::sync::Arc;
@@ -42,18 +42,21 @@ fn spark_read_side_padding2(
 ) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
+            let rpad_arg = RPadArgument::ConstLength(*length);
             match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
+                DataType::Utf8 => {
+                    spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
+                }
                 DataType::LargeUtf8 => {
-                    spark_read_side_padding_internal::<i64>(array, *length, truncate)
+                    spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
                 }
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
                     let dict = as_dictionary_array::<Int32Type>(array);
                     let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
+                        spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
                     } else {
-                        spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
+                        spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
                     };
                     // col consists of an array, so arg of to_array() is not used. Can be anything
                     let values = col.to_array(0)?;
@@ -65,20 +68,22 @@ fn spark_read_side_padding2(
                 ))),
             }
         }
-        [ColumnarValue::Array(array), ColumnarValue::Array(arrayInt)] => {
-            let lengthToPad = arrayInt.len() as i32;
+        [ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => {
+            let rpad_arg = RPadArgument::ColArray(Arc::clone(array_int));
             match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, lengthToPad, truncate),
+                DataType::Utf8 => {
+                    spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
+                }
                 DataType::LargeUtf8 => {
-                    spark_read_side_padding_internal::<i64>(array, lengthToPad, truncate)
+                    spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
                 }
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
                     let dict = as_dictionary_array::<Int32Type>(array);
                     let col = if value_type.as_ref() == &DataType::Utf8 {
-                        spark_read_side_padding_internal::<i32>(dict.values(), lengthToPad, truncate)?
+                        spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
                     } else {
-                        spark_read_side_padding_internal::<i64>(dict.values(), lengthToPad, truncate)?
+                        spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
                     };
                     // col consists of an array, so arg of to_array() is not used. Can be anything
                     let values = col.to_array(0)?;
@@ -96,44 +101,101 @@ fn spark_read_side_padding2(
     }
 }
 
+enum RPadArgument {
+    ConstLength(i32),
+    ColArray(ArrayRef),
+}
+
 fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
-    length: i32,
     truncate: bool,
+    rpad_argument: RPadArgument,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
-    let length = 0.max(length) as usize;
-    let space_string = " ".repeat(length);
+    match rpad_argument {
+        RPadArgument::ColArray(array_int) => {
+            let int_pad_array = array_int.as_primitive::<Int32Type>();
+            let mut str_pad_value_map = HashMap::new();
+            for i in 0..string_array.len() {
+                if string_array.is_null(i) || int_pad_array.is_null(i) {
+                    continue; // skip nulls
+                }
+                str_pad_value_map.insert(string_array.value(i), int_pad_array.value(i));
+            }
 
-    let mut builder =
-        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+            let mut builder = GenericStringBuilder::<T>::with_capacity(
+                str_pad_value_map.len(),
+                str_pad_value_map.len() * int_pad_array.len(),
+            );
 
-    for string in string_array.iter() {
-        match string {
-            Some(string) => {
-                // It looks Spark's UTF8String is closer to chars rather than graphemes
-                // https://stackoverflow.com/a/46290728
-                let char_len = string.chars().count();
-                if length <= char_len {
-                    if truncate {
-                        let idx = string
-                            .char_indices()
-                            .nth(length)
-                            .map(|(i, _)| i)
-                            .unwrap_or(string.len());
-                        builder.append_value(&string[..idx]);
-                    } else {
-                        builder.append_value(string);
+            for string in string_array.iter() {
+                match string {
+                    Some(string) => {
+                        // It looks Spark's UTF8String is closer to chars rather than graphemes
+                        // https://stackoverflow.com/a/46290728
+                        let char_len = string.chars().count();
+                        let length: usize = 0.max(*str_pad_value_map.get(string).unwrap()) as usize;
+                        let space_string = " ".repeat(length);
+                        if length <= char_len {
+                            if truncate {
+                                let idx = string
+                                    .char_indices()
+                                    .nth(length)
+                                    .map(|(i, _)| i)
+                                    .unwrap_or(string.len());
+                                builder.append_value(&string[..idx]);
+                            } else {
+                                builder.append_value(string);
+                            }
+                        } else {
+                            // write_str updates only the value buffer, not null nor offset buffer
+                            // This is convenient for concatenating str(s)
+                            builder.write_str(string)?;
+                            builder.append_value(&space_string[char_len..]);
+                        }
+                    }
+                    _ => builder.append_null(),
+                }
+            }
+            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+        }
+        RPadArgument::ConstLength(length) => {
+            let length = 0.max(length) as usize;
+            let space_string = " ".repeat(length);
+
+            let mut builder = GenericStringBuilder::<T>::with_capacity(
+                string_array.len(),
+                string_array.len() * length,
+            );
+
+            for string in string_array.iter() {
+                match string {
+                    Some(string) => {
+                        // It looks Spark's UTF8String is closer to chars rather than graphemes
+                        // https://stackoverflow.com/a/46290728
+                        let char_len = string.chars().count();
+                        if length <= char_len {
+                            if truncate {
+                                let idx = string
+                                    .char_indices()
+                                    .nth(length)
+                                    .map(|(i, _)| i)
+                                    .unwrap_or(string.len());
+                                builder.append_value(&string[..idx]);
+                            } else {
+                                builder.append_value(string);
+                            }
+                        } else {
+                            // write_str updates only the value buffer, not null nor offset buffer
+                            // This is convenient for concatenating str(s)
+                            builder.write_str(string)?;
+                            builder.append_value(&space_string[char_len..]);
+                        }
                     }
-                } else {
-                    // write_str updates only the value buffer, not null nor offset buffer
-                    // This is convenient for concatenating str(s)
-                    builder.write_str(string)?;
-                    builder.append_value(&space_string[char_len..]);
+                    _ => builder.append_null(),
                 }
             }
-            _ => builder.append_null(),
+            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
         }
     }
-    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
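
Both match arms of the new enum run the same char-based truncate-or-pad step: count chars (not bytes, and not graphemes, per the UTF8String comment in the code), cut at a char boundary when the target length does not exceed the string's char count, otherwise append spaces. A minimal standalone sketch of that step, using only the Rust standard library (rpad_chars is an illustrative name, not part of this commit):

// Standalone sketch of the truncate-or-pad step; not the Comet implementation.
fn rpad_chars(s: &str, length: usize, truncate: bool) -> String {
    // Spark's rpad counts chars, not bytes or grapheme clusters.
    let char_len = s.chars().count();
    if length <= char_len {
        if truncate {
            // Byte index of the char just past `length`, so the slice
            // always lands on a char boundary.
            let idx = s
                .char_indices()
                .nth(length)
                .map(|(i, _)| i)
                .unwrap_or(s.len());
            s[..idx].to_string()
        } else {
            s.to_string()
        }
    } else {
        // Right-pad with spaces up to `length` chars.
        let mut out = String::with_capacity(s.len() + (length - char_len));
        out.push_str(s);
        out.extend(std::iter::repeat(' ').take(length - char_len));
        out
    }
}

fn main() {
    assert_eq!(rpad_chars("héllo", 3, true), "hél"); // 3 chars, 4 bytes
    assert_eq!(rpad_chars("ab", 5, true), "ab   ");
}

The RPadArgument enum only decides where length comes from: ConstLength carries the single Int32 scalar shared by every row, while ColArray carries the second Int32 column whose per-row value supplies the pad length.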

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 9 additions & 10 deletions
@@ -322,16 +322,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       checkSparkAnswer("SELECT try_add(_1, _2) FROM tbl")
     }
   }
-  test("fix_rpad") {
-    withTable("t1") {
-      val value = "IfIWasARoadIWouldBeBent"
-      sql("create table t1(c1 varchar(100), c2 int) using parquet")
-      sql(s"insert into t1 values('$value', 10)")
-      val res = sql("select rpad(c1,c2) from t1 order by c1")
-      res.show(10, false)
-      checkSparkAnswerAndOperator(res)
-    }
-  }
+  test("fix_rpad") {
+    withTable("t1") {
+      val value = "IfIWasARoadIWouldBeBent"
+      sql("create table t1(c1 varchar(100), c2 int) using parquet")
+      sql(s"insert into t1 values('$value', 10)")
+      val res = sql("select rpad(c1,c2) , rpad(c1,5) from t1 order by c1")
+      checkSparkAnswerAndOperator(res)
+    }
+  }
 
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
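
The reworked test covers both call shapes that the native change distinguishes: rpad(c1, c2) with a per-row length column and rpad(c1, 5) with a constant length. Row-wise, the column-length form behaves like this plain-Rust sketch (illustrative names only; assumes SQL-style null propagation, where a null in either input yields a null result):

// Sketch only: mirrors per-row rpad semantics, not the Comet code path.
fn rpad_rows(strings: &[Option<&str>], lengths: &[Option<i32>]) -> Vec<Option<String>> {
    strings
        .iter()
        .zip(lengths)
        .map(|(s, len)| match (s, len) {
            (Some(s), Some(len)) => {
                let length = (*len).max(0) as usize;
                let char_len = s.chars().count();
                if length <= char_len {
                    // Truncate at a char boundary.
                    Some(s.chars().take(length).collect())
                } else {
                    // Right-pad with spaces to `length` chars.
                    Some(format!("{s}{}", " ".repeat(length - char_len)))
                }
            }
            _ => None, // null string or null length => null result
        })
        .collect()
}

fn main() {
    let out = rpad_rows(
        &[Some("IfIWasARoadIWouldBeBent"), Some("x"), None],
        &[Some(10), Some(3), Some(4)],
    );
    assert_eq!(out[0].as_deref(), Some("IfIWasARoa")); // rpad(c1, c2) with c2 = 10
    assert_eq!(out[1].as_deref(), Some("x  "));
    assert_eq!(out[2], None);
}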
