@@ -817,6 +817,8 @@ impl Accumulator for TrivialFirstValueAccumulator {
817817 // Second index contains is_set flag.
818818 if !self . is_set {
819819 let flags = states[ 1 ] . as_boolean ( ) ;
820+ validate_is_set_flags ( flags, "first_value" ) ?;
821+
820822 let filtered_states =
821823 filter_states_according_to_is_set ( & states[ 0 ..1 ] , flags) ?;
822824 if let Some ( first) = filtered_states. first ( ) {
@@ -962,6 +964,8 @@ impl Accumulator for FirstValueAccumulator {
962964 // last index contains is_set flag.
963965 let is_set_idx = states. len ( ) - 1 ;
964966 let flags = states[ is_set_idx] . as_boolean ( ) ;
967+ validate_is_set_flags ( flags, "first_value" ) ?;
968+
965969 let filtered_states =
966970 filter_states_according_to_is_set ( & states[ 0 ..is_set_idx] , flags) ?;
967971 // 1..is_set_idx range corresponds to ordering section
@@ -1299,6 +1303,8 @@ impl Accumulator for TrivialLastValueAccumulator {
12991303 // LAST_VALUE(last1, last2, last3, ...)
13001304 // Second index contains is_set flag.
13011305 let flags = states[ 1 ] . as_boolean ( ) ;
1306+ validate_is_set_flags ( flags, "last_value" ) ?;
1307+
13021308 let filtered_states = filter_states_according_to_is_set ( & states[ 0 ..1 ] , flags) ?;
13031309 if let Some ( last) = filtered_states. last ( ) {
13041310 if !last. is_empty ( ) {
@@ -1444,6 +1450,8 @@ impl Accumulator for LastValueAccumulator {
14441450 // last index contains is_set flag.
14451451 let is_set_idx = states. len ( ) - 1 ;
14461452 let flags = states[ is_set_idx] . as_boolean ( ) ;
1453+ validate_is_set_flags ( flags, "last_value" ) ?;
1454+
14471455 let filtered_states =
14481456 filter_states_according_to_is_set ( & states[ 0 ..is_set_idx] , flags) ?;
14491457 // 1..is_set_idx range corresponds to ordering section
@@ -1487,6 +1495,16 @@ impl Accumulator for LastValueAccumulator {
14871495 }
14881496}
14891497
1498+ /// Validates that `is_set flags` do not contain NULL values.
1499+ fn validate_is_set_flags ( flags : & BooleanArray , function_name : & str ) -> Result < ( ) > {
1500+ if flags. null_count ( ) > 0 {
1501+ return Err ( DataFusionError :: Internal ( format ! (
1502+ "{function_name}: is_set flags contain nulls"
1503+ ) ) ) ;
1504+ }
1505+ Ok ( ( ) )
1506+ }
1507+
14901508/// Filters states according to the `is_set` flag at the last column and returns
14911509/// the resulting states.
14921510fn filter_states_according_to_is_set (
@@ -1515,7 +1533,7 @@ mod tests {
15151533 use std:: iter:: repeat_with;
15161534
15171535 use arrow:: {
1518- array:: { Int64Array , ListArray } ,
1536+ array:: { BooleanArray , Int64Array , ListArray , StringArray } ,
15191537 compute:: SortOptions ,
15201538 datatypes:: Schema ,
15211539 } ;
@@ -1928,4 +1946,90 @@ mod tests {
19281946
19291947 Ok ( ( ) )
19301948 }
1949+
1950+ #[ test]
1951+ fn test_first_value_merge_with_is_set_nulls ( ) -> Result < ( ) > {
1952+ // Test data with corrupted is_set flag
1953+ let value = Arc :: new ( StringArray :: from ( vec ! [ Some ( "first_string" ) ] ) ) as ArrayRef ;
1954+ let corrupted_flag = Arc :: new ( BooleanArray :: from ( vec ! [ None ] ) ) as ArrayRef ;
1955+
1956+ // Test TrivialFirstValueAccumulator
1957+ let mut trivial_accumulator =
1958+ TrivialFirstValueAccumulator :: try_new ( & DataType :: Utf8 , false ) ?;
1959+ let trivial_states = vec ! [ Arc :: clone( & value) , Arc :: clone( & corrupted_flag) ] ;
1960+ let result = trivial_accumulator. merge_batch ( & trivial_states) ;
1961+ assert ! ( result. is_err( ) ) ;
1962+ assert ! ( result
1963+ . unwrap_err( )
1964+ . to_string( )
1965+ . contains( "is_set flags contain nulls" ) ) ;
1966+
1967+ // Test FirstValueAccumulator (with ordering)
1968+ let schema = Schema :: new ( vec ! [ Field :: new( "ordering" , DataType :: Int64 , false ) ] ) ;
1969+ let ordering_expr = col ( "ordering" , & schema) ?;
1970+ let mut ordered_accumulator = FirstValueAccumulator :: try_new (
1971+ & DataType :: Utf8 ,
1972+ & [ DataType :: Int64 ] ,
1973+ LexOrdering :: new ( vec ! [ PhysicalSortExpr {
1974+ expr: ordering_expr,
1975+ options: SortOptions :: default ( ) ,
1976+ } ] )
1977+ . unwrap ( ) ,
1978+ false ,
1979+ false ,
1980+ ) ?;
1981+ let ordering = Arc :: new ( Int64Array :: from ( vec ! [ Some ( 1 ) ] ) ) as ArrayRef ;
1982+ let ordered_states = vec ! [ value, ordering, corrupted_flag] ;
1983+ let result = ordered_accumulator. merge_batch ( & ordered_states) ;
1984+ assert ! ( result. is_err( ) ) ;
1985+ assert ! ( result
1986+ . unwrap_err( )
1987+ . to_string( )
1988+ . contains( "is_set flags contain nulls" ) ) ;
1989+
1990+ Ok ( ( ) )
1991+ }
1992+
1993+ #[ test]
1994+ fn test_last_value_merge_with_is_set_nulls ( ) -> Result < ( ) > {
1995+ // Test data with corrupted is_set flag
1996+ let value = Arc :: new ( StringArray :: from ( vec ! [ Some ( "last_string" ) ] ) ) as ArrayRef ;
1997+ let corrupted_flag = Arc :: new ( BooleanArray :: from ( vec ! [ None ] ) ) as ArrayRef ;
1998+
1999+ // Test TrivialLastValueAccumulator
2000+ let mut trivial_accumulator =
2001+ TrivialLastValueAccumulator :: try_new ( & DataType :: Utf8 , false ) ?;
2002+ let trivial_states = vec ! [ Arc :: clone( & value) , Arc :: clone( & corrupted_flag) ] ;
2003+ let result = trivial_accumulator. merge_batch ( & trivial_states) ;
2004+ assert ! ( result. is_err( ) ) ;
2005+ assert ! ( result
2006+ . unwrap_err( )
2007+ . to_string( )
2008+ . contains( "is_set flags contain nulls" ) ) ;
2009+
2010+ // Test LastValueAccumulator (with ordering)
2011+ let schema = Schema :: new ( vec ! [ Field :: new( "ordering" , DataType :: Int64 , false ) ] ) ;
2012+ let ordering_expr = col ( "ordering" , & schema) ?;
2013+ let mut ordered_accumulator = LastValueAccumulator :: try_new (
2014+ & DataType :: Utf8 ,
2015+ & [ DataType :: Int64 ] ,
2016+ LexOrdering :: new ( vec ! [ PhysicalSortExpr {
2017+ expr: ordering_expr,
2018+ options: SortOptions :: default ( ) ,
2019+ } ] )
2020+ . unwrap ( ) ,
2021+ false ,
2022+ false ,
2023+ ) ?;
2024+ let ordering = Arc :: new ( Int64Array :: from ( vec ! [ Some ( 1 ) ] ) ) as ArrayRef ;
2025+ let ordered_states = vec ! [ value, ordering, corrupted_flag] ;
2026+ let result = ordered_accumulator. merge_batch ( & ordered_states) ;
2027+ assert ! ( result. is_err( ) ) ;
2028+ assert ! ( result
2029+ . unwrap_err( )
2030+ . to_string( )
2031+ . contains( "is_set flags contain nulls" ) ) ;
2032+
2033+ Ok ( ( ) )
2034+ }
19312035}
0 commit comments