From af20d75f73d8f9ca68af5da6459fd698fc6bfab1 Mon Sep 17 00:00:00 2001 From: Jeremy Leibs Date: Mon, 15 Jul 2024 11:06:35 -0400 Subject: [PATCH] Extract_field udf --- .../re_datafusion/examples/datafusion.rs | 15 +++ crates/store/re_datafusion/src/chunk_table.rs | 9 +- .../re_datafusion/src/field_extraction.rs | 117 ++++++++++++++++++ crates/store/re_datafusion/src/lib.rs | 7 ++ 4 files changed, 143 insertions(+), 5 deletions(-) create mode 100644 crates/store/re_datafusion/src/field_extraction.rs diff --git a/crates/store/re_datafusion/examples/datafusion.rs b/crates/store/re_datafusion/examples/datafusion.rs index 9f973244f5e9..117c12a60e87 100644 --- a/crates/store/re_datafusion/examples/datafusion.rs +++ b/crates/store/re_datafusion/examples/datafusion.rs @@ -47,7 +47,22 @@ async fn main() -> anyhow::Result<()> { let ctx = create_datafusion_context(store)?; let df = ctx.sql("SELECT * FROM custom_table").await?; + df.show().await?; + + let df = ctx + .sql("SELECT \"example.MyPoint\" FROM custom_table") + .await?; + df.show().await?; + let df = ctx + .sql("SELECT \"example.MyLabel\" FROM custom_table") + .await?; df.show().await?; + + let df = ctx + .sql("SELECT array_extract(\"example.MyPoint\", 'x') as X, array_extract(\"example.MyPoint\", 'y') as Y FROM custom_table") + .await?; + df.show().await?; + Ok(()) } diff --git a/crates/store/re_datafusion/src/chunk_table.rs b/crates/store/re_datafusion/src/chunk_table.rs index 6965dd0573fc..eb2fc5694df9 100644 --- a/crates/store/re_datafusion/src/chunk_table.rs +++ b/crates/store/re_datafusion/src/chunk_table.rs @@ -161,14 +161,13 @@ impl ExecutionPlan for CustomExec { .store .iter_chunks() .map(|chunk| { - let components = re_log_types::example_components::MyPoints::all_components(); - RecordBatch::try_new( self.projected_schema.clone(), - components + self.projected_schema + .fields() .iter() - .filter_map(|c| { - chunk.components().get(c).map(|c| { + .filter_map(|f| { + chunk.components().get(&f.name().clone().into()).map(|c| { let data = c.to_data(); let converted = GenericListArray::::from(data); diff --git a/crates/store/re_datafusion/src/field_extraction.rs b/crates/store/re_datafusion/src/field_extraction.rs new file mode 100644 index 000000000000..4e06cf0da7c5 --- /dev/null +++ b/crates/store/re_datafusion/src/field_extraction.rs @@ -0,0 +1,117 @@ +use datafusion::arrow::array::{Array, ListArray, StructArray}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::{plan_err, DataFusionError, ExprSchema, Result}; +use datafusion::logical_expr::ScalarUDFImpl; +use datafusion::logical_expr::{ColumnarValue, Expr, Signature, Volatility}; +use datafusion::scalar::ScalarValue; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub struct ExtractField { + signature: Signature, +} + +impl ExtractField { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ExtractField { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_extract" + } + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Err(DataFusionError::Internal( + "Should have dispatched to return_type_from_exprs".to_owned(), + )) + } + + fn return_type_from_exprs(&self, args: &[Expr], schema: &dyn ExprSchema) -> Result { + let Some(Expr::Column(col)) = args.first() else { + return plan_err!("rr_extract first arg must be a Column containing a List of Structs"); + }; + let dt = schema.data_type(col)?; + + let DataType::List(inner) = dt else { + return plan_err!("rr_extract first arg must be a Column containing a List of Structs"); + }; + + let DataType::Struct(fields) = inner.data_type() else { + return plan_err!("rr_extract first arg must be a Column containing a List of Structs"); + }; + + let Some(Expr::Literal(ScalarValue::Utf8(Some(field)))) = args.get(1) else { + return plan_err!( + "rr_extract second arg must be a string matching a field in the struct" + ); + }; + + let Some(final_dt) = fields.find(field) else { + return plan_err!( + "rr_extract second arg must be a string matching a field in the struct" + ); + }; + + Ok(DataType::List(Arc::new(Field::new( + "item", + final_dt.1.data_type().clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let ColumnarValue::Scalar(ScalarValue::Utf8(Some(field_name))) = &args[1] else { + return Err(DataFusionError::Internal( + "Expected second argument to be a string".to_owned(), + )); + }; + + let args = ColumnarValue::values_to_arrays(args)?; + let arg = &args[0]; + + // Downcast to list array + let Some(list_array) = arg.as_any().downcast_ref::() else { + return Err(DataFusionError::Internal( + "Expected first argument to be a ListArray".to_owned(), + )); + }; + + // Get the child values array + let child_values = list_array.values(); + + // Downcast to a struct array + let Some(struct_array) = child_values.as_any().downcast_ref::() else { + return Err(DataFusionError::Internal( + "Expected ListArray to contain StructArray".to_owned(), + )); + }; + + // Get the values of the field with the correct name + let Some(field_values) = struct_array.column_by_name(field_name) else { + return Err(DataFusionError::Internal(format!( + "Expected StructArray to contain field named '{field_name}'", + ))); + }; + + // Create a new list array with the same offsets but the child values + let new_array = ListArray::new( + Arc::new(Field::new("item", field_values.data_type().clone(), true)), + list_array.offsets().clone(), + field_values.clone(), + list_array.nulls().cloned(), + ); + + Ok(ColumnarValue::Array(Arc::new(new_array))) + } +} diff --git a/crates/store/re_datafusion/src/lib.rs b/crates/store/re_datafusion/src/lib.rs index 78b991d1a55f..b0d023465cff 100644 --- a/crates/store/re_datafusion/src/lib.rs +++ b/crates/store/re_datafusion/src/lib.rs @@ -5,17 +5,24 @@ //! mod chunk_table; +mod field_extraction; use chunk_table::CustomDataSource; use datafusion::error::Result; +use datafusion::logical_expr::ScalarUDF; use datafusion::prelude::*; +use field_extraction::ExtractField; use re_chunk_store::ChunkStore; use std::sync::Arc; pub fn create_datafusion_context(store: ChunkStore) -> Result { + let extract_field = ScalarUDF::from(ExtractField::new()); + let ctx = SessionContext::new(); + ctx.register_udf(extract_field.clone()); + let db = CustomDataSource::new(store); ctx.register_table("custom_table", Arc::new(db))?;