@@ -29,7 +29,10 @@ use arrow::{
2929 } ,
3030 record_batch:: RecordBatch ,
3131} ;
32- use arrow_array:: { Array , Float32Array , Float64Array , UnionArray } ;
32+ use arrow_array:: {
33+ Array , BooleanArray , DictionaryArray , Float32Array , Float64Array , Int8Array ,
34+ UnionArray ,
35+ } ;
3336use arrow_buffer:: ScalarBuffer ;
3437use arrow_schema:: { ArrowError , UnionFields , UnionMode } ;
3538use datafusion_functions_aggregate:: count:: count_udaf;
@@ -2363,3 +2366,105 @@ async fn dense_union_is_null() {
23632366 ] ;
23642367 assert_batches_sorted_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
23652368}
2369+
2370+ #[ tokio:: test]
2371+ async fn boolean_dictionary_as_filter ( ) {
2372+ let values = vec ! [ Some ( true ) , Some ( false ) , None , Some ( true ) ] ;
2373+ let keys = vec ! [ 0 , 0 , 1 , 2 , 1 , 3 , 1 ] ;
2374+ let values_array = BooleanArray :: from ( values) ;
2375+ let keys_array = Int8Array :: from ( keys) ;
2376+ let array =
2377+ DictionaryArray :: new ( keys_array, Arc :: new ( values_array) as Arc < dyn Array > ) ;
2378+ let array = Arc :: new ( array) ;
2379+
2380+ let field = Field :: new (
2381+ "my_dict" ,
2382+ DataType :: Dictionary ( Box :: new ( DataType :: Int8 ) , Box :: new ( DataType :: Boolean ) ) ,
2383+ true ,
2384+ ) ;
2385+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
2386+
2387+ let batch = RecordBatch :: try_new ( schema, vec ! [ array. clone( ) ] ) . unwrap ( ) ;
2388+
2389+ let ctx = SessionContext :: new ( ) ;
2390+
2391+ ctx. register_batch ( "dict_batch" , batch) . unwrap ( ) ;
2392+
2393+ let df = ctx. table ( "dict_batch" ) . await . unwrap ( ) ;
2394+
2395+ // view_all
2396+ let expected = [
2397+ "+---------+" ,
2398+ "| my_dict |" ,
2399+ "+---------+" ,
2400+ "| true |" ,
2401+ "| true |" ,
2402+ "| false |" ,
2403+ "| |" ,
2404+ "| false |" ,
2405+ "| true |" ,
2406+ "| false |" ,
2407+ "+---------+" ,
2408+ ] ;
2409+ assert_batches_eq ! ( expected, & df. clone( ) . collect( ) . await . unwrap( ) ) ;
2410+
2411+ let result_df = df. clone ( ) . filter ( col ( "my_dict" ) ) . unwrap ( ) ;
2412+ let expected = [
2413+ "+---------+" ,
2414+ "| my_dict |" ,
2415+ "+---------+" ,
2416+ "| true |" ,
2417+ "| true |" ,
2418+ "| true |" ,
2419+ "+---------+" ,
2420+ ] ;
2421+ assert_batches_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2422+
2423+ // test nested dictionary
2424+ let keys = vec ! [ 0 , 2 ] ; // 0 -> true, 2 -> false
2425+ let keys_array = Int8Array :: from ( keys) ;
2426+ let nested_array = DictionaryArray :: new ( keys_array, array) ;
2427+
2428+ let field = Field :: new (
2429+ "my_nested_dict" ,
2430+ DataType :: Dictionary (
2431+ Box :: new ( DataType :: Int8 ) ,
2432+ Box :: new ( DataType :: Dictionary (
2433+ Box :: new ( DataType :: Int8 ) ,
2434+ Box :: new ( DataType :: Boolean ) ,
2435+ ) ) ,
2436+ ) ,
2437+ true ,
2438+ ) ;
2439+
2440+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
2441+
2442+ let batch = RecordBatch :: try_new ( schema, vec ! [ Arc :: new( nested_array) ] ) . unwrap ( ) ;
2443+
2444+ ctx. register_batch ( "nested_dict_batch" , batch) . unwrap ( ) ;
2445+
2446+ let df = ctx. table ( "nested_dict_batch" ) . await . unwrap ( ) ;
2447+
2448+ // view_all
2449+ let expected = [
2450+ "+----------------+" ,
2451+ "| my_nested_dict |" ,
2452+ "+----------------+" ,
2453+ "| true |" ,
2454+ "| false |" ,
2455+ "+----------------+" ,
2456+ ] ;
2457+
2458+ assert_batches_eq ! ( expected, & df. clone( ) . collect( ) . await . unwrap( ) ) ;
2459+
2460+ let result_df = df. clone ( ) . filter ( col ( "my_nested_dict" ) ) . unwrap ( ) ;
2461+ let expected = [
2462+ "+----------------+" ,
2463+ "| my_nested_dict |" ,
2464+ "+----------------+" ,
2465+ "| true |" ,
2466+ "+----------------+" ,
2467+ ] ;
2468+
2469+ assert_batches_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2470+ }
0 commit comments