Skip to content

Commit

Permalink
1.2.0 followup (#428)
Browse files Browse the repository at this point in the history
* Use ptr::write when writing uninitialized memory

* Use smaller unsafe blocks

* Rust BindData/InitData can just use Rust types

* Use unsafe blocks inside unsafe fns generated by macro

This will prevent the macro generating a warning in edition 2024

* Fix unused import

* Better safety docs for vtab methods

Contrary to the previous docs, the instances passed to these functions are *not* initialized by the caller. Rather, the called function is responsible for writing into uninitialized memory.

* Similar safety fixes in vtab tests

* VTab::bind and init are now safe

Rather than passing a pointer to a block of uninitialized memory, which can easily lead to UB, these functions now just return Rust objects.

This improves #414 by reducing the amount of unsafe code needed from extensions.

* vtab::Free is no longer needed

BindInfo and InitInfo will be dropped in the usual way when freed by duckdb core. Any necessary destructors can be in Drop impls.

* BindData and InitData should be Send+Sync

It's not completely clear but it looks like the engine could run the table fn from multiple threads, so requiring this seems safer

* Add a safe & typed interface to get bind_data

* Also safely retrieve the init_data

* Add unsafe blocks, rm unnecessary cast

* clippy

---------

Co-authored-by: Martin Pool <[email protected]>
  • Loading branch information
Maxxen and sourcefrog authored Feb 10, 2025
1 parent 7848ebb commit 1e29fc1
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 294 deletions.
67 changes: 22 additions & 45 deletions crates/duckdb/examples/hello-ext-capi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,78 +4,55 @@ extern crate libduckdb_sys;

use duckdb::{
core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab},
vtab::{BindInfo, FunctionInfo, InitInfo, VTab},
Connection, Result,
};
use duckdb_loadable_macros::duckdb_entrypoint_c_api;
use libduckdb_sys as ffi;
use std::{
error::Error,
ffi::{c_char, CString},
ffi::CString,
sync::atomic::{AtomicBool, Ordering},
};

#[repr(C)]
struct HelloBindData {
name: *mut c_char,
}

impl Free for HelloBindData {
fn free(&mut self) {
unsafe {
if self.name.is_null() {
return;
}
drop(CString::from_raw(self.name));
}
}
name: String,
}

#[repr(C)]
struct HelloInitData {
done: bool,
done: AtomicBool,
}

struct HelloVTab;

impl Free for HelloInitData {}

impl VTab for HelloVTab {
type InitData = HelloInitData;
type BindData = HelloBindData;

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let param = bind.get_parameter(0).to_string();
unsafe {
(*data).name = CString::new(param).unwrap().into_raw();
}
Ok(())
let name = bind.get_parameter(0).to_string();
Ok(HelloBindData { name })
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
(*data).done = false;
}
Ok(())
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
Ok(HelloInitData {
done: AtomicBool::new(false),
})
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<HelloInitData>();
let bind_info = func.get_bind_data::<HelloBindData>();

unsafe {
if (*init_info).done {
output.set_len(0);
} else {
(*init_info).done = true;
let vector = output.flat_vector(0);
let name = CString::from_raw((*bind_info).name);
let result = CString::new(format!("Hello {}", name.to_str()?))?;
// Can't consume the CString
(*bind_info).name = CString::into_raw(name);
vector.insert(0, result);
output.set_len(1);
}
fn func(func: &FunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_data = func.get_init_data();
let bind_data = func.get_bind_data();
if init_data.done.swap(true, Ordering::Relaxed) {
output.set_len(0);
} else {
let vector = output.flat_vector(0);
let result = CString::new(format!("Hello {}", bind_data.name))?;
vector.insert(0, result);
output.set_len(1);
}
Ok(())
}
Expand Down
38 changes: 14 additions & 24 deletions crates/duckdb/examples/hello-ext/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,51 @@ extern crate libduckdb_sys;

use duckdb::{
core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab},
vtab::{BindInfo, FunctionInfo, InitInfo, VTab},
Connection, Result,
};
use duckdb_loadable_macros::duckdb_entrypoint;
use libduckdb_sys as ffi;
use std::{
error::Error,
ffi::{c_char, c_void, CString},
ptr,
sync::atomic::{AtomicBool, Ordering},
};

struct HelloBindData {
name: String,
}

impl Free for HelloBindData {}

struct HelloInitData {
done: bool,
done: AtomicBool,
}

struct HelloVTab;

impl Free for HelloInitData {}

impl VTab for HelloVTab {
type InitData = HelloInitData;
type BindData = HelloBindData;

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let name = bind.get_parameter(0).to_string();
unsafe {
ptr::write(data, HelloBindData { name });
}
Ok(())
Ok(HelloBindData { name })
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
ptr::write(data, HelloInitData { done: false });
}
Ok(())
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
Ok(HelloInitData {
done: AtomicBool::new(false),
})
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = unsafe { func.get_init_data::<HelloInitData>().as_mut().unwrap() };
let bind_info = unsafe { func.get_bind_data::<HelloBindData>().as_mut().unwrap() };

if init_info.done {
fn func(func: &FunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_data = func.get_init_data();
let bind_data = func.get_bind_data();
if init_data.done.swap(true, Ordering::Relaxed) {
output.set_len(0);
} else {
init_info.done = true;
let vector = output.flat_vector(0);
let result = CString::new(format!("Hello {}", bind_info.name))?;
let result = CString::new(format!("Hello {}", bind_data.name))?;
vector.insert(0, result);
output.set_len(1);
}
Expand Down
62 changes: 23 additions & 39 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{BindInfo, DataChunkHandle, Free, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab};
use std::ptr::null_mut;
use super::{BindInfo, DataChunkHandle, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab};
use std::sync::{atomic::AtomicBool, Mutex};

use crate::core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector};
use arrow::{
Expand All @@ -24,28 +24,15 @@ use num::{cast::AsPrimitive, ToPrimitive};
/// A pointer to the Arrow record batch for the table function.
#[repr(C)]
pub struct ArrowBindData {
rb: *mut RecordBatch,
}

impl Free for ArrowBindData {
fn free(&mut self) {
unsafe {
if self.rb.is_null() {
return;
}
drop(Box::from_raw(self.rb));
}
}
rb: Mutex<RecordBatch>,
}

/// Keeps track of whether the Arrow record batch has been consumed.
#[repr(C)]
pub struct ArrowInitData {
done: bool,
done: AtomicBool,
}

impl Free for ArrowInitData {}

/// The Arrow table function.
pub struct ArrowVTab;

Expand Down Expand Up @@ -76,14 +63,14 @@ impl VTab for ArrowVTab {
type BindData = ArrowBindData;
type InitData = ArrowInitData;

unsafe fn bind(bind: &BindInfo, data: *mut ArrowBindData) -> Result<(), Box<dyn std::error::Error>> {
(*data).rb = null_mut();
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>> {
let param_count = bind.get_parameter_count();
if param_count != 2 {
return Err(format!("Bad param count: {param_count}, expected 2").into());
}
let array = bind.get_parameter(0).to_int64();
let schema = bind.get_parameter(1).to_int64();

unsafe {
let rb = address_to_arrow_record_batch(array as usize, schema as usize);
for f in rb.schema().fields() {
Expand All @@ -92,32 +79,29 @@ impl VTab for ArrowVTab {
let logical_type = to_duckdb_logical_type(data_type)?;
bind.add_result_column(name, logical_type);
}
(*data).rb = Box::into_raw(Box::new(rb));

Ok(ArrowBindData { rb: Mutex::new(rb) })
}
Ok(())
}

unsafe fn init(_: &InitInfo, data: *mut ArrowInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
(*data).done = false;
}
Ok(())
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
Ok(ArrowInitData {
done: AtomicBool::new(false),
})
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<ArrowInitData>();
let bind_info = func.get_bind_data::<ArrowBindData>();
unsafe {
if (*init_info).done {
output.set_len(0);
} else {
let rb = Box::from_raw((*bind_info).rb);
(*bind_info).rb = null_mut(); // erase ref in case of failure in record_batch_to_duckdb_data_chunk
record_batch_to_duckdb_data_chunk(&rb, output)?;
(*bind_info).rb = Box::into_raw(rb);
(*init_info).done = true;
}
fn func(func: &FunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data();
let bind_info = func.get_bind_data();

if init_info.done.load(std::sync::atomic::Ordering::Relaxed) {
output.set_len(0);
} else {
let rb = bind_info.rb.lock().unwrap();
record_batch_to_duckdb_data_chunk(&rb, output)?;
init_info.done.store(true, std::sync::atomic::Ordering::Relaxed);
}

Ok(())
}

Expand Down
Loading

0 comments on commit 1e29fc1

Please sign in to comment.