Skip to content

Commit 5814c7e

Browse files
Do not accept null is_set for first_value/last_value (#18301)
## Which issue does this PR close?

- Closes #18300

## Rationale for this change

As laid out in the issue, this improves internal checks by testing an assumed invariant instead of silently nulling data on error. The cost is a single null check on a column whose length depends on the number of partitions, not on the size of the data itself.

## What changes are included in this PR?

* Adds a null check on the `is_set` flag column (the second column for the trivial accumulators, the last column for the ordered ones) in `merge_batch` of both `FIRST_VALUE` and `LAST_VALUE`.

## Are these changes tested?

Yes, tests are included.

## Are there any user-facing changes?

Hopefully not.
1 parent bea4b68 commit 5814c7e

File tree

1 file changed

+105
-1
lines changed

1 file changed

+105
-1
lines changed

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,8 @@ impl Accumulator for TrivialFirstValueAccumulator {
817817
// Second index contains is_set flag.
818818
if !self.is_set {
819819
let flags = states[1].as_boolean();
820+
validate_is_set_flags(flags, "first_value")?;
821+
820822
let filtered_states =
821823
filter_states_according_to_is_set(&states[0..1], flags)?;
822824
if let Some(first) = filtered_states.first() {
@@ -962,6 +964,8 @@ impl Accumulator for FirstValueAccumulator {
962964
// last index contains is_set flag.
963965
let is_set_idx = states.len() - 1;
964966
let flags = states[is_set_idx].as_boolean();
967+
validate_is_set_flags(flags, "first_value")?;
968+
965969
let filtered_states =
966970
filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
967971
// 1..is_set_idx range corresponds to ordering section
@@ -1299,6 +1303,8 @@ impl Accumulator for TrivialLastValueAccumulator {
12991303
// LAST_VALUE(last1, last2, last3, ...)
13001304
// Second index contains is_set flag.
13011305
let flags = states[1].as_boolean();
1306+
validate_is_set_flags(flags, "last_value")?;
1307+
13021308
let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?;
13031309
if let Some(last) = filtered_states.last() {
13041310
if !last.is_empty() {
@@ -1444,6 +1450,8 @@ impl Accumulator for LastValueAccumulator {
14441450
// last index contains is_set flag.
14451451
let is_set_idx = states.len() - 1;
14461452
let flags = states[is_set_idx].as_boolean();
1453+
validate_is_set_flags(flags, "last_value")?;
1454+
14471455
let filtered_states =
14481456
filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
14491457
// 1..is_set_idx range corresponds to ordering section
@@ -1487,6 +1495,16 @@ impl Accumulator for LastValueAccumulator {
14871495
}
14881496
}
14891497

1498+
/// Validates that the `is_set` flags do not contain NULL values.
///
/// `flags` is the boolean `is_set` column handed over by the `merge_batch`
/// callers; `function_name` is only used as the prefix of the error message.
///
/// # Errors
///
/// Returns `DataFusionError::Internal` with the message
/// `"{function_name}: is_set flags contain nulls"` when
/// `flags.null_count() > 0`.
1499+
fn validate_is_set_flags(flags: &BooleanArray, function_name: &str) -> Result<()> {
1500+
if flags.null_count() > 0 {
1501+
return Err(DataFusionError::Internal(format!(
1502+
"{function_name}: is_set flags contain nulls"
1503+
)));
1504+
}
1505+
Ok(())
1506+
}
1507+
14901508
/// Filters states according to the `is_set` flag at the last column and returns
14911509
/// the resulting states.
14921510
fn filter_states_according_to_is_set(
@@ -1515,7 +1533,7 @@ mod tests {
15151533
use std::iter::repeat_with;
15161534

15171535
use arrow::{
1518-
array::{Int64Array, ListArray},
1536+
array::{BooleanArray, Int64Array, ListArray, StringArray},
15191537
compute::SortOptions,
15201538
datatypes::Schema,
15211539
};
@@ -1928,4 +1946,90 @@ mod tests {
19281946

19291947
Ok(())
19301948
}
1949+
1950+
// Checks that `merge_batch` of both `TrivialFirstValueAccumulator` and
// `FirstValueAccumulator` returns an error whose text contains
// "is_set flags contain nulls" when the is_set state column holds a NULL.
#[test]
1951+
fn test_first_value_merge_with_is_set_nulls() -> Result<()> {
1952+
// Test data with a corrupted is_set flag: one value row paired with a
// single NULL entry in the boolean flag column.
1953+
let value = Arc::new(StringArray::from(vec![Some("first_string")])) as ArrayRef;
1954+
let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;
1955+
1956+
// Test TrivialFirstValueAccumulator
1957+
let mut trivial_accumulator =
1958+
TrivialFirstValueAccumulator::try_new(&DataType::Utf8, false)?;
1959+
// States passed to the trivial accumulator: [value, is_set flag].
let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
1960+
let result = trivial_accumulator.merge_batch(&trivial_states);
1961+
assert!(result.is_err());
1962+
assert!(result
1963+
.unwrap_err()
1964+
.to_string()
1965+
.contains("is_set flags contain nulls"));
1966+
1967+
// Test FirstValueAccumulator (with ordering)
1968+
let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
1969+
let ordering_expr = col("ordering", &schema)?;
1970+
// NOTE(review): the meaning of the two trailing `false` arguments is not
// visible in this hunk — confirm against `FirstValueAccumulator::try_new`.
let mut ordered_accumulator = FirstValueAccumulator::try_new(
1971+
&DataType::Utf8,
1972+
&[DataType::Int64],
1973+
LexOrdering::new(vec![PhysicalSortExpr {
1974+
expr: ordering_expr,
1975+
options: SortOptions::default(),
1976+
}])
1977+
.unwrap(),
1978+
false,
1979+
false,
1980+
)?;
1981+
let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
1982+
// States passed to the ordered accumulator: [value, ordering, is_set flag];
// the flag is the last column.
let ordered_states = vec![value, ordering, corrupted_flag];
1983+
let result = ordered_accumulator.merge_batch(&ordered_states);
1984+
assert!(result.is_err());
1985+
assert!(result
1986+
.unwrap_err()
1987+
.to_string()
1988+
.contains("is_set flags contain nulls"));
1989+
1990+
Ok(())
1991+
}
1992+
1993+
// Checks that `merge_batch` of both `TrivialLastValueAccumulator` and
// `LastValueAccumulator` returns an error whose text contains
// "is_set flags contain nulls" when the is_set state column holds a NULL.
#[test]
1994+
fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
1995+
// Test data with a corrupted is_set flag: one value row paired with a
// single NULL entry in the boolean flag column.
1996+
let value = Arc::new(StringArray::from(vec![Some("last_string")])) as ArrayRef;
1997+
let corrupted_flag = Arc::new(BooleanArray::from(vec![None])) as ArrayRef;
1998+
1999+
// Test TrivialLastValueAccumulator
2000+
let mut trivial_accumulator =
2001+
TrivialLastValueAccumulator::try_new(&DataType::Utf8, false)?;
2002+
// States passed to the trivial accumulator: [value, is_set flag].
let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)];
2003+
let result = trivial_accumulator.merge_batch(&trivial_states);
2004+
assert!(result.is_err());
2005+
assert!(result
2006+
.unwrap_err()
2007+
.to_string()
2008+
.contains("is_set flags contain nulls"));
2009+
2010+
// Test LastValueAccumulator (with ordering)
2011+
let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
2012+
let ordering_expr = col("ordering", &schema)?;
2013+
// NOTE(review): the meaning of the two trailing `false` arguments is not
// visible in this hunk — confirm against `LastValueAccumulator::try_new`.
let mut ordered_accumulator = LastValueAccumulator::try_new(
2014+
&DataType::Utf8,
2015+
&[DataType::Int64],
2016+
LexOrdering::new(vec![PhysicalSortExpr {
2017+
expr: ordering_expr,
2018+
options: SortOptions::default(),
2019+
}])
2020+
.unwrap(),
2021+
false,
2022+
false,
2023+
)?;
2024+
let ordering = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
2025+
// States passed to the ordered accumulator: [value, ordering, is_set flag];
// the flag is the last column.
let ordered_states = vec![value, ordering, corrupted_flag];
2026+
let result = ordered_accumulator.merge_batch(&ordered_states);
2027+
assert!(result.is_err());
2028+
assert!(result
2029+
.unwrap_err()
2030+
.to_string()
2031+
.contains("is_set flags contain nulls"));
2032+
2033+
Ok(())
2034+
}
19312035
}

0 commit comments

Comments
 (0)