Skip to content

Commit 3d68aa9

Browse files
committed
address_review_comments_rpad
1 parent b78cb65 commit 3d68aa9

File tree

2 files changed

+36
-59
lines changed

2 files changed

+36
-59
lines changed

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 35 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ use arrow::array::types::Int32Type;
2121
use arrow::array::{make_array, Array, AsArray, DictionaryArray};
2222
use arrow::array::{ArrayRef, OffsetSizeTrait};
2323
use arrow::datatypes::DataType;
24-
use datafusion::common::{cast::as_generic_string_array, DataFusionError, HashMap, ScalarValue};
24+
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
2525
use datafusion::physical_plan::ColumnarValue;
26-
use std::fmt::Write;
2726
use std::sync::Arc;
2827

2928
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
@@ -115,53 +114,26 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
115114
match rpad_argument {
116115
RPadArgument::ColArray(array_int) => {
117116
let int_pad_array = array_int.as_primitive::<Int32Type>();
118-
let mut str_pad_value_map = HashMap::new();
119-
for i in 0..string_array.len() {
120-
if string_array.is_null(i) || int_pad_array.is_null(i) {
121-
continue; // skip nulls
122-
}
123-
str_pad_value_map.insert(string_array.value(i), int_pad_array.value(i));
124-
}
125117

126118
let mut builder = GenericStringBuilder::<T>::with_capacity(
127-
str_pad_value_map.len(),
128-
str_pad_value_map.len() * int_pad_array.len(),
119+
string_array.len(),
120+
string_array.len() * int_pad_array.len(),
129121
);
130122

131-
for string in string_array.iter() {
123+
for (string, length) in string_array.iter().zip(int_pad_array) {
132124
match string {
133-
Some(string) => {
134-
// It looks Spark's UTF8String is closer to chars rather than graphemes
135-
// https://stackoverflow.com/a/46290728
136-
let char_len = string.chars().count();
137-
let length: usize = 0.max(*str_pad_value_map.get(string).unwrap()) as usize;
138-
let space_string = " ".repeat(length);
139-
if length <= char_len {
140-
if truncate {
141-
let idx = string
142-
.char_indices()
143-
.nth(length)
144-
.map(|(i, _)| i)
145-
.unwrap_or(string.len());
146-
builder.append_value(&string[..idx]);
147-
} else {
148-
builder.append_value(string);
149-
}
150-
} else {
151-
// write_str updates only the value buffer, not null nor offset buffer
152-
// This is convenient for concatenating str(s)
153-
builder.write_str(string)?;
154-
builder.append_value(&space_string[char_len..]);
155-
}
156-
}
125+
Some(string) => builder.append_value(add_padding_string(
126+
string.parse().unwrap(),
127+
length.unwrap() as usize,
128+
truncate,
129+
)),
157130
_ => builder.append_null(),
158131
}
159132
}
160133
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
161134
}
162135
RPadArgument::ConstLength(length) => {
163136
let length = 0.max(length) as usize;
164-
let space_string = " ".repeat(length);
165137

166138
let mut builder = GenericStringBuilder::<T>::with_capacity(
167139
string_array.len(),
@@ -170,32 +142,36 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
170142

171143
for string in string_array.iter() {
172144
match string {
173-
Some(string) => {
174-
// It looks Spark's UTF8String is closer to chars rather than graphemes
175-
// https://stackoverflow.com/a/46290728
176-
let char_len = string.chars().count();
177-
if length <= char_len {
178-
if truncate {
179-
let idx = string
180-
.char_indices()
181-
.nth(length)
182-
.map(|(i, _)| i)
183-
.unwrap_or(string.len());
184-
builder.append_value(&string[..idx]);
185-
} else {
186-
builder.append_value(string);
187-
}
188-
} else {
189-
// write_str updates only the value buffer, not null nor offset buffer
190-
// This is convenient for concatenating str(s)
191-
builder.write_str(string)?;
192-
builder.append_value(&space_string[char_len..]);
193-
}
194-
}
145+
Some(string) => builder.append_value(add_padding_string(
146+
string.parse().unwrap(),
147+
length,
148+
truncate,
149+
)),
195150
_ => builder.append_null(),
196151
}
197152
}
198153
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
199154
}
200155
}
201156
}
157+
158+
/// Pads `string` on the right with spaces up to `length` characters, Spark-style.
///
/// - If the string already has at least `length` chars: returns it unchanged,
///   or truncated to exactly `length` chars when `truncate` is true.
/// - Otherwise appends `length - char_len` spaces.
///
/// Length is measured in Unicode scalar values (`chars`), not bytes or
/// graphemes — Spark's UTF8String semantics are closer to chars than
/// graphemes (https://stackoverflow.com/a/46290728).
fn add_padding_string(string: String, length: usize, truncate: bool) -> String {
    let char_len = string.chars().count();
    if length <= char_len {
        if truncate {
            // Map the char-based length to a byte index for slicing;
            // falls back to the full string when length == char_len.
            let idx = string
                .char_indices()
                .nth(length)
                .map(|(i, _)| i)
                .unwrap_or(string.len());
            string[..idx].to_string()
        } else {
            string
        }
    } else {
        // Allocate the pad lazily and only as large as actually needed
        // (the previous code built a `length`-sized space string up front,
        // even on paths that never used it).
        string + &" ".repeat(length - char_len)
    }
}

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
327327
val value = "IfIWasARoadIWouldBeBent"
328328
sql("create table t1(c1 varchar(100), c2 int) using parquet")
329329
sql(s"insert into t1 values('$value', 10)")
330+
sql(s"insert into t1 values((${null}, 10))")
330331
val res = sql("select rpad(c1,c2) , rpad(c1,5) from t1 order by c1")
331332
checkSparkAnswerAndOperator(res)
332333
}

0 commit comments

Comments
 (0)