diff --git a/vortex-duckdb/build.rs b/vortex-duckdb/build.rs index 18cb31ff964..2ebf0ded516 100644 --- a/vortex-duckdb/build.rs +++ b/vortex-duckdb/build.rs @@ -243,6 +243,7 @@ fn main() { .file("cpp/expr.cpp") .file("cpp/logical_type.cpp") .file("cpp/object_cache.cpp") + .file("cpp/logical_expr.cpp") .file("cpp/scalar_function.cpp") .file("cpp/table_filter.cpp") .file("cpp/table_function.cpp") diff --git a/vortex-duckdb/cpp/CMakeLists.txt b/vortex-duckdb/cpp/CMakeLists.txt index 5ec2cd9b3f3..3abc76964ae 100644 --- a/vortex-duckdb/cpp/CMakeLists.txt +++ b/vortex-duckdb/cpp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(vortex error.cpp expr.cpp object_cache.cpp + logical_expr.cpp scalar_function.cpp table_filter.cpp table_function.cpp diff --git a/vortex-duckdb/cpp/expr.cpp b/vortex-duckdb/cpp/expr.cpp index ae953a52ccd..4a459c074f9 100644 --- a/vortex-duckdb/cpp/expr.cpp +++ b/vortex-duckdb/cpp/expr.cpp @@ -23,9 +23,102 @@ extern "C" const char *duckdb_vx_expr_to_string(duckdb_vx_expr ffi_expr) { return result; } +// Get detailed debug string representation of expression +extern "C" duckdb_vx_string duckdb_vx_expr_to_debug_string(duckdb_vx_expr ffi_expr) { + try { + if (!ffi_expr) { + return nullptr; + } + + auto expr = reinterpret_cast(ffi_expr); + + // Create detailed debug string with class, type, and content information + std::string debug_str = "Expression Debug Info:\n"; + debug_str += " Class: " + ExpressionClassToString(expr->GetExpressionClass()) + "\n"; + debug_str += " Type: " + ExpressionTypeToString(expr->GetExpressionType()) + "\n"; + debug_str += " Return Type: " + expr->return_type.ToString() + "\n"; + debug_str += " ToString(): " + expr->ToString() + "\n"; + + // Add specific information based on expression class + switch (expr->GetExpressionClass()) { + case ExpressionClass::BOUND_COLUMN_REF: { + auto &col_ref = expr->Cast(); + debug_str += " Column Binding: table=" + std::to_string(col_ref.binding.table_index) + + ", column=" + std::to_string(col_ref.binding.column_index) + "\n"; + debug_str += " Depth: " + std::to_string(col_ref.depth) + "\n"; + break; + } + case ExpressionClass::BOUND_FUNCTION: { + auto &func_expr = expr->Cast(); + debug_str += " Function: " + func_expr.function.name + "\n"; + debug_str += " Arguments: " + std::to_string(func_expr.children.size()) + "\n"; + for (size_t i = 0; i < func_expr.children.size(); i++) { + debug_str += " [" + std::to_string(i) + "] " + func_expr.children[i]->ToString() + "\n"; + } + break; + } + case ExpressionClass::BOUND_CONSTANT: { + auto &const_expr = expr->Cast(); + debug_str += " Value: " + const_expr.value.ToString() + "\n"; + break; + } + case ExpressionClass::BOUND_COMPARISON: { + auto &comp_expr = expr->Cast(); + debug_str += " Left: " + comp_expr.left->ToString() + "\n"; + debug_str += " Right: " + comp_expr.right->ToString() + "\n"; + break; + } + case ExpressionClass::BOUND_CONJUNCTION: { + auto &conj_expr = expr->Cast(); + debug_str += " Children: " + std::to_string(conj_expr.children.size()) + "\n"; + for (size_t i = 0; i < conj_expr.children.size(); i++) { + debug_str += " [" + std::to_string(i) + "] " + conj_expr.children[i]->ToString() + "\n"; + } + break; + } + case ExpressionClass::BOUND_OPERATOR: { + auto &op_expr = expr->Cast(); + debug_str += " Children: " + std::to_string(op_expr.children.size()) + "\n"; + for (size_t i = 0; i < op_expr.children.size(); i++) { + debug_str += " [" + std::to_string(i) + "] " + op_expr.children[i]->ToString() + "\n"; + } + break; + } + default: + debug_str += " (No additional debug info for this expression class)\n"; + break; + } + + // Create string wrapper + return new std::string(debug_str); + } catch (...) { + return nullptr; + } +} + +// Legacy alias for backwards compatibility with optimizer_rule.h +extern "C" duckdb_vx_string duckdb_vx_expression_to_string(duckdb_vx_expr ffi_expr) { + try { + if (!ffi_expr) { + return nullptr; + } + + auto expr = reinterpret_cast(ffi_expr); + std::string str = expr->ToString(); + + // Create string wrapper + return new std::string(str); + } catch (...) { + return nullptr; + } +} + //! Create a DuckDB vortex error. extern "C" void duckdb_vx_destroy_expr(duckdb_vx_expr *ffi_expr) { - auto expr = reinterpret_cast(ffi_expr); + if (ffi_expr == nullptr) { + return; + } + auto expr = reinterpret_cast(*ffi_expr); delete expr; memset(ffi_expr, 0, sizeof(duckdb_vx_expr)); } @@ -49,6 +142,14 @@ extern "C" const char *duckdb_vx_expr_get_bound_column_ref_get_name(duckdb_vx_ex return result; } +extern "C" uint64_t duckdb_vx_expr_get_bound_column_ref_depth(duckdb_vx_expr ffi_expr) { + if (!ffi_expr) { + return 0; + } + auto &expr = reinterpret_cast(ffi_expr)->Cast(); + return expr.depth; +} + extern "C" duckdb_value duckdb_vx_expr_bound_constant_get_value(duckdb_vx_expr ffi_expr) { if (!ffi_expr) { return nullptr; diff --git a/vortex-duckdb/cpp/include/duckdb_vx.h b/vortex-duckdb/cpp/include/duckdb_vx.h index cd6275c16c5..76d61b652dc 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx.h +++ b/vortex-duckdb/cpp/include/duckdb_vx.h @@ -11,6 +11,7 @@ #include "duckdb_vx/error.h" #include "duckdb_vx/expr.h" #include "duckdb_vx/logical_type.h" +#include "duckdb_vx/logical_operator.h" #include "duckdb_vx/object_cache.h" #include "duckdb_vx/scalar_function.h" #include "duckdb_vx/table_filter.h" diff --git a/vortex-duckdb/cpp/include/duckdb_vx/expr.h b/vortex-duckdb/cpp/include/duckdb_vx/expr.h index b1344341403..4d0cdc81866 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/expr.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/expr.h @@ -10,10 +10,17 @@ extern "C" { #endif typedef struct duckdb_vx_expr_ *duckdb_vx_expr; +typedef void* duckdb_vx_string; /// Return the string representation of the expression. Must be freed with `duckdb_vx_free`. const char *duckdb_vx_expr_to_string(duckdb_vx_expr expr); +/// Return a detailed debug string representation of the expression +duckdb_vx_string duckdb_vx_expr_to_debug_string(duckdb_vx_expr expr); + +/// Legacy alias for backwards compatibility with optimizer_rule.h +duckdb_vx_string duckdb_vx_expression_to_string(duckdb_vx_expr expr); + void duckdb_vx_destroy_expr(duckdb_vx_expr *expr); // See ExpressionClass in duckdb/include/duckdb/common/enums/expression_class.hpp @@ -211,6 +218,8 @@ duckdb_vx_expr_class duckdb_vx_expr_get_class(duckdb_vx_expr expr); const char *duckdb_vx_expr_get_bound_column_ref_get_name(duckdb_vx_expr expr); +uint64_t duckdb_vx_expr_get_bound_column_ref_depth(duckdb_vx_expr expr); + duckdb_value duckdb_vx_expr_bound_constant_get_value(duckdb_vx_expr expr); typedef struct { @@ -256,6 +265,22 @@ typedef struct { void duckdb_vx_expr_get_bound_function(duckdb_vx_expr expr, duckdb_vx_expr_bound_function *out); +// ============================================== +// String Wrapper Functions +// ============================================== + +// Create a string wrapper from std::string +duckdb_vx_string duckdb_vx_create_string(const char* str); + +// Get length of wrapped string +uint64_t duckdb_vx_string_length(duckdb_vx_string str); + +// Get C string data from wrapped string +const char* duckdb_vx_string_data(duckdb_vx_string str); + +// Free wrapped string +void duckdb_vx_string_free(duckdb_vx_string str); + #ifdef __cplusplus /* End C ABI */ } #endif diff --git a/vortex-duckdb/cpp/include/duckdb_vx/logical_operator.h b/vortex-duckdb/cpp/include/duckdb_vx/logical_operator.h new file mode 100644 index 00000000000..dd7dc07f680 --- /dev/null +++ b/vortex-duckdb/cpp/include/duckdb_vx/logical_operator.h @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#pragma once + +#include "duckdb.h" +#include "expr.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================== +// Forward Declarations +// ============================================== + +// Forward declarations for DuckDB types - opaque pointers +typedef void* duckdb_vx_logical_operator; +typedef void* duckdb_vx_string; + +// ============================================== +// Type Definitions +// ============================================== + +// Logical operator types enum (subset of DuckDB's LogicalOperatorType) +typedef enum { + DUCKDB_VX_LOGICAL_GET = 0, + DUCKDB_VX_LOGICAL_PROJECTION = 1, + DUCKDB_VX_LOGICAL_FILTER = 2, + DUCKDB_VX_LOGICAL_JOIN = 3, + DUCKDB_VX_LOGICAL_AGGREGATE = 4, + DUCKDB_VX_LOGICAL_UNKNOWN = 999 +} DUCKDB_VX_LOGICAL_OPERATOR_TYPE; + +// Expression types enum (subset of DuckDB's ExpressionType) for logical plan compatibility +typedef enum { + DUCKDB_VX_BOUND_COLUMN_REF = 0, + DUCKDB_VX_BOUND_FUNCTION = 1, + DUCKDB_VX_CONSTANT = 2, + DUCKDB_VX_EXPRESSION_UNKNOWN = 999 +} DUCKDB_VX_EXPRESSION_TYPE; + +// Column binding structure for logical plans +typedef struct { + uint64_t table_index; + uint64_t column_index; +} duckdb_vx_column_binding; + +// Rust callback function type for visiting operators +typedef void (*duckdb_vx_rust_visitor_callback)(duckdb_vx_logical_operator op, void* user_data); + +// ============================================== +// Basic Logical Operator Inspection +// ============================================== + +// Get operator type +DUCKDB_VX_LOGICAL_OPERATOR_TYPE duckdb_vx_get_operator_type(duckdb_vx_logical_operator op); + +// Get string representation of operator +duckdb_vx_string duckdb_vx_logical_operator_to_string(duckdb_vx_logical_operator op); + +// Get operator children count +uint64_t duckdb_vx_get_children_count(duckdb_vx_logical_operator op); + +// Get operator child by index +duckdb_vx_logical_operator duckdb_vx_get_child(duckdb_vx_logical_operator op, uint64_t index); + +// Get operator expressions count +uint64_t duckdb_vx_get_expressions_count(duckdb_vx_logical_operator op); + +// Get operator expression by index +duckdb_vx_expr duckdb_vx_get_expression(duckdb_vx_logical_operator op, uint64_t index); + +// Set operator expression by index +void duckdb_vx_set_expression(duckdb_vx_logical_operator op, uint64_t index, duckdb_vx_expr expr); + +// ============================================== +// LogicalGet (Table Scan) Functions +// ============================================== + +// Get table function name from LogicalGet +char* duckdb_vx_get_function_name(duckdb_vx_logical_operator get_op); + +// Get column names count from LogicalGet +uint64_t duckdb_vx_get_column_names_count(duckdb_vx_logical_operator get_op); + +// Get individual column name by index from LogicalGet +char* duckdb_vx_get_column_name(duckdb_vx_logical_operator get_op, uint64_t index); + +// Get projection IDs count from LogicalGet +uint64_t duckdb_vx_get_projection_ids_count(duckdb_vx_logical_operator get_op); + +// Get individual projection ID by index from LogicalGet +uint64_t duckdb_vx_get_projection_id(duckdb_vx_logical_operator get_op, uint64_t index); + +// Update projection IDs in LogicalGet +void duckdb_vx_update_projection_ids(duckdb_vx_logical_operator get_op, + uint64_t* new_projection_ids, + uint64_t count); + +// Add column ID to LogicalGet +void duckdb_vx_add_column_id(duckdb_vx_logical_operator get_op, uint64_t column_id); + +// Clear column IDs in LogicalGet +void duckdb_vx_clear_column_ids(duckdb_vx_logical_operator get_op); + +// Get detailed string representation of LogicalGet operator +duckdb_vx_string duckdb_vx_logical_get_to_string(duckdb_vx_logical_operator get_op); + +// ============================================== +// LogicalProjection Functions +// ============================================== + +// Get detailed string representation of LogicalProjection operator +duckdb_vx_string duckdb_vx_logical_projection_to_string(duckdb_vx_logical_operator proj_op); + +// ============================================== +// Expression Functions +// ============================================== + +// Logical plan expression functions - using simplified type enum +DUCKDB_VX_EXPRESSION_TYPE duckdb_vx_get_expression_type(duckdb_vx_expr expr); +char* duckdb_vx_get_function_name_from_expr(duckdb_vx_expr expr); +uint64_t duckdb_vx_get_function_arg_count(duckdb_vx_expr expr); +duckdb_vx_expr duckdb_vx_get_function_arg(duckdb_vx_expr expr, uint64_t index); +char* duckdb_vx_get_column_alias(duckdb_vx_expr expr); +duckdb_vx_column_binding duckdb_vx_get_column_binding(duckdb_vx_expr expr); +duckdb_vx_expr duckdb_vx_create_column_ref(const char* name, duckdb_vx_column_binding binding, uint64_t depth); +void duckdb_vx_update_column_binding(duckdb_vx_expr expr, duckdb_vx_column_binding binding); + +// ============================================== +// Visitor Pattern +// ============================================== + +// Visit all operators in plan tree with Rust callback +void duckdb_vx_visit_operators(duckdb_vx_logical_operator plan, + duckdb_vx_rust_visitor_callback callback, + void* user_data); + +// ============================================== +// Optimizer Registration +// ============================================== + +// Register a Rust-based optimizer function +void duckdb_vx_register_rust_optimizer(duckdb_database db_handle, + duckdb_vx_rust_visitor_callback optimizer_func, + void* user_data); + +// ============================================== +// String Wrapper Functions +// ============================================== + +// Create a string wrapper from std::string +duckdb_vx_string duckdb_vx_create_string(const char* str); + +// Get length of wrapped string +uint64_t duckdb_vx_string_length(duckdb_vx_string str); + +// Get C string data from wrapped string +const char* duckdb_vx_string_data(duckdb_vx_string str); + +// Free wrapped string +void duckdb_vx_string_free(duckdb_vx_string str); + +// ============================================== +// Memory Management +// ============================================== + +// Memory management functions +void duckdb_vx_free_string(char* str); + +// String utility functions for C strings +uint64_t duckdb_vx_c_string_length(const char* str); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/vortex-duckdb/cpp/include/duckdb_vx/vector.h b/vortex-duckdb/cpp/include/duckdb_vx/vector.h index 4290ef4bb40..8e0efc5776b 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/vector.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/vector.h @@ -21,6 +21,10 @@ void duckdb_vx_set_dictionary_vector_length(duckdb_vector dict, unsigned int len // Add the buffer to the string vector (basically, keep it alive as long as the vector). void duckdb_vx_string_vector_add_buffer(duckdb_vector ffi_vector, duckdb_vx_data buffer); +// Get string data and size for computing virtual columns +uint32_t duckdb_vx_string_vector_get_string_size(duckdb_vector ffi_vector, idx_t index); +const char *duckdb_vx_string_vector_get_string_data(duckdb_vector ffi_vector, idx_t index); + // Converts a duckdb flat vector into a Sequence vector. void duckdb_vx_sequence_vector(duckdb_vector c_vector, int64_t start, int64_t step, idx_t capacity); diff --git a/vortex-duckdb/cpp/logical_expr.cpp b/vortex-duckdb/cpp/logical_expr.cpp new file mode 100644 index 00000000000..d5e2bb66ea0 --- /dev/null +++ b/vortex-duckdb/cpp/logical_expr.cpp @@ -0,0 +1,527 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#include "duckdb.hpp" +#include "duckdb/optimizer/optimizer_extension.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include +#include +#include +#include +#include + +#include "duckdb_vx/logical_operator.h" + +using namespace duckdb; + +namespace vortex { + +// Global variables to store Rust optimizer callback +static duckdb_vx_rust_visitor_callback g_rust_optimizer_callback = nullptr; +static void *g_rust_optimizer_user_data = nullptr; + +// C++ wrapper for Rust optimizer callback - this is the actual optimizer function +static void VortexLengthOptimizeFunction(OptimizerExtensionInput &input, duckdb::unique_ptr &plan) { + if (g_rust_optimizer_callback && plan) { + g_rust_optimizer_callback(plan.get(), g_rust_optimizer_user_data); + } +} + +class VortexLengthExtension : public OptimizerExtension { +public: + VortexLengthExtension() { + optimize_function = VortexLengthOptimizeFunction; + } + + static void Register(DatabaseInstance &db) { + try { + auto &config = DBConfig::GetConfig(db); + + // Create the extension and ensure function pointer is set + OptimizerExtension optimizer; + optimizer.optimize_function = VortexLengthOptimizeFunction; + + config.optimizer_extensions.push_back(std::move(optimizer)); + } catch (std::exception &e) { + throw e; + } + } +}; + +} // namespace vortex + +// ============================================== +// C API Implementation for Rust FFI +// ============================================== + +// Basic operator inspection functions +extern "C" DUCKDB_VX_LOGICAL_OPERATOR_TYPE duckdb_vx_get_operator_type(duckdb_vx_logical_operator op) { + if (!op) + return DUCKDB_VX_LOGICAL_UNKNOWN; + + auto &logical_op = *reinterpret_cast(op); + switch (logical_op.type) { + case LogicalOperatorType::LOGICAL_GET: + return DUCKDB_VX_LOGICAL_GET; + case LogicalOperatorType::LOGICAL_PROJECTION: + return DUCKDB_VX_LOGICAL_PROJECTION; + case LogicalOperatorType::LOGICAL_FILTER: + return DUCKDB_VX_LOGICAL_FILTER; + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + return DUCKDB_VX_LOGICAL_JOIN; + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + return DUCKDB_VX_LOGICAL_AGGREGATE; + default: + return DUCKDB_VX_LOGICAL_UNKNOWN; + } +} + +extern "C" uint64_t duckdb_vx_get_children_count(duckdb_vx_logical_operator op) { + if (!op) + return 0; + auto &logical_op = *reinterpret_cast(op); + return logical_op.children.size(); +} + +extern "C" duckdb_vx_logical_operator duckdb_vx_get_child(duckdb_vx_logical_operator op, uint64_t index) { + if (!op) + return nullptr; + auto &logical_op = *reinterpret_cast(op); + if (index >= logical_op.children.size()) + return nullptr; + return logical_op.children[index].get(); +} + +extern "C" uint64_t duckdb_vx_get_expressions_count(duckdb_vx_logical_operator op) { + if (!op) + return 0; + auto &logical_op = *reinterpret_cast(op); + return logical_op.expressions.size(); +} + +extern "C" duckdb_vx_expr duckdb_vx_get_expression(duckdb_vx_logical_operator op, uint64_t index) { + if (!op) + return nullptr; + auto &logical_op = *reinterpret_cast(op); + if (index >= logical_op.expressions.size()) + return nullptr; + return reinterpret_cast(logical_op.expressions[index].get()); +} + +extern "C" void duckdb_vx_set_expression(duckdb_vx_logical_operator op, uint64_t index, duckdb_vx_expr expr) { + if (!op || !expr) + return; + auto &logical_op = *reinterpret_cast(op); + if (index >= logical_op.expressions.size()) + return; + + // Transfer ownership of the expression + logical_op.expressions[index].reset(reinterpret_cast(expr)); +} + +// LogicalGet specific functions +extern "C" char *duckdb_vx_get_function_name(duckdb_vx_logical_operator get_op) { + if (!get_op) + return nullptr; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return nullptr; + + auto &get = logical_op.Cast(); + return strdup(get.function.name.c_str()); +} + +extern "C" uint64_t duckdb_vx_get_column_names_count(duckdb_vx_logical_operator get_op) { + if (!get_op) + return 0; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return 0; + + auto &get = logical_op.Cast(); + return get.names.size(); +} + +extern "C" char *duckdb_vx_get_column_name(duckdb_vx_logical_operator get_op, uint64_t index) { + if (!get_op) + return nullptr; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return nullptr; + + auto &get = logical_op.Cast(); + if (index >= get.names.size()) + return nullptr; + + return strdup(get.names[index].c_str()); +} + +extern "C" uint64_t duckdb_vx_get_projection_ids_count(duckdb_vx_logical_operator get_op) { + if (!get_op) + return 0; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return 0; + + auto &get = logical_op.Cast(); + return get.projection_ids.size(); +} + +extern "C" uint64_t duckdb_vx_get_projection_id(duckdb_vx_logical_operator get_op, uint64_t index) { + if (!get_op) + return 0; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return 0; + + auto &get = logical_op.Cast(); + if (index >= get.projection_ids.size()) + return 0; + + return get.projection_ids[index]; +} + +extern "C" void duckdb_vx_update_projection_ids(duckdb_vx_logical_operator get_op, + uint64_t *new_projection_ids, uint64_t count) { + if (!get_op || !new_projection_ids) + return; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return; + + auto &get = logical_op.Cast(); + get.projection_ids.clear(); + for (uint64_t i = 0; i < count; i++) { + get.projection_ids.push_back(new_projection_ids[i]); + } +} + +extern "C" void duckdb_vx_add_column_id(duckdb_vx_logical_operator get_op, uint64_t column_id) { + if (!get_op) + return; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return; + + auto &get = logical_op.Cast(); + get.AddColumnId(column_id); +} + +extern "C" void duckdb_vx_clear_column_ids(duckdb_vx_logical_operator get_op) { + if (!get_op) + return; + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) + return; + + auto &get = logical_op.Cast(); + get.ClearColumnIds(); +} + +// Get detailed string representation of LogicalGet operator +extern "C" duckdb_vx_string duckdb_vx_logical_get_to_string(duckdb_vx_logical_operator get_op) { + try { + if (!get_op) { + return nullptr; + } + + auto &logical_op = *reinterpret_cast(get_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_GET) { + return nullptr; + } + + auto &get = logical_op.Cast(); + + // Create detailed string representation + std::string str = "LogicalGet:\n"; + str += " Function: " + get.function.name + "\n"; + str += " Table Index: " + std::to_string(get.table_index) + "\n"; + str += " Columns Idx: ["; + + auto &column_ids = get.GetColumnIds(); + for (size_t i = 0; i < column_ids.size(); i++) { + if (i > 0) + str += ", "; + str += std::to_string(column_ids[i].GetPrimaryIndex()); + } + str += "]\n"; + + str += " Columns Names: ["; + + if (!get.names.empty()) { + for (size_t i = 0; i < get.names.size(); i++) { + if (i > 0) + str += ", "; + str += get.names[i]; + } + } + str += "]\n"; + + str += " Projection IDs: ["; + if (!get.projection_ids.empty()) { + for (size_t i = 0; i < get.projection_ids.size(); i++) { + if (i > 0) + str += ", "; + str += std::to_string(get.projection_ids[i]); + } + } + str += "]"; + + // Create string wrapper + return new std::string(str); + } catch (...) { + return nullptr; + } +} + +// Get detailed string representation of LogicalProjection operator +extern "C" duckdb_vx_string duckdb_vx_logical_projection_to_string(duckdb_vx_logical_operator proj_op) { + try { + if (!proj_op) { + return nullptr; + } + + auto &logical_op = *reinterpret_cast(proj_op); + if (logical_op.type != LogicalOperatorType::LOGICAL_PROJECTION) { + return nullptr; + } + + // Create detailed string representation + std::string str = "LogicalProjection:\n"; + str += " Expressions: [\n"; + + for (size_t i = 0; i < logical_op.expressions.size(); i++) { + str += " [" + std::to_string(i) + "] " + logical_op.expressions[i]->ToString() + "\n"; + } + str += " ]"; + + // Create string wrapper + return new std::string(str); + } catch (...) { + return nullptr; + } +} + +// Expression functions +extern "C" DUCKDB_VX_EXPRESSION_TYPE duckdb_vx_get_expression_type(duckdb_vx_expr expr) { + if (!expr) + return DUCKDB_VX_EXPRESSION_UNKNOWN; + + auto &expression = *reinterpret_cast(expr); + switch (expression.type) { + case ExpressionType::BOUND_COLUMN_REF: + return DUCKDB_VX_BOUND_COLUMN_REF; + case ExpressionType::BOUND_FUNCTION: + return DUCKDB_VX_BOUND_FUNCTION; + case ExpressionType::VALUE_CONSTANT: + return DUCKDB_VX_CONSTANT; + default: + return DUCKDB_VX_EXPRESSION_UNKNOWN; + } +} + +extern "C" char *duckdb_vx_get_function_name_from_expr(duckdb_vx_expr expr) { + if (!expr) + return nullptr; + auto &expression = *reinterpret_cast(expr); + + if (expression.type == ExpressionType::BOUND_FUNCTION) { + auto &func_expr = expression.Cast(); + return strdup(func_expr.function.name.c_str()); + } + return nullptr; +} + +extern "C" uint64_t duckdb_vx_get_function_arg_count(duckdb_vx_expr expr) { + if (!expr) + return 0; + auto &expression = *reinterpret_cast(expr); + + if (expression.type == ExpressionType::BOUND_FUNCTION) { + auto &func_expr = expression.Cast(); + return func_expr.children.size(); + } + return 0; +} + +extern "C" duckdb_vx_expr duckdb_vx_get_function_arg(duckdb_vx_expr expr, uint64_t index) { + if (!expr) + return nullptr; + auto &expression = *reinterpret_cast(expr); + + if (expression.type == ExpressionType::BOUND_FUNCTION) { + auto &func_expr = expression.Cast(); + if (index >= func_expr.children.size()) + return nullptr; + return reinterpret_cast(func_expr.children[index].get()); + } + return nullptr; +} + +extern "C" char *duckdb_vx_get_column_alias(duckdb_vx_expr expr) { + if (!expr) + return nullptr; + auto &expression = *reinterpret_cast(expr); + + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + auto &col_ref = expression.Cast(); + return strdup(col_ref.alias.c_str()); + } + return nullptr; +} + +extern "C" duckdb_vx_column_binding duckdb_vx_get_column_binding(duckdb_vx_expr expr) { + duckdb_vx_column_binding binding = {0, 0}; + if (!expr) + return binding; + + auto &expression = *reinterpret_cast(expr); + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + auto &col_ref = expression.Cast(); + binding.table_index = col_ref.binding.table_index; + binding.column_index = col_ref.binding.column_index; + } + return binding; +} + +extern "C" duckdb_vx_expr duckdb_vx_create_column_ref(const char *name, duckdb_vx_column_binding binding, + uint64_t depth) { + if (!name) + return nullptr; + + Expression *col_ref = new BoundColumnRefExpression( + std::string(name), + LogicalType::INTEGER, ColumnBinding(binding.table_index, binding.column_index), + depth + ); + + return reinterpret_cast(col_ref); +} + +extern "C" void duckdb_vx_update_column_binding(duckdb_vx_expr expr, duckdb_vx_column_binding binding) { + if (!expr) + return; + auto &expression = *reinterpret_cast(expr); + + if (expression.type == ExpressionType::BOUND_COLUMN_REF) { + auto &col_ref = expression.Cast(); + col_ref.binding.table_index = binding.table_index; + col_ref.binding.column_index = binding.column_index; + } +} + +// Visitor pattern implementation +extern "C" void duckdb_vx_visit_operators(duckdb_vx_logical_operator plan, + duckdb_vx_rust_visitor_callback callback, void *user_data) { + if (!plan || !callback) + return; + + auto &logical_op = *reinterpret_cast(plan); + + // Call the Rust callback on this operator + callback(plan, user_data); + + // Recursively visit children + for (auto &child : logical_op.children) { + duckdb_vx_visit_operators(child.get(), callback, user_data); + } +} + +extern "C" void duckdb_vx_register_rust_optimizer(duckdb_database db_handle, + duckdb_vx_rust_visitor_callback optimizer_func, + void *user_data) { + std::cout << "🔧 REGISTERING: Rust-based optimizer..." << std::endl; + + if (!db_handle || !optimizer_func) { + std::cout << "❌ ERROR: NULL parameters passed to Rust optimizer registration" << std::endl; + return; + } + + try { + // Store the Rust callback and user data + vortex::g_rust_optimizer_callback = optimizer_func; + vortex::g_rust_optimizer_user_data = user_data; + + // Get the DuckDB instance + struct DatabaseWrapper { + void *internal_ptr; + }; + + auto wrapper = reinterpret_cast(db_handle); + auto db = reinterpret_cast(wrapper->internal_ptr); + + // Register the optimizer using VortexLengthExtension + vortex::VortexLengthExtension::Register(*db->instance); + + std::cout << "✅ SUCCESS: Rust-based optimizer registered!" << std::endl; + } catch (std::exception &e) { + std::cout << "❌ EXCEPTION during Rust optimizer registration: " << e.what() << std::endl; + } +} + +// Memory management functions +extern "C" void duckdb_vx_free_string(char *str) { + if (str) + free(str); +} + +// String utility functions for C strings +extern "C" uint64_t duckdb_vx_c_string_length(const char *str) { + if (!str) + return 0; + return strlen(str); +} + +// ============================================== +// String Wrapper Functions +// ============================================== + +extern "C" duckdb_vx_string duckdb_vx_create_string(const char* str) { + if (!str) return nullptr; + return new std::string(str); +} + +extern "C" uint64_t duckdb_vx_string_length(duckdb_vx_string str) { + if (!str) return 0; + auto* std_str = static_cast(str); + return std_str->length(); +} + +extern "C" const char* duckdb_vx_string_data(duckdb_vx_string str) { + if (!str) return nullptr; + auto* std_str = static_cast(str); + return std_str->c_str(); +} + +extern "C" void duckdb_vx_string_free(duckdb_vx_string str) { + if (str) { + auto* std_str = static_cast(str); + delete std_str; + } +} + + + + +// Get string representation of logical operator +extern "C" duckdb_vx_string duckdb_vx_logical_operator_to_string(duckdb_vx_logical_operator op) { + try { + if (!op) { + return nullptr; + } + + auto *logical_op = reinterpret_cast(op); + std::string str = logical_op->ToString(); + + // Create string wrapper + return new std::string(str); + } catch (...) { + return nullptr; + } +} \ No newline at end of file diff --git a/vortex-duckdb/cpp/vector.cpp b/vortex-duckdb/cpp/vector.cpp index c472981212a..66fc2b66d4c 100644 --- a/vortex-duckdb/cpp/vector.cpp +++ b/vortex-duckdb/cpp/vector.cpp @@ -55,6 +55,18 @@ extern "C" void duckdb_vx_string_vector_add_buffer(duckdb_vector ffi_vector, duc StringVector::AddBuffer(*vector, ext_buffer); } +extern "C" uint32_t duckdb_vx_string_vector_get_string_size(duckdb_vector ffi_vector, idx_t index) { + auto vector = reinterpret_cast(ffi_vector); + auto string_data = FlatVector::GetData(*vector); + return string_data[index].GetSize(); +} + +extern "C" const char *duckdb_vx_string_vector_get_string_data(duckdb_vector ffi_vector, idx_t index) { + auto vector = reinterpret_cast(ffi_vector); + auto string_data = FlatVector::GetData(*vector); + return string_data[index].GetData(); +} + void duckdb_vector_flatten(duckdb_vector vector, unsigned long len) { auto dvector = reinterpret_cast(vector); dvector->Flatten(len); diff --git a/vortex-duckdb/src/convert/expr.rs b/vortex-duckdb/src/convert/expr.rs index 57bf4c7f110..d0dfdb5f1a1 100644 --- a/vortex-duckdb/src/convert/expr.rs +++ b/vortex-duckdb/src/convert/expr.rs @@ -17,6 +17,9 @@ use crate::cpp::DUCKDB_VX_EXPR_TYPE; use crate::duckdb::{Expression, ExpressionClass}; const DUCKDB_FUNCTION_NAME_CONTAINS: &str = "contains"; +const DUCKDB_FUNCTION_NAME_LENGTH: &str = "length"; +const DUCKDB_FUNCTION_NAME_LEN: &str = "len"; +const DUCKDB_FUNCTION_NAME_STRLEN: &str = "strlen"; fn like_pattern_str(value: &Expression) -> VortexResult> { match value.as_class().vortex_expect("unknown class") { @@ -146,6 +149,14 @@ pub fn try_from_bound_expression(value: &Expression) -> VortexResult { + // For now, we'll return None to indicate we can't handle this + // We'll need to pass context to properly handle this + log::debug!("length function detected on column - optimization opportunity"); + return Ok(None); + } _ => { log::debug!("bound function {}", func.scalar_function.name()); return Ok(None); diff --git a/vortex-duckdb/src/duckdb/expr.rs b/vortex-duckdb/src/duckdb/expr.rs index 585ed5a9ac3..fe605f372f3 100644 --- a/vortex-duckdb/src/duckdb/expr.rs +++ b/vortex-duckdb/src/duckdb/expr.rs @@ -1,173 +1,196 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::ffi::{CStr, c_void}; -use std::fmt::{Display, Formatter}; +use std::ffi::{CString, c_void}; +use std::fmt::{Debug, Display, Formatter}; use std::ptr; -use crate::cpp::duckdb_vx_expr_class; -use crate::duckdb::{ScalarFunction, ValueRef}; +use vortex::error::{VortexResult, vortex_bail, vortex_err}; + +use crate::cpp::{duckdb_vx_expr_class, *}; +use crate::duckdb::string::{VxString, c_string_to_rust_string}; +use crate::duckdb::{ScalarFunction, Value, ValueRef}; use crate::{cpp, duckdb, wrapper}; // TODO(joe): replace with lifetime_wrapper! -wrapper!(Expression, cpp::duckdb_vx_expr, cpp::duckdb_vx_destroy_expr); - -impl Display for Expression { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let ptr = unsafe { cpp::duckdb_vx_expr_to_string(self.as_ptr()) }; - let cstr = unsafe { CStr::from_ptr(ptr) }; - let result = write!(f, "{}", cstr.to_string_lossy()); - unsafe { cpp::duckdb_free(ptr.cast_mut().cast()) }; - result - } -} +wrapper!(Expression, duckdb_vx_expr, duckdb_vx_destroy_expr); impl Expression { pub fn as_class_id(&self) -> duckdb_vx_expr_class { - unsafe { cpp::duckdb_vx_expr_get_class(self.as_ptr()) } + unsafe { duckdb_vx_expr_get_class(self.as_ptr()) } + } + + /// Get the expression depth if this is a BoundColumnRef expression. + /// Returns None for other expression types. + /// + /// Expression depth represents how many query levels deep a column reference is. + /// Depth 0 = current query level, depth 1 = parent query (correlated), etc. + pub fn get_expression_depth(&self) -> Option { + (self.as_class_id() == DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF) + .then(|| unsafe { duckdb_vx_expr_get_bound_column_ref_depth(self.as_ptr()) }) } /// Match the subclass of the expression. pub fn as_class(&self) -> Option> { - Some( - match unsafe { cpp::duckdb_vx_expr_get_class(self.as_ptr()) } { - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF => { - let ptr = - unsafe { cpp::duckdb_vx_expr_get_bound_column_ref_get_name(self.as_ptr()) }; - - ExpressionClass::BoundColumnRef(BoundColumnRef { - name: duckdb::string::String::from_ptr(ptr), - }) - } - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_CONSTANT => { - let value = unsafe { - ValueRef::borrow(cpp::duckdb_vx_expr_bound_constant_get_value( - self.as_ptr(), - )) - }; - ExpressionClass::BoundConstant(BoundConstant { value }) - } - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_CONJUNCTION => { - let mut out = cpp::duckdb_vx_expr_bound_conjunction { - children: ptr::null_mut(), - children_count: 0, - type_: cpp::DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_INVALID, - }; - unsafe { - cpp::duckdb_vx_expr_get_bound_conjunction(self.as_ptr(), &raw mut out) - }; - - let children = - unsafe { std::slice::from_raw_parts(out.children, out.children_count) }; - - ExpressionClass::BoundConjunction(BoundConjunction { - children, - op: out.type_, - }) - } - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COMPARISON => { - let mut out = cpp::duckdb_vx_expr_bound_comparison { - left: ptr::null_mut(), - right: ptr::null_mut(), - type_: cpp::DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_INVALID, - }; - unsafe { - cpp::duckdb_vx_expr_get_bound_comparison(self.as_ptr(), &raw mut out) - }; - - ExpressionClass::BoundComparison(BoundComparison { - left: unsafe { Expression::borrow(out.left) }, - right: unsafe { Expression::borrow(out.right) }, - op: out.type_, - }) - } - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_BETWEEN => { - let mut out = cpp::duckdb_vx_expr_bound_between { - input: ptr::null_mut(), - lower: ptr::null_mut(), - upper: ptr::null_mut(), - lower_inclusive: false, - upper_inclusive: false, - }; - unsafe { - cpp::duckdb_vx_expr_get_bound_between(self.as_ptr(), &raw mut out); - } - - ExpressionClass::BoundBetween(BoundBetween { - input: unsafe { Expression::borrow(out.input) }, - lower: unsafe { Expression::borrow(out.lower) }, - upper: unsafe { Expression::borrow(out.upper) }, - lower_inclusive: out.lower_inclusive, - upper_inclusive: out.upper_inclusive, - }) - } - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_OPERATOR => { - let mut out = cpp::duckdb_vx_expr_bound_operator { - children: ptr::null_mut(), - children_count: 0, - type_: cpp::DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_INVALID, - }; - unsafe { cpp::duckdb_vx_expr_get_bound_operator(self.as_ptr(), &raw mut out) }; - - let children = - unsafe { std::slice::from_raw_parts(out.children, out.children_count) }; - - ExpressionClass::BoundOperator(BoundOperator { - children, - op: out.type_, - }) - } - cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_FUNCTION => { - let mut out = cpp::duckdb_vx_expr_bound_function { - children: ptr::null_mut(), - children_count: 0, - scalar_function: ptr::null_mut(), - bind_info: ptr::null_mut(), - }; - unsafe { cpp::duckdb_vx_expr_get_bound_function(self.as_ptr(), &raw mut out) }; - - let children = - unsafe { std::slice::from_raw_parts(out.children, out.children_count) }; - - ExpressionClass::BoundFunction(BoundFunction { - children, - scalar_function: unsafe { ScalarFunction::borrow(out.scalar_function) }, - bind_info: out.bind_info, - }) - } - _ => { - return None; + Some(match unsafe { duckdb_vx_expr_get_class(self.as_ptr()) } { + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF => { + let ptr = unsafe { duckdb_vx_expr_get_bound_column_ref_get_name(self.as_ptr()) }; + let bind_ptr = unsafe { duckdb_vx_get_column_binding(self.as_ptr()) }; + + ExpressionClass::BoundColumnRef(BoundColumnRef { + expr: self, + name: duckdb::string::String::from_ptr(ptr), + column_binding: bind_ptr.into(), + }) + } + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_CONSTANT => { + let value = unsafe { + ValueRef::borrow(duckdb_vx_expr_bound_constant_get_value(self.as_ptr())) + }; + ExpressionClass::BoundConstant(BoundConstant { expr: self, value }) + } + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_CONJUNCTION => { + let mut out = duckdb_vx_expr_bound_conjunction { + children: ptr::null_mut(), + children_count: 0, + type_: DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_INVALID, + }; + unsafe { duckdb_vx_expr_get_bound_conjunction(self.as_ptr(), &raw mut out) }; + + let children = + unsafe { std::slice::from_raw_parts(out.children, out.children_count) }; + + ExpressionClass::BoundConjunction(BoundConjunction { + expr: self, + children, + op: out.type_, + }) + } + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COMPARISON => { + let mut out = duckdb_vx_expr_bound_comparison { + left: ptr::null_mut(), + right: ptr::null_mut(), + type_: DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_INVALID, + }; + unsafe { duckdb_vx_expr_get_bound_comparison(self.as_ptr(), &raw mut out) }; + + ExpressionClass::BoundComparison(BoundComparison { + expr: self, + left: unsafe { Expression::borrow(out.left) }, + right: unsafe { Expression::borrow(out.right) }, + op: out.type_, + }) + } + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_BETWEEN => { + let mut out = duckdb_vx_expr_bound_between { + input: ptr::null_mut(), + lower: ptr::null_mut(), + upper: ptr::null_mut(), + lower_inclusive: false, + upper_inclusive: false, + }; + unsafe { + duckdb_vx_expr_get_bound_between(self.as_ptr(), &raw mut out); } - }, - ) + + ExpressionClass::BoundBetween(BoundBetween { + expr: self, + input: unsafe { Expression::borrow(out.input) }, + lower: unsafe { Expression::borrow(out.lower) }, + upper: unsafe { Expression::borrow(out.upper) }, + lower_inclusive: out.lower_inclusive, + upper_inclusive: out.upper_inclusive, + }) + } + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_OPERATOR => { + let mut out = duckdb_vx_expr_bound_operator { + children: ptr::null_mut(), + children_count: 0, + type_: DUCKDB_VX_EXPR_TYPE::DUCKDB_VX_EXPR_TYPE_INVALID, + }; + unsafe { duckdb_vx_expr_get_bound_operator(self.as_ptr(), &raw mut out) }; + + let children = + unsafe { std::slice::from_raw_parts(out.children, out.children_count) }; + + ExpressionClass::BoundOperator(BoundOperator { + expr: self, + children, + op: out.type_, + }) + } + DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_FUNCTION => { + let mut out = duckdb_vx_expr_bound_function { + children: ptr::null_mut(), + children_count: 0, + scalar_function: ptr::null_mut(), + bind_info: ptr::null_mut(), + }; + unsafe { duckdb_vx_expr_get_bound_function(self.as_ptr(), &raw mut out) }; + + let children = + unsafe { std::slice::from_raw_parts(out.children, out.children_count) }; + + ExpressionClass::BoundFunction(BoundFunction { + expr: self, + children, + scalar_function: unsafe { ScalarFunction::borrow(out.scalar_function) }, + bind_info: out.bind_info, + }) + } + _ => { + return None; + } + }) } } pub enum ExpressionClass<'a> { - BoundColumnRef(BoundColumnRef), + BoundColumnRef(BoundColumnRef<'a>), BoundConstant(BoundConstant<'a>), - BoundComparison(BoundComparison), + BoundComparison(BoundComparison<'a>), BoundConjunction(BoundConjunction<'a>), - BoundBetween(BoundBetween), + BoundBetween(BoundBetween<'a>), BoundOperator(BoundOperator<'a>), BoundFunction(BoundFunction<'a>), } -pub struct BoundColumnRef { +pub struct BoundColumnRef<'a> { + expr: &'a Expression, pub name: duckdb::string::String, + pub column_binding: ColumnBinding, +} + +impl BoundColumnRef<'_> { + /// Get the expression depth for this BoundColumnRef. + /// + /// Expression depth in DuckDB represents how many query levels deep this column reference is. + /// A depth of 0 means the column is from the current query level, + /// depth 1 means it's from a parent query (correlated subquery), etc. + /// This is important for query optimization and determining if a subquery is correlated. + pub fn expression_depth(&self) -> u64 { + unsafe { duckdb_vx_expr_get_bound_column_ref_depth(self.expr.as_ptr()) } + } } pub struct BoundConstant<'a> { + expr: &'a Expression, pub value: ValueRef<'a>, } -pub struct BoundComparison { +impl BoundConstant<'_> { + // Specific methods for BoundConstant can be added here +} + +pub struct BoundComparison<'a> { + expr: &'a Expression, pub left: Expression, pub right: Expression, - pub op: cpp::DUCKDB_VX_EXPR_TYPE, + pub op: DUCKDB_VX_EXPR_TYPE, } -pub struct BoundBetween { +pub struct BoundBetween<'a> { + expr: &'a Expression, pub input: Expression, pub lower: Expression, pub upper: Expression, @@ -176,8 +199,9 @@ pub struct BoundBetween { } pub struct BoundConjunction<'a> { - children: &'a [cpp::duckdb_vx_expr], - pub op: cpp::DUCKDB_VX_EXPR_TYPE, + expr: &'a Expression, + children: &'a [duckdb_vx_expr], + pub op: DUCKDB_VX_EXPR_TYPE, } impl BoundConjunction<'_> { @@ -190,8 +214,9 @@ impl BoundConjunction<'_> { } pub struct BoundOperator<'a> { - children: &'a [cpp::duckdb_vx_expr], - pub op: cpp::DUCKDB_VX_EXPR_TYPE, + expr: &'a Expression, + children: &'a [duckdb_vx_expr], + pub op: DUCKDB_VX_EXPR_TYPE, } impl BoundOperator<'_> { @@ -204,7 +229,8 @@ impl BoundOperator<'_> { } pub struct BoundFunction<'a> { - children: &'a [cpp::duckdb_vx_expr], + expr: &'a Expression, + children: &'a [duckdb_vx_expr], pub scalar_function: ScalarFunction, pub bind_info: *const c_void, } @@ -216,4 +242,188 @@ impl BoundFunction<'_> { .iter() .map(|&child| unsafe { Expression::borrow(child) }) } + + pub fn function_name(&self) -> Option { + unsafe { + let name_ptr = duckdb_vx_get_function_name_from_expr(self.expr.as_ptr()); + c_string_to_rust_string(name_ptr) + } + } + + /// Get function argument count - only works if this is a BoundFunction expression + pub fn function_arg_count(&self) -> usize { + unsafe { + duckdb_vx_get_function_arg_count(self.expr.as_ptr()) + .try_into() + .unwrap_or(0) + } + // } else { + // 0 + // } + } + + /// Get function argument by index - only works if this is a BoundFunction expression + pub fn get_function_arg(&self, index: usize) -> Option { + // Check if this is a BoundFunction using the class ID directly to avoid borrowing issues + unsafe { + let arg_ptr = duckdb_vx_get_function_arg(self.expr.as_ptr(), index as u64); + if arg_ptr.is_null() { + None + } else { + Some(Expression::borrow(arg_ptr)) + } + } + } +} + +// ============================================== +// Logical Plan Expression Support +// ============================================== + +/// Represents the type of an expression in DuckDB's logical plan (simplified enum) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum LogicalExpressionType { + BoundColumnRef = DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_BOUND_COLUMN_REF, + BoundFunction = DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_BOUND_FUNCTION, + Constant = DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_CONSTANT, + Unknown = DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_EXPRESSION_UNKNOWN, +} + +/// Column binding information for logical plans +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ColumnBinding { + pub table_index: u64, + pub column_index: u64, +} + +impl From for duckdb_vx_column_binding { + fn from(binding: ColumnBinding) -> Self { + duckdb_vx_column_binding { + table_index: binding.table_index, + column_index: binding.column_index, + } + } +} + +impl From for ColumnBinding { + fn from(binding: duckdb_vx_column_binding) -> Self { + ColumnBinding { + table_index: binding.table_index, + column_index: binding.column_index, + } + } +} + +// Add logical plan methods to the unified Expression struct +impl Expression { + /// Get the logical plan type of this expression (simplified enum) + pub fn logical_expression_type(&self) -> LogicalExpressionType { + let expr_type = unsafe { duckdb_vx_get_expression_type(self.as_ptr()) }; + match expr_type { + DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_BOUND_COLUMN_REF => { + LogicalExpressionType::BoundColumnRef + } + DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_BOUND_FUNCTION => { + LogicalExpressionType::BoundFunction + } + DUCKDB_VX_EXPRESSION_TYPE_DUCKDB_VX_CONSTANT => LogicalExpressionType::Constant, + _ => LogicalExpressionType::Unknown, + } + } + + /// Get string representation using legacy function name + pub fn to_string_legacy(&self) -> VortexResult { + unsafe { + let vx_string_ptr = duckdb_vx_expression_to_string(self.as_ptr()); + match VxString::from_raw(vx_string_ptr) { + Some(vx_string) => Ok(vx_string.to_string()), + None => vortex_bail!("Failed to convert expression to string"), + } + } + } + + /// Get column alias - only works if this is a BoundColumnRef expression + pub fn column_alias(&self) -> VortexResult> { + // Check if this is a BoundColumnRef using the class ID directly to avoid borrowing issues + // if unsafe { duckdb_vx_expr_get_class(self.as_ptr()) } + // == DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF + // { + unsafe { + let alias_ptr = duckdb_vx_get_column_alias(self.as_ptr()); + Ok(c_string_to_rust_string(alias_ptr)) + } + // } else { + // Ok(None) + // } + } + + /// Get column binding - only works if this is a BoundColumnRef expression + pub fn column_binding(&self) -> Option { + // Check if this is a BoundColumnRef using the class ID directly to avoid borrowing issues + // if unsafe { duckdb_vx_expr_get_class(self.as_ptr()) } + // == DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF + // { + let binding = unsafe { duckdb_vx_get_column_binding(self.as_ptr()) }; + Some(binding.into()) + // } else { + // None + // } + } + + /// Update column binding - only works if this is a BoundColumnRef expression + pub fn update_column_binding(&self, binding: ColumnBinding) { + // Check if this is a BoundColumnRef using the class ID directly to avoid borrowing issues + // if unsafe { duckdb_vx_expr_get_class(self.as_ptr()) } + // == DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF + // { + unsafe { + duckdb_vx_update_column_binding(self.as_ptr(), binding.into()); + } + // } else { + // false + // } + } + + /// Create a new column reference expression + pub fn create_column_ref(name: &str, binding: ColumnBinding, depth: u64) -> VortexResult { + let c_name = CString::new(name).map_err(|e| vortex_err!("Invalid column name: {}", e))?; + unsafe { + let expr_ptr = duckdb_vx_create_column_ref(c_name.as_ptr(), binding.into(), depth); + if expr_ptr.is_null() { + vortex_bail!("Failed to create column reference expression") + } else { + Ok(Self::own(expr_ptr)) + } + } + } + + /// Get detailed debug string representation of this expression + pub fn to_debug_string(&self) -> VortexResult { + unsafe { + let vx_string_ptr = duckdb_vx_expr_to_debug_string(self.as_ptr()); + match VxString::from_raw(vx_string_ptr) { + Some(vx_string) => Ok(vx_string.to_string()), + None => vortex_bail!("Failed to convert expression to debug string"), + } + } + } +} + +impl Display for Expression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.to_string_legacy() { + Ok(s) => write!(f, "{}", s), + Err(_) => write!(f, ""), + } + } +} + +impl Debug for Expression { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.to_debug_string() { + Ok(s) => write!(f, "{}", s), + Err(_) => write!(f, ""), + } + } } diff --git a/vortex-duckdb/src/duckdb/logical_operator.rs b/vortex-duckdb/src/duckdb/logical_operator.rs new file mode 100644 index 00000000000..54bbbf0ff8f --- /dev/null +++ b/vortex-duckdb/src/duckdb/logical_operator.rs @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! DuckDB logical operator manipulation API with downcasting +//! +//! This module provides safe Rust wrappers around DuckDB's logical plan operators, +//! following a similar pattern to the Expression API with downcasting to specific +//! operator types for type-safe manipulation. + +use std::fmt::{Debug, Display, Formatter}; + +use itertools::Itertools; +use vortex::error::{VortexResult, vortex_bail}; + +use crate::cpp::*; +use crate::duckdb::expr::Expression; +use crate::duckdb::string::{c_string_to_rust_string, VxString}; +use crate::wrapper; + +wrapper!(LogicalOperator, duckdb_vx_logical_operator, |_ptr| { + // TODO: Free memory + // LogicalOperator doesn't need destruction as it's owned by DuckDB's plan tree +}); + +impl Display for LogicalOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.to_debug_string() { + Ok(s) => write!(f, "{}", s), + Err(_) => write!(f, ""), + } + } +} + +/// Represents the type of a logical operator in DuckDB's query plan +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u32)] +pub enum LogicalOperatorType { + Get = DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_GET, + Projection = DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_PROJECTION, + Filter = DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_FILTER, + Join = DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_JOIN, + Aggregate = DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_AGGREGATE, + Unknown = DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_UNKNOWN, +} + +impl LogicalOperator { + /// Get the type of this logical operator + pub fn operator_type(&self) -> LogicalOperatorType { + let op_type = unsafe { duckdb_vx_get_operator_type(self.as_ptr()) }; + match op_type { + DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_GET => LogicalOperatorType::Get, + DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_PROJECTION => { + LogicalOperatorType::Projection + } + DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_FILTER => LogicalOperatorType::Filter, + DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_JOIN => LogicalOperatorType::Join, + DUCKDB_VX_LOGICAL_OPERATOR_TYPE_DUCKDB_VX_LOGICAL_AGGREGATE => { + LogicalOperatorType::Aggregate + } + _ => LogicalOperatorType::Unknown, + } + } + + /// Downcast the operator to its specific type with specialized methods + pub fn as_class(&self) -> Option { + match self.operator_type() { + LogicalOperatorType::Get => Some(LogicalOperatorClass::Get(LogicalGet { op: self })), + LogicalOperatorType::Projection => { + Some(LogicalOperatorClass::Projection(LogicalProjection { + op: self, + })) + } + LogicalOperatorType::Filter + | LogicalOperatorType::Join + | LogicalOperatorType::Aggregate + | LogicalOperatorType::Unknown => None, + } + } + + /// Get string representation of this operator for debugging + pub fn to_debug_string(&self) -> VortexResult { + unsafe { + let vx_string_ptr = duckdb_vx_logical_operator_to_string(self.as_ptr()); + match VxString::from_raw(vx_string_ptr) { + Some(vx_string) => Ok(vx_string.to_string()), + None => vortex_bail!("Failed to convert logical operator to string") + } + } + } + + /// Get the number of child operators + pub fn children_count(&self) -> usize { + unsafe { duckdb_vx_get_children_count(self.as_ptr()) as usize } + } + + /// Get a child operator by index + pub fn get_child(&self, index: usize) -> Option { + unsafe { + let child_ptr = duckdb_vx_get_child(self.as_ptr(), index as u64); + if child_ptr.is_null() { + None + } else { + Some(LogicalOperator::borrow(child_ptr)) + } + } + } + + /// Get the number of expressions in this operator + pub fn expressions_count(&self) -> usize { + unsafe { duckdb_vx_get_expressions_count(self.as_ptr()) as usize } + } + + /// Get an expression by index + pub fn get_expression(&self, index: usize) -> Option { + unsafe { + let expr_ptr = duckdb_vx_get_expression(self.as_ptr(), index as u64); + if expr_ptr.is_null() { + None + } else { + Some(Expression::borrow(expr_ptr)) + } + } + } + + /// Set an expression by index (transfers ownership) + pub fn set_expression(&self, index: usize, expression: Expression) { + unsafe { + duckdb_vx_set_expression(self.as_ptr(), index as u64, expression.as_ptr()); + // Prevent the expression from being dropped since ownership was transferred + std::mem::forget(expression); + } + } +} + +/// Enum representing different logical operator types with specialized methods +pub enum LogicalOperatorClass<'a> { + Get(LogicalGet<'a>), + Projection(LogicalProjection<'a>), +} + +/// LogicalGet operator (table scan) with table-specific methods +pub struct LogicalGet<'a> { + op: &'a LogicalOperator, +} + +impl<'a> LogicalGet<'a> { + /// Get the table function name + pub fn function_name(&self) -> VortexResult> { + unsafe { + let name_ptr = duckdb_vx_get_function_name(self.op.as_ptr()); + Ok(c_string_to_rust_string(name_ptr)) + } + } + + /// Check if this is a vortex_scan table function + pub fn is_vortex_scan(&self) -> bool { + self.function_name() + .ok() + .flatten() + .is_some_and(|func_name| func_name == "vortex_scan") + } + + /// Get column names from the table schema + pub fn column_names(&self) -> VortexResult> { + unsafe { + let count = duckdb_vx_get_column_names_count(self.op.as_ptr()); + + if count == 0 { + return Ok(Vec::new()); + } + + let mut names = Vec::with_capacity(count as usize); + for i in 0..count { + let name_ptr = duckdb_vx_get_column_name(self.op.as_ptr(), i); + if let Some(name) = c_string_to_rust_string(name_ptr) { + names.push(name); + } + } + + Ok(names) + } + } + + /// Get the current projection IDs + pub fn get_projection_ids(&self) -> VortexResult> { + unsafe { + let count = duckdb_vx_get_projection_ids_count(self.op.as_ptr()); + Ok((0..count) + .map(|i| duckdb_vx_get_projection_id(self.op.as_ptr(), i)) + .collect_vec()) + } + } + + /// Update the projection IDs for this table scan + pub fn update_projection_ids(&self, new_projection_ids: &[u64]) -> VortexResult<()> { + unsafe { + duckdb_vx_update_projection_ids( + self.op.as_ptr(), + new_projection_ids.as_ptr() as *mut u64, + new_projection_ids.len() as u64, + ); + } + Ok(()) + } + + /// Add a column ID to the scan + pub fn add_column_id(&self, column_id: u64) { + unsafe { + duckdb_vx_add_column_id(self.op.as_ptr(), column_id); + } + } + + /// Clear all column IDs + pub fn clear_column_ids(&self) { + unsafe { + duckdb_vx_clear_column_ids(self.op.as_ptr()); + } + } + + /// Get column names (wrapper for convenience) + pub fn get_column_names(&self) -> VortexResult> { + self.column_names() + } + + /// Get detailed string representation of this LogicalGet operator + pub fn to_string(&self) -> VortexResult { + unsafe { + let vx_string_ptr = duckdb_vx_logical_get_to_string(self.op.as_ptr()); + match VxString::from_raw(vx_string_ptr) { + Some(vx_string) => Ok(vx_string.to_string()), + None => vortex_bail!("Failed to convert LogicalGet to string") + } + } + } +} + +impl Debug for LogicalGet<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.to_string() { + Ok(s) => write!(f, "{}", s), + Err(_) => write!(f, ""), + } + } +} + +/// LogicalProjection operator with projection-specific methods +pub struct LogicalProjection<'a> { + op: &'a LogicalOperator, +} + +impl<'a> LogicalProjection<'a> { + /// Get the projection expressions + pub fn projections(&self) -> impl Iterator> { + (0..self.op.expressions_count()).map(move |i| self.op.get_expression(i)) + } + + /// Set a projection expression at the given index + pub fn set_projection(&self, index: usize, expression: Expression) { + self.op.set_expression(index, expression); + } + + /// Get detailed string representation of this LogicalProjection operator + pub fn to_string(&self) -> VortexResult { + unsafe { + let vx_string_ptr = duckdb_vx_logical_projection_to_string(self.op.as_ptr()); + match VxString::from_raw(vx_string_ptr) { + Some(vx_string) => Ok(vx_string.to_string()), + None => vortex_bail!("Failed to convert LogicalProjection to string") + } + } + } +} diff --git a/vortex-duckdb/src/duckdb/mod.rs b/vortex-duckdb/src/duckdb/mod.rs index c50ef706a7a..4bbf31e4220 100644 --- a/vortex-duckdb/src/duckdb/mod.rs +++ b/vortex-duckdb/src/duckdb/mod.rs @@ -8,8 +8,9 @@ mod copy_function; mod data; mod data_chunk; mod database; -mod expr; +pub mod expr; pub mod footer_cache; +pub mod logical_operator; mod logical_type; mod macro_; mod object_cache; @@ -33,6 +34,7 @@ pub use data::*; pub use data_chunk::*; pub use database::*; pub use expr::*; +pub use logical_operator::*; pub use logical_type::*; pub use object_cache::*; pub use query_result::*; diff --git a/vortex-duckdb/src/duckdb/string.rs b/vortex-duckdb/src/duckdb/string.rs index e26a080e4dc..f42d5a92cf4 100644 --- a/vortex-duckdb/src/duckdb/string.rs +++ b/vortex-duckdb/src/duckdb/string.rs @@ -5,7 +5,7 @@ use std::ffi::{CStr, c_char}; use std::fmt::{Debug, Display, Formatter}; use std::str::Utf8Error; -use crate::cpp; +use crate::cpp::*; /// Wraps a heap allocated DuckDB string. pub struct String { @@ -40,6 +40,94 @@ impl Display for String { impl Drop for String { fn drop(&mut self) { - unsafe { cpp::duckdb_free(self.ptr.cast_mut().cast()) }; + unsafe { duckdb_free(self.ptr.cast_mut().cast()) }; + } +} + +/// Safely convert a C string pointer to a Rust String using length-based copying +/// This is more efficient than null-terminated string copying and safer than CStr +pub unsafe fn c_string_to_rust_string( + str_ptr: *mut std::os::raw::c_char, +) -> Option { + if str_ptr.is_null() { + return None; + } + + let len = unsafe { duckdb_vx_c_string_length(str_ptr) }; + if len == 0 { + unsafe { duckdb_vx_free_string(str_ptr) }; + return Some(std::string::String::new()); + } + + let slice = unsafe { std::slice::from_raw_parts(str_ptr as *const u8, len as usize) }; + let result = std::string::String::from_utf8_lossy(slice).into_owned(); + unsafe { duckdb_vx_free_string(str_ptr) }; + Some(result) +} + +/// Wrapper for duckdb_vx_string that provides safe access to C++ std::string +pub struct VxString { + ptr: duckdb_vx_string, +} + +impl VxString { + /// Create a VxString from a duckdb_vx_string pointer (takes ownership) + pub unsafe fn from_raw(ptr: duckdb_vx_string) -> Option { + if ptr.is_null() { + None + } else { + Some(VxString { ptr }) + } + } + + /// Get the length of the string + pub fn len(&self) -> usize { + unsafe { duckdb_vx_string_length(self.ptr) as usize } + } + + /// Check if the string is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get the string data as a C string pointer + pub fn as_ptr(&self) -> *const std::os::raw::c_char { + unsafe { duckdb_vx_string_data(self.ptr) } + } + + /// Convert to Rust String + pub fn to_string(&self) -> std::string::String { + if self.ptr.is_null() { + return std::string::String::new(); + } + + let len = self.len(); + if len == 0 { + return std::string::String::new(); + } + + let c_str = self.as_ptr(); + let slice = unsafe { std::slice::from_raw_parts(c_str as *const u8, len) }; + std::string::String::from_utf8_lossy(slice).into_owned() + } +} + +impl Drop for VxString { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { duckdb_vx_string_free(self.ptr) }; + } + } +} + +impl Display for VxString { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_string()) + } +} + +impl Debug for VxString { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "VxString(\"{}\")", self.to_string()) } } diff --git a/vortex-duckdb/src/e2e_test/mod.rs b/vortex-duckdb/src/e2e_test/mod.rs index 0e5ad05963d..c40331823c3 100644 --- a/vortex-duckdb/src/e2e_test/mod.rs +++ b/vortex-duckdb/src/e2e_test/mod.rs @@ -4,4 +4,6 @@ #[cfg(test)] mod object_cache_test; #[cfg(test)] +mod virtual_column_test; +#[cfg(test)] mod vortex_scan_test; diff --git a/vortex-duckdb/src/e2e_test/virtual_column_test.rs b/vortex-duckdb/src/e2e_test/virtual_column_test.rs new file mode 100644 index 00000000000..5b9f6e46d1f --- /dev/null +++ b/vortex-duckdb/src/e2e_test/virtual_column_test.rs @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tests for virtual column exposure and len() function optimization. + +use tempfile::NamedTempFile; +use vortex::IntoArray; +use vortex::arrays::{PrimitiveArray, StructArray, VarBinArray}; +use vortex::file::VortexWriteOptions; + +use crate::RUNTIME; +use crate::duckdb::{Connection, Database}; + +fn database_connection_with_optimizer() -> Connection { + let mut db = Database::open_in_memory().unwrap(); + + // Register the full extension including optimizer + crate::register_extension(&mut db).unwrap(); + + db.connect().unwrap() +} + +// Use simpler test helper from existing tests +fn database_connection() -> Connection { + let db = Database::open_in_memory().unwrap(); + let connection = db.connect().unwrap(); + crate::register_table_functions(&connection).unwrap(); + connection +} + +async fn create_test_vortex_file() -> NamedTempFile { + let temp_file = NamedTempFile::new().unwrap(); + + // Create test data with string columns + let urls = VarBinArray::from_iter( + [ + "https://example.com", + "https://test.org/path", + "https://short.co", + ] + .iter() + .map(|s| Some(s.as_bytes())), + vortex::dtype::DType::Utf8(vortex::dtype::Nullability::NonNullable), + ); + + let names = VarBinArray::from_iter( + ["Alice", "Bob", "Charlie"] + .iter() + .map(|s| Some(s.as_bytes())), + vortex::dtype::DType::Utf8(vortex::dtype::Nullability::NonNullable), + ); + + let struct_array = StructArray::try_from_iter([("url", urls), ("name", names)]).unwrap(); + + let file = tokio::fs::File::create(&temp_file).await.unwrap(); + VortexWriteOptions::default() + .write(file, struct_array.to_array_stream()) + .await + .unwrap(); + + temp_file +} + +async fn create_complex_test_vortex_file() -> NamedTempFile { + let temp_file = NamedTempFile::new().unwrap(); + + // Create test data with multiple string columns and an integer column + let titles = VarBinArray::from_iter( + [ + "Machine Learning Fundamentals", + "Advanced Database Systems", + "Web Development with Rust", + "Data Structures and Algorithms", + ] + .iter() + .map(|s| Some(s.as_bytes())), + vortex::dtype::DType::Utf8(vortex::dtype::Nullability::NonNullable), + ); + + let descriptions = VarBinArray::from_iter( + [ + "An introduction to ML concepts and techniques", + "Deep dive into modern database architectures", + "Building fast web applications using Rust", + "Core computer science fundamentals", + ] + .iter() + .map(|s| Some(s.as_bytes())), + vortex::dtype::DType::Utf8(vortex::dtype::Nullability::NonNullable), + ); + + let page_counts = PrimitiveArray::from_iter([245i32, 512i32, 398i32, 687i32]); + + let struct_array = StructArray::try_from_iter([ + ("title", titles.into_array()), + ("description", descriptions.into_array()), + ("page_count", page_counts.into_array()), + ]) + .unwrap(); + + let file = tokio::fs::File::create(&temp_file).await.unwrap(); + VortexWriteOptions::default() + .write(file, struct_array.to_array_stream()) + .await + .unwrap(); + + temp_file +} + +#[test] +fn test_virtual_columns_exposed_in_schema() { + let temp_file = RUNTIME.block_on(create_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Test that we can reference virtual columns directly in a query + // If virtual columns are properly exposed, this query should succeed + let query = format!( + "SELECT url$length, name$length FROM vortex_scan('{}')", + file_path + ); + + // The fact that this doesn't error means virtual columns are exposed + conn.query(&query).unwrap(); +} + +#[test] +fn test_len_function_works() { + let temp_file = RUNTIME.block_on(create_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Query using len() function - test that it executes without error + // Whether it uses optimization or not, it should still produce correct results + let query = format!( + "SELECT len(url), len(name) FROM vortex_scan('{}')", + file_path + ); + + conn.query(&query).unwrap(); +} + +#[test] +fn test_virtual_column_direct_access() { + let temp_file = RUNTIME.block_on(create_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Query virtual columns directly (this tests that they work without len() optimization) + let query = format!( + "SELECT url$length, name$length FROM vortex_scan('{}')", + file_path + ); + + conn.query(&query).unwrap(); +} + +#[test] +fn test_optimizer_registration() { + let temp_file = RUNTIME.block_on(create_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Execute a query that should trigger optimization + let query = format!("SELECT len(url) FROM vortex_scan('{}')", file_path); + + // The fact that this doesn't crash indicates the optimizer was registered successfully + conn.query(&query).unwrap(); +} + +#[test] +fn test_mixed_virtual_and_real_columns() { + let temp_file = RUNTIME.block_on(create_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Query mixing real columns and virtual columns + let query = format!("SELECT url, url$length FROM vortex_scan('{}')", file_path); + + conn.query(&query).unwrap(); +} + +#[test] +fn test_multiple_expr() { + let temp_file = RUNTIME.block_on(create_complex_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Test virtual columns are exposed for both string columns + let virtual_columns_query = format!( + "SELECT page_count, page_count + 1, page_count + 2 FROM vortex_scan('{}')", + file_path + ); + conn.query(&virtual_columns_query).unwrap(); +} + +#[test] +fn test_multiple_string_columns_with_len_and_integer_column() { + let temp_file = RUNTIME.block_on(create_complex_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Test len() functions work on multiple columns + let len_functions_query = format!( + "SELECT len(title), len(description) FROM vortex_scan('{}')", + file_path + ); + conn.query(&len_functions_query).unwrap(); + + // Test WHERE clause with len() function + let where_clause_query = format!( + "SELECT title FROM vortex_scan('{}') WHERE len(title) > 25", + file_path + ); + conn.query(&where_clause_query).unwrap(); +} + +#[test] +fn test_multiple_string_columns_with_len_and_integer_column_complex() { + let temp_file = RUNTIME.block_on(create_complex_test_vortex_file()); + let conn = database_connection_with_optimizer(); + let file_path = temp_file.path().to_string_lossy(); + + // Test mixing string columns, virtual columns, and integer column + let mixed_query = format!( + "SELECT title, len(title), description$length, page_count FROM vortex_scan('{}')", + file_path + ); + conn.query(&mixed_query).unwrap(); +} diff --git a/vortex-duckdb/src/exporter/mod.rs b/vortex-duckdb/src/exporter/mod.rs index 02708d48a13..8cf728b53fc 100644 --- a/vortex-duckdb/src/exporter/mod.rs +++ b/vortex-duckdb/src/exporter/mod.rs @@ -13,6 +13,7 @@ mod sequence; mod temporal; mod varbinview; +use std::collections::HashMap; use std::sync::Arc; use bitvec::prelude::Lsb0; @@ -97,6 +98,95 @@ impl ArrayExporter { }) } + pub fn try_new_with_virtual_columns( + array: &StructArray, + cache: &ConversionCache, + virtual_column_requests: &[(usize, String)], + total_columns: usize, + ) -> VortexResult { + log::debug!( + "EXPORTER: Creating exporter for {} total columns, {} virtual requests", + total_columns, + virtual_column_requests.len() + ); + log::debug!("EXPORTER: Array has {} real fields", array.fields().len()); + log::debug!( + "EXPORTER: Virtual column requests: {:?}", + virtual_column_requests + ); + + // Create a mapping from field names to indices in the struct array + let field_name_to_idx: HashMap = array + .names() + .iter() + .enumerate() + .map(|(idx, name)| (name.as_ref().to_string(), idx)) + .collect(); + + // Create exporters for exactly the columns that are projected + let mut fields: Vec> = Vec::with_capacity(total_columns); + + // Track which real columns have been used + let mut next_real_field_idx = 0; + + // The key insight: we need to place exporters at the exact projection positions + // If a position corresponds to a virtual column, use a virtual exporter + // If a position corresponds to a real column, use a real exporter + for i in 0..total_columns { + // Check if this projection position is a virtual column request + if let Some((_, source_col_name)) = virtual_column_requests + .iter() + .find(|(proj_idx, _)| *proj_idx == i) + { + // This position needs a virtual column exporter + log::debug!( + "EXPORTER: Position {} is virtual column from source '{}'", + i, + source_col_name + ); + + // Find the field by name + if let Some(&field_idx) = field_name_to_idx.get(source_col_name) { + let source_field = &array.fields()[field_idx]; + fields.push(new_virtual_length_exporter(source_field.as_ref())?); + } else { + // Fallback for missing field + log::warn!( + "EXPORTER: Could not find field '{}' in struct array", + source_col_name + ); + fields.push(new_virtual_length_exporter(array.fields()[0].as_ref())?); + } + } else { + // This position needs a real column exporter + // Use the next available real field + log::debug!( + "EXPORTER: Position {} is real column, using field {}", + i, + next_real_field_idx + ); + if next_real_field_idx < array.fields().len() { + fields.push(new_array_exporter( + array.fields()[next_real_field_idx].as_ref(), + cache, + )?); + next_real_field_idx += 1; + } else { + // Fallback if we run out of fields (shouldn't happen normally) + fields.push(new_array_exporter(array.fields()[0].as_ref(), cache)?); + } + } + } + + log::debug!("EXPORTER: Created {} exporters", fields.len()); + + Ok(Self { + fields, + array_len: array.len(), + remaining: array.len(), + }) + } + /// Export the data into the next chunk. /// /// Returns `true` if a chunk was exported, `false` if all rows have been exported. @@ -238,6 +328,59 @@ impl VectorExt for Vector { } } +/// Virtual column exporter for string length computation +struct VirtualLengthExporter { + source_array: vortex::ArrayRef, +} + +impl VirtualLengthExporter { + fn new(source_array: &dyn Array) -> VortexResult { + println!("VirtualLengthExporter: new {}", source_array.display_tree()); + Ok(Self { + source_array: source_array.to_owned(), + }) + } +} + +impl ColumnExporter for VirtualLengthExporter { + fn export(&self, offset: usize, len: usize, vector: &mut Vector) -> VortexResult<()> { + use crate::cpp::duckdb_vector_get_data; + + // Get the data pointer for the integer vector + let data_ptr = unsafe { duckdb_vector_get_data(vector.as_ptr()) }; + let data_slice: &mut [i32] = + unsafe { std::slice::from_raw_parts_mut(data_ptr as *mut i32, len) }; + + println!( + "VirtualLengthExporter: export {}", + self.source_array.display_tree() + ); + + // Convert source array to canonical VarBinView to compute lengths + let varbinview = self.source_array.to_varbinview(); + + // Compute string lengths + for (i, data) in data_slice.iter_mut().enumerate().take(len) { + let string_value = varbinview.bytes_at(offset + i); + *data = i32::try_from(string_value.len()).unwrap_or(i32::MAX); + } + + // Set vector validity - all virtual columns are non-null + unsafe { + if let Some(validity) = vector.validity_slice_mut(len) { + validity.fill(true); + } + } + + Ok(()) + } +} + +/// Create a virtual length column exporter for string arrays +fn new_virtual_length_exporter(source_array: &dyn Array) -> VortexResult> { + Ok(Box::new(VirtualLengthExporter::new(source_array)?)) +} + #[cfg(test)] mod tests { use arrow_buffer::buffer::BooleanBuffer; diff --git a/vortex-duckdb/src/lib.rs b/vortex-duckdb/src/lib.rs index 846705a1849..b622cb70513 100644 --- a/vortex-duckdb/src/lib.rs +++ b/vortex-duckdb/src/lib.rs @@ -18,6 +18,8 @@ use crate::scan::VortexTableFunction; mod convert; pub mod duckdb; pub mod exporter; +pub mod optimizer; +mod rust_optimizer; mod scan; mod utils; @@ -38,6 +40,30 @@ pub fn register_table_functions(conn: &Connection) -> VortexResult<()> { conn.register_copy_function::(c"vortex", c"vortex") } +/// Initialize the Vortex extension (table functions AND optimizer) +pub fn register_extension(db: &mut Database) -> VortexResult<()> { + println!("🚀 REGISTERING: Starting Vortex extension registration..."); + + // Register the Rust-based optimizer + println!("🚀 REGISTERING: Registering Rust optimizer..."); + if let Err(e) = optimizer::register_rust_optimizer(db) { + println!( + "âš ī¸ REGISTERING: Rust optimizer registration failed: {}. Continuing with table functions only.", + e + ); + } else { + println!("✅ REGISTERING: Rust optimizer registration succeeded!"); + } + + // Register table functions + println!("🚀 REGISTERING: Registering table functions..."); + let conn = db.connect()?; + let result = register_table_functions(&conn); + println!("✅ REGISTERING: Extension registration completed!"); + + result +} + /// Global symbol visibility in the Vortex extension: /// - Rust functions use C ABI with "_rust" suffix (e.g., vortex_init_rust) /// - C++ wrapper functions have the expected name without suffix (e.g., vortex_init) @@ -49,10 +75,28 @@ pub fn register_table_functions(conn: &Connection) -> VortexResult<()> { /// The DuckDB extension ABI initialization function. #[unsafe(no_mangle)] pub unsafe extern "C" fn vortex_init_rust(db: cpp::duckdb_database) { - let conn = unsafe { Database::borrow(db) } + println!("🚀 INIT: vortex_init_rust called - registering at DuckDB extension loading time"); + + let mut database = unsafe { Database::borrow(db) }; + + // Try registering Rust optimizer first, during extension loading + println!("🚀 INIT: Registering Rust optimizer during extension loading..."); + if let Err(e) = optimizer::register_rust_optimizer(&mut database) { + println!( + "âš ī¸ INIT: Rust optimizer registration failed: {}, continuing without optimizer", + e + ); + } else { + println!("✅ INIT: Rust optimizer registration succeeded during extension loading"); + } + + // Register table functions + println!("🚀 INIT: Registering table functions..."); + let conn = database .connect() - .vortex_expect("Failed to connect to DuckDB database"); - register_table_functions(&conn).vortex_expect("Failed to initialize Vortex extension"); + .vortex_expect("Failed to connect to database"); + register_table_functions(&conn).vortex_expect("Failed to register table functions"); + println!("✅ INIT: Extension initialization complete"); } /// The DuckDB extension ABI version function. diff --git a/vortex-duckdb/src/optimizer.rs b/vortex-duckdb/src/optimizer.rs new file mode 100644 index 00000000000..d1ede13c8f5 --- /dev/null +++ b/vortex-duckdb/src/optimizer.rs @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Optimizer extension for DuckDB to rewrite len(column) -> column$length +//! +//! This module provides both the legacy C++ optimizer and the new pure Rust +//! optimizer implementation for better maintainability and customization. + +use vortex::error::VortexResult; + +use crate::duckdb::Database; +// Re-export types for backwards compatibility +pub use crate::duckdb::expr::{ColumnBinding, Expression, LogicalExpressionType as ExpressionType}; +pub use crate::duckdb::logical_operator::{LogicalOperator, LogicalOperatorType}; +pub use crate::rust_optimizer::{LengthReplacement, RustLengthOptimizer}; + + +/// Register the Rust-based length optimizer with DuckDB +/// +/// This registers the pure Rust implementation of the length optimization +/// that automatically rewrites len(column) function calls to use virtual column references. +pub fn register_rust_optimizer(db: &mut Database) -> VortexResult<()> { + crate::rust_optimizer::register_rust_optimizer(db) +} + +/// Legacy alias for backwards compatibility +pub fn register_optimizer(db: &mut Database) -> VortexResult<()> { + register_rust_optimizer(db) +} diff --git a/vortex-duckdb/src/rust_optimizer.rs b/vortex-duckdb/src/rust_optimizer.rs new file mode 100644 index 00000000000..9acbcc1db72 --- /dev/null +++ b/vortex-duckdb/src/rust_optimizer.rs @@ -0,0 +1,650 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Pure Rust implementation of the length optimization logic +//! +//! This module implements the length optimization using the generic logical plan API. + +use std::ptr; + +use log::trace; +use vortex::error::{VortexExpect, VortexResult}; +use vortex::utils::aliases::hash_map::HashMap; +use vortex::utils::aliases::hash_set::HashSet; + +use crate::duckdb::expr::{ColumnBinding, Expression}; +use crate::duckdb::logical_operator::{LogicalOperator, LogicalOperatorClass}; +use crate::duckdb::{Database, ExpressionClass}; + +/// Length replacement information +#[derive(Debug, Clone)] +pub struct LengthReplacement { + pub original_column_binding: u64, + pub virtual_column_index: u64, + pub virtual_col_name: String, + pub new_expression_binding: u64, + pub expression_index: u64, +} + +/// Pure Rust implementation of the length optimization logic +pub struct RustLengthOptimizer { + replacements: Vec, +} + +impl RustLengthOptimizer { + pub fn new() -> Self { + Self { + replacements: Vec::new(), + } + } + + /// Apply length optimization to a logical plan + pub fn optimize_plan(&mut self, plan: &LogicalOperator) -> VortexResult<()> { + // First check if the plan contains any vortex_scan + if !Self::has_vortex_scan(plan)? { + trace!("â„šī¸ RUST OPTIMIZER: No vortex_scan found in plan, skipping"); + return Ok(()); + } + + trace!("✅ RUST OPTIMIZER: Found vortex_scan in plan!"); + + // Visit all operators and apply length optimization + self.visit_and_optimize(plan)?; + + if !self.replacements.is_empty() { + trace!( + "đŸŽ¯ RUST OPTIMIZER: Found {} len() → virtual column transformations!", + self.replacements.len() + ); + } else { + trace!("â„šī¸ RUST OPTIMIZER: No len() functions found to optimize"); + } + + trace!("✅ RUST OPTIMIZER: Length optimization completed!"); + Ok(()) + } + + /// Check if the plan contains a vortex_scan + fn has_vortex_scan(op: &LogicalOperator) -> VortexResult { + // Check this operator + if let Some(LogicalOperatorClass::Get(get_op)) = op.as_class() + && get_op.is_vortex_scan() + { + return Ok(true); + } + + // Check children recursively + for i in 0..op.children_count() { + if let Some(child) = op.get_child(i) + && Self::has_vortex_scan(&child)? + { + return Ok(true); + } + } + + Ok(false) + } + + fn visit_node(&mut self, operator: &LogicalOperator) -> Option<()> { + trace!("🔍 VISITING: Operator type: {:?}", operator.operator_type()); + + let LogicalOperatorClass::Projection(proj) = operator.as_class()? else { + trace!("🔍 Not a projection operator"); + return None; + }; + if operator.children_count() != 1 { + trace!( + "🔍 Projection operator has {} children, expected 1", + operator.children_count() + ); + return None; + } + let op_child = operator.get_child(0).unwrap(); + let LogicalOperatorClass::Get(get_op) = op_child.as_class()? else { + trace!("🔍 Child is not a Get operator"); + return None; + }; + if !get_op.is_vortex_scan() { + trace!("🔍 Get operator is not a vortex_scan"); + return None; + } + + trace!("FOUND VORTEX SCAN"); + trace!("projection operator {}", operator); + trace!("scan operator {:?}", get_op); + + // Get current state + let column_names = get_op.column_names().vortex_expect("column names"); + + // First pass: Analyze all projections to understand what's being used + let mut len_replacements: Vec<(usize, u64, String)> = Vec::new(); // (proj_idx, virtual_col_idx, virtual_col_name) + let mut original_columns_used: HashSet = HashSet::new(); + let mut projection_expressions = Vec::new(); + + // Collect all projection expressions first + for projection_expr in proj.projections() { + projection_expressions.push(projection_expr); + } + + // Analyze each projection + for (idx, projection_expr) in projection_expressions.iter().enumerate() { + let Some(projection_expr) = projection_expr else { + trace!("🔍 Projection {} is None", idx); + continue; + }; + + // Check if this is a len() function + if let Some(ExpressionClass::BoundFunction(func_)) = projection_expr.as_class() { + if let Some(function_name) = func_.function_name() { + if function_name == "len" && func_.function_arg_count() > 0 { + // This is a len() function + if let Some(arg) = func_.get_function_arg(0) + && let Some(ExpressionClass::BoundColumnRef(bound_col)) = arg.as_class() + { + let column_alias = bound_col.name; + let _original_col_idx = bound_col.column_binding.column_index; + let virtual_col_name = format!("{}$length", column_alias); + + // Find virtual column in schema + if let Some(virtual_col_idx) = + column_names.iter().position(|n| *n == virtual_col_name) + { + len_replacements.push(( + idx, + virtual_col_idx as u64, + virtual_col_name.clone(), + )); + trace!( + "Found len({}) -> {} at index {}", + column_alias, virtual_col_name, virtual_col_idx + ); + } + } + } else { + // Not a len() function - check if it uses any columns + // This helps us know if original columns are still needed + Self::find_column_refs_in_expr(projection_expr, &mut original_columns_used); + } + } + } else if let Some(ExpressionClass::BoundColumnRef(col_ref)) = + projection_expr.as_class() + { + // Direct column reference + let col_idx = col_ref.column_binding.column_index; + let col_name = &col_ref.name; + + // Check if this is a virtual column that was directly referenced + let col_name_str = col_name.to_string(); + if col_name_str.ends_with("$length") { + // This is a virtual column - find its actual index in the schema + if let Some(actual_col_idx) = + column_names.iter().position(|n| *n == col_name_str) + { + trace!( + "Virtual column reference at projection {}: {} (bound to index {} but actually at {})", + idx, col_name, col_idx, actual_col_idx + ); + // Don't add to original_columns_used since this is a virtual column + } else { + trace!("Virtual column {} not found in schema", col_name); + } + } else { + // Regular column reference + original_columns_used.insert(col_idx); + trace!( + "Direct column reference at projection {}: column {} (name: {})", + idx, col_idx, col_name + ); + } + } + } + + // Now determine the final column_ids and update projections + if !len_replacements.is_empty() { + // Step 1: Collect all unique columns needed (both regular and virtual) + let mut required_columns = HashSet::new(); + let mut projection_mappings = Vec::new(); // Maps each projection to its required column + + // Process each projection expression to understand what columns it needs + for (idx, expr) in projection_expressions.iter().enumerate() { + if let Some(expr) = expr { + // Check if this projection is a len() replacement + if let Some((_, virtual_col_idx, _)) = len_replacements + .iter() + .find(|(proj_idx, ..)| *proj_idx == idx) + { + // This projection will use the virtual column + required_columns.insert(*virtual_col_idx); + projection_mappings.push(*virtual_col_idx); + } else if let Some(ExpressionClass::BoundColumnRef(col_ref)) = expr.as_class() { + // Check if this is a virtual column reference + let col_name_str = col_ref.name.to_string(); + if col_name_str.ends_with("$length") { + // This is a direct virtual column reference - find its actual index + if let Some(actual_col_idx) = + column_names.iter().position(|n| *n == col_name_str) + { + let actual_col_idx = actual_col_idx as u64; + required_columns.insert(actual_col_idx); + projection_mappings.push(actual_col_idx); + } else { + // Fallback to bound index if not found + let col_idx = col_ref.column_binding.column_index; + required_columns.insert(col_idx); + projection_mappings.push(col_idx); + } + } else { + // Regular column reference + let col_idx = col_ref.column_binding.column_index; + required_columns.insert(col_idx); + projection_mappings.push(col_idx); + } + } else { + // For other expressions, try to find column dependencies + let mut expr_columns = HashSet::new(); + Self::find_column_refs_in_expr(expr, &mut expr_columns); + for col_id in expr_columns { + required_columns.insert(col_id); + } + // Use the first column found, or default to projection index + let first_col = projection_mappings.first().copied().unwrap_or(idx as u64); + projection_mappings.push(first_col); + } + } else { + projection_mappings.push(idx as u64); + } + } + + // Step 2: Create column_ids list in the order they're needed by projections + // Don't sort - preserve the order projections need them in + let mut new_column_ids = Vec::new(); + let mut seen_columns = HashSet::new(); + for &col_id in &projection_mappings { + if !seen_columns.contains(&col_id) { + new_column_ids.push(col_id); + seen_columns.insert(col_id); + } + } + + // Step 3: Create mapping from column_id to position in new_column_ids + let column_to_position: HashMap = new_column_ids + .iter() + .enumerate() + .map(|(pos, &col_id)| (col_id, pos)) + .collect(); + + // Step 4: Replace len() expressions with virtual column references AND fix direct virtual column bindings + for (proj_idx, virtual_col_idx, virtual_col_name) in &len_replacements { + if let Some(proj_expr) = projection_expressions[*proj_idx].as_ref() + && let Some(ExpressionClass::BoundFunction(func_)) = proj_expr.as_class() + && let Some(arg) = func_.get_function_arg(0) + && let Some(ExpressionClass::BoundColumnRef(bound_col)) = arg.as_class() + { + // Get the position of the virtual column in our new column_ids + let position_in_column_ids = column_to_position + .get(virtual_col_idx) + .copied() + .unwrap_or(0); + + let Ok(virtual_col_ref) = Expression::create_column_ref( + virtual_col_name, + ColumnBinding { + table_index: bound_col.column_binding.table_index, + column_index: position_in_column_ids as u64, + }, + 0, + ) else { + continue; + }; + + proj.set_projection(*proj_idx, virtual_col_ref); + + self.replacements.push(LengthReplacement { + original_column_binding: bound_col.column_binding.column_index, + virtual_column_index: *virtual_col_idx, + virtual_col_name: virtual_col_name.clone(), + new_expression_binding: position_in_column_ids as u64, + expression_index: *proj_idx as u64, + }); + } + } + + // Step 4.5: Fix direct virtual column references + for (idx, expr) in projection_expressions.iter().enumerate() { + if let Some(expr) = expr + && let Some(ExpressionClass::BoundColumnRef(col_ref)) = expr.as_class() + { + let col_name_str = col_ref.name.to_string(); + if col_name_str.ends_with("$length") { + // This is a direct virtual column reference that needs fixing + if let Some(actual_col_idx) = + column_names.iter().position(|n| *n == col_name_str) + { + let position_in_column_ids = column_to_position + .get(&(actual_col_idx as u64)) + .copied() + .unwrap_or(0); + + // Create corrected virtual column reference + if let Ok(corrected_col_ref) = Expression::create_column_ref( + &col_name_str, + ColumnBinding { + table_index: col_ref.column_binding.table_index, + column_index: position_in_column_ids as u64, + }, + 0, + ) { + proj.set_projection(idx, corrected_col_ref); + trace!( + "🔍 Fixed virtual column reference {}: {} -> position {}", + idx, col_name_str, position_in_column_ids + ); + } + } + } + } + } + + // Step 5: Update column_ids and projection_ids + trace!("🔍 Final column_ids: {:?}", new_column_ids); + get_op.clear_column_ids(); + for &col_id in &new_column_ids { + get_op.add_column_id(col_id); + } + + // Create projection_ids that map each projection to its position in column_ids + trace!( + "🔍 Projection mappings (column each projection needs): {:?}", + projection_mappings + ); + trace!("🔍 Column to position mapping: {:?}", column_to_position); + let projection_ids: Vec = projection_mappings + .iter() + .map(|&col_id| column_to_position.get(&col_id).copied().unwrap_or(0) as u64) + .collect(); + + let _ = get_op.update_projection_ids(&projection_ids); + trace!("🔍 Final projection_ids: {:?}", projection_ids); + + // Debug: Print final projection expressions + trace!("🔍 Final projection expressions:"); + for (i, expr) in proj.projections().enumerate() { + if let Some(expr) = expr { + trace!(" [{}]: {}", i, expr); + } else { + trace!(" [{}]: None", i); + } + } + } + + Some(()) + } + + /// Find column references in an expression + fn find_column_refs_in_expr(expr: &Expression, columns_used: &mut HashSet) { + if let Some(ExpressionClass::BoundColumnRef(col_ref)) = expr.as_class() { + columns_used.insert(col_ref.column_binding.column_index); + } else if let Some(ExpressionClass::BoundFunction(func)) = expr.as_class() { + // Check arguments of functions + for i in 0..func.function_arg_count() { + if let Some(arg) = func.get_function_arg(i) { + Self::find_column_refs_in_expr(&arg, columns_used); + } + } + } else if let Some(ExpressionClass::BoundOperator(op)) = expr.as_class() { + // Check children of operators + for child in op.children() { + Self::find_column_refs_in_expr(&child, columns_used); + } + } + } + + /// Visit all operators and apply optimizations + fn visit_and_optimize(&mut self, op: &LogicalOperator) -> VortexResult<()> { + trace!("🔍 VISITING: Operator type: {:?}", op.operator_type()); + + self.visit_node(op); + + // Visit children + for i in 0..op.children_count() { + if let Some(child) = op.get_child(i) { + self.visit_and_optimize(&child)?; + } + } + + Ok(()) + } + + /// Get the collected replacements + pub fn get_replacements(&self) -> &[LengthReplacement] { + &self.replacements + } +} + +impl Default for RustLengthOptimizer { + fn default() -> Self { + Self::new() + } +} + +/// C callback function that implements the optimization in Rust +extern "C-unwind" fn rust_optimizer_callback( + plan: crate::cpp::duckdb_vx_logical_operator, + _user_data: *mut std::ffi::c_void, +) { + if plan.is_null() { + return; + } + + let logical_op = unsafe { LogicalOperator::borrow(plan) }; + + // Create and run the optimizer + let mut optimizer = RustLengthOptimizer::new(); + match optimizer.optimize_plan(&logical_op) { + Ok(()) => { + trace!("✅ RUST OPTIMIZER: Optimization completed successfully!"); + let replacements = optimizer.get_replacements(); + if !replacements.is_empty() { + trace!( + "📊 RUST OPTIMIZER: Made {} replacements:", + replacements.len() + ); + for (i, replacement) in replacements.iter().enumerate() { + trace!( + " {}. {} → {}", + i + 1, + replacement.original_column_binding, + replacement.virtual_col_name + ); + } + } + } + Err(e) => { + trace!("❌ RUST OPTIMIZER: Optimization failed: {}", e); + } + } +} + +/// Register the Rust-based length optimizer with DuckDB +pub fn register_rust_optimizer(db: &mut Database) -> VortexResult<()> { + trace!("🔧 REGISTERING: Rust-based length optimizer..."); + + unsafe { + crate::cpp::duckdb_vx_register_rust_optimizer( + db.as_ptr(), + Some(rust_optimizer_callback), + ptr::null_mut(), + ); + } + + trace!("✅ SUCCESS: Rust-based length optimizer registered!"); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_length_replacement() { + let replacement = LengthReplacement { + original_column_binding: 1, + virtual_column_index: 5, + virtual_col_name: "title$length".to_string(), + new_expression_binding: 5, + expression_index: 0, + }; + + assert_eq!(replacement.virtual_col_name, "title$length"); + assert_eq!(replacement.virtual_column_index, 5); + } + + #[test] + fn test_optimizer_creation() { + let _optimizer = RustLengthOptimizer::new(); + assert_eq!(_optimizer.get_replacements().len(), 0); + } + + #[test] + fn test_multiple_replacements() { + let mut optimizer = RustLengthOptimizer::new(); + + // Simulate finding multiple length functions + optimizer.replacements.push(LengthReplacement { + original_column_binding: 0, + virtual_column_index: 2, + virtual_col_name: "title$length".to_string(), + new_expression_binding: 2, + expression_index: 0, + }); + + optimizer.replacements.push(LengthReplacement { + original_column_binding: 1, + virtual_column_index: 3, + virtual_col_name: "description$length".to_string(), + new_expression_binding: 3, + expression_index: 1, + }); + + assert_eq!(optimizer.get_replacements().len(), 2); + assert_eq!( + optimizer.get_replacements()[0].virtual_col_name, + "title$length" + ); + assert_eq!( + optimizer.get_replacements()[1].virtual_col_name, + "description$length" + ); + } + + #[test] + fn test_replacement_tracking() { + let mut optimizer = RustLengthOptimizer::new(); + + // Add replacements for different columns + let columns = ["title", "author", "description", "content"]; + + for (idx, col) in columns.iter().enumerate() { + optimizer.replacements.push(LengthReplacement { + original_column_binding: idx as u64, + virtual_column_index: (idx + 10) as u64, + virtual_col_name: format!("{}$length", col), + new_expression_binding: (idx + 10) as u64, + expression_index: idx as u64, + }); + } + + // Verify all replacements are tracked + assert_eq!(optimizer.get_replacements().len(), 4); + + // Verify each replacement has correct virtual column name + for (idx, replacement) in optimizer.get_replacements().iter().enumerate() { + assert_eq!( + replacement.virtual_col_name, + format!("{}$length", columns[idx]) + ); + assert_eq!(replacement.virtual_column_index, (idx + 10) as u64); + } + } + + #[test] + fn test_projection_id_updates() { + let _optimizer = RustLengthOptimizer::new(); + + // Test different projection update scenarios + + // Scenario 1: All projections are length functions + let mut proj_ids_all_length = vec![0, 1, 2]; + let virtual_cols = [10, 11, 12]; + + // Simulate what would happen in update_vortex_scan_projections + for (i, &virtual_col) in virtual_cols.iter().enumerate() { + if i < proj_ids_all_length.len() { + proj_ids_all_length[i] = virtual_col; + } + } + + assert_eq!(proj_ids_all_length, vec![10, 11, 12]); + + // Scenario 2: Mixed projections (some length, some regular) + let mut proj_ids_mixed = vec![0, 1, 2, 3]; + let replacements_at = [1, 3]; // Only replace at positions 1 and 3 + let virtual_values = [20, 21]; + + for (i, &pos) in replacements_at.iter().enumerate() { + proj_ids_mixed[pos] = virtual_values[i]; + } + + assert_eq!(proj_ids_mixed, vec![0, 20, 2, 21]); + } + + #[test] + fn test_virtual_column_name_generation() { + let test_cases = vec![ + ("title", "title$length"), + ("description", "description$length"), + ("user_name", "user_name$length"), + ("id", "id$length"), + ]; + + for (column_name, expected_virtual) in test_cases { + let virtual_name = format!("{}$length", column_name); + assert_eq!(virtual_name, expected_virtual); + } + } + + #[test] + fn test_optimizer_state_management() { + let mut optimizer1 = RustLengthOptimizer::new(); + let mut optimizer2 = RustLengthOptimizer::new(); + + // Add replacements to optimizer1 + optimizer1.replacements.push(LengthReplacement { + original_column_binding: 0, + virtual_column_index: 10, + virtual_col_name: "col1$length".to_string(), + new_expression_binding: 10, + expression_index: 0, + }); + + // Verify optimizers maintain independent state + assert_eq!(optimizer1.get_replacements().len(), 1); + assert_eq!(optimizer2.get_replacements().len(), 0); + + // Add different replacement to optimizer2 + optimizer2.replacements.push(LengthReplacement { + original_column_binding: 1, + virtual_column_index: 20, + virtual_col_name: "col2$length".to_string(), + new_expression_binding: 20, + expression_index: 1, + }); + + assert_eq!(optimizer1.get_replacements().len(), 1); + assert_eq!(optimizer2.get_replacements().len(), 1); + assert_ne!( + optimizer1.get_replacements()[0].virtual_col_name, + optimizer2.get_replacements()[0].virtual_col_name + ); + } +} diff --git a/vortex-duckdb/src/scan.rs b/vortex-duckdb/src/scan.rs index a599ed67eb2..24e8275798d 100644 --- a/vortex-duckdb/src/scan.rs +++ b/vortex-duckdb/src/scan.rs @@ -11,7 +11,6 @@ use itertools::Itertools; use num_traits::AsPrimitive; use tokio::task::block_in_place; use url::Url; -use vortex::dtype::FieldNames; use vortex::error::{VortexExpect, VortexResult, vortex_bail, vortex_err}; use vortex::expr::{ExprRef, and, and_collect, col, lit, root, select}; use vortex::file::{VortexFile, VortexOpenOptions}; @@ -19,7 +18,6 @@ use vortex::scan::{MultiScan, MultiScanIterator}; use vortex::{ArrayRef, ToCanonical}; use vortex_file::GenericVortexFile; -use crate::RUNTIME; use crate::convert::{try_from_bound_expression, try_from_table_filter}; use crate::duckdb::footer_cache::FooterCache; use crate::duckdb::{ @@ -29,6 +27,7 @@ use crate::duckdb::{ use crate::exporter::{ArrayExporter, ConversionCache}; use crate::utils::glob::expand_glob; use crate::utils::object_store::s3_store; +use crate::{RUNTIME, cpp}; pub struct VortexBindData { first_file: VortexFile, @@ -36,6 +35,7 @@ pub struct VortexBindData { file_urls: Vec, column_names: Vec, column_types: Vec, + virtual_column_mappings: Vec<(usize, usize)>, // (virtual_column_index, source_column_index) } impl Clone for VortexBindData { @@ -48,6 +48,7 @@ impl Clone for VortexBindData { file_urls: self.file_urls.clone(), column_names: self.column_names.clone(), column_types: self.column_types.clone(), + virtual_column_mappings: self.virtual_column_mappings.clone(), } } } @@ -66,6 +67,7 @@ impl Debug for VortexBindData { pub struct VortexGlobalData { scan: MultiScan<(ArrayRef, Arc)>, batch_id: AtomicU64, + virtual_column_requests: Vec<(usize, String)>, // (projection_idx, source_column_name) } pub struct VortexLocalData { @@ -102,28 +104,84 @@ fn extract_schema_from_vortex_file( } /// Creates a projection expression based on the table initialization input. -fn extract_projection_expr(init: &TableInitInput) -> ExprRef { +/// Returns both the projection expression for real columns and indices of virtual columns +fn extract_projection_expr( + init: &TableInitInput, +) -> (ExprRef, Vec<(usize, String)>) { let projection_ids = init.projection_ids().unwrap_or(&[]); let column_ids = init.column_ids(); + let bind_data = init.bind_data(); + + let mut real_columns: Vec> = Vec::new(); + let mut virtual_column_requests = Vec::new(); + #[allow(clippy::disallowed_types)] + let mut needed_source_columns = std::collections::HashSet::new(); + + // First pass: identify virtual columns and their source columns + for (proj_idx, p) in projection_ids.iter().enumerate() { + let idx: usize = p.as_(); + + // Check if this index is within the column_ids range + if idx >= column_ids.len() { + // This might be a virtual column that doesn't have a column_id assigned + println!( + "âš ī¸ Column index {} is beyond column_ids range ({}), might be virtual column", + idx, + column_ids.len() + ); + continue; + } + + let col_idx: usize = column_ids[idx].as_(); - select( - projection_ids + // Check if this is a virtual column + if let Some((_, source_idx)) = bind_data + .virtual_column_mappings .iter() - .map(|p| { - let idx: usize = p.as_(); - let val: usize = column_ids[idx].as_(); - val - }) - .map(|idx| { - init.bind_data() - .column_names - .get(idx) - .vortex_expect("prune idx in column names") - }) - .map(|s| Arc::from(s.as_str())) - .collect::(), - root(), - ) + .find(|(virt_idx, _)| *virt_idx == col_idx) + { + // This is a virtual column, track it with the source column name + let source_col_name = bind_data + .column_names + .get(*source_idx) + .vortex_expect("source column must exist"); + virtual_column_requests.push((proj_idx, source_col_name.clone())); + needed_source_columns.insert(*source_idx); + } else if col_idx < bind_data.column_names.len() - bind_data.virtual_column_mappings.len() { + // This is a real column + let col_name = bind_data + .column_names + .get(col_idx) + .vortex_expect("prune idx in column names"); + real_columns.push(Arc::from(col_name.as_str())); + } + } + + // Second pass: ensure all needed source columns are included in projection + for &source_idx in &needed_source_columns { + let source_col_name = bind_data + .column_names + .get(source_idx) + .vortex_expect("source column must exist"); + let source_col_arc = Arc::from(source_col_name.as_str()); + if !real_columns.contains(&source_col_arc) { + real_columns.push(source_col_arc); + } + } + + let projection = if real_columns.is_empty() { + // If no real columns requested, still need to project something + // Project the first column to get row count + use vortex::dtype::FieldNames; + let field_names: FieldNames = vec![Arc::from(bind_data.column_names[0].as_str())].into(); + select(field_names, root()) + } else { + use vortex::dtype::FieldNames; + let field_names: FieldNames = real_columns.into(); + select(field_names, root()) + }; + + (projection, virtual_column_requests) } /// Creates a table filter expression from the table filter set. @@ -236,19 +294,45 @@ impl TableFunction for VortexTableFunction { }) })?; - let (column_names, column_types) = extract_schema_from_vortex_file(&first_file)?; + let (mut column_names, mut column_types) = extract_schema_from_vortex_file(&first_file)?; + + // Create a list to track original column count before adding virtual columns + let original_column_count = column_names.len(); + + // Add virtual $length columns for string columns + let mut virtual_columns = Vec::new(); + let mut virtual_column_mappings = Vec::new(); + for i in 0..original_column_count { + if column_types[i].as_type_id() == cpp::DUCKDB_TYPE::DUCKDB_TYPE_VARCHAR { + let virtual_name = format!("{}$length", column_names[i]); + let virtual_index = column_names.len() + virtual_columns.len(); + virtual_columns.push(( + virtual_name.clone(), + LogicalType::new(cpp::DUCKDB_TYPE::DUCKDB_TYPE_INTEGER), + )); + virtual_column_mappings.push((virtual_index, i)); + } + } // Add result columns based on the extracted schema. for (column_name, column_type) in column_names.iter().zip(&column_types) { result.add_result_column(column_name, column_type); } + // Add virtual columns to the result + for (virtual_name, virtual_type) in &virtual_columns { + result.add_result_column(virtual_name, virtual_type); + column_names.push(virtual_name.clone()); + column_types.push(virtual_type.clone()); + } + Ok(VortexBindData { file_urls, first_file, filter_exprs: vec![], column_names, column_types, + virtual_column_mappings, }) } @@ -267,9 +351,11 @@ impl TableFunction for VortexTableFunction { let (array_result, conversion_cache) = result?; - local_state.exporter = Some(ArrayExporter::try_new( + local_state.exporter = Some(ArrayExporter::try_new_with_virtual_columns( &array_result.to_struct(), &conversion_cache, + &global_state.virtual_column_requests, + chunk.column_count(), )?); // Relaxed since there is no intra-instruction ordering required. local_state.batch_id = Some(global_state.batch_id.fetch_add(1, Ordering::Relaxed)); @@ -298,20 +384,22 @@ impl TableFunction for VortexTableFunction { fn init_global(init_input: &TableInitInput) -> VortexResult { let bind_data = init_input.bind_data(); - let projection_expr = extract_projection_expr(init_input); + let (projection_expr, virtual_column_requests) = extract_projection_expr(init_input); let filter_expr = extract_table_filter_expr(init_input, init_input.column_ids())?; log::trace!( - "Global init Vortex scan SELECT {} WHERE {}", + "Global init Vortex scan SELECT {} WHERE {}, virtual columns: {:?}", &projection_expr, filter_expr .as_ref() - .map_or("true".to_string(), |f| f.to_string()) + .map_or("true".to_string(), |f| f.to_string()), + virtual_column_requests ); let client_context = init_input.client_context()?; let object_cache = client_context.object_cache(); + let closures = bind_data .file_urls @@ -360,6 +448,7 @@ impl TableFunction for VortexTableFunction { Ok(VortexGlobalData { scan: MultiScan::new(closures), batch_id: AtomicU64::new(0), + virtual_column_requests, }) } @@ -423,8 +512,12 @@ impl TableFunction for VortexTableFunction { let mut filters = bind_data.filter_exprs.iter().map(|f| format!("{}", f)); result.push(("Filters".to_string(), filters.join(" /\\\n"))); } + // NOTE: Projection is already printed by the planner. Some(result) } } + +#[cfg(test)] +mod tests; diff --git a/vortex-duckdb/src/scan/tests.rs b/vortex-duckdb/src/scan/tests.rs new file mode 100644 index 00000000000..2cd98b49c7a --- /dev/null +++ b/vortex-duckdb/src/scan/tests.rs @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Unit tests for scan.rs functionality + +#[cfg(test)] +#[allow(clippy::module_inception)] +mod tests { + use crate::cpp::DUCKDB_TYPE; + use crate::duckdb::LogicalType; + + #[test] + fn test_varchar_detection() { + // Test that VARCHAR type is correctly identified for virtual column generation + let varchar_type = LogicalType::new(DUCKDB_TYPE::DUCKDB_TYPE_VARCHAR); + assert_eq!(varchar_type.as_type_id(), DUCKDB_TYPE::DUCKDB_TYPE_VARCHAR); + + let int_type = LogicalType::new(DUCKDB_TYPE::DUCKDB_TYPE_INTEGER); + assert_ne!(int_type.as_type_id(), DUCKDB_TYPE::DUCKDB_TYPE_VARCHAR); + } + + #[test] + fn test_virtual_column_name_format() { + // Test that virtual column names are formatted correctly + let base_name = "my_column"; + let virtual_name = format!("{}$length", base_name); + assert_eq!(virtual_name, "my_column$length"); + } +}