Skip to content

Commit

Permalink
nits and revert status change
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Jan 3, 2025
1 parent 34e60b9 commit 0bd3ee5
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 84 deletions.
10 changes: 5 additions & 5 deletions appender.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appende
var duckdbAppender C.duckdb_appender
state := C.duckdb_appender_create(con.duckdbCon, cSchema, cTable, &duckdbAppender)

if returnState(state) == stateError {
if state == C.DuckDBError {
// We destroy the error message when destroying the appender.
err := duckdbError(C.duckdb_appender_error(duckdbAppender))
C.duckdb_appender_destroy(&duckdbAppender)
Expand Down Expand Up @@ -95,7 +95,7 @@ func (a *Appender) Flush() error {
}

state := C.duckdb_appender_flush(a.duckdbAppender)
if returnState(state) == stateError {
if state == C.DuckDBError {
err := duckdbError(C.duckdb_appender_error(a.duckdbAppender))
return getError(errAppenderFlush, invalidatedAppenderError(err))
}
Expand All @@ -116,15 +116,15 @@ func (a *Appender) Close() error {
// We flush before closing to get a meaningful error message.
var errFlush error
state := C.duckdb_appender_flush(a.duckdbAppender)
if returnState(state) == stateError {
if state == C.DuckDBError {
errFlush = duckdbError(C.duckdb_appender_error(a.duckdbAppender))
}

// Destroy all appender data and the appender.
destroyTypeSlice(a.ptr, a.types)
var errClose error
state = C.duckdb_appender_destroy(&a.duckdbAppender)
if returnState(state) == stateError {
if state == C.DuckDBError {
errClose = errAppenderClose
}

Expand Down Expand Up @@ -199,7 +199,7 @@ func (a *Appender) appendDataChunks() error {
}

state = C.duckdb_append_data_chunk(a.duckdbAppender, chunk.data)
if returnState(state) == stateError {
if state == C.DuckDBError {
err = duckdbError(C.duckdb_appender_error(a.duckdbAppender))
break
}
Expand Down
8 changes: 4 additions & 4 deletions arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (a *Arrow) queryArrowSchema(res *C.duckdb_arrow) (*arrow.Schema, error) {
if state := C.duckdb_query_arrow_schema(
*res,
(*C.duckdb_arrow_schema)(unsafe.Pointer(&schema)),
); returnState(state) == stateError {
); state == C.DuckDBError {
return nil, errors.New("duckdb_query_arrow_schema")
}

Expand All @@ -204,7 +204,7 @@ func (a *Arrow) queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Re
if state := C.duckdb_query_arrow_array(
*res,
(*C.duckdb_arrow_array)(unsafe.Pointer(&arr)),
); returnState(state) == stateError {
); state == C.DuckDBError {
return nil, errors.New("duckdb_query_arrow_array")
}

Expand All @@ -226,7 +226,7 @@ func (a *Arrow) execute(s *Stmt, args []driver.NamedValue) (*C.duckdb_arrow, err
}

var res C.duckdb_arrow
if state := C.duckdb_execute_prepared_arrow(*s.stmt, &res); returnState(state) == stateError {
if state := C.duckdb_execute_prepared_arrow(*s.stmt, &res); state == C.DuckDBError {
dbErr := C.GoString(C.duckdb_query_arrow_error(res))
C.duckdb_destroy_arrow(&res)
return nil, fmt.Errorf("duckdb_execute_prepared_arrow: %v", dbErr)
Expand Down Expand Up @@ -270,7 +270,7 @@ func (a *Arrow) RegisterView(reader array.RecordReader, name string) (release fu
a.c.duckdbCon,
cName,
(C.duckdb_arrow_stream)(stream),
); returnState(state) == stateError {
); state == C.DuckDBError {
release()
return nil, errors.New("duckdb_arrow_scan")
}
Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (c *Conn) prepareExtractedStmt(stmts C.duckdb_extracted_statements, i C.idx
var s C.duckdb_prepared_statement
state := C.duckdb_prepare_extracted_statement(c.duckdbCon, stmts, i, &s)

if returnState(state) == stateError {
if state == C.DuckDBError {
err := getDuckDBError(C.GoString(C.duckdb_prepare_error(s)))
C.duckdb_destroy_prepare(&s)
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error
var outError *C.char
defer C.duckdb_free(unsafe.Pointer(outError))

if state := C.duckdb_open_ext(connStr, &db, config, &outError); returnState(state) == stateError {
if state := C.duckdb_open_ext(connStr, &db, config, &outError); state == C.DuckDBError {
return nil, getError(errConnect, duckdbError(outError))
}

Expand All @@ -83,7 +83,7 @@ func (*Connector) Driver() driver.Driver {

func (c *Connector) Connect(context.Context) (driver.Conn, error) {
var duckdbCon C.duckdb_connection
if state := C.duckdb_connect(c.db, &duckdbCon); returnState(state) == stateError {
if state := C.duckdb_connect(c.db, &duckdbCon); state == C.DuckDBError {
return nil, getError(errConnect, nil)
}

Expand Down Expand Up @@ -114,7 +114,7 @@ func getConnString(dsn string) string {

func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) {
var config C.duckdb_config
if state := C.duckdb_create_config(&config); returnState(state) == stateError {
if state := C.duckdb_create_config(&config); state == C.DuckDBError {
C.duckdb_destroy_config(&config)
return nil, getError(errCreateConfig, nil)
}
Expand Down Expand Up @@ -148,7 +148,7 @@ func setConfigOption(config C.duckdb_config, name string, option string) error {
defer C.duckdb_free(unsafe.Pointer(cOption))

state := C.duckdb_set_config(config, cName, cOption)
if returnState(state) == stateError {
if state == C.DuckDBError {
C.duckdb_destroy_config(&config)
return getError(errSetConfig, fmt.Errorf("%s=%s", name, option))
}
Expand Down
11 changes: 0 additions & 11 deletions errors.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
package duckdb

/*
#include <duckdb.h>
*/
import "C"

import (
"errors"
"fmt"
Expand Down Expand Up @@ -276,10 +272,3 @@ func getDuckDBError(errMsg string) error {
Msg: errMsg,
}
}

type returnState C.duckdb_state

const (
stateSuccess returnState = C.DuckDBSuccess
stateError returnState = C.DuckDBError
)
6 changes: 3 additions & 3 deletions scalarUDF.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
con := driverConn.(*Conn)
state := C.duckdb_register_scalar_function(con.duckdbCon, function)
C.duckdb_destroy_scalar_function(&function)
if returnState(state) == stateError {
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreate)
}
return nil
Expand Down Expand Up @@ -103,7 +103,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err

state := C.duckdb_add_scalar_function_to_set(set, function)
C.duckdb_destroy_scalar_function(&function)
if returnState(state) == stateError {
if state == C.DuckDBError {
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i))
}
Expand All @@ -114,7 +114,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err
con := driverConn.(*Conn)
state := C.duckdb_register_scalar_function_set(con.duckdbCon, set)
C.duckdb_destroy_scalar_function_set(&set)
if returnState(state) == stateError {
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreateSet)
}
return nil
Expand Down
97 changes: 42 additions & 55 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,85 +143,85 @@ func (s *Stmt) Bind(args []driver.NamedValue) error {
return s.bind(args)
}

func (s *Stmt) bindHugeint(val *big.Int, n int) (returnState, error) {
func (s *Stmt) bindHugeint(val *big.Int, n int) (C.duckdb_state, error) {
hugeint, err := hugeIntFromNative(val)
if err != nil {
return stateError, err
return C.DuckDBError, err
}
state := C.duckdb_bind_hugeint(*s.stmt, C.idx_t(n+1), hugeint)
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindString(val string, n int) (returnState, error) {
func (s *Stmt) bindString(val string, n int) (C.duckdb_state, error) {
v := C.CString(val)
state := C.duckdb_bind_varchar(*s.stmt, C.idx_t(n+1), v)
C.duckdb_free(unsafe.Pointer(v))
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindBlob(val []byte, n int) (returnState, error) {
func (s *Stmt) bindBlob(val []byte, n int) (C.duckdb_state, error) {
v := C.CBytes(val)
state := C.duckdb_bind_blob(*s.stmt, C.idx_t(n+1), v, C.uint64_t(len(val)))
C.duckdb_free(unsafe.Pointer(v))
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindInterval(val Interval, n int) (returnState, error) {
func (s *Stmt) bindInterval(val Interval, n int) (C.duckdb_state, error) {
v := C.duckdb_interval{
months: C.int32_t(val.Months),
days: C.int32_t(val.Days),
micros: C.int64_t(val.Micros),
}
state := C.duckdb_bind_interval(*s.stmt, C.idx_t(n+1), v)
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (returnState, error) {
func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (C.duckdb_state, error) {
ts, err := getCTimestamp(t, val.Value)
if err != nil {
return stateError, err
return C.DuckDBError, err
}
state := C.duckdb_bind_timestamp(*s.stmt, C.idx_t(n+1), ts)
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindDate(val driver.NamedValue, n int) (returnState, error) {
func (s *Stmt) bindDate(val driver.NamedValue, n int) (C.duckdb_state, error) {
date, err := getCDate(val.Value)
if err != nil {
return stateError, err
return C.DuckDBError, err
}
state := C.duckdb_bind_date(*s.stmt, C.idx_t(n+1), date)
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindTime(val driver.NamedValue, t Type, n int) (returnState, error) {
func (s *Stmt) bindTime(val driver.NamedValue, t Type, n int) (C.duckdb_state, error) {
ticks, err := getTimeTicks(val.Value)
if err != nil {
return stateError, err
return C.DuckDBError, err
}

if t == TYPE_TIME {
var ti C.duckdb_time
ti.micros = C.int64_t(ticks)
state := C.duckdb_bind_time(*s.stmt, C.idx_t(n+1), ti)
return returnState(state), nil
return state, nil
}

// TYPE_TIME_TZ: The UTC offset is 0.
ti := C.duckdb_create_time_tz(C.int64_t(ticks), 0)
v := C.duckdb_create_time_tz_value(ti)
state := C.duckdb_bind_value(*s.stmt, C.idx_t(n+1), v)
C.duckdb_destroy_value(&v)
return returnState(state), nil
return state, nil
}

func (s *Stmt) bindComplexValue(val driver.NamedValue, n int) (returnState, error) {
func (s *Stmt) bindComplexValue(val driver.NamedValue, n int) (C.duckdb_state, error) {
t, err := s.ParamType(n + 1)
if err != nil {
return stateError, err
return C.DuckDBError, err
}
if name, ok := unsupportedTypeToStringMap[t]; ok {
return stateError, addIndexToError(unsupportedTypeError(name), n+1)
return C.DuckDBError, addIndexToError(unsupportedTypeError(name), n+1)
}

switch t {
Expand All @@ -237,64 +237,51 @@ func (s *Stmt) bindComplexValue(val driver.NamedValue, n int) (returnState, erro
// FIXME: for other types: duckdb_param_logical_type once available, then create duckdb_value + duckdb_bind_value
// FIXME: for other types: implement NamedValueChecker to support custom data types.
name := typeToStringMap[t]
return stateError, addIndexToError(unsupportedTypeError(name), n+1)
return C.DuckDBError, addIndexToError(unsupportedTypeError(name), n+1)
}
return stateError, addIndexToError(unsupportedTypeError(unknownTypeErrMsg), n+1)
return C.DuckDBError, addIndexToError(unsupportedTypeError(unknownTypeErrMsg), n+1)
}

func (s *Stmt) bindValue(val driver.NamedValue, n int) (returnState, error) {
func (s *Stmt) bindValue(val driver.NamedValue, n int) (C.duckdb_state, error) {
switch v := val.Value.(type) {
case bool:
state := C.duckdb_bind_boolean(*s.stmt, C.idx_t(n+1), C.bool(v))
return returnState(state), nil
return C.duckdb_bind_boolean(*s.stmt, C.idx_t(n+1), C.bool(v)), nil
case int8:
state := C.duckdb_bind_int8(*s.stmt, C.idx_t(n+1), C.int8_t(v))
return returnState(state), nil
return C.duckdb_bind_int8(*s.stmt, C.idx_t(n+1), C.int8_t(v)), nil
case int16:
state := C.duckdb_bind_int16(*s.stmt, C.idx_t(n+1), C.int16_t(v))
return returnState(state), nil
return C.duckdb_bind_int16(*s.stmt, C.idx_t(n+1), C.int16_t(v)), nil
case int32:
state := C.duckdb_bind_int32(*s.stmt, C.idx_t(n+1), C.int32_t(v))
return returnState(state), nil
return C.duckdb_bind_int32(*s.stmt, C.idx_t(n+1), C.int32_t(v)), nil
case int64:
state := C.duckdb_bind_int64(*s.stmt, C.idx_t(n+1), C.int64_t(v))
return returnState(state), nil
return C.duckdb_bind_int64(*s.stmt, C.idx_t(n+1), C.int64_t(v)), nil
case int:
state := C.duckdb_bind_int64(*s.stmt, C.idx_t(n+1), C.int64_t(v))
return returnState(state), nil
return C.duckdb_bind_int64(*s.stmt, C.idx_t(n+1), C.int64_t(v)), nil
case *big.Int:
return s.bindHugeint(v, n)
case Decimal:
// FIXME: implement NamedValueChecker to support custom data types.
name := typeToStringMap[TYPE_DECIMAL]
return stateError, addIndexToError(unsupportedTypeError(name), n+1)
return C.DuckDBError, addIndexToError(unsupportedTypeError(name), n+1)
case uint8:
state := C.duckdb_bind_uint8(*s.stmt, C.idx_t(n+1), C.uint8_t(v))
return returnState(state), nil
return C.duckdb_bind_uint8(*s.stmt, C.idx_t(n+1), C.uint8_t(v)), nil
case uint16:
state := C.duckdb_bind_uint16(*s.stmt, C.idx_t(n+1), C.uint16_t(v))
return returnState(state), nil
return C.duckdb_bind_uint16(*s.stmt, C.idx_t(n+1), C.uint16_t(v)), nil
case uint32:
state := C.duckdb_bind_uint32(*s.stmt, C.idx_t(n+1), C.uint32_t(v))
return returnState(state), nil
return C.duckdb_bind_uint32(*s.stmt, C.idx_t(n+1), C.uint32_t(v)), nil
case uint64:
state := C.duckdb_bind_uint64(*s.stmt, C.idx_t(n+1), C.uint64_t(v))
return returnState(state), nil
return C.duckdb_bind_uint64(*s.stmt, C.idx_t(n+1), C.uint64_t(v)), nil
case float32:
state := C.duckdb_bind_float(*s.stmt, C.idx_t(n+1), C.float(v))
return returnState(state), nil
return C.duckdb_bind_float(*s.stmt, C.idx_t(n+1), C.float(v)), nil
case float64:
state := C.duckdb_bind_double(*s.stmt, C.idx_t(n+1), C.double(v))
return returnState(state), nil
return C.duckdb_bind_double(*s.stmt, C.idx_t(n+1), C.double(v)), nil
case string:
return s.bindString(v, n)
case []byte:
return s.bindBlob(v, n)
case Interval:
return s.bindInterval(v, n)
case nil:
state := C.duckdb_bind_null(*s.stmt, C.idx_t(n+1))
return returnState(state), nil
return C.duckdb_bind_null(*s.stmt, C.idx_t(n+1)), nil
}
return s.bindComplexValue(val, n)
}
Expand Down Expand Up @@ -328,7 +315,7 @@ func (s *Stmt) bind(args []driver.NamedValue) error {
}

state, err := s.bindValue(arg, i)
if state == stateError {
if state == C.DuckDBError {
// TODO: more info might be interesting, do we set an error in the statement?
return errors.Join(errCouldNotBind, err)
}
Expand Down Expand Up @@ -435,7 +422,7 @@ func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb

func (s *Stmt) executeBound(ctx context.Context) (*C.duckdb_result, error) {
var pendingRes C.duckdb_pending_result
if state := C.duckdb_pending_prepared(*s.stmt, &pendingRes); returnState(state) == stateError {
if state := C.duckdb_pending_prepared(*s.stmt, &pendingRes); state == C.DuckDBError {
dbErr := getDuckDBError(C.GoString(C.duckdb_pending_error(pendingRes)))
C.duckdb_destroy_pending(&pendingRes)
return nil, dbErr
Expand Down Expand Up @@ -463,7 +450,7 @@ func (s *Stmt) executeBound(ctx context.Context) (*C.duckdb_result, error) {
// sometimes the bg goroutine is not scheduled immediately and by that time if another query is running on this connection
// it can cancel that query so need to wait for it to finish as well
<-bgDoneCh
if returnState(state) == stateError {
if state == C.DuckDBError {
if ctx.Err() != nil {
C.duckdb_destroy_result(&res)
return nil, ctx.Err()
Expand Down
Loading

0 comments on commit 0bd3ee5

Please sign in to comment.