Skip to content

Commit

Permalink
Merge pull request #333 from apecloud/expose-stmt-bind
Browse files Browse the repository at this point in the history
Expose `Bind` and `(Query|Exec)Bound` on `Stmt` for advanced usage
  • Loading branch information
taniabogatsch authored Dec 27, 2024
2 parents 0614b2d + b8a948d commit 99e7b9c
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 3 deletions.
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ var (
errPrepare = errors.New("could not prepare query")
errMissingPrepareContext = errors.New("missing context for multi-statement query: try using PrepareContext")
errEmptyQuery = errors.New("empty query")
errCouldNotBind = errors.New("could not bind parameter")
errActiveRows = errors.New("ExecContext or QueryContext with active Rows")
errNotBound = errors.New("parameters have not been bound")
errBeginTx = errors.New("could not begin transaction")
errMultipleTx = errors.New("multiple transactions")
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")
Expand Down
66 changes: 63 additions & 3 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type Stmt struct {
c *Conn
stmt *C.duckdb_prepared_statement
closeOnRowsClose bool
bound bool
closed bool
rows bool
}
Expand Down Expand Up @@ -131,6 +132,18 @@ func (s *Stmt) StatementType() (StmtType, error) {
return StmtType(C.duckdb_prepared_statement_type(*s.stmt)), nil
}

// Bind binds the parameters to the statement.
// WARNING: This is a low-level API and should be used with caution.
func (s *Stmt) Bind(args []driver.NamedValue) error {
if s.closed {
return errors.Join(errCouldNotBind, errClosedStmt)
}
if s.stmt == nil {
return errors.Join(errCouldNotBind, errUninitializedStmt)
}
return s.bind(args)
}

func (s *Stmt) bind(args []driver.NamedValue) error {
if s.NumInput() > len(args) {
return fmt.Errorf("incorrect argument count for command: have %d want %d", len(args), s.NumInput())
Expand Down Expand Up @@ -258,6 +271,7 @@ func (s *Stmt) bind(args []driver.NamedValue) error {
}
}

s.bound = true
return nil
}

Expand All @@ -279,6 +293,30 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv
return &result{ra}, nil
}

// ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE.
// It can only be used after Bind has been called.
// WARNING: This is a low-level API and should be used with caution.
func (s *Stmt) ExecBound(ctx context.Context) (driver.Result, error) {
if s.closed {
return nil, errClosedCon
}
if s.rows {
return nil, errActiveRows
}
if !s.bound {
return nil, errNotBound
}

res, err := s.executeBound(ctx)
if err != nil {
return nil, err
}
defer C.duckdb_destroy_result(res)

ra := int64(C.duckdb_value_int64(res, 0, 0))
return &result{ra}, nil
}

// Deprecated: Use QueryContext instead.
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), argsToNamedArgs(args))
Expand All @@ -295,6 +333,28 @@ func (s *Stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (dri
return newRowsWithStmt(*res, s), nil
}

// QueryBound executes a bound query that may return rows, such as a SELECT.
// It can only be used after Bind has been called.
// WARNING: This is a low-level API and should be used with caution.
func (s *Stmt) QueryBound(ctx context.Context) (driver.Rows, error) {
if s.closed {
return nil, errClosedCon
}
if s.rows {
return nil, errActiveRows
}
if !s.bound {
return nil, errNotBound
}

res, err := s.executeBound(ctx)
if err != nil {
return nil, err
}
s.rows = true
return newRowsWithStmt(*res, s), nil
}

// This method executes the query in steps and checks if context is cancelled before executing each step.
// It uses Pending Result Interface C APIs to achieve this. Reference - https://duckdb.org/docs/api/c/api#pending-result-interface
func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb_result, error) {
Expand All @@ -304,11 +364,13 @@ func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb
if s.rows {
panic("database/sql/driver: misuse of duckdb driver: ExecContext or QueryContext with active Rows")
}

if err := s.bind(args); err != nil {
return nil, err
}
return s.executeBound(ctx)
}

func (s *Stmt) executeBound(ctx context.Context) (*C.duckdb_result, error) {
var pendingRes C.duckdb_pending_result
if state := C.duckdb_pending_prepared(*s.stmt, &pendingRes); state == C.DuckDBError {
dbErr := getDuckDBError(C.GoString(C.duckdb_pending_error(pendingRes)))
Expand Down Expand Up @@ -360,5 +422,3 @@ func argsToNamedArgs(values []driver.Value) []driver.NamedValue {
}
return args
}

var errCouldNotBind = errors.New("could not bind parameter")
63 changes: 63 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package duckdb
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"testing"

Expand Down Expand Up @@ -56,6 +57,27 @@ func TestPrepareQuery(t *testing.T) {
require.ErrorContains(t, err, paramIndexErrMsg)
require.Equal(t, TYPE_INVALID, paramType)

rows, err := stmt.QueryBound(context.Background())
require.Nil(t, rows)
require.ErrorIs(t, err, errNotBound)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
require.NoError(t, err)

rows, err = stmt.QueryBound(context.Background())
require.NoError(t, err)
require.NotNil(t, rows)

badRows, err := stmt.QueryBound(context.Background())
require.ErrorIs(t, err, errActiveRows)
require.Nil(t, badRows)

badResults, err := stmt.ExecBound(context.Background())
require.ErrorIs(t, err, errActiveRows)
require.Nil(t, badResults)

require.NoError(t, rows.Close())

require.NoError(t, stmt.Close())

stmtType, err = stmt.StatementType()
Expand All @@ -66,6 +88,10 @@ func TestPrepareQuery(t *testing.T) {
require.ErrorIs(t, err, errClosedStmt)
require.Equal(t, TYPE_INVALID, paramType)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
require.ErrorIs(t, err, errCouldNotBind)
require.ErrorIs(t, err, errClosedStmt)

return nil
})
require.NoError(t, err)
Expand Down Expand Up @@ -146,6 +172,17 @@ func TestPrepareQueryPositional(t *testing.T) {
require.ErrorContains(t, err, paramIndexErrMsg)
require.Equal(t, TYPE_INVALID, paramType)

result, err := stmt.ExecBound(context.Background())
require.Nil(t, result)
require.ErrorIs(t, err, errNotBound)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}})
require.NoError(t, err)

result, err = stmt.ExecBound(context.Background())
require.NoError(t, err)
require.NotNil(t, result)

require.NoError(t, stmt.Close())

stmtType, err = stmt.StatementType()
Expand All @@ -160,6 +197,10 @@ func TestPrepareQueryPositional(t *testing.T) {
require.ErrorIs(t, err, errClosedStmt)
require.Equal(t, TYPE_INVALID, paramType)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}})
require.ErrorIs(t, err, errCouldNotBind)
require.ErrorIs(t, err, errClosedStmt)

return nil
})
require.NoError(t, err)
Expand Down Expand Up @@ -245,6 +286,17 @@ func TestPrepareQueryNamed(t *testing.T) {
require.ErrorContains(t, err, paramIndexErrMsg)
require.Equal(t, TYPE_INVALID, paramType)

result, err := stmt.ExecBound(context.Background())
require.Nil(t, result)
require.ErrorIs(t, err, errNotBound)

err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}})
require.NoError(t, err)

result, err = stmt.ExecBound(context.Background())
require.NoError(t, err)
require.NotNil(t, result)

require.NoError(t, stmt.Close())

stmtType, err = stmt.StatementType()
Expand All @@ -259,6 +311,10 @@ func TestPrepareQueryNamed(t *testing.T) {
require.ErrorIs(t, err, errClosedStmt)
require.Equal(t, TYPE_INVALID, paramType)

err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}})
require.ErrorIs(t, err, errCouldNotBind)
require.ErrorIs(t, err, errClosedStmt)

return nil
})
require.NoError(t, err)
Expand All @@ -280,6 +336,13 @@ func TestUninitializedStmt(t *testing.T) {
paramName, err := stmt.ParamName(1)
require.ErrorIs(t, err, errUninitializedStmt)
require.Equal(t, "", paramName)

err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
require.ErrorIs(t, err, errCouldNotBind)
require.ErrorIs(t, err, errUninitializedStmt)

_, err = stmt.ExecBound(context.Background())
require.ErrorIs(t, err, errNotBound)
}

func TestPrepareWithError(t *testing.T) {
Expand Down

0 comments on commit 99e7b9c

Please sign in to comment.