Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func paramIndexError(idx int, max uint64) error {
return fmt.Errorf("%s: %d is out of range [1, %d]", paramIndexErrMsg, idx, max)
}

func columnIndexError(idx int, max uint64) error {
return fmt.Errorf("%s: %d is out of range [0, %d)", columnIndexErrMsg, idx, max)
}

func unsupportedTypeError(name string) error {
return fmt.Errorf("%s: %s", unsupportedTypeErrMsg, name)
}
Expand Down Expand Up @@ -81,6 +85,7 @@ const (
interfaceIsNilErrMsg = "interface is nil"
duplicateNameErrMsg = "duplicate name"
paramIndexErrMsg = "invalid parameter index"
columnIndexErrMsg = "invalid column index"
)

var (
Expand Down
110 changes: 90 additions & 20 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ type Stmt struct {
rows bool
}

// checkState checks if the statement is closed or uninitialized.
func (s *Stmt) checkState() error {
if s.closed {
return errClosedStmt
}
if s.preparedStmt == nil {
return errUninitializedStmt
}
return nil
}

// Close the statement.
// Implements the driver.Stmt interface.
func (s *Stmt) Close() error {
Expand Down Expand Up @@ -81,11 +92,8 @@ func (s *Stmt) NumInput() int {

// ParamName returns the name of the parameter at the given index (1-based).
func (s *Stmt) ParamName(n int) (string, error) {
if s.closed {
return "", errClosedStmt
}
if s.preparedStmt == nil {
return "", errUninitializedStmt
if err := s.checkState(); err != nil {
return "", err
}

count := mapping.NParams(*s.preparedStmt)
Expand All @@ -99,11 +107,8 @@ func (s *Stmt) ParamName(n int) (string, error) {

// ParamType returns the expected type of the parameter at the given index (1-based).
func (s *Stmt) ParamType(n int) (Type, error) {
if s.closed {
return TYPE_INVALID, errClosedStmt
}
if s.preparedStmt == nil {
return TYPE_INVALID, errUninitializedStmt
if err := s.checkState(); err != nil {
return TYPE_INVALID, err
}

count := mapping.NParams(*s.preparedStmt)
Expand All @@ -117,11 +122,8 @@ func (s *Stmt) ParamType(n int) (Type, error) {

func (s *Stmt) paramLogicalType(n int) (mapping.LogicalType, error) {
var lt mapping.LogicalType
if s.closed {
return lt, errClosedStmt
}
if s.preparedStmt == nil {
return lt, errUninitializedStmt
if err := s.checkState(); err != nil {
return lt, err
}

count := mapping.NParams(*s.preparedStmt)
Expand All @@ -134,11 +136,8 @@ func (s *Stmt) paramLogicalType(n int) (mapping.LogicalType, error) {

// StatementType returns the type of the statement.
func (s *Stmt) StatementType() (StmtType, error) {
if s.closed {
return STATEMENT_TYPE_INVALID, errClosedStmt
}
if s.preparedStmt == nil {
return STATEMENT_TYPE_INVALID, errUninitializedStmt
if err := s.checkState(); err != nil {
return STATEMENT_TYPE_INVALID, err
}

t := mapping.PreparedStatementType(*s.preparedStmt)
Expand Down Expand Up @@ -470,6 +469,77 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv
return &result{ra}, nil
}

// ColumnCount returns the number of columns that will be returned by executing the prepared statement.
// If any of the column types is invalid (which can happen when the type is ambiguous), the result will be 1.
// Returns an error if the statement is closed or uninitialized.
func (s *Stmt) ColumnCount() (int, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add tests here? For both error paths, and the happy path?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we return uint here? Then, we can also take uint in the below functions and can skip the n < 0 check?

if err := s.checkState(); err != nil {
return 0, err
}

count := mapping.PreparedStatementColumnCount(*s.preparedStmt)
return int(count), nil
}

// ColumnType returns the type of the column at the given index (0-based).
// Returns TYPE_INVALID and a columnIndexError if the column is out of range.
// Returns an error if the statement is closed or uninitialized.
func (s *Stmt) ColumnType(n int) (Type, error) {
if err := s.checkState(); err != nil {
return TYPE_INVALID, err
}

count := mapping.PreparedStatementColumnCount(*s.preparedStmt)
if n < 0 || n >= int(count) {
return TYPE_INVALID, getError(errAPI, columnIndexError(n, uint64(count)))
}

t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(n))
return t, nil
}

// ColumnTypeInfo returns the TypeInfo of the column at the given index (0-based).
// TypeInfo provides detailed type information including nested structures, DECIMAL precision,
// ENUM values, etc.
// Returns a TypeInfo with internalType TYPE_INVALID and a columnIndexError if the column is out of range.
// Returns an error if the statement is closed or uninitialized.
func (s *Stmt) ColumnTypeInfo(n int) (TypeInfo, error) {
if err := s.checkState(); err != nil {
return nil, err
}

count := mapping.PreparedStatementColumnCount(*s.preparedStmt)
if n < 0 || n >= int(count) {
return nil, getError(errAPI, columnIndexError(n, uint64(count)))
}

lt := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(n))
defer mapping.DestroyLogicalType(&lt)

return NewTypeInfoFromLogicalType(lt)
}

// ColumnName returns the name of the column at the given index (0-based).
// Returns "" and a columnIndexError if the column is out of range.
// Returns an error if the statement is closed or uninitialized.
func (s *Stmt) ColumnName(n int) (string, error) {
if err := s.checkState(); err != nil {
return "", err
}

count := mapping.PreparedStatementColumnCount(*s.preparedStmt)
if n < 0 || n >= int(count) {
return "", getError(errAPI, columnIndexError(n, uint64(count)))
}

name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(n))
// C API returns nullptr for out-of-range indices
if name == "" {
return "", nil
}
return name, nil
Comment on lines +537 to +540
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these paths the same?

}

// ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE.
// It can only be used after Bind has been called.
// WARNING: This is a low-level API and should be used with caution.
Expand Down
Loading
Loading