Add sd_transmute() and sd_eval() to the sedonadb R package.

What the hunks below add:
  * R (dataframe.R, expression.R): sd_transmute() turns `...` into SedonaDB
    expressions, aliases the named ones and calls the new
    InternalDataFrame$transmute(); sd_eval() does the same for a list of
    expressions over a nanoarrow array stream and calls the new
    SedonaDBExprFactory$evaluate_scalar().
  * R (000-wrappers.R), C (init.c, api.h): wrapper closures, `__impl` shims,
    `__ffi` declarations and CallEntries registrations for the two new methods
    (3 and 4 arguments respectively).
  * Rust (dataframe.rs, expression.rs): InternalDataFrame::transmute builds a
    logical plan through SedonaDBExprFactory::select, which picks between an
    aggregate plan, a window plan (aggregates rewritten with an empty OVER)
    and a plain projection; evaluate_scalar evaluates physical expressions
    batch by batch and swaps the resulting stream into the caller's
    FFI_ArrowArrayStream, returning the number of input rows.

NOTE(review): this copy of the patch was re-flowed from an export that had
collapsed all newlines/indentation and dropped every `<...>` span. Line
breaks were recovered from the +/- markers and checked against the hunk
headers; indentation, blank context lines and generic type arguments were
reconstructed. Confirm against the upstream commit before applying (or apply
with --ignore-whitespace).

NOTE(review), not applied here so the patch stays faithful to its index
hashes:
  * evaluate_scalar builds DFSchema::try_from(reader_in.schema()...) inside
    the per-expression closure; it is loop-invariant and could be built once.
  * evaluate_scalar pushes every output batch into a Vec before exporting the
    stream, so the whole result is held in memory.
  * The null-pointer error text says "evaluate()" although the method is
    evaluate_scalar.
  * sd_transmute() and sd_eval() contain the same aliasing loop; a shared
    helper would keep them in sync.

diff --git a/r/sedonadb/R/000-wrappers.R b/r/sedonadb/R/000-wrappers.R
index 6cd4654f..26c1f376 100644
--- a/r/sedonadb/R/000-wrappers.R
+++ b/r/sedonadb/R/000-wrappers.R
@@ -310,6 +310,18 @@ class(`InternalContext`) <- c(
   }
 }
 
+`InternalDataFrame_transmute` <- function(self) {
+  function(`ctx`, `exprs_sexp`) {
+    `ctx` <- .savvy_extract_ptr(`ctx`, "sedonadb::InternalContext")
+    .savvy_wrap_InternalDataFrame(.Call(
+      savvy_InternalDataFrame_transmute__impl,
+      `self`,
+      `ctx`,
+      `exprs_sexp`
+    ))
+  }
+}
+
 `.savvy_wrap_InternalDataFrame` <- function(ptr) {
   e <- new.env(parent = emptyenv())
   e$.ptr <- ptr
@@ -327,6 +339,7 @@ class(`InternalContext`) <- c(
   e$`to_parquet` <- `InternalDataFrame_to_parquet`(ptr)
   e$`to_provider` <- `InternalDataFrame_to_provider`(ptr)
   e$`to_view` <- `InternalDataFrame_to_view`(ptr)
+  e$`transmute` <- `InternalDataFrame_transmute`(ptr)
 
   class(e) <- c(
     "sedonadb::InternalDataFrame",
@@ -448,6 +461,18 @@ class(`SedonaDBExpr`) <- c("sedonadb::SedonaDBExpr__bundle", "savvy_sedonadb__se
   }
 }
 
+`SedonaDBExprFactory_evaluate_scalar` <- function(self) {
+  function(`exprs_sexp`, `stream_in`, `stream_out`) {
+    .Call(
+      savvy_SedonaDBExprFactory_evaluate_scalar__impl,
+      `self`,
+      `exprs_sexp`,
+      `stream_in`,
+      `stream_out`
+    )
+  }
+}
+
 `SedonaDBExprFactory_scalar_function` <- function(self) {
   function(`name`, `args`) {
     .savvy_wrap_SedonaDBExpr(.Call(
@@ -465,6 +490,7 @@ class(`SedonaDBExpr`) <- c("sedonadb::SedonaDBExpr__bundle", "savvy_sedonadb__se
   e$`aggregate_function` <- `SedonaDBExprFactory_aggregate_function`(ptr)
   e$`binary` <- `SedonaDBExprFactory_binary`(ptr)
   e$`column` <- `SedonaDBExprFactory_column`(ptr)
+  e$`evaluate_scalar` <- `SedonaDBExprFactory_evaluate_scalar`(ptr)
   e$`scalar_function` <- `SedonaDBExprFactory_scalar_function`(ptr)
 
   class(e) <- c(
diff --git a/r/sedonadb/R/dataframe.R b/r/sedonadb/R/dataframe.R
index fefc3a3d..8a11c51a 100644
--- a/r/sedonadb/R/dataframe.R
+++ b/r/sedonadb/R/dataframe.R
@@ -193,6 +193,28 @@ sd_preview <- function(.data, n = NULL, ascii = NULL, width = NULL) {
   invisible(.data)
 }
 
+sd_transmute <- function(.data, ...) {
+  .data <- as_sedonadb_dataframe(.data)
+  expr_quos <- rlang::enquos(...)
+  env <- parent.frame()
+
+  expr_ctx <- sd_expr_ctx(infer_nanoarrow_schema(.data), env)
+  exprs <- lapply(expr_quos, rlang::quo_get_expr)
+  sd_exprs <- lapply(exprs, sd_eval_expr, expr_ctx = expr_ctx, env = env)
+  exprs_names <- names(exprs)
+  if (!is.null(exprs_names)) {
+    for (i in seq_along(sd_exprs)) {
+      name <- exprs_names[i]
+      if (!is.na(name) && name != "") {
+        sd_exprs[[i]] <- sd_expr_alias(sd_exprs[[i]], name, expr_ctx$factory)
+      }
+    }
+  }
+
+  df <- .data$df$transmute(.data$ctx, sd_exprs)
+  new_sedonadb_dataframe(.data$ctx, df)
+}
+
 #' Write DataFrame to (Geo)Parquet files
 #'
 #' Write this DataFrame to one or more (Geo)Parquet files. For input that contains
diff --git a/r/sedonadb/R/expression.R b/r/sedonadb/R/expression.R
index cca754a2..902c151d 100644
--- a/r/sedonadb/R/expression.R
+++ b/r/sedonadb/R/expression.R
@@ -158,6 +158,28 @@ sd_eval_expr <- function(expr, expr_ctx = sd_expr_ctx(env = env), env = parent.f
   )
 }
 
+sd_eval <- function(stream, exprs, env = parent.frame()) {
+  stream <- nanoarrow::as_nanoarrow_array_stream(
+    stream,
+    geometry_schema = geoarrow::geoarrow_wkb()
+  )
+  expr_ctx <- sd_expr_ctx(stream$get_schema(), env)
+  sd_exprs <- lapply(exprs, sd_eval_expr, expr_ctx = expr_ctx, env = env)
+  exprs_names <- names(exprs)
+  if (!is.null(exprs_names)) {
+    for (i in seq_along(sd_exprs)) {
+      name <- exprs_names[i]
+      if (!is.na(name) && name != "") {
+        sd_exprs[[i]] <- sd_expr_alias(sd_exprs[[i]], name, expr_ctx$factory)
+      }
+    }
+  }
+
+  stream_out <- nanoarrow::nanoarrow_allocate_array_stream()
+  expr_ctx$factory$evaluate_scalar(sd_exprs, stream, stream_out)
+  stream_out
+}
+
 sd_eval_expr_inner <- function(expr, expr_ctx) {
   if (rlang::is_call(expr)) {
     # Extract `pkg::fun` or `fun` if this is a usual call (e.g., not
diff --git a/r/sedonadb/src/init.c b/r/sedonadb/src/init.c
index 0e9efae4..0da04e29 100644
--- a/r/sedonadb/src/init.c
+++ b/r/sedonadb/src/init.c
@@ -212,6 +212,13 @@ SEXP savvy_InternalDataFrame_to_view__impl(SEXP self__, SEXP c_arg__ctx,
   return handle_result(res);
 }
 
+SEXP savvy_InternalDataFrame_transmute__impl(SEXP self__, SEXP c_arg__ctx,
+                                             SEXP c_arg__exprs_sexp) {
+  SEXP res = savvy_InternalDataFrame_transmute__ffi(self__, c_arg__ctx,
+                                                    c_arg__exprs_sexp);
+  return handle_result(res);
+}
+
 SEXP savvy_SedonaDBExpr_alias__impl(SEXP self__, SEXP c_arg__name) {
   SEXP res = savvy_SedonaDBExpr_alias__ffi(self__, c_arg__name);
   return handle_result(res);
@@ -261,6 +268,15 @@ SEXP savvy_SedonaDBExprFactory_column__impl(SEXP self__, SEXP c_arg__name,
   return handle_result(res);
 }
 
+SEXP savvy_SedonaDBExprFactory_evaluate_scalar__impl(SEXP self__,
+                                                     SEXP c_arg__exprs_sexp,
+                                                     SEXP c_arg__stream_in,
+                                                     SEXP c_arg__stream_out) {
+  SEXP res = savvy_SedonaDBExprFactory_evaluate_scalar__ffi(
+      self__, c_arg__exprs_sexp, c_arg__stream_in, c_arg__stream_out);
+  return handle_result(res);
+}
+
 SEXP savvy_SedonaDBExprFactory_literal__impl(SEXP c_arg__array_xptr,
                                              SEXP c_arg__schema_xptr) {
   SEXP res = savvy_SedonaDBExprFactory_literal__ffi(c_arg__array_xptr,
@@ -330,6 +346,8 @@ static const R_CallMethodDef CallEntries[] = {
      (DL_FUNC)&savvy_InternalDataFrame_to_provider__impl, 1},
     {"savvy_InternalDataFrame_to_view__impl",
      (DL_FUNC)&savvy_InternalDataFrame_to_view__impl, 4},
+    {"savvy_InternalDataFrame_transmute__impl",
+     (DL_FUNC)&savvy_InternalDataFrame_transmute__impl, 3},
     {"savvy_SedonaDBExpr_alias__impl", (DL_FUNC)&savvy_SedonaDBExpr_alias__impl,
      2},
     {"savvy_SedonaDBExpr_cast__impl", (DL_FUNC)&savvy_SedonaDBExpr_cast__impl,
@@ -346,6 +364,8 @@ static const R_CallMethodDef CallEntries[] = {
      (DL_FUNC)&savvy_SedonaDBExprFactory_binary__impl, 4},
     {"savvy_SedonaDBExprFactory_column__impl",
      (DL_FUNC)&savvy_SedonaDBExprFactory_column__impl, 3},
+    {"savvy_SedonaDBExprFactory_evaluate_scalar__impl",
+     (DL_FUNC)&savvy_SedonaDBExprFactory_evaluate_scalar__impl, 4},
     {"savvy_SedonaDBExprFactory_literal__impl",
      (DL_FUNC)&savvy_SedonaDBExprFactory_literal__impl, 2},
     {"savvy_SedonaDBExprFactory_new__impl",
diff --git a/r/sedonadb/src/rust/api.h b/r/sedonadb/src/rust/api.h
index fac6258b..54559d13 100644
--- a/r/sedonadb/src/rust/api.h
+++ b/r/sedonadb/src/rust/api.h
@@ -60,6 +60,8 @@ SEXP savvy_InternalDataFrame_to_provider__ffi(SEXP self__);
 SEXP savvy_InternalDataFrame_to_view__ffi(SEXP self__, SEXP c_arg__ctx,
                                           SEXP c_arg__table_ref,
                                           SEXP c_arg__overwrite);
+SEXP savvy_InternalDataFrame_transmute__ffi(SEXP self__, SEXP c_arg__ctx,
+                                            SEXP c_arg__exprs_sexp);
 
 // methods and associated functions for SedonaDBExpr
 SEXP savvy_SedonaDBExpr_alias__ffi(SEXP self__, SEXP c_arg__name);
@@ -78,6 +80,10 @@ SEXP savvy_SedonaDBExprFactory_binary__ffi(SEXP self__, SEXP c_arg__op,
                                            SEXP c_arg__lhs, SEXP c_arg__rhs);
 SEXP savvy_SedonaDBExprFactory_column__ffi(SEXP self__, SEXP c_arg__name,
                                            SEXP c_arg__qualifier);
+SEXP savvy_SedonaDBExprFactory_evaluate_scalar__ffi(SEXP self__,
+                                                    SEXP c_arg__exprs_sexp,
+                                                    SEXP c_arg__stream_in,
+                                                    SEXP c_arg__stream_out);
 SEXP savvy_SedonaDBExprFactory_literal__ffi(SEXP c_arg__array_xptr,
                                             SEXP c_arg__schema_xptr);
 SEXP savvy_SedonaDBExprFactory_new__ffi(SEXP c_arg__ctx);
diff --git a/r/sedonadb/src/rust/src/dataframe.rs b/r/sedonadb/src/rust/src/dataframe.rs
index e34cee82..4df20c3f 100644
--- a/r/sedonadb/src/rust/src/dataframe.rs
+++ b/r/sedonadb/src/rust/src/dataframe.rs
@@ -33,6 +33,7 @@ use std::{iter::zip, ptr::swap_nonoverlapping, sync::Arc};
 use tokio::runtime::Runtime;
 
 use crate::context::InternalContext;
+use crate::expression::SedonaDBExprFactory;
 use crate::ffi::{import_schema, FFITableProviderR};
 use crate::runtime::wait_for_future_captured_r;
 
@@ -311,4 +312,17 @@ impl InternalDataFrame {
         let inner = self.inner.clone().select(exprs)?;
         Ok(new_data_frame(inner, self.runtime.clone()))
     }
+
+    fn transmute(
+        &self,
+        ctx: &InternalContext,
+        exprs_sexp: savvy::Sexp,
+    ) -> savvy::Result<InternalDataFrame> {
+        let exprs = SedonaDBExprFactory::exprs(exprs_sexp)?;
+
+        let plan =
+            SedonaDBExprFactory::select(self.inner.clone().into_unoptimized_plan(), exprs, vec![])?;
+        let inner = DataFrame::new(ctx.inner.ctx.state(), plan);
+        Ok(new_data_frame(inner, self.runtime.clone()))
+    }
 }
diff --git a/r/sedonadb/src/rust/src/expression.rs b/r/sedonadb/src/rust/src/expression.rs
index 0add4b53..23f086de 100644
--- a/r/sedonadb/src/rust/src/expression.rs
+++ b/r/sedonadb/src/rust/src/expression.rs
@@ -15,12 +15,22 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
+use std::{iter::zip, ptr::swap_nonoverlapping, sync::Arc};
 
-use datafusion_common::{Column, ScalarValue};
+use arrow_array::{
+    ffi_stream::FFI_ArrowArrayStream, RecordBatch, RecordBatchIterator, RecordBatchReader,
+};
+use arrow_schema::{FieldRef, Schema};
+use datafusion::physical_plan::PhysicalExpr;
+use datafusion_common::{
+    tree_node::{Transformed, TreeNode},
+    Column, DFSchema, Result, ScalarValue,
+};
 use datafusion_expr::{
-    expr::{AggregateFunction, FieldMetadata, NullTreatment, ScalarFunction},
-    BinaryExpr, Cast, Expr, Operator,
+    expr::{AggregateFunction, FieldMetadata, NullTreatment, ScalarFunction, WindowFunction},
+    utils::{expr_as_column_expr, find_aggregate_exprs, find_column_exprs, find_window_exprs},
+    BinaryExpr, Cast, ColumnarValue, Expr, LogicalPlan, LogicalPlanBuilder,
+    LogicalPlanBuilderOptions, Operator, WindowFunctionDefinition,
 };
 use savvy::{savvy, savvy_err, EnvironmentSexp};
 use sedona::context::SedonaContext;
@@ -172,10 +182,74 @@ impl SedonaDBExprFactory {
             Err(savvy_err!("Aggregate UDF '{name}' not found"))
         }
     }
+
+    fn evaluate_scalar(
+        &self,
+        exprs_sexp: savvy::Sexp,
+        stream_in: savvy::Sexp,
+        stream_out: savvy::Sexp,
+    ) -> savvy::Result<savvy::Sexp> {
+        let out_void = unsafe { savvy_ffi::R_ExternalPtrAddr(stream_out.0) };
+        if out_void.is_null() {
+            return Err(savvy_err!("external pointer to null in evaluate()"));
+        }
+
+        let exprs = Self::exprs(exprs_sexp)?;
+        let expr_names = exprs
+            .iter()
+            .map(|e| e.schema_name().to_string())
+            .collect::<Vec<_>>();
+        let reader_in = crate::ffi::import_array_stream(stream_in)?;
+
+        let physical_exprs = exprs
+            .into_iter()
+            .map(|e| {
+                self.ctx.ctx.create_physical_expr(
+                    e,
+                    &DFSchema::try_from(reader_in.schema().as_ref().clone())?,
+                )
+            })
+            .collect::<Result<Vec<Arc<dyn PhysicalExpr>>>>()?;
+
+        let out_fields = physical_exprs
+            .iter()
+            .map(|e| e.return_field(&reader_in.schema()))
+            .collect::<Result<Vec<FieldRef>>>()?;
+        let out_fields_named = zip(out_fields, expr_names)
+            .map(|(f, name)| f.as_ref().clone().with_name(name))
+            .collect::<Vec<_>>();
+        let out_schema = Arc::new(Schema::new(out_fields_named));
+
+        let mut out_batches = Vec::new();
+        let mut size = 0;
+        for batch in reader_in {
+            let batch = batch?;
+            size += batch.num_rows();
+            let columns = physical_exprs
+                .iter()
+                .map(|e| e.evaluate(&batch))
+                .collect::<Result<Vec<_>>>()?;
+            let out_batch = RecordBatch::try_new(
+                out_schema.clone(),
+                ColumnarValue::values_to_arrays(&columns)?,
+            )?;
+            out_batches.push(out_batch);
+        }
+
+        let reader = Box::new(RecordBatchIterator::new(
+            out_batches.into_iter().map(Ok),
+            out_schema,
+        ));
+        let mut ffi_stream = FFI_ArrowArrayStream::new(reader);
+        let ffi_out = out_void as *mut FFI_ArrowArrayStream;
+        unsafe { swap_nonoverlapping(&mut ffi_stream, ffi_out, 1) };
+
+        savvy::Sexp::try_from(size as f64)
+    }
 }
 
 impl SedonaDBExprFactory {
-    fn exprs(exprs_sexp: savvy::Sexp) -> savvy::Result<Vec<Expr>> {
+    pub fn exprs(exprs_sexp: savvy::Sexp) -> savvy::Result<Vec<Expr>> {
         savvy::ListSexp::try_from(exprs_sexp)?
             .iter()
             .map(|(_, item)| -> savvy::Result<Expr> {
@@ -185,6 +259,223 @@ impl SedonaDBExprFactory {
             })
             .collect()
     }
+
+    pub fn select(
+        base_plan: LogicalPlan,
+        exprs: Vec<Expr>,
+        group_by_exprs: Vec<Expr>,
+    ) -> datafusion_common::Result<LogicalPlan> {
+        // Translated from DataFusion's SQL SELECT -> LogicalPlan constructor
+        // https://github.com/apache/datafusion/blob/102caeb2261c5ae006c201546cf74769d80ceff8/datafusion/sql/src/select.rs#L890-L1098
+
+        // First, find aggregates in SELECT
+        let aggr_exprs = find_aggregate_exprs(&exprs);
+
+        // Determine if we should use aggregation or window functions
+        // If we have an explicit GROUP BY or can infer one, use aggregation
+        // Otherwise, treat aggregates as window functions
+        let use_aggregation = if !group_by_exprs.is_empty() {
+            true
+        } else if !aggr_exprs.is_empty() {
+            // Try to infer GROUP BY from columns outside aggregates
+            let all_columns = find_column_exprs(&exprs);
+            let agg_columns = find_column_exprs(&aggr_exprs);
+            let non_agg_columns: Vec<_> = all_columns
+                .into_iter()
+                .filter(|col| !agg_columns.contains(col))
+                .collect();
+            !non_agg_columns.is_empty()
+        } else {
+            false
+        };
+
+        // Process aggregation if appropriate
+        let (plan, select_exprs_post_aggr) = if use_aggregation && !aggr_exprs.is_empty() {
+            // We have aggregates with a valid GROUP BY, create aggregate plan
+            let result = Self::aggregate(&base_plan, &exprs, group_by_exprs, &aggr_exprs)?;
+            (result.plan, result.select_exprs)
+        } else if !aggr_exprs.is_empty() {
+            // We have aggregates but no valid GROUP BY - convert to window functions
+            // First resolve column references to be fully qualified
+            let exprs_resolved: Vec<Expr> = exprs
+                .iter()
+                .map(|expr| Self::resolve_columns(expr, &base_plan))
+                .collect::<Result<Vec<_>>>()?;
+
+            let exprs_with_windows = Self::aggregates_to_window_functions(&exprs_resolved)?;
+            (base_plan, exprs_with_windows)
+        } else {
+            // No aggregation
+            (base_plan, exprs.clone())
+        };
+
+        // All of the window expressions (includes aggregates converted to windows)
+        let window_func_exprs = find_window_exprs(&select_exprs_post_aggr);
+
+        // Process window functions after aggregation
+        let plan = if window_func_exprs.is_empty() {
+            plan
+        } else {
+            // Resolve columns in window expressions to be fully qualified
+            let window_func_exprs: Vec<Expr> = window_func_exprs
+                .iter()
+                .map(|expr| Self::resolve_columns(expr, &plan))
+                .collect::<Result<Vec<_>>>()?;
+
+            let plan = LogicalPlanBuilder::window_plan(plan, window_func_exprs.clone())?;
+
+            // Re-write the projection
+            let select_exprs_post_aggr = select_exprs_post_aggr
+                .iter()
+                .map(|expr| Self::rebase_expr(expr, &window_func_exprs, &plan))
+                .collect::<Result<Vec<_>>>()?;
+
+            // Final projection
+            LogicalPlanBuilder::from(plan)
+                .project(select_exprs_post_aggr)?
+                .build()?
+        };
+
+        // Final projection if no windows
+        if window_func_exprs.is_empty() {
+            LogicalPlanBuilder::from(plan)
+                .project(select_exprs_post_aggr)?
+                .build()
+        } else {
+            Ok(plan)
+        }
+    }
+
+    /// Helper function to rebase expressions to reference columns from the plan.
+    /// Simplified version of datafusion-sql's rebase_expr (which is pub(crate)).
+    fn rebase_expr(expr: &Expr, base_exprs: &[Expr], plan: &LogicalPlan) -> Result<Expr> {
+        let result = expr.clone().transform_down(|nested_expr| {
+            if base_exprs.contains(&nested_expr) {
+                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
+            } else {
+                Ok(Transformed::no(nested_expr))
+            }
+        })?;
+        Ok(result.data)
+    }
+
+    /// Helper function to resolve column references to fully qualified columns.
+    /// Simplified version of datafusion-sql's resolve_columns (which is pub(crate)).
+    fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
+        let result = expr.clone().transform_up(|nested_expr| {
+            match nested_expr {
+                Expr::Column(col) => {
+                    let (qualifier, field) = plan.schema().qualified_field_from_column(&col)?;
+                    Ok(Transformed::yes(Expr::Column(Column::from((
+                        qualifier, field,
+                    )))))
+                }
+                _ => {
+                    // keep recursing
+                    Ok(Transformed::no(nested_expr))
+                }
+            }
+        })?;
+        Ok(result.data)
+    }
+
+    /// Convert aggregate functions to window functions with empty OVER clause
+    fn aggregates_to_window_functions(exprs: &[Expr]) -> Result<Vec<Expr>> {
+        exprs
+            .iter()
+            .map(|expr| {
+                expr.clone()
+                    .transform_up(|nested_expr| {
+                        match nested_expr {
+                            Expr::AggregateFunction(agg) => {
+                                // Convert to window function with empty OVER ()
+                                let window_func =
+                                    Expr::WindowFunction(Box::new(WindowFunction::new(
+                                        WindowFunctionDefinition::AggregateUDF(agg.func.clone()),
+                                        agg.params.args,
+                                    )));
+                                Ok(Transformed::yes(window_func))
+                            }
+                            _ => Ok(Transformed::no(nested_expr)),
+                        }
+                    })
+                    .map(|t| t.data)
+            })
+            .collect()
+    }
+
+    /// Create an aggregate plan from the given input, group by, and aggregate expressions.
+    /// Based on DataFusion's aggregate() method.
+    /// https://github.com/apache/datafusion/blob/102caeb2261c5ae006c201546cf74769d80ceff8/datafusion/sql/src/select.rs#L652-L764
+    fn aggregate(
+        input: &LogicalPlan,
+        select_exprs: &[Expr],
+        group_by_exprs: Vec<Expr>,
+        aggr_exprs: &[Expr],
+    ) -> Result<AggregatePlanResult> {
+        // If group_by_exprs is empty, we need to extract column references from
+        // select_exprs that are NOT inside aggregate functions
+        let group_by_exprs = if group_by_exprs.is_empty() {
+            // Find all columns referenced in select expressions
+            let all_columns = find_column_exprs(select_exprs);
+
+            // Find columns that are inside aggregate expressions
+            let agg_columns = find_column_exprs(aggr_exprs);
+
+            // Keep only columns that are NOT inside aggregates
+            all_columns
+                .into_iter()
+                .filter(|col| !agg_columns.contains(col))
+                .collect::<Vec<_>>()
+        } else {
+            group_by_exprs
+        };
+
+        // Create the aggregate plan
+        let options = LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
+        let plan = LogicalPlanBuilder::from(input.clone())
+            .with_options(options)
+            .aggregate(group_by_exprs, aggr_exprs.to_vec())?
+            .build()?;
+
+        // Get the group_by_exprs and aggr_exprs from the constructed plan
+        // (they may have been modified by implicit group by logic)
+        let (group_by_exprs, aggr_exprs_from_plan) = if let LogicalPlan::Aggregate(agg) = &plan {
+            (&agg.group_expr, &agg.aggr_expr)
+        } else {
+            unreachable!();
+        };
+
+        // Combine the original grouping and aggregate expressions into one list
+        let mut aggr_projection_exprs = vec![];
+        for expr in group_by_exprs {
+            aggr_projection_exprs.push(expr.clone());
+        }
+        aggr_projection_exprs.extend_from_slice(aggr_exprs_from_plan);
+
+        // Now attempt to resolve columns and replace with fully-qualified columns
+        let aggr_projection_exprs = aggr_projection_exprs
+            .iter()
+            .map(|expr| Self::resolve_columns(expr, input))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Resolve columns in select expressions too, so qualifiers match when rebasing
+        let select_exprs_resolved = select_exprs
+            .iter()
+            .map(|expr| Self::resolve_columns(expr, input))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Re-write the projection
+        let select_exprs_post_aggr = select_exprs_resolved
+            .iter()
+            .map(|expr| Self::rebase_expr(expr, &aggr_projection_exprs, input))
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(AggregatePlanResult {
+            plan,
+            select_exprs: select_exprs_post_aggr,
+        })
+    }
 }
 
 impl TryFrom<EnvironmentSexp> for &SedonaDBExpr {
@@ -197,3 +488,13 @@ impl TryFrom<EnvironmentSexp> for &SedonaDBExpr {
             .ok_or(savvy_err!("Invalid SedonaDBExpr object."))
     }
 }
+
+/// Result of the `aggregate` function, containing the aggregate plan and
+/// rewritten expressions that reference the aggregate output columns.
+/// https://github.com/apache/datafusion/blob/102caeb2261c5ae006c201546cf74769d80ceff8/datafusion/sql/src/select.rs#L55-L68
+struct AggregatePlanResult {
+    /// The aggregate logical plan
+    plan: LogicalPlan,
+    /// SELECT expressions rewritten to reference aggregate output columns
+    select_exprs: Vec<Expr>,
+}