Extract_field udf
jleibs committed Aug 2, 2024
1 parent 3ba95cb commit af20d75
Showing 4 changed files with 143 additions and 5 deletions.
15 changes: 15 additions & 0 deletions crates/store/re_datafusion/examples/datafusion.rs
@@ -47,7 +47,22 @@ async fn main() -> anyhow::Result<()> {
    let ctx = create_datafusion_context(store)?;

    let df = ctx.sql("SELECT * FROM custom_table").await?;
    df.show().await?;

    let df = ctx
        .sql("SELECT \"example.MyPoint\" FROM custom_table")
        .await?;
    df.show().await?;

    let df = ctx
        .sql("SELECT \"example.MyLabel\" FROM custom_table")
        .await?;
    df.show().await?;

    let df = ctx
        .sql("SELECT array_extract(\"example.MyPoint\", 'x') as X, array_extract(\"example.MyPoint\", 'y') as Y FROM custom_table")
        .await?;
    df.show().await?;

    Ok(())
}
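
The last query exercises the new array_extract UDF defined in field_extraction.rs below. As a rough illustration (not part of this commit, and assuming "example.MyPoint" serializes as a list of structs with Float32 x/y fields, matching the re_log_types example components), the UDF maps the column's Arrow type roughly like this:

    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Fields};

    fn main() {
        // Assumed layout of the "example.MyPoint" column: one list of points per row.
        let point = DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Float32, true),
            Field::new("y", DataType::Float32, true),
        ]));
        let input = DataType::List(Arc::new(Field::new("item", point, true)));

        // What array_extract("example.MyPoint", 'x') returns, per return_type_from_exprs:
        // the same list shape with the child narrowed to the requested field.
        let output = DataType::List(Arc::new(Field::new("item", DataType::Float32, true)));

        println!("{input:?} -> {output:?}");
    }
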
9 changes: 4 additions & 5 deletions crates/store/re_datafusion/src/chunk_table.rs
@@ -161,14 +161,13 @@ impl ExecutionPlan for CustomExec {
             .store
             .iter_chunks()
             .map(|chunk| {
-                let components = re_log_types::example_components::MyPoints::all_components();
-
                 RecordBatch::try_new(
                     self.projected_schema.clone(),
-                    components
+                    self.projected_schema
+                        .fields()
                         .iter()
-                        .filter_map(|c| {
-                            chunk.components().get(c).map(|c| {
+                        .filter_map(|f| {
+                            chunk.components().get(&f.name().clone().into()).map(|c| {
                                 let data = c.to_data();
                                 let converted = GenericListArray::<i32>::from(data);

117 changes: 117 additions & 0 deletions crates/store/re_datafusion/src/field_extraction.rs
@@ -0,0 +1,117 @@
use datafusion::arrow::array::{Array, ListArray, StructArray};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{plan_err, DataFusionError, ExprSchema, Result};
use datafusion::logical_expr::ScalarUDFImpl;
use datafusion::logical_expr::{ColumnarValue, Expr, Signature, Volatility};
use datafusion::scalar::ScalarValue;
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
pub struct ExtractField {
    signature: Signature,
}

impl ExtractField {
    pub fn new() -> Self {
        Self {
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for ExtractField {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "array_extract"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
        Err(DataFusionError::Internal(
            "Should have dispatched to return_type_from_exprs".to_owned(),
        ))
    }

    fn return_type_from_exprs(&self, args: &[Expr], schema: &dyn ExprSchema) -> Result<DataType> {
        let Some(Expr::Column(col)) = args.first() else {
            return plan_err!("rr_extract first arg must be a Column containing a List of Structs");
        };
        let dt = schema.data_type(col)?;

        let DataType::List(inner) = dt else {
            return plan_err!("rr_extract first arg must be a Column containing a List of Structs");
        };

        let DataType::Struct(fields) = inner.data_type() else {
            return plan_err!("rr_extract first arg must be a Column containing a List of Structs");
        };

        let Some(Expr::Literal(ScalarValue::Utf8(Some(field)))) = args.get(1) else {
            return plan_err!(
                "rr_extract second arg must be a string matching a field in the struct"
            );
        };

        let Some(final_dt) = fields.find(field) else {
            return plan_err!(
                "rr_extract second arg must be a string matching a field in the struct"
            );
        };

        Ok(DataType::List(Arc::new(Field::new(
            "item",
            final_dt.1.data_type().clone(),
            true,
        ))))
    }

    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let ColumnarValue::Scalar(ScalarValue::Utf8(Some(field_name))) = &args[1] else {
            return Err(DataFusionError::Internal(
                "Expected second argument to be a string".to_owned(),
            ));
        };

        let args = ColumnarValue::values_to_arrays(args)?;
        let arg = &args[0];

        // Downcast to list array
        let Some(list_array) = arg.as_any().downcast_ref::<ListArray>() else {
            return Err(DataFusionError::Internal(
                "Expected first argument to be a ListArray".to_owned(),
            ));
        };

        // Get the child values array
        let child_values = list_array.values();

        // Downcast to a struct array
        let Some(struct_array) = child_values.as_any().downcast_ref::<StructArray>() else {
            return Err(DataFusionError::Internal(
                "Expected ListArray to contain StructArray".to_owned(),
            ));
        };

        // Get the values of the field with the correct name
        let Some(field_values) = struct_array.column_by_name(field_name) else {
            return Err(DataFusionError::Internal(format!(
                "Expected StructArray to contain field named '{field_name}'",
            )));
        };

        // Create a new list array with the same offsets but the child values
        let new_array = ListArray::new(
            Arc::new(Field::new("item", field_values.data_type().clone(), true)),
            list_array.offsets().clone(),
            field_values.clone(),
            list_array.nulls().cloned(),
        );

        Ok(ColumnarValue::Array(Arc::new(new_array)))
    }
}
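
The core trick in invoke above is that a list array can keep its offsets and null mask while its child array is swapped from the full struct array to a single struct field. A standalone sketch of that idea (not part of this commit; it only uses the arrow APIs re-exported by DataFusion):

    use std::sync::Arc;

    use datafusion::arrow::array::{Array, ArrayRef, Float32Array, ListArray, StructArray};
    use datafusion::arrow::buffer::OffsetBuffer;
    use datafusion::arrow::datatypes::{DataType, Field, Fields};

    fn main() {
        // Three {x, y} structs.
        let x: ArrayRef = Arc::new(Float32Array::from(vec![1.0_f32, 2.0, 3.0]));
        let y: ArrayRef = Arc::new(Float32Array::from(vec![4.0_f32, 5.0, 6.0]));
        let fields = Fields::from(vec![
            Field::new("x", DataType::Float32, true),
            Field::new("y", DataType::Float32, true),
        ]);
        let structs = StructArray::new(fields.clone(), vec![x, y], None);

        // Two list rows: [{1,4}, {2,5}] and [{3,6}].
        let list = ListArray::new(
            Arc::new(Field::new("item", DataType::Struct(fields), true)),
            OffsetBuffer::new(vec![0_i32, 2, 3].into()),
            Arc::new(structs.clone()),
            None,
        );

        // Extract "x": reuse the offsets and nulls, narrow the child to one column.
        let x_values = structs.column_by_name("x").unwrap().clone();
        let x_list = ListArray::new(
            Arc::new(Field::new("item", x_values.data_type().clone(), true)),
            list.offsets().clone(),
            x_values,
            list.nulls().cloned(),
        );

        // Prints a list array equivalent to [[1.0, 2.0], [3.0]].
        println!("{x_list:?}");
    }
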
7 changes: 7 additions & 0 deletions crates/store/re_datafusion/src/lib.rs
@@ -5,17 +5,24 @@
//!
mod chunk_table;
mod field_extraction;

use chunk_table::CustomDataSource;
use datafusion::error::Result;
use datafusion::logical_expr::ScalarUDF;
use datafusion::prelude::*;
use field_extraction::ExtractField;
use re_chunk_store::ChunkStore;

use std::sync::Arc;

pub fn create_datafusion_context(store: ChunkStore) -> Result<SessionContext> {
    let extract_field = ScalarUDF::from(ExtractField::new());

    let ctx = SessionContext::new();

    ctx.register_udf(extract_field.clone());

    let db = CustomDataSource::new(store);

    ctx.register_table("custom_table", Arc::new(db))?;
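
Once registered, the UDF is available to SQL queries (as in the example above) and, assuming the context exposes registered functions via SessionContext::udf, to the DataFrame API as well. A minimal sketch of the latter (not part of this commit; names and quoting are illustrative):

    use datafusion::common::Column;
    use datafusion::error::Result;
    use datafusion::prelude::*;

    // Sketch only: a DataFrame-API counterpart of the array_extract SQL query,
    // assuming `ctx` came from create_datafusion_context and the table is registered.
    async fn extract_points(ctx: &SessionContext) -> Result<()> {
        // Look up the registered UDF by name (via SessionContext::udf or the
        // FunctionRegistry trait, depending on the DataFusion version).
        let extract = ctx.udf("array_extract")?;
        // Column::from_name keeps the dot in "example.MyPoint" from being parsed
        // as a table qualifier.
        let point = || col(Column::from_name("example.MyPoint"));
        let df = ctx.table("custom_table").await?.select(vec![
            extract.call(vec![point(), lit("x")]).alias("X"),
            extract.call(vec![point(), lit("y")]).alias("Y"),
        ])?;
        df.show().await?;
        Ok(())
    }
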
