From 101b3386e754e6e8e5138738ce7327a93d6baae3 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Fri, 3 Oct 2025 13:53:00 -0400 Subject: [PATCH 01/14] lint --- statement.go | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/statement.go b/statement.go index 54c413a4..c68fa0c0 100644 --- a/statement.go +++ b/statement.go @@ -470,6 +470,82 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv return &result{ra}, nil } +func (s *Stmt) ColumnCount() (int, error) { + if s.closed { + return 0, errClosedStmt + } + if s.preparedStmt == nil { + return 0, errUninitializedStmt + } + + count := mapping.PreparedStatementColumnCount(*s.preparedStmt) + return int(count), nil +} + +func (s *Stmt) GetColumnTypes() ([]Type, error) { + n, err := s.ColumnCount() + if err != nil { + return nil, err + } + + types := make([]Type, n) + for i := range n { + t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(i+1)) + types[i] = t + } + return types, nil +} + +func (s *Stmt) GetColumnNames() ([]string, error) { + if s.closed { + return nil, errClosedStmt + } + if s.preparedStmt == nil { + return nil, errUninitializedStmt + } + + n, err := s.ColumnCount() + if err != nil { + return nil, err + } + + names := make([]string, n) + for i := range n { + name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(i+1)) + names[i] = name + } + return names, nil +} + +func (s *Stmt) GetParameterTypes() ([]Type, error) { + if s.closed { + return nil, errClosedStmt + } + if s.preparedStmt == nil { + return nil, errUninitializedStmt + } + + count := mapping.NParams(*s.preparedStmt) + types := make([]Type, count) + for i := mapping.IdxT(0); i < count; i++ { + t := mapping.ParamType(*s.preparedStmt, i+1) + types[i] = t + } + return types, nil +} + +func (s *Stmt) NParams() (int, error) { + if s.closed { + return 0, errClosedStmt + } + if s.preparedStmt == nil { + return 0, errUninitializedStmt + } + + count := mapping.NParams(*s.preparedStmt) + return int(count), nil +} + // ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE. // It can only be used after Bind has been called. // WARNING: This is a low-level API and should be used with caution. From 290bc976a12c9929a3ab4b9f297191fc35627798 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Tue, 14 Oct 2025 15:10:12 -0400 Subject: [PATCH 02/14] fix indexing off by 1 :( --- statement.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/statement.go b/statement.go index c68fa0c0..24116fa0 100644 --- a/statement.go +++ b/statement.go @@ -490,7 +490,7 @@ func (s *Stmt) GetColumnTypes() ([]Type, error) { types := make([]Type, n) for i := range n { - t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(i+1)) + t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(i)) types[i] = t } return types, nil @@ -511,7 +511,7 @@ func (s *Stmt) GetColumnNames() ([]string, error) { names := make([]string, n) for i := range n { - name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(i+1)) + name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(i)) names[i] = name } return names, nil From ec3b96a48ec137fc58700f6cc8a7e07f18f70add Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Wed, 29 Oct 2025 17:54:36 -0400 Subject: [PATCH 03/14] add get column logical types --- statement.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/statement.go b/statement.go index 24116fa0..64cf9f3f 100644 --- a/statement.go +++ b/statement.go @@ -496,6 +496,20 @@ func (s *Stmt) GetColumnTypes() ([]Type, error) { return types, nil } +func (s *Stmt) GetColumnLogicalTypes() ([]mapping.LogicalType, error) { + n, err := s.ColumnCount() + if err != nil { + return nil, err + } + + types := make([]mapping.LogicalType, n) + for i := range n { + t := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(i)) + types[i] = t + } + return types, nil +} + func (s *Stmt) GetColumnNames() ([]string, error) { if s.closed { return nil, errClosedStmt From a0468f0ed9b37436588fa6ee94926bf59b94451a Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Wed, 29 Oct 2025 18:28:07 -0400 Subject: [PATCH 04/14] add param logical types as well --- statement.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/statement.go b/statement.go index 64cf9f3f..d2c00081 100644 --- a/statement.go +++ b/statement.go @@ -548,6 +548,23 @@ func (s *Stmt) GetParameterTypes() ([]Type, error) { return types, nil } +func (s *Stmt) GetParameterLogicalTypes() ([]mapping.LogicalType, error) { + if s.closed { + return nil, errClosedStmt + } + if s.preparedStmt == nil { + return nil, errUninitializedStmt + } + + count := mapping.NParams(*s.preparedStmt) + types := make([]mapping.LogicalType, count) + for i := mapping.IdxT(0); i < count; i++ { + t := mapping.ParamLogicalType(*s.preparedStmt, i+1) + types[i] = t + } + return types, nil +} + func (s *Stmt) NParams() (int, error) { if s.closed { return 0, errClosedStmt From 8187ed3ca578b1f19897e3b09a5bde209b54806e Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Tue, 4 Nov 2025 17:00:57 -0500 Subject: [PATCH 05/14] clean up --- statement.go | 67 ++++++++++++++-------------------------------------- 1 file changed, 18 insertions(+), 49 deletions(-) diff --git a/statement.go b/statement.go index d2c00081..98abd046 100644 --- a/statement.go +++ b/statement.go @@ -470,6 +470,9 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv return &result{ra}, nil } +// ColumnCount returns the number of columns that will be returned by executing the prepared statement. +// If any of the column types is invalid, the result will be 1. +// Returns an error if the statement is closed or uninitialized. func (s *Stmt) ColumnCount() (int, error) { if s.closed { return 0, errClosedStmt @@ -482,7 +485,9 @@ func (s *Stmt) ColumnCount() (int, error) { return int(count), nil } -func (s *Stmt) GetColumnTypes() ([]Type, error) { +// ColumnTypes returns the types of all columns in the result set of the prepared statement. +// Returns an error if the statement is closed or uninitialized. +func (s *Stmt) ColumnTypes() ([]Type, error) { n, err := s.ColumnCount() if err != nil { return nil, err @@ -496,7 +501,13 @@ func (s *Stmt) GetColumnTypes() ([]Type, error) { return types, nil } -func (s *Stmt) GetColumnLogicalTypes() ([]mapping.LogicalType, error) { +// ColumnLogicalTypes returns the logical types of all columns in the result set of the prepared statement. +// Logical types provide more detailed type information than the basic Type, including custom types, +// nested structures, and type modifiers. The returned slice has one LogicalType entry for each column. +// Note: The caller is responsible for calling mapping.DestroyLogicalType on each returned LogicalType +// during cleanup. +// Returns an error if the statement is closed or uninitialized. +func (s *Stmt) ColumnLogicalTypes() ([]mapping.LogicalType, error) { n, err := s.ColumnCount() if err != nil { return nil, err @@ -510,7 +521,11 @@ func (s *Stmt) GetColumnLogicalTypes() ([]mapping.LogicalType, error) { return types, nil } -func (s *Stmt) GetColumnNames() ([]string, error) { +// ColumnNames returns the names of all columns in the result set of the prepared statement. +// The returned slice has one string entry for each column, indexed starting from 0. +// For statements that do not return a result set, this returns an empty slice. +// Returns an error if the statement is closed or uninitialized. +func (s *Stmt) ColumnNames() ([]string, error) { if s.closed { return nil, errClosedStmt } @@ -531,52 +546,6 @@ func (s *Stmt) GetColumnNames() ([]string, error) { return names, nil } -func (s *Stmt) GetParameterTypes() ([]Type, error) { - if s.closed { - return nil, errClosedStmt - } - if s.preparedStmt == nil { - return nil, errUninitializedStmt - } - - count := mapping.NParams(*s.preparedStmt) - types := make([]Type, count) - for i := mapping.IdxT(0); i < count; i++ { - t := mapping.ParamType(*s.preparedStmt, i+1) - types[i] = t - } - return types, nil -} - -func (s *Stmt) GetParameterLogicalTypes() ([]mapping.LogicalType, error) { - if s.closed { - return nil, errClosedStmt - } - if s.preparedStmt == nil { - return nil, errUninitializedStmt - } - - count := mapping.NParams(*s.preparedStmt) - types := make([]mapping.LogicalType, count) - for i := mapping.IdxT(0); i < count; i++ { - t := mapping.ParamLogicalType(*s.preparedStmt, i+1) - types[i] = t - } - return types, nil -} - -func (s *Stmt) NParams() (int, error) { - if s.closed { - return 0, errClosedStmt - } - if s.preparedStmt == nil { - return 0, errUninitializedStmt - } - - count := mapping.NParams(*s.preparedStmt) - return int(count), nil -} - // ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE. // It can only be used after Bind has been called. // WARNING: This is a low-level API and should be used with caution. From 69b9fdf09eb66b5f248c17c992e3ae65efb8b0f6 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Tue, 4 Nov 2025 17:03:03 -0500 Subject: [PATCH 06/14] clarification --- statement.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/statement.go b/statement.go index 98abd046..deee6c08 100644 --- a/statement.go +++ b/statement.go @@ -471,7 +471,7 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv } // ColumnCount returns the number of columns that will be returned by executing the prepared statement. -// If any of the column types is invalid, the result will be 1. +// If any of the column types is invalid (which can happen when the type is ambiguous), the result will be 1. // Returns an error if the statement is closed or uninitialized. func (s *Stmt) ColumnCount() (int, error) { if s.closed { From 085390388708be109909d1d7f2b8824c34982298 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Wed, 5 Nov 2025 23:37:02 -0500 Subject: [PATCH 07/14] update the prepared statement metadata apis, augment typeinfo and use typeinfo --- statement.go | 75 +++---- statement_test.go | 542 ++++++++++++++++++++++++++++++++++++++++++++++ type_info.go | 268 +++++++++++++++++++++++ type_info_test.go | 449 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 1292 insertions(+), 42 deletions(-) diff --git a/statement.go b/statement.go index deee6c08..4e3f181b 100644 --- a/statement.go +++ b/statement.go @@ -485,47 +485,26 @@ func (s *Stmt) ColumnCount() (int, error) { return int(count), nil } -// ColumnTypes returns the types of all columns in the result set of the prepared statement. +// ColumnType returns the type of the column at the given index (0-based). +// Returns TYPE_INVALID if the column is out of range. // Returns an error if the statement is closed or uninitialized. -func (s *Stmt) ColumnTypes() ([]Type, error) { - n, err := s.ColumnCount() - if err != nil { - return nil, err - } - - types := make([]Type, n) - for i := range n { - t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(i)) - types[i] = t +func (s *Stmt) ColumnType(idx int) (Type, error) { + if s.closed { + return TYPE_INVALID, errClosedStmt } - return types, nil -} - -// ColumnLogicalTypes returns the logical types of all columns in the result set of the prepared statement. -// Logical types provide more detailed type information than the basic Type, including custom types, -// nested structures, and type modifiers. The returned slice has one LogicalType entry for each column. -// Note: The caller is responsible for calling mapping.DestroyLogicalType on each returned LogicalType -// during cleanup. -// Returns an error if the statement is closed or uninitialized. -func (s *Stmt) ColumnLogicalTypes() ([]mapping.LogicalType, error) { - n, err := s.ColumnCount() - if err != nil { - return nil, err + if s.preparedStmt == nil { + return TYPE_INVALID, errUninitializedStmt } - types := make([]mapping.LogicalType, n) - for i := range n { - t := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(i)) - types[i] = t - } - return types, nil + t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(idx)) + return t, nil } -// ColumnNames returns the names of all columns in the result set of the prepared statement. -// The returned slice has one string entry for each column, indexed starting from 0. -// For statements that do not return a result set, this returns an empty slice. +// ColumnTypeInfo returns the TypeInfo of the column at the given index (0-based). +// TypeInfo provides detailed type information including nested structures, DECIMAL precision, +// ENUM values, etc. For out-of-range indices, returns an error indicating TYPE_INVALID. // Returns an error if the statement is closed or uninitialized. -func (s *Stmt) ColumnNames() ([]string, error) { +func (s *Stmt) ColumnTypeInfo(idx int) (TypeInfo, error) { if s.closed { return nil, errClosedStmt } @@ -533,17 +512,29 @@ func (s *Stmt) ColumnNames() ([]string, error) { return nil, errUninitializedStmt } - n, err := s.ColumnCount() - if err != nil { - return nil, err + lt := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(idx)) + defer mapping.DestroyLogicalType(<) + + return NewTypeInfoFromLogicalType(lt) +} + +// ColumnName returns the name of the column at the given index (0-based). +// Returns an empty string if the column is out of range. +// Returns an error if the statement is closed or uninitialized. +func (s *Stmt) ColumnName(idx int) (string, error) { + if s.closed { + return "", errClosedStmt + } + if s.preparedStmt == nil { + return "", errUninitializedStmt } - names := make([]string, n) - for i := range n { - name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(i)) - names[i] = name + name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(idx)) + // C API returns nullptr for out-of-range indices + if name == "" { + return "", nil } - return names, nil + return name, nil } // ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE. diff --git a/statement_test.go b/statement_test.go index 6cdf3795..75e67fca 100644 --- a/statement_test.go +++ b/statement_test.go @@ -61,6 +61,39 @@ func TestPrepareQuery(t *testing.T) { require.ErrorContains(t, innerErr, paramIndexErrMsg) require.Equal(t, TYPE_INVALID, paramType) + // Test column methods for SELECT * FROM foo + columnCount, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 2, columnCount) // bar and baz columns + + // Test column names + colName, innerErr := stmt.ColumnName(0) + require.NoError(t, innerErr) + require.Equal(t, "bar", colName) + + colName, innerErr = stmt.ColumnName(1) + require.NoError(t, innerErr) + require.Equal(t, "baz", colName) + + // Test out of bounds - should return empty string + colName, innerErr = stmt.ColumnName(2) + require.NoError(t, innerErr) + require.Equal(t, "", colName) + + // Test column types + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_VARCHAR, colType) + + colType, innerErr = stmt.ColumnType(1) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INTEGER, colType) + + // Test out of bounds - should return TYPE_INVALID + colType, innerErr = stmt.ColumnType(2) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + r, innerErr := stmt.QueryBound(context.Background()) require.Nil(t, r) require.ErrorIs(t, innerErr, errNotBound) @@ -92,6 +125,16 @@ func TestPrepareQuery(t *testing.T) { require.ErrorIs(t, innerErr, errClosedStmt) require.Equal(t, TYPE_INVALID, paramType) + // Test column methods on closed statement + _, innerErr = stmt.ColumnCount() + require.ErrorIs(t, innerErr, errClosedStmt) + + _, innerErr = stmt.ColumnName(0) + require.ErrorIs(t, innerErr, errClosedStmt) + + _, innerErr = stmt.ColumnType(0) + require.ErrorIs(t, innerErr, errClosedStmt) + innerErr = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}}) require.ErrorIs(t, innerErr, errCouldNotBind) require.ErrorIs(t, innerErr, errClosedStmt) @@ -128,6 +171,41 @@ func TestPrepareQueryPositional(t *testing.T) { closeRowsWrapper(t, res) closePreparedWrapper(t, prepared) + // Test column methods for positional SELECT + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT $1 AS first_param, $2 AS second_param`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // Test column count + columnCount, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 2, columnCount) + + // Test column names + colName, innerErr := stmt.ColumnName(0) + require.NoError(t, innerErr) + require.Equal(t, "first_param", colName) + + colName, innerErr = stmt.ColumnName(1) + require.NoError(t, innerErr) + require.Equal(t, "second_param", colName) + + // Test column types - should be TYPE_INVALID for unresolved parameter types + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) // Type cannot be resolved from parameters + + colType, innerErr = stmt.ColumnType(1) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) // Type cannot be resolved from parameters + + return nil + }) + require.NoError(t, err) + // Access the raw connection and statement. err = conn.Raw(func(driverConn any) error { innerConn := driverConn.(*Conn) @@ -171,6 +249,20 @@ func TestPrepareQueryPositional(t *testing.T) { require.ErrorContains(t, innerErr, paramIndexErrMsg) require.Equal(t, TYPE_INVALID, paramType) + // Test column methods for UPDATE statement (should have no columns) + columnCount, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 0, columnCount) // UPDATE doesn't return columns + + // Test out of bounds access - should return empty/invalid for UPDATE with no columns + colName, innerErr := stmt.ColumnName(0) + require.NoError(t, innerErr) + require.Equal(t, "", colName) + + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + r, innerErr := stmt.ExecBound(context.Background()) require.Nil(t, r) require.ErrorIs(t, innerErr, errNotBound) @@ -330,6 +422,19 @@ func TestUninitializedStmt(t *testing.T) { require.ErrorIs(t, err, errCouldNotBind) require.ErrorIs(t, err, errUninitializedStmt) + // Test column methods on uninitialized statement + _, err = stmt.ColumnCount() + require.ErrorIs(t, err, errUninitializedStmt) + + _, err = stmt.ColumnName(0) + require.ErrorIs(t, err, errUninitializedStmt) + + _, err = stmt.ColumnType(0) + require.ErrorIs(t, err, errUninitializedStmt) + + _, err = stmt.ColumnTypeInfo(0) + require.ErrorIs(t, err, errUninitializedStmt) + _, err = stmt.ExecBound(context.Background()) require.ErrorIs(t, err, errNotBound) } @@ -1029,3 +1134,440 @@ func TestInterrupt(t *testing.T) { cancel() }() } + +func TestPreparedStatementColumnMethods(t *testing.T) { + db := openDbWrapper(t, ``) + defer closeDbWrapper(t, db) + + createTable(t, db, `CREATE TABLE test_columns (id INTEGER, name VARCHAR, value DOUBLE, created_at TIMESTAMP)`) + + // Prepare a SELECT statement + conn := openConnWrapper(t, db, context.Background()) + defer closeConnWrapper(t, conn) + + err := conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, name, value, created_at FROM test_columns`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // Test ColumnCount + count, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 4, count) + + // Test ColumnName + name, innerErr := stmt.ColumnName(0) + require.NoError(t, innerErr) + require.Equal(t, "id", name) + + name, innerErr = stmt.ColumnName(1) + require.NoError(t, innerErr) + require.Equal(t, "name", name) + + name, innerErr = stmt.ColumnName(2) + require.NoError(t, innerErr) + require.Equal(t, "value", name) + + name, innerErr = stmt.ColumnName(3) + require.NoError(t, innerErr) + require.Equal(t, "created_at", name) + + // Test out of bounds - should return empty string + name, innerErr = stmt.ColumnName(-1) + require.NoError(t, innerErr) + require.Equal(t, "", name) + + name, innerErr = stmt.ColumnName(4) + require.NoError(t, innerErr) + require.Equal(t, "", name) + + // Test ColumnType + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INTEGER, colType) + + colType, innerErr = stmt.ColumnType(1) + require.NoError(t, innerErr) + require.Equal(t, TYPE_VARCHAR, colType) + + colType, innerErr = stmt.ColumnType(2) + require.NoError(t, innerErr) + require.Equal(t, TYPE_DOUBLE, colType) + + colType, innerErr = stmt.ColumnType(3) + require.NoError(t, innerErr) + require.Equal(t, TYPE_TIMESTAMP, colType) + + // Test out of bounds - should return TYPE_INVALID + colType, innerErr = stmt.ColumnType(-1) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + + colType, innerErr = stmt.ColumnType(4) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + + // Test ColumnTypeInfo - should return TypeInfo for each column + typeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_INTEGER, typeInfo.InternalType()) + + typeInfo, innerErr = stmt.ColumnTypeInfo(1) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_VARCHAR, typeInfo.InternalType()) + + typeInfo, innerErr = stmt.ColumnTypeInfo(2) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_DOUBLE, typeInfo.InternalType()) + + typeInfo, innerErr = stmt.ColumnTypeInfo(3) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_TIMESTAMP, typeInfo.InternalType()) + + // Test out of bounds - should return error with TYPE_INVALID + typeInfo, innerErr = stmt.ColumnTypeInfo(4) + require.Error(t, innerErr) + require.Nil(t, typeInfo) + require.Contains(t, innerErr.Error(), "cannot create TypeInfo from TYPE_INVALID") + + return nil + }) + require.NoError(t, err) + + // Test with closed statement + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT * FROM test_columns`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + + // Close the statement + require.NoError(t, stmt.Close()) + + // Test methods on closed statement + _, innerErr = stmt.ColumnCount() + require.ErrorIs(t, innerErr, errClosedStmt) + + _, innerErr = stmt.ColumnName(0) + require.ErrorIs(t, innerErr, errClosedStmt) + + _, innerErr = stmt.ColumnType(0) + require.ErrorIs(t, innerErr, errClosedStmt) + + _, innerErr = stmt.ColumnTypeInfo(0) + require.ErrorIs(t, innerErr, errClosedStmt) + + return nil + }) + require.NoError(t, err) +} + +func TestPreparedStatementColumnTypeInfo(t *testing.T) { + db := openDbWrapper(t, ``) + defer closeDbWrapper(t, db) + + conn := openConnWrapper(t, db, context.Background()) + defer closeConnWrapper(t, conn) + + // Test with complex types + err := conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + // Create a query with ARRAY, LIST, and STRUCT types + s, innerErr := innerConn.PrepareContext(context.Background(), + `SELECT [1, 2, 3]::INTEGER[3] AS arr_col, + [4, 5, 6] AS list_col, + {'x': 10, 'y': 20} AS struct_col`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // Test ARRAY column + typeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_ARRAY, typeInfo.InternalType()) + + // Assert ARRAY details + details := typeInfo.Details() + require.NotNil(t, details) + arrayDetails, ok := details.(*ArrayDetails) + require.True(t, ok, "Expected ArrayDetails") + require.Equal(t, TYPE_INTEGER, arrayDetails.Child.InternalType()) + require.Equal(t, uint64(3), arrayDetails.Size) + + // Test LIST column + typeInfo, innerErr = stmt.ColumnTypeInfo(1) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_LIST, typeInfo.InternalType()) + + // Assert LIST details + details = typeInfo.Details() + require.NotNil(t, details) + listDetails, ok := details.(*ListDetails) + require.True(t, ok, "Expected ListDetails") + require.Equal(t, TYPE_INTEGER, listDetails.Child.InternalType()) + + // Test STRUCT column + typeInfo, innerErr = stmt.ColumnTypeInfo(2) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_STRUCT, typeInfo.InternalType()) + + // Assert STRUCT details + details = typeInfo.Details() + require.NotNil(t, details) + structDetails, ok := details.(*StructDetails) + require.True(t, ok, "Expected StructDetails") + require.Equal(t, 2, len(structDetails.Entries)) + + // Check first field 'x' + require.Equal(t, "x", structDetails.Entries[0].Name()) + require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) + + // Check second field 'y' + require.Equal(t, "y", structDetails.Entries[1].Name()) + require.Equal(t, TYPE_INTEGER, structDetails.Entries[1].Info().InternalType()) + + return nil + }) + require.NoError(t, err) + + // Test with DECIMAL type + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT 123.45::DECIMAL(10,2) AS dec_col`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + typeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_DECIMAL, typeInfo.InternalType()) + + // Assert DECIMAL details + details := typeInfo.Details() + require.NotNil(t, details) + decimalDetails, ok := details.(*DecimalDetails) + require.True(t, ok, "Expected DecimalDetails") + require.Equal(t, uint8(10), decimalDetails.Width) + require.Equal(t, uint8(2), decimalDetails.Scale) + + return nil + }) + require.NoError(t, err) + + // Test with ENUM type + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + // Create an ENUM type first + _, innerErr := innerConn.ExecContext(context.Background(), + `CREATE TYPE mood AS ENUM ('happy', 'sad', 'neutral')`, nil) + require.NoError(t, innerErr) + + s, innerErr := innerConn.PrepareContext(context.Background(), + `SELECT 'happy'::mood AS mood_col`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + typeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_ENUM, typeInfo.InternalType()) + + // Assert ENUM details + details := typeInfo.Details() + require.NotNil(t, details) + enumDetails, ok := details.(*EnumDetails) + require.True(t, ok, "Expected EnumDetails") + require.Equal(t, []string{"happy", "sad", "neutral"}, enumDetails.Values) + + return nil + }) + require.NoError(t, err) + + // Test with MAP type + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + s, innerErr := innerConn.PrepareContext(context.Background(), + `SELECT MAP([1, 2], ['a', 'b']) AS map_col`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + typeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_MAP, typeInfo.InternalType()) + + // Assert MAP details + details := typeInfo.Details() + require.NotNil(t, details) + mapDetails, ok := details.(*MapDetails) + require.True(t, ok, "Expected MapDetails") + require.Equal(t, TYPE_INTEGER, mapDetails.Key.InternalType()) + require.Equal(t, TYPE_VARCHAR, mapDetails.Value.InternalType()) + + return nil + }) + require.NoError(t, err) + + // Test with nested types: LIST of STRUCTs + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + s, innerErr := innerConn.PrepareContext(context.Background(), + `SELECT [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}] AS list_struct_col`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + typeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.NotNil(t, typeInfo) + require.Equal(t, TYPE_LIST, typeInfo.InternalType()) + + // Assert LIST details + details := typeInfo.Details() + require.NotNil(t, details) + listDetails, ok := details.(*ListDetails) + require.True(t, ok, "Expected ListDetails") + require.Equal(t, TYPE_STRUCT, listDetails.Child.InternalType()) + + // Assert nested STRUCT details + structDetails, ok := listDetails.Child.Details().(*StructDetails) + require.True(t, ok, "Expected StructDetails for nested type") + require.Equal(t, 2, len(structDetails.Entries)) + require.Equal(t, "id", structDetails.Entries[0].Name()) + require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) + require.Equal(t, "name", structDetails.Entries[1].Name()) + require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) + + return nil + }) + require.NoError(t, err) +} + +func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { + db := openDbWrapper(t, ``) + defer closeDbWrapper(t, db) + + conn := openConnWrapper(t, db, context.Background()) + defer closeConnWrapper(t, conn) + + // Test cases where column types cannot be resolved + err := conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + // Test 1: VALUES clause without type casting - ambiguous types + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT * FROM (VALUES (?, ?)) t(a, b)`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // When columns have ambiguous types, count becomes 1 + count, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 1, count) + + // Column type should be INVALID + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + + // Out of bounds access + colType, innerErr = stmt.ColumnType(1) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + + return nil + }) + require.NoError(t, err) + + // Test 2: Direct parameter selection - all ambiguous + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT ?, ?, ? + ?`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // When columns have ambiguous types, count becomes 1 + count, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 1, count) + + // Column type should be INVALID + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + + return nil + }) + require.NoError(t, err) + + // Test 3: Mixed known and unknown types + createTable(t, db, `CREATE TABLE test_mixed (id INTEGER, value VARCHAR)`) + + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, value, ? AS param_col FROM test_mixed`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // When any column has ambiguous type, count becomes 1 + count, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 1, count) + + // All column types become INVALID + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, colType) + + return nil + }) + require.NoError(t, err) + + // Test 4: Statement with no ambiguous types should work normally + err = conn.Raw(func(driverConn any) error { + innerConn := driverConn.(*Conn) + + s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, value FROM test_mixed`) + require.NoError(t, innerErr) + stmt := s.(*Stmt) + defer stmt.Close() + + // Normal count when no ambiguous types + count, innerErr := stmt.ColumnCount() + require.NoError(t, innerErr) + require.Equal(t, 2, count) + + // Column types are resolved + colType, innerErr := stmt.ColumnType(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_INTEGER, colType) + + colType, innerErr = stmt.ColumnType(1) + require.NoError(t, innerErr) + require.Equal(t, TYPE_VARCHAR, colType) + + return nil + }) + require.NoError(t, err) +} diff --git a/type_info.go b/type_info.go index 39a9d670..76658dd9 100644 --- a/type_info.go +++ b/type_info.go @@ -45,6 +45,70 @@ func (entry *structEntry) Name() string { return entry.name } +// TypeDetails is an interface for type-specific details. +// Use type assertion to access specific detail types. +type TypeDetails interface { + isTypeDetails() +} + +// DecimalDetails provides DECIMAL type information. +type DecimalDetails struct { + Width uint8 + Scale uint8 +} + +func (d *DecimalDetails) isTypeDetails() {} + +// EnumDetails provides ENUM type information. +type EnumDetails struct { + Values []string +} + +func (e *EnumDetails) isTypeDetails() {} + +// ListDetails provides LIST type information. +type ListDetails struct { + Child TypeInfo +} + +func (l *ListDetails) isTypeDetails() {} + +// ArrayDetails provides ARRAY type information. +type ArrayDetails struct { + Child TypeInfo + Size uint64 +} + +func (a *ArrayDetails) isTypeDetails() {} + +// MapDetails provides MAP type information. +type MapDetails struct { + Key TypeInfo + Value TypeInfo +} + +func (m *MapDetails) isTypeDetails() {} + +// StructDetails provides STRUCT type information. +type StructDetails struct { + Entries []StructEntry +} + +func (s *StructDetails) isTypeDetails() {} + +// UnionMember represents a UNION member with its name and type. +type UnionMember struct { + Name string + Type TypeInfo +} + +// UnionDetails provides UNION type information. +type UnionDetails struct { + Members []UnionMember +} + +func (u *UnionDetails) isTypeDetails() {} + type baseTypeInfo struct { Type @@ -76,6 +140,10 @@ type typeInfo struct { type TypeInfo interface { // InternalType returns the Type. InternalType() Type + // Details returns type-specific details for complex types. + // Returns nil for simple/primitive types. + // Use type assertion to access specific detail types. + Details() TypeDetails logicalType() mapping.LogicalType } @@ -83,6 +151,60 @@ func (info *typeInfo) InternalType() Type { return info.Type } +// Details returns type-specific details for complex types. +// Returns nil for simple/primitive types. +func (info *typeInfo) Details() TypeDetails { + switch info.Type { + case TYPE_DECIMAL: + return &DecimalDetails{ + Width: info.decimalWidth, + Scale: info.decimalScale, + } + case TYPE_ENUM: + // Make a copy of the slice to avoid exposing internal state + values := make([]string, len(info.names)) + copy(values, info.names) + return &EnumDetails{ + Values: values, + } + case TYPE_LIST: + return &ListDetails{ + Child: info.types[0], + } + case TYPE_ARRAY: + return &ArrayDetails{ + Child: info.types[0], + Size: uint64(info.arrayLength), + } + case TYPE_MAP: + return &MapDetails{ + Key: info.types[0], + Value: info.types[1], + } + case TYPE_STRUCT: + // Make a copy of the slice to avoid exposing internal state + entries := make([]StructEntry, len(info.structEntries)) + copy(entries, info.structEntries) + return &StructDetails{ + Entries: entries, + } + case TYPE_UNION: + // Build UnionMembers from types and names + members := make([]UnionMember, len(info.types)) + for i := range info.types { + members[i] = UnionMember{ + Name: info.names[i], + Type: info.types[i], + } + } + return &UnionDetails{ + Members: members, + } + default: + return nil + } +} + // NewTypeInfo returns type information for DuckDB's primitive types. // It returns the TypeInfo, if the Type parameter is a valid primitive type. // Else, it returns nil, and an error. @@ -357,6 +479,152 @@ func (info *typeInfo) logicalUnionType() mapping.LogicalType { return mapping.CreateUnionType(types, info.names) } +// NewTypeInfoFromLogicalType converts a mapping.LogicalType to TypeInfo. +// This allows inspecting types returned from prepared statements. +// The LogicalType must remain valid for the duration of this call. +// The returned TypeInfo does not hold a reference to the LogicalType. +func NewTypeInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + t := mapping.GetTypeId(lt) + + switch t { + case TYPE_DECIMAL: + return newDecimalInfoFromLogicalType(lt) + case TYPE_ENUM: + return newEnumInfoFromLogicalType(lt) + case TYPE_LIST: + return newListInfoFromLogicalType(lt) + case TYPE_ARRAY: + return newArrayInfoFromLogicalType(lt) + case TYPE_MAP: + return newMapInfoFromLogicalType(lt) + case TYPE_STRUCT: + return newStructInfoFromLogicalType(lt) + case TYPE_UNION: + return newUnionInfoFromLogicalType(lt) + case TYPE_INVALID: + return nil, getError(errAPI, errors.New("cannot create TypeInfo from TYPE_INVALID")) + default: + // Simple/primitive type + return NewTypeInfo(t) + } +} + +func newDecimalInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + width := mapping.DecimalWidth(lt) + scale := mapping.DecimalScale(lt) + return NewDecimalInfo(width, scale) +} + +func newEnumInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + size := mapping.EnumDictionarySize(lt) + if size == 0 { + return nil, getError(errAPI, errors.New("ENUM type must have at least one value")) + } + + values := make([]string, size) + for i := uint32(0); i < size; i++ { + values[i] = mapping.EnumDictionaryValue(lt, mapping.IdxT(i)) + } + + return NewEnumInfo(values[0], values[1:]...) +} + +func newListInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + childLT := mapping.ListTypeChildType(lt) + defer mapping.DestroyLogicalType(&childLT) + + childInfo, err := NewTypeInfoFromLogicalType(childLT) + if err != nil { + return nil, err + } + + return NewListInfo(childInfo) +} + +func newArrayInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + childLT := mapping.ArrayTypeChildType(lt) + defer mapping.DestroyLogicalType(&childLT) + + childInfo, err := NewTypeInfoFromLogicalType(childLT) + if err != nil { + return nil, err + } + + size := mapping.ArrayTypeArraySize(lt) + return NewArrayInfo(childInfo, uint64(size)) +} + +func newMapInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + keyLT := mapping.MapTypeKeyType(lt) + defer mapping.DestroyLogicalType(&keyLT) + + valueLT := mapping.MapTypeValueType(lt) + defer mapping.DestroyLogicalType(&valueLT) + + keyInfo, err := NewTypeInfoFromLogicalType(keyLT) + if err != nil { + return nil, err + } + + valueInfo, err := NewTypeInfoFromLogicalType(valueLT) + if err != nil { + return nil, err + } + + return NewMapInfo(keyInfo, valueInfo) +} + +func newStructInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + count := mapping.StructTypeChildCount(lt) + if count == 0 { + return nil, getError(errAPI, errors.New("STRUCT type must have at least one field")) + } + + entries := make([]StructEntry, count) + for i := mapping.IdxT(0); i < count; i++ { + name := mapping.StructTypeChildName(lt, i) + childLT := mapping.StructTypeChildType(lt, i) + + childInfo, err := NewTypeInfoFromLogicalType(childLT) + mapping.DestroyLogicalType(&childLT) + if err != nil { + return nil, err + } + + entry, err := NewStructEntry(childInfo, name) + if err != nil { + return nil, err + } + entries[i] = entry + } + + return NewStructInfo(entries[0], entries[1:]...) +} + +func newUnionInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { + count := mapping.UnionTypeMemberCount(lt) + if count == 0 { + return nil, getError(errAPI, errors.New("UNION type must have at least one member")) + } + + memberTypes := make([]TypeInfo, count) + memberNames := make([]string, count) + + for i := mapping.IdxT(0); i < count; i++ { + memberNames[i] = mapping.UnionTypeMemberName(lt, i) + memberLT := mapping.UnionTypeMemberType(lt, i) + + memberInfo, err := NewTypeInfoFromLogicalType(memberLT) + mapping.DestroyLogicalType(&memberLT) + if err != nil { + return nil, err + } + memberTypes[i] = memberInfo + } + + return NewUnionInfo(memberTypes, memberNames) +} + func funcName(i any) string { return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } diff --git a/type_info_test.go b/type_info_test.go index ffecd253..5592260f 100644 --- a/type_info_test.go +++ b/type_info_test.go @@ -3,6 +3,7 @@ package duckdb import ( "testing" + "github.com/duckdb/duckdb-go/mapping" "github.com/stretchr/testify/require" ) @@ -343,3 +344,451 @@ func TestErrTypeInfo(t *testing.T) { ) testError(t, err, errAPI.Error(), duplicateNameErrMsg) } + +func TestNewTypeInfoFromLogicalType(t *testing.T) { + // Test primitive types + primitiveTests := []Type{ + TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, + TYPE_UTINYINT, TYPE_USMALLINT, TYPE_UINTEGER, TYPE_UBIGINT, + TYPE_FLOAT, TYPE_DOUBLE, TYPE_TIMESTAMP, TYPE_DATE, TYPE_TIME, + TYPE_INTERVAL, TYPE_HUGEINT, TYPE_VARCHAR, TYPE_BLOB, + TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, + TYPE_UUID, TYPE_TIME_TZ, TYPE_TIMESTAMP_TZ, + } + + for _, primitiveType := range primitiveTests { + t.Run(typeToStringMap[primitiveType], func(t *testing.T) { + // Create TypeInfo and convert to LogicalType + originalInfo, err := NewTypeInfo(primitiveType) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + // Convert back to TypeInfo + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, primitiveType, reconstructedInfo.InternalType()) + }) + } + + // Test DECIMAL type + t.Run("DECIMAL", func(t *testing.T) { + originalInfo, err := NewDecimalInfo(10, 3) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_DECIMAL, reconstructedInfo.InternalType()) + + // Verify we can convert back and get the same logical type + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, uint8(10), mapping.DecimalWidth(reconstructedLT)) + require.Equal(t, uint8(3), mapping.DecimalScale(reconstructedLT)) + }) + + // Test ENUM type + t.Run("ENUM", func(t *testing.T) { + originalInfo, err := NewEnumInfo("red", "green", "blue") + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_ENUM, reconstructedInfo.InternalType()) + + // Verify enum values + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, uint32(3), mapping.EnumDictionarySize(reconstructedLT)) + require.Equal(t, "red", mapping.EnumDictionaryValue(reconstructedLT, 0)) + require.Equal(t, "green", mapping.EnumDictionaryValue(reconstructedLT, 1)) + require.Equal(t, "blue", mapping.EnumDictionaryValue(reconstructedLT, 2)) + }) + + // Test LIST type + t.Run("LIST", func(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + originalInfo, err := NewListInfo(intInfo) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_LIST, reconstructedInfo.InternalType()) + + // Verify child type + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + childLT := mapping.ListTypeChildType(reconstructedLT) + defer mapping.DestroyLogicalType(&childLT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(childLT)) + }) + + // Test ARRAY type + t.Run("ARRAY", func(t *testing.T) { + varcharInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + originalInfo, err := NewArrayInfo(varcharInfo, 5) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_ARRAY, reconstructedInfo.InternalType()) + + // Verify child type and size + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + childLT := mapping.ArrayTypeChildType(reconstructedLT) + defer mapping.DestroyLogicalType(&childLT) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(childLT)) + require.Equal(t, mapping.IdxT(5), mapping.ArrayTypeArraySize(reconstructedLT)) + }) + + // Test MAP type + t.Run("MAP", func(t *testing.T) { + keyInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + valueInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + originalInfo, err := NewMapInfo(keyInfo, valueInfo) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_MAP, reconstructedInfo.InternalType()) + + // Verify key and value types + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + keyLT := mapping.MapTypeKeyType(reconstructedLT) + defer mapping.DestroyLogicalType(&keyLT) + valueLT := mapping.MapTypeValueType(reconstructedLT) + defer mapping.DestroyLogicalType(&valueLT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(keyLT)) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(valueLT)) + }) + + // Test STRUCT type + t.Run("STRUCT", func(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + entry1, err := NewStructEntry(intInfo, "id") + require.NoError(t, err) + entry2, err := NewStructEntry(strInfo, "name") + require.NoError(t, err) + + originalInfo, err := NewStructInfo(entry1, entry2) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_STRUCT, reconstructedInfo.InternalType()) + + // Verify struct fields + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, mapping.IdxT(2), mapping.StructTypeChildCount(reconstructedLT)) + require.Equal(t, "id", mapping.StructTypeChildName(reconstructedLT, 0)) + require.Equal(t, "name", mapping.StructTypeChildName(reconstructedLT, 1)) + + child0LT := mapping.StructTypeChildType(reconstructedLT, 0) + defer mapping.DestroyLogicalType(&child0LT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(child0LT)) + + child1LT := mapping.StructTypeChildType(reconstructedLT, 1) + defer mapping.DestroyLogicalType(&child1LT) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(child1LT)) + }) + + // Test UNION type + t.Run("UNION", func(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + originalInfo, err := NewUnionInfo( + []TypeInfo{intInfo, strInfo}, + []string{"num", "text"}, + ) + require.NoError(t, err) + + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_UNION, reconstructedInfo.InternalType()) + + // Verify union members + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, mapping.IdxT(2), mapping.UnionTypeMemberCount(reconstructedLT)) + require.Equal(t, "num", mapping.UnionTypeMemberName(reconstructedLT, 0)) + require.Equal(t, "text", mapping.UnionTypeMemberName(reconstructedLT, 1)) + + member0LT := mapping.UnionTypeMemberType(reconstructedLT, 0) + defer mapping.DestroyLogicalType(&member0LT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(member0LT)) + + member1LT := mapping.UnionTypeMemberType(reconstructedLT, 1) + defer mapping.DestroyLogicalType(&member1LT) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(member1LT)) + }) + + // Test nested complex types + t.Run("NestedTypes", func(t *testing.T) { + // Create LIST of STRUCTs + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + entry1, err := NewStructEntry(intInfo, "id") + require.NoError(t, err) + entry2, err := NewStructEntry(strInfo, "name") + require.NoError(t, err) + + structInfo, err := NewStructInfo(entry1, entry2) + require.NoError(t, err) + + listInfo, err := NewListInfo(structInfo) + require.NoError(t, err) + + lt := listInfo.logicalType() + defer mapping.DestroyLogicalType(<) + + reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_LIST, reconstructedInfo.InternalType()) + + details := reconstructedInfo.Details() + listDetails, ok := details.(*ListDetails) + require.True(t, ok) + require.Equal(t, TYPE_STRUCT, listDetails.Child.InternalType()) + + structDetails, ok := listDetails.Child.Details().(*StructDetails) + require.True(t, ok) + require.Equal(t, 2, len(structDetails.Entries)) + require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) + require.Equal(t, "id", structDetails.Entries[0].Name()) + require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) + require.Equal(t, "name", structDetails.Entries[1].Name()) + + }) + +} + +func TestTypeInfoDetails(t *testing.T) { + // Test primitive types return nil + t.Run("PrimitiveTypes", func(t *testing.T) { + primitiveTypes := []Type{ + TYPE_BOOLEAN, TYPE_INTEGER, TYPE_VARCHAR, TYPE_TIMESTAMP, TYPE_DATE, + } + + for _, primitiveType := range primitiveTypes { + info, err := NewTypeInfo(primitiveType) + require.NoError(t, err) + require.Nil(t, info.Details()) + } + }) + + // Test DECIMAL details + t.Run("DecimalDetails", func(t *testing.T) { + info, err := NewDecimalInfo(10, 3) + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + decimalDetails, ok := details.(*DecimalDetails) + require.True(t, ok) + require.Equal(t, uint8(10), decimalDetails.Width) + require.Equal(t, uint8(3), decimalDetails.Scale) + }) + + // Test ENUM details + t.Run("EnumDetails", func(t *testing.T) { + info, err := NewEnumInfo("red", "green", "blue") + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + enumDetails, ok := details.(*EnumDetails) + require.True(t, ok) + require.Equal(t, []string{"red", "green", "blue"}, enumDetails.Values) + + // Test that modifying the returned slice doesn't affect the original + enumDetails.Values[0] = "modified" + details2 := info.Details() + enumDetails2, ok := details2.(*EnumDetails) + require.True(t, ok) + require.Equal(t, "red", enumDetails2.Values[0]) + }) + + // Test LIST details + t.Run("ListDetails", func(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + info, err := NewListInfo(intInfo) + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + listDetails, ok := details.(*ListDetails) + require.True(t, ok) + require.Equal(t, TYPE_INTEGER, listDetails.Child.InternalType()) + }) + + // Test ARRAY details + t.Run("ArrayDetails", func(t *testing.T) { + varcharInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + info, err := NewArrayInfo(varcharInfo, 5) + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + arrayDetails, ok := details.(*ArrayDetails) + require.True(t, ok) + require.Equal(t, TYPE_VARCHAR, arrayDetails.Child.InternalType()) + require.Equal(t, uint64(5), arrayDetails.Size) + }) + + // Test MAP details + t.Run("MapDetails", func(t *testing.T) { + keyInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + valueInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + info, err := NewMapInfo(keyInfo, valueInfo) + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + mapDetails, ok := details.(*MapDetails) + require.True(t, ok) + require.Equal(t, TYPE_INTEGER, mapDetails.Key.InternalType()) + require.Equal(t, TYPE_VARCHAR, mapDetails.Value.InternalType()) + }) + + // Test STRUCT details + t.Run("StructDetails", func(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + entry1, err := NewStructEntry(intInfo, "id") + require.NoError(t, err) + entry2, err := NewStructEntry(strInfo, "name") + require.NoError(t, err) + + info, err := NewStructInfo(entry1, entry2) + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + structDetails, ok := details.(*StructDetails) + require.True(t, ok) + require.Equal(t, 2, len(structDetails.Entries)) + require.Equal(t, "id", structDetails.Entries[0].Name()) + require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) + require.Equal(t, "name", structDetails.Entries[1].Name()) + require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) + }) + + // Test UNION details + t.Run("UnionDetails", func(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + info, err := NewUnionInfo( + []TypeInfo{intInfo, strInfo}, + []string{"num", "text"}, + ) + require.NoError(t, err) + + details := info.Details() + require.NotNil(t, details) + + unionDetails, ok := details.(*UnionDetails) + require.True(t, ok) + require.Equal(t, 2, len(unionDetails.Members)) + require.Equal(t, "num", unionDetails.Members[0].Name) + require.Equal(t, TYPE_INTEGER, unionDetails.Members[0].Type.InternalType()) + require.Equal(t, "text", unionDetails.Members[1].Name) + require.Equal(t, TYPE_VARCHAR, unionDetails.Members[1].Type.InternalType()) + }) + + // Test nested type details + t.Run("NestedTypeDetails", func(t *testing.T) { + // Create a LIST of STRUCTs + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + entry1, err := NewStructEntry(intInfo, "id") + require.NoError(t, err) + entry2, err := NewStructEntry(strInfo, "name") + require.NoError(t, err) + + structInfo, err := NewStructInfo(entry1, entry2) + require.NoError(t, err) + + listInfo, err := NewListInfo(structInfo) + require.NoError(t, err) + + // Get list details + details := listInfo.Details() + require.NotNil(t, details) + + listDetails, ok := details.(*ListDetails) + require.True(t, ok) + require.Equal(t, TYPE_STRUCT, listDetails.Child.InternalType()) + + // Get struct details from the child + structDetails := listDetails.Child.Details() + require.NotNil(t, structDetails) + + structDetailsTyped, ok := structDetails.(*StructDetails) + require.True(t, ok) + require.Equal(t, 2, len(structDetailsTyped.Entries)) + }) +} From e5b9c414bda8abf2b48f265ac2b471b6d301b4de Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Thu, 6 Nov 2025 17:26:30 -0500 Subject: [PATCH 08/14] update --- statement.go | 5 +++-- statement_test.go | 23 ++++++++++++----------- type_info.go | 2 +- type_info_test.go | 11 ++++++++++- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/statement.go b/statement.go index 4e3f181b..cb9a5b12 100644 --- a/statement.go +++ b/statement.go @@ -502,7 +502,8 @@ func (s *Stmt) ColumnType(idx int) (Type, error) { // ColumnTypeInfo returns the TypeInfo of the column at the given index (0-based). // TypeInfo provides detailed type information including nested structures, DECIMAL precision, -// ENUM values, etc. For out-of-range indices, returns an error indicating TYPE_INVALID. +// ENUM values, etc. +// Returns a TypeInfo with internalType TYPE_INVALID if the column is out of range. // Returns an error if the statement is closed or uninitialized. func (s *Stmt) ColumnTypeInfo(idx int) (TypeInfo, error) { if s.closed { @@ -519,7 +520,7 @@ func (s *Stmt) ColumnTypeInfo(idx int) (TypeInfo, error) { } // ColumnName returns the name of the column at the given index (0-based). -// Returns an empty string if the column is out of range. +// Returns "" if the column is out of range. // Returns an error if the statement is closed or uninitialized. func (s *Stmt) ColumnName(idx int) (string, error) { if s.closed { diff --git a/statement_test.go b/statement_test.go index 75e67fca..75dcb3e1 100644 --- a/statement_test.go +++ b/statement_test.go @@ -182,25 +182,27 @@ func TestPrepareQueryPositional(t *testing.T) { // Test column count columnCount, innerErr := stmt.ColumnCount() require.NoError(t, innerErr) - require.Equal(t, 2, columnCount) + require.Equal(t, 1, columnCount) // Test column names colName, innerErr := stmt.ColumnName(0) require.NoError(t, innerErr) - require.Equal(t, "first_param", colName) + require.Equal(t, "unknown", colName) + // Out of range colName, innerErr = stmt.ColumnName(1) require.NoError(t, innerErr) - require.Equal(t, "second_param", colName) + require.Equal(t, "", colName) // Test column types - should be TYPE_INVALID for unresolved parameter types colType, innerErr := stmt.ColumnType(0) require.NoError(t, innerErr) - require.Equal(t, TYPE_INVALID, colType) // Type cannot be resolved from parameters + require.Equal(t, TYPE_INVALID, colType) + // Out of range also returns TYPE_INVALID colType, innerErr = stmt.ColumnType(1) require.NoError(t, innerErr) - require.Equal(t, TYPE_INVALID, colType) // Type cannot be resolved from parameters + require.Equal(t, TYPE_INVALID, colType) return nil }) @@ -252,16 +254,16 @@ func TestPrepareQueryPositional(t *testing.T) { // Test column methods for UPDATE statement (should have no columns) columnCount, innerErr := stmt.ColumnCount() require.NoError(t, innerErr) - require.Equal(t, 0, columnCount) // UPDATE doesn't return columns + require.Equal(t, 1, columnCount) // Test out of bounds access - should return empty/invalid for UPDATE with no columns colName, innerErr := stmt.ColumnName(0) require.NoError(t, innerErr) - require.Equal(t, "", colName) + require.Equal(t, "Count", colName) colType, innerErr := stmt.ColumnType(0) require.NoError(t, innerErr) - require.Equal(t, TYPE_INVALID, colType) + require.Equal(t, TYPE_BIGINT, colType) r, innerErr := stmt.ExecBound(context.Background()) require.Nil(t, r) @@ -1232,9 +1234,8 @@ func TestPreparedStatementColumnMethods(t *testing.T) { // Test out of bounds - should return error with TYPE_INVALID typeInfo, innerErr = stmt.ColumnTypeInfo(4) - require.Error(t, innerErr) - require.Nil(t, typeInfo) - require.Contains(t, innerErr.Error(), "cannot create TypeInfo from TYPE_INVALID") + require.NoError(t, innerErr) + require.Equal(t, TYPE_INVALID, typeInfo.InternalType()) return nil }) diff --git a/type_info.go b/type_info.go index 76658dd9..e3f712a5 100644 --- a/type_info.go +++ b/type_info.go @@ -502,7 +502,7 @@ func NewTypeInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { case TYPE_UNION: return newUnionInfoFromLogicalType(lt) case TYPE_INVALID: - return nil, getError(errAPI, errors.New("cannot create TypeInfo from TYPE_INVALID")) + return &typeInfo{baseTypeInfo: baseTypeInfo{Type: TYPE_INVALID}}, nil default: // Simple/primitive type return NewTypeInfo(t) diff --git a/type_info_test.go b/type_info_test.go index 5592260f..53f33358 100644 --- a/type_info_test.go +++ b/type_info_test.go @@ -644,7 +644,7 @@ func TestTypeInfoDetails(t *testing.T) { require.True(t, ok) require.Equal(t, []string{"red", "green", "blue"}, enumDetails.Values) - // Test that modifying the returned slice doesn't affect the original + // Test that modifying the returned slice doesn't affect the original TypeInfo enumDetails.Values[0] = "modified" details2 := info.Details() enumDetails2, ok := details2.(*EnumDetails) @@ -754,6 +754,15 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, TYPE_INTEGER, unionDetails.Members[0].Type.InternalType()) require.Equal(t, "text", unionDetails.Members[1].Name) require.Equal(t, TYPE_VARCHAR, unionDetails.Members[1].Type.InternalType()) + + // Test that modifying the returned details doesn't affect the original TypeInfo + unionDetails.Members[0].Name = "new_name" + + details2 := info.Details() + require.NotNil(t, details2) + unionDetails2, ok := details2.(*UnionDetails) + require.True(t, ok) + require.Equal(t, "num", unionDetails2.Members[0].Name) }) // Test nested type details From 81a60f249cf3fe768538177f902eb3236e413d01 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Thu, 6 Nov 2025 18:10:06 -0500 Subject: [PATCH 09/14] apply range check; abstract out checkState --- errors.go | 5 +++ statement.go | 94 ++++++++++++++++++++++++----------------------- statement_test.go | 53 ++++++++++++++++---------- type_info.go | 2 - 4 files changed, 87 insertions(+), 67 deletions(-) diff --git a/errors.go b/errors.go index 3aaf71eb..dffa24d7 100644 --- a/errors.go +++ b/errors.go @@ -39,6 +39,10 @@ func paramIndexError(idx int, max uint64) error { return fmt.Errorf("%s: %d is out of range [1, %d]", paramIndexErrMsg, idx, max) } +func columnIndexError(idx int, max uint64) error { + return fmt.Errorf("%s: %d is out of range [0, %d)", columnIndexErrMsg, idx, max) +} + func unsupportedTypeError(name string) error { return fmt.Errorf("%s: %s", unsupportedTypeErrMsg, name) } @@ -81,6 +85,7 @@ const ( interfaceIsNilErrMsg = "interface is nil" duplicateNameErrMsg = "duplicate name" paramIndexErrMsg = "invalid parameter index" + columnIndexErrMsg = "invalid column index" ) var ( diff --git a/statement.go b/statement.go index cb9a5b12..22f93b9c 100644 --- a/statement.go +++ b/statement.go @@ -54,6 +54,17 @@ type Stmt struct { rows bool } +// checkState checks if the statement is closed or uninitialized. +func (s *Stmt) checkState() error { + if s.closed { + return errClosedStmt + } + if s.preparedStmt == nil { + return errUninitializedStmt + } + return nil +} + // Close the statement. // Implements the driver.Stmt interface. func (s *Stmt) Close() error { @@ -81,11 +92,8 @@ func (s *Stmt) NumInput() int { // ParamName returns the name of the parameter at the given index (1-based). func (s *Stmt) ParamName(n int) (string, error) { - if s.closed { - return "", errClosedStmt - } - if s.preparedStmt == nil { - return "", errUninitializedStmt + if err := s.checkState(); err != nil { + return "", err } count := mapping.NParams(*s.preparedStmt) @@ -99,11 +107,8 @@ func (s *Stmt) ParamName(n int) (string, error) { // ParamType returns the expected type of the parameter at the given index (1-based). func (s *Stmt) ParamType(n int) (Type, error) { - if s.closed { - return TYPE_INVALID, errClosedStmt - } - if s.preparedStmt == nil { - return TYPE_INVALID, errUninitializedStmt + if err := s.checkState(); err != nil { + return TYPE_INVALID, err } count := mapping.NParams(*s.preparedStmt) @@ -117,11 +122,8 @@ func (s *Stmt) ParamType(n int) (Type, error) { func (s *Stmt) paramLogicalType(n int) (mapping.LogicalType, error) { var lt mapping.LogicalType - if s.closed { - return lt, errClosedStmt - } - if s.preparedStmt == nil { - return lt, errUninitializedStmt + if err := s.checkState(); err != nil { + return lt, err } count := mapping.NParams(*s.preparedStmt) @@ -134,11 +136,8 @@ func (s *Stmt) paramLogicalType(n int) (mapping.LogicalType, error) { // StatementType returns the type of the statement. func (s *Stmt) StatementType() (StmtType, error) { - if s.closed { - return STATEMENT_TYPE_INVALID, errClosedStmt - } - if s.preparedStmt == nil { - return STATEMENT_TYPE_INVALID, errUninitializedStmt + if err := s.checkState(); err != nil { + return STATEMENT_TYPE_INVALID, err } t := mapping.PreparedStatementType(*s.preparedStmt) @@ -474,11 +473,8 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv // If any of the column types is invalid (which can happen when the type is ambiguous), the result will be 1. // Returns an error if the statement is closed or uninitialized. func (s *Stmt) ColumnCount() (int, error) { - if s.closed { - return 0, errClosedStmt - } - if s.preparedStmt == nil { - return 0, errUninitializedStmt + if err := s.checkState(); err != nil { + return 0, err } count := mapping.PreparedStatementColumnCount(*s.preparedStmt) @@ -486,51 +482,57 @@ func (s *Stmt) ColumnCount() (int, error) { } // ColumnType returns the type of the column at the given index (0-based). -// Returns TYPE_INVALID if the column is out of range. +// Returns TYPE_INVALID and a columnIndexError if the column is out of range. // Returns an error if the statement is closed or uninitialized. -func (s *Stmt) ColumnType(idx int) (Type, error) { - if s.closed { - return TYPE_INVALID, errClosedStmt +func (s *Stmt) ColumnType(n int) (Type, error) { + if err := s.checkState(); err != nil { + return TYPE_INVALID, err } - if s.preparedStmt == nil { - return TYPE_INVALID, errUninitializedStmt + + count := mapping.PreparedStatementColumnCount(*s.preparedStmt) + if n < 0 || n >= int(count) { + return TYPE_INVALID, getError(errAPI, columnIndexError(n, uint64(count))) } - t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(idx)) + t := mapping.PreparedStatementColumnType(*s.preparedStmt, mapping.IdxT(n)) return t, nil } // ColumnTypeInfo returns the TypeInfo of the column at the given index (0-based). // TypeInfo provides detailed type information including nested structures, DECIMAL precision, // ENUM values, etc. -// Returns a TypeInfo with internalType TYPE_INVALID if the column is out of range. +// Returns a TypeInfo with internalType TYPE_INVALID and a columnIndexError if the column is out of range. // Returns an error if the statement is closed or uninitialized. -func (s *Stmt) ColumnTypeInfo(idx int) (TypeInfo, error) { - if s.closed { - return nil, errClosedStmt +func (s *Stmt) ColumnTypeInfo(n int) (TypeInfo, error) { + if err := s.checkState(); err != nil { + return nil, err } - if s.preparedStmt == nil { - return nil, errUninitializedStmt + + count := mapping.PreparedStatementColumnCount(*s.preparedStmt) + if n < 0 || n >= int(count) { + return nil, getError(errAPI, columnIndexError(n, uint64(count))) } - lt := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(idx)) + lt := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(n)) defer mapping.DestroyLogicalType(<) return NewTypeInfoFromLogicalType(lt) } // ColumnName returns the name of the column at the given index (0-based). -// Returns "" if the column is out of range. +// Returns "" and a columnIndexError if the column is out of range. // Returns an error if the statement is closed or uninitialized. -func (s *Stmt) ColumnName(idx int) (string, error) { - if s.closed { - return "", errClosedStmt +func (s *Stmt) ColumnName(n int) (string, error) { + if err := s.checkState(); err != nil { + return "", err } - if s.preparedStmt == nil { - return "", errUninitializedStmt + + count := mapping.PreparedStatementColumnCount(*s.preparedStmt) + if n < 0 || n >= int(count) { + return "", getError(errAPI, columnIndexError(n, uint64(count))) } - name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(idx)) + name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(n)) // C API returns nullptr for out-of-range indices if name == "" { return "", nil diff --git a/statement_test.go b/statement_test.go index 75dcb3e1..6923ab21 100644 --- a/statement_test.go +++ b/statement_test.go @@ -75,9 +75,10 @@ func TestPrepareQuery(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, "baz", colName) - // Test out of bounds - should return empty string + // Test out of bounds - should return error colName, innerErr = stmt.ColumnName(2) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, "", colName) // Test column types @@ -89,9 +90,10 @@ func TestPrepareQuery(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_INTEGER, colType) - // Test out of bounds - should return TYPE_INVALID + // Test out of bounds - should return error colType, innerErr = stmt.ColumnType(2) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) r, innerErr := stmt.QueryBound(context.Background()) @@ -189,9 +191,10 @@ func TestPrepareQueryPositional(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, "unknown", colName) - // Out of range + // Out of range - should return error colName, innerErr = stmt.ColumnName(1) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, "", colName) // Test column types - should be TYPE_INVALID for unresolved parameter types @@ -199,9 +202,10 @@ func TestPrepareQueryPositional(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_INVALID, colType) - // Out of range also returns TYPE_INVALID + // Out of range - should return error colType, innerErr = stmt.ColumnType(1) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) return nil @@ -1176,13 +1180,15 @@ func TestPreparedStatementColumnMethods(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, "created_at", name) - // Test out of bounds - should return empty string + // Test out of bounds - should return error name, innerErr = stmt.ColumnName(-1) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, "", name) name, innerErr = stmt.ColumnName(4) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, "", name) // Test ColumnType @@ -1202,13 +1208,15 @@ func TestPreparedStatementColumnMethods(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_TIMESTAMP, colType) - // Test out of bounds - should return TYPE_INVALID + // Test out of bounds - should return error colType, innerErr = stmt.ColumnType(-1) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) colType, innerErr = stmt.ColumnType(4) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) // Test ColumnTypeInfo - should return TypeInfo for each column @@ -1232,10 +1240,16 @@ func TestPreparedStatementColumnMethods(t *testing.T) { require.NotNil(t, typeInfo) require.Equal(t, TYPE_TIMESTAMP, typeInfo.InternalType()) - // Test out of bounds - should return error with TYPE_INVALID + // Test out of bounds - should return error typeInfo, innerErr = stmt.ColumnTypeInfo(4) - require.NoError(t, innerErr) - require.Equal(t, TYPE_INVALID, typeInfo.InternalType()) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) + require.Nil(t, typeInfo) + + typeInfo, innerErr = stmt.ColumnTypeInfo(-1) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) + require.Nil(t, typeInfo) return nil }) @@ -1488,9 +1502,10 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_INVALID, colType) - // Out of bounds access + // Out of bounds access - should return error colType, innerErr = stmt.ColumnType(1) - require.NoError(t, innerErr) + require.Error(t, innerErr) + require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) return nil diff --git a/type_info.go b/type_info.go index e3f712a5..502cabbd 100644 --- a/type_info.go +++ b/type_info.go @@ -501,8 +501,6 @@ func NewTypeInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { return newStructInfoFromLogicalType(lt) case TYPE_UNION: return newUnionInfoFromLogicalType(lt) - case TYPE_INVALID: - return &typeInfo{baseTypeInfo: baseTypeInfo{Type: TYPE_INVALID}}, nil default: // Simple/primitive type return NewTypeInfo(t) From ddf336814ffd2153f4431752042cecea114c067c Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Thu, 6 Nov 2025 20:06:21 -0500 Subject: [PATCH 10/14] lint --- statement_test.go | 34 +++++++++++++++++----------------- type_info.go | 2 +- type_info_test.go | 13 ++++++------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/statement_test.go b/statement_test.go index 6923ab21..88a3695e 100644 --- a/statement_test.go +++ b/statement_test.go @@ -79,7 +79,7 @@ func TestPrepareQuery(t *testing.T) { colName, innerErr = stmt.ColumnName(2) require.Error(t, innerErr) require.ErrorIs(t, innerErr, errAPI) - require.Equal(t, "", colName) + require.Empty(t, colName) // Test column types colType, innerErr := stmt.ColumnType(0) @@ -179,7 +179,6 @@ func TestPrepareQueryPositional(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT $1 AS first_param, $2 AS second_param`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // Test column count columnCount, innerErr := stmt.ColumnCount() @@ -195,7 +194,7 @@ func TestPrepareQueryPositional(t *testing.T) { colName, innerErr = stmt.ColumnName(1) require.Error(t, innerErr) require.ErrorIs(t, innerErr, errAPI) - require.Equal(t, "", colName) + require.Empty(t, colName) // Test column types - should be TYPE_INVALID for unresolved parameter types colType, innerErr := stmt.ColumnType(0) @@ -208,6 +207,7 @@ func TestPrepareQueryPositional(t *testing.T) { require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1156,7 +1156,6 @@ func TestPreparedStatementColumnMethods(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, name, value, created_at FROM test_columns`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // Test ColumnCount count, innerErr := stmt.ColumnCount() @@ -1184,12 +1183,12 @@ func TestPreparedStatementColumnMethods(t *testing.T) { name, innerErr = stmt.ColumnName(-1) require.Error(t, innerErr) require.ErrorIs(t, innerErr, errAPI) - require.Equal(t, "", name) + require.Empty(t, name) name, innerErr = stmt.ColumnName(4) require.Error(t, innerErr) require.ErrorIs(t, innerErr, errAPI) - require.Equal(t, "", name) + require.Empty(t, name) // Test ColumnType colType, innerErr := stmt.ColumnType(0) @@ -1251,6 +1250,7 @@ func TestPreparedStatementColumnMethods(t *testing.T) { require.ErrorIs(t, innerErr, errAPI) require.Nil(t, typeInfo) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1301,7 +1301,6 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { {'x': 10, 'y': 20} AS struct_col`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // Test ARRAY column typeInfo, innerErr := stmt.ColumnTypeInfo(0) @@ -1341,7 +1340,7 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { require.NotNil(t, details) structDetails, ok := details.(*StructDetails) require.True(t, ok, "Expected StructDetails") - require.Equal(t, 2, len(structDetails.Entries)) + require.Len(t, structDetails.Entries, 2) // Check first field 'x' require.Equal(t, "x", structDetails.Entries[0].Name()) @@ -1351,6 +1350,7 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { require.Equal(t, "y", structDetails.Entries[1].Name()) require.Equal(t, TYPE_INTEGER, structDetails.Entries[1].Info().InternalType()) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1362,7 +1362,6 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT 123.45::DECIMAL(10,2) AS dec_col`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() typeInfo, innerErr := stmt.ColumnTypeInfo(0) require.NoError(t, innerErr) @@ -1377,6 +1376,7 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { require.Equal(t, uint8(10), decimalDetails.Width) require.Equal(t, uint8(2), decimalDetails.Scale) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1394,7 +1394,6 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { `SELECT 'happy'::mood AS mood_col`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() typeInfo, innerErr := stmt.ColumnTypeInfo(0) require.NoError(t, innerErr) @@ -1408,6 +1407,7 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { require.True(t, ok, "Expected EnumDetails") require.Equal(t, []string{"happy", "sad", "neutral"}, enumDetails.Values) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1420,7 +1420,6 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { `SELECT MAP([1, 2], ['a', 'b']) AS map_col`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() typeInfo, innerErr := stmt.ColumnTypeInfo(0) require.NoError(t, innerErr) @@ -1435,6 +1434,7 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { require.Equal(t, TYPE_INTEGER, mapDetails.Key.InternalType()) require.Equal(t, TYPE_VARCHAR, mapDetails.Value.InternalType()) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1447,7 +1447,6 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { `SELECT [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}] AS list_struct_col`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() typeInfo, innerErr := stmt.ColumnTypeInfo(0) require.NoError(t, innerErr) @@ -1464,12 +1463,13 @@ func TestPreparedStatementColumnTypeInfo(t *testing.T) { // Assert nested STRUCT details structDetails, ok := listDetails.Child.Details().(*StructDetails) require.True(t, ok, "Expected StructDetails for nested type") - require.Equal(t, 2, len(structDetails.Entries)) + require.Len(t, structDetails.Entries, 2) require.Equal(t, "id", structDetails.Entries[0].Name()) require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) require.Equal(t, "name", structDetails.Entries[1].Name()) require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1490,7 +1490,6 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT * FROM (VALUES (?, ?)) t(a, b)`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // When columns have ambiguous types, count becomes 1 count, innerErr := stmt.ColumnCount() @@ -1508,6 +1507,7 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { require.ErrorIs(t, innerErr, errAPI) require.Equal(t, TYPE_INVALID, colType) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1519,7 +1519,6 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT ?, ?, ? + ?`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // When columns have ambiguous types, count becomes 1 count, innerErr := stmt.ColumnCount() @@ -1531,6 +1530,7 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_INVALID, colType) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1544,7 +1544,6 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, value, ? AS param_col FROM test_mixed`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // When any column has ambiguous type, count becomes 1 count, innerErr := stmt.ColumnCount() @@ -1556,6 +1555,7 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_INVALID, colType) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) @@ -1567,7 +1567,6 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, value FROM test_mixed`) require.NoError(t, innerErr) stmt := s.(*Stmt) - defer stmt.Close() // Normal count when no ambiguous types count, innerErr := stmt.ColumnCount() @@ -1583,6 +1582,7 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_VARCHAR, colType) + require.NoError(t, stmt.Close()) return nil }) require.NoError(t, err) diff --git a/type_info.go b/type_info.go index 502cabbd..b9fecc92 100644 --- a/type_info.go +++ b/type_info.go @@ -520,7 +520,7 @@ func newEnumInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { } values := make([]string, size) - for i := uint32(0); i < size; i++ { + for i := range size { values[i] = mapping.EnumDictionaryValue(lt, mapping.IdxT(i)) } diff --git a/type_info_test.go b/type_info_test.go index 53f33358..1b6d8bb7 100644 --- a/type_info_test.go +++ b/type_info_test.go @@ -3,8 +3,9 @@ package duckdb import ( "testing" - "github.com/duckdb/duckdb-go/mapping" "github.com/stretchr/testify/require" + + "github.com/duckdb/duckdb-go/mapping" ) type testTypeValues struct { @@ -594,14 +595,12 @@ func TestNewTypeInfoFromLogicalType(t *testing.T) { structDetails, ok := listDetails.Child.Details().(*StructDetails) require.True(t, ok) - require.Equal(t, 2, len(structDetails.Entries)) + require.Len(t, structDetails.Entries, 2) require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) require.Equal(t, "id", structDetails.Entries[0].Name()) require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) require.Equal(t, "name", structDetails.Entries[1].Name()) - }) - } func TestTypeInfoDetails(t *testing.T) { @@ -724,7 +723,7 @@ func TestTypeInfoDetails(t *testing.T) { structDetails, ok := details.(*StructDetails) require.True(t, ok) - require.Equal(t, 2, len(structDetails.Entries)) + require.Len(t, structDetails.Entries, 2) require.Equal(t, "id", structDetails.Entries[0].Name()) require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) require.Equal(t, "name", structDetails.Entries[1].Name()) @@ -749,7 +748,7 @@ func TestTypeInfoDetails(t *testing.T) { unionDetails, ok := details.(*UnionDetails) require.True(t, ok) - require.Equal(t, 2, len(unionDetails.Members)) + require.Len(t, unionDetails.Members, 2) require.Equal(t, "num", unionDetails.Members[0].Name) require.Equal(t, TYPE_INTEGER, unionDetails.Members[0].Type.InternalType()) require.Equal(t, "text", unionDetails.Members[1].Name) @@ -798,6 +797,6 @@ func TestTypeInfoDetails(t *testing.T) { structDetailsTyped, ok := structDetails.(*StructDetails) require.True(t, ok) - require.Equal(t, 2, len(structDetailsTyped.Entries)) + require.Len(t, len(structDetailsTyped.Entries), 2) }) } From a49e3b4ba9dd4c179cad206cee49e98a04675491 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Thu, 6 Nov 2025 20:41:45 -0500 Subject: [PATCH 11/14] fix test lol --- type_info_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/type_info_test.go b/type_info_test.go index 1b6d8bb7..5ee6fabe 100644 --- a/type_info_test.go +++ b/type_info_test.go @@ -797,6 +797,6 @@ func TestTypeInfoDetails(t *testing.T) { structDetailsTyped, ok := structDetails.(*StructDetails) require.True(t, ok) - require.Len(t, len(structDetailsTyped.Entries), 2) + require.Len(t, structDetailsTyped.Entries, 2) }) } From 92051f14bc360b0189445fa7b629a23d95f96566 Mon Sep 17 00:00:00 2001 From: Louisa Huang Date: Mon, 8 Dec 2025 14:53:57 +0800 Subject: [PATCH 12/14] address comments --- statement.go | 6 +- statement_test.go | 63 +------- type_info.go | 16 +- type_info_test.go | 392 ++++++++++++++++++++++------------------------ 4 files changed, 203 insertions(+), 274 deletions(-) diff --git a/statement.go b/statement.go index b9d5ef30..cb914455 100644 --- a/statement.go +++ b/statement.go @@ -519,7 +519,7 @@ func (s *Stmt) ColumnTypeInfo(n int) (TypeInfo, error) { lt := mapping.PreparedStatementColumnLogicalType(*s.preparedStmt, mapping.IdxT(n)) defer mapping.DestroyLogicalType(<) - return NewTypeInfoFromLogicalType(lt) + return newTypeInfoFromLogicalType(lt) } // ColumnName returns the name of the column at the given index (0-based). @@ -536,10 +536,6 @@ func (s *Stmt) ColumnName(n int) (string, error) { } name := mapping.PreparedStatementColumnName(*s.preparedStmt, mapping.IdxT(n)) - // C API returns nullptr for out-of-range indices - if name == "" { - return "", nil - } return name, nil } diff --git a/statement_test.go b/statement_test.go index 88a3695e..88dc6184 100644 --- a/statement_test.go +++ b/statement_test.go @@ -134,7 +134,7 @@ func TestPrepareQuery(t *testing.T) { _, innerErr = stmt.ColumnName(0) require.ErrorIs(t, innerErr, errClosedStmt) - _, innerErr = stmt.ColumnType(0) + _, innerErr = stmt.ColumnTypeInfo(0) require.ErrorIs(t, innerErr, errClosedStmt) innerErr = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}}) @@ -255,12 +255,11 @@ func TestPrepareQueryPositional(t *testing.T) { require.ErrorContains(t, innerErr, paramIndexErrMsg) require.Equal(t, TYPE_INVALID, paramType) - // Test column methods for UPDATE statement (should have no columns) + // Test column methods for UPDATE statement (should have a single Count column) columnCount, innerErr := stmt.ColumnCount() require.NoError(t, innerErr) require.Equal(t, 1, columnCount) - // Test out of bounds access - should return empty/invalid for UPDATE with no columns colName, innerErr := stmt.ColumnName(0) require.NoError(t, innerErr) require.Equal(t, "Count", colName) @@ -269,6 +268,10 @@ func TestPrepareQueryPositional(t *testing.T) { require.NoError(t, innerErr) require.Equal(t, TYPE_BIGINT, colType) + colTypeInfo, innerErr := stmt.ColumnTypeInfo(0) + require.NoError(t, innerErr) + require.Equal(t, TYPE_BIGINT, colTypeInfo.InternalType()) + r, innerErr := stmt.ExecBound(context.Background()) require.Nil(t, r) require.ErrorIs(t, innerErr, errNotBound) @@ -1254,33 +1257,6 @@ func TestPreparedStatementColumnMethods(t *testing.T) { return nil }) require.NoError(t, err) - - // Test with closed statement - err = conn.Raw(func(driverConn any) error { - innerConn := driverConn.(*Conn) - s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT * FROM test_columns`) - require.NoError(t, innerErr) - stmt := s.(*Stmt) - - // Close the statement - require.NoError(t, stmt.Close()) - - // Test methods on closed statement - _, innerErr = stmt.ColumnCount() - require.ErrorIs(t, innerErr, errClosedStmt) - - _, innerErr = stmt.ColumnName(0) - require.ErrorIs(t, innerErr, errClosedStmt) - - _, innerErr = stmt.ColumnType(0) - require.ErrorIs(t, innerErr, errClosedStmt) - - _, innerErr = stmt.ColumnTypeInfo(0) - require.ErrorIs(t, innerErr, errClosedStmt) - - return nil - }) - require.NoError(t, err) } func TestPreparedStatementColumnTypeInfo(t *testing.T) { @@ -1559,31 +1535,4 @@ func TestPreparedStatementAmbiguousColumnTypes(t *testing.T) { return nil }) require.NoError(t, err) - - // Test 4: Statement with no ambiguous types should work normally - err = conn.Raw(func(driverConn any) error { - innerConn := driverConn.(*Conn) - - s, innerErr := innerConn.PrepareContext(context.Background(), `SELECT id, value FROM test_mixed`) - require.NoError(t, innerErr) - stmt := s.(*Stmt) - - // Normal count when no ambiguous types - count, innerErr := stmt.ColumnCount() - require.NoError(t, innerErr) - require.Equal(t, 2, count) - - // Column types are resolved - colType, innerErr := stmt.ColumnType(0) - require.NoError(t, innerErr) - require.Equal(t, TYPE_INTEGER, colType) - - colType, innerErr = stmt.ColumnType(1) - require.NoError(t, innerErr) - require.Equal(t, TYPE_VARCHAR, colType) - - require.NoError(t, stmt.Close()) - return nil - }) - require.NoError(t, err) } diff --git a/type_info.go b/type_info.go index b9fecc92..2aaadc10 100644 --- a/type_info.go +++ b/type_info.go @@ -479,11 +479,11 @@ func (info *typeInfo) logicalUnionType() mapping.LogicalType { return mapping.CreateUnionType(types, info.names) } -// NewTypeInfoFromLogicalType converts a mapping.LogicalType to TypeInfo. +// newTypeInfoFromLogicalType converts a mapping.LogicalType to TypeInfo. // This allows inspecting types returned from prepared statements. // The LogicalType must remain valid for the duration of this call. // The returned TypeInfo does not hold a reference to the LogicalType. -func NewTypeInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { +func newTypeInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { t := mapping.GetTypeId(lt) switch t { @@ -531,7 +531,7 @@ func newListInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { childLT := mapping.ListTypeChildType(lt) defer mapping.DestroyLogicalType(&childLT) - childInfo, err := NewTypeInfoFromLogicalType(childLT) + childInfo, err := newTypeInfoFromLogicalType(childLT) if err != nil { return nil, err } @@ -543,7 +543,7 @@ func newArrayInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { childLT := mapping.ArrayTypeChildType(lt) defer mapping.DestroyLogicalType(&childLT) - childInfo, err := NewTypeInfoFromLogicalType(childLT) + childInfo, err := newTypeInfoFromLogicalType(childLT) if err != nil { return nil, err } @@ -559,12 +559,12 @@ func newMapInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { valueLT := mapping.MapTypeValueType(lt) defer mapping.DestroyLogicalType(&valueLT) - keyInfo, err := NewTypeInfoFromLogicalType(keyLT) + keyInfo, err := newTypeInfoFromLogicalType(keyLT) if err != nil { return nil, err } - valueInfo, err := NewTypeInfoFromLogicalType(valueLT) + valueInfo, err := newTypeInfoFromLogicalType(valueLT) if err != nil { return nil, err } @@ -583,7 +583,7 @@ func newStructInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { name := mapping.StructTypeChildName(lt, i) childLT := mapping.StructTypeChildType(lt, i) - childInfo, err := NewTypeInfoFromLogicalType(childLT) + childInfo, err := newTypeInfoFromLogicalType(childLT) mapping.DestroyLogicalType(&childLT) if err != nil { return nil, err @@ -612,7 +612,7 @@ func newUnionInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { memberNames[i] = mapping.UnionTypeMemberName(lt, i) memberLT := mapping.UnionTypeMemberType(lt, i) - memberInfo, err := NewTypeInfoFromLogicalType(memberLT) + memberInfo, err := newTypeInfoFromLogicalType(memberLT) mapping.DestroyLogicalType(&memberLT) if err != nil { return nil, err diff --git a/type_info_test.go b/type_info_test.go index 5ee6fabe..cf1f8176 100644 --- a/type_info_test.go +++ b/type_info_test.go @@ -347,7 +347,6 @@ func TestErrTypeInfo(t *testing.T) { } func TestNewTypeInfoFromLogicalType(t *testing.T) { - // Test primitive types primitiveTests := []Type{ TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, TYPE_UTINYINT, TYPE_USMALLINT, TYPE_UINTEGER, TYPE_UBIGINT, @@ -367,240 +366,232 @@ func TestNewTypeInfoFromLogicalType(t *testing.T) { defer mapping.DestroyLogicalType(<) // Convert back to TypeInfo - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) require.NoError(t, err) require.Equal(t, primitiveType, reconstructedInfo.InternalType()) }) } +} - // Test DECIMAL type - t.Run("DECIMAL", func(t *testing.T) { - originalInfo, err := NewDecimalInfo(10, 3) - require.NoError(t, err) - - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) - - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_DECIMAL, reconstructedInfo.InternalType()) +func TestNewTypeInfoFromLogicalTypeDecimal(t *testing.T) { + originalInfo, err := NewDecimalInfo(10, 3) + require.NoError(t, err) - // Verify we can convert back and get the same logical type - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - require.Equal(t, uint8(10), mapping.DecimalWidth(reconstructedLT)) - require.Equal(t, uint8(3), mapping.DecimalScale(reconstructedLT)) - }) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - // Test ENUM type - t.Run("ENUM", func(t *testing.T) { - originalInfo, err := NewEnumInfo("red", "green", "blue") - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_DECIMAL, reconstructedInfo.InternalType()) - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) + // Verify we can convert back and get the same logical type + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, uint8(10), mapping.DecimalWidth(reconstructedLT)) + require.Equal(t, uint8(3), mapping.DecimalScale(reconstructedLT)) +} - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_ENUM, reconstructedInfo.InternalType()) - - // Verify enum values - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - require.Equal(t, uint32(3), mapping.EnumDictionarySize(reconstructedLT)) - require.Equal(t, "red", mapping.EnumDictionaryValue(reconstructedLT, 0)) - require.Equal(t, "green", mapping.EnumDictionaryValue(reconstructedLT, 1)) - require.Equal(t, "blue", mapping.EnumDictionaryValue(reconstructedLT, 2)) - }) +func TestNewTypeInfoFromLogicalTypeEnum(t *testing.T) { + originalInfo, err := NewEnumInfo("red", "green", "blue") + require.NoError(t, err) - // Test LIST type - t.Run("LIST", func(t *testing.T) { - intInfo, err := NewTypeInfo(TYPE_INTEGER) - require.NoError(t, err) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - originalInfo, err := NewListInfo(intInfo) - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_ENUM, reconstructedInfo.InternalType()) + + // Verify enum values + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, uint32(3), mapping.EnumDictionarySize(reconstructedLT)) + require.Equal(t, "red", mapping.EnumDictionaryValue(reconstructedLT, 0)) + require.Equal(t, "green", mapping.EnumDictionaryValue(reconstructedLT, 1)) + require.Equal(t, "blue", mapping.EnumDictionaryValue(reconstructedLT, 2)) +} - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) +func TestNewTypeInfoFromLogicalTypeList(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_LIST, reconstructedInfo.InternalType()) - - // Verify child type - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - childLT := mapping.ListTypeChildType(reconstructedLT) - defer mapping.DestroyLogicalType(&childLT) - require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(childLT)) - }) + originalInfo, err := NewListInfo(intInfo) + require.NoError(t, err) - // Test ARRAY type - t.Run("ARRAY", func(t *testing.T) { - varcharInfo, err := NewTypeInfo(TYPE_VARCHAR) - require.NoError(t, err) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - originalInfo, err := NewArrayInfo(varcharInfo, 5) - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_LIST, reconstructedInfo.InternalType()) + + // Verify child type + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + childLT := mapping.ListTypeChildType(reconstructedLT) + defer mapping.DestroyLogicalType(&childLT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(childLT)) +} - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) +func TestNewTypeInfoFromLogicalTypeArray(t *testing.T) { + varcharInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_ARRAY, reconstructedInfo.InternalType()) - - // Verify child type and size - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - childLT := mapping.ArrayTypeChildType(reconstructedLT) - defer mapping.DestroyLogicalType(&childLT) - require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(childLT)) - require.Equal(t, mapping.IdxT(5), mapping.ArrayTypeArraySize(reconstructedLT)) - }) + originalInfo, err := NewArrayInfo(varcharInfo, 5) + require.NoError(t, err) - // Test MAP type - t.Run("MAP", func(t *testing.T) { - keyInfo, err := NewTypeInfo(TYPE_INTEGER) - require.NoError(t, err) - valueInfo, err := NewTypeInfo(TYPE_VARCHAR) - require.NoError(t, err) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - originalInfo, err := NewMapInfo(keyInfo, valueInfo) - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_ARRAY, reconstructedInfo.InternalType()) + + // Verify child type and size + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + childLT := mapping.ArrayTypeChildType(reconstructedLT) + defer mapping.DestroyLogicalType(&childLT) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(childLT)) + require.Equal(t, mapping.IdxT(5), mapping.ArrayTypeArraySize(reconstructedLT)) +} - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) +func TestNewTypeInfoFromLogicalTypeMap(t *testing.T) { + keyInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + valueInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_MAP, reconstructedInfo.InternalType()) - - // Verify key and value types - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - keyLT := mapping.MapTypeKeyType(reconstructedLT) - defer mapping.DestroyLogicalType(&keyLT) - valueLT := mapping.MapTypeValueType(reconstructedLT) - defer mapping.DestroyLogicalType(&valueLT) - require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(keyLT)) - require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(valueLT)) - }) + originalInfo, err := NewMapInfo(keyInfo, valueInfo) + require.NoError(t, err) - // Test STRUCT type - t.Run("STRUCT", func(t *testing.T) { - intInfo, err := NewTypeInfo(TYPE_INTEGER) - require.NoError(t, err) - strInfo, err := NewTypeInfo(TYPE_VARCHAR) - require.NoError(t, err) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - entry1, err := NewStructEntry(intInfo, "id") - require.NoError(t, err) - entry2, err := NewStructEntry(strInfo, "name") - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_MAP, reconstructedInfo.InternalType()) + + // Verify key and value types + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + keyLT := mapping.MapTypeKeyType(reconstructedLT) + defer mapping.DestroyLogicalType(&keyLT) + valueLT := mapping.MapTypeValueType(reconstructedLT) + defer mapping.DestroyLogicalType(&valueLT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(keyLT)) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(valueLT)) +} - originalInfo, err := NewStructInfo(entry1, entry2) - require.NoError(t, err) +func TestNewTypeInfoFromLogicalTypeStruct(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) + entry1, err := NewStructEntry(intInfo, "id") + require.NoError(t, err) + entry2, err := NewStructEntry(strInfo, "name") + require.NoError(t, err) - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_STRUCT, reconstructedInfo.InternalType()) - - // Verify struct fields - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - require.Equal(t, mapping.IdxT(2), mapping.StructTypeChildCount(reconstructedLT)) - require.Equal(t, "id", mapping.StructTypeChildName(reconstructedLT, 0)) - require.Equal(t, "name", mapping.StructTypeChildName(reconstructedLT, 1)) - - child0LT := mapping.StructTypeChildType(reconstructedLT, 0) - defer mapping.DestroyLogicalType(&child0LT) - require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(child0LT)) - - child1LT := mapping.StructTypeChildType(reconstructedLT, 1) - defer mapping.DestroyLogicalType(&child1LT) - require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(child1LT)) - }) + originalInfo, err := NewStructInfo(entry1, entry2) + require.NoError(t, err) - // Test UNION type - t.Run("UNION", func(t *testing.T) { - intInfo, err := NewTypeInfo(TYPE_INTEGER) - require.NoError(t, err) - strInfo, err := NewTypeInfo(TYPE_VARCHAR) - require.NoError(t, err) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - originalInfo, err := NewUnionInfo( - []TypeInfo{intInfo, strInfo}, - []string{"num", "text"}, - ) - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_STRUCT, reconstructedInfo.InternalType()) + + // Verify struct fields + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, mapping.IdxT(2), mapping.StructTypeChildCount(reconstructedLT)) + require.Equal(t, "id", mapping.StructTypeChildName(reconstructedLT, 0)) + require.Equal(t, "name", mapping.StructTypeChildName(reconstructedLT, 1)) + + child0LT := mapping.StructTypeChildType(reconstructedLT, 0) + defer mapping.DestroyLogicalType(&child0LT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(child0LT)) + + child1LT := mapping.StructTypeChildType(reconstructedLT, 1) + defer mapping.DestroyLogicalType(&child1LT) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(child1LT)) +} - lt := originalInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(<) +func TestNewTypeInfoFromLogicalTypeUnion(t *testing.T) { + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_UNION, reconstructedInfo.InternalType()) - - // Verify union members - reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() - defer mapping.DestroyLogicalType(&reconstructedLT) - require.Equal(t, mapping.IdxT(2), mapping.UnionTypeMemberCount(reconstructedLT)) - require.Equal(t, "num", mapping.UnionTypeMemberName(reconstructedLT, 0)) - require.Equal(t, "text", mapping.UnionTypeMemberName(reconstructedLT, 1)) - - member0LT := mapping.UnionTypeMemberType(reconstructedLT, 0) - defer mapping.DestroyLogicalType(&member0LT) - require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(member0LT)) - - member1LT := mapping.UnionTypeMemberType(reconstructedLT, 1) - defer mapping.DestroyLogicalType(&member1LT) - require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(member1LT)) - }) + originalInfo, err := NewUnionInfo( + []TypeInfo{intInfo, strInfo}, + []string{"num", "text"}, + ) + require.NoError(t, err) - // Test nested complex types - t.Run("NestedTypes", func(t *testing.T) { - // Create LIST of STRUCTs - intInfo, err := NewTypeInfo(TYPE_INTEGER) - require.NoError(t, err) - strInfo, err := NewTypeInfo(TYPE_VARCHAR) - require.NoError(t, err) + lt := originalInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(<) - entry1, err := NewStructEntry(intInfo, "id") - require.NoError(t, err) - entry2, err := NewStructEntry(strInfo, "name") - require.NoError(t, err) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_UNION, reconstructedInfo.InternalType()) + + // Verify union members + reconstructedLT := reconstructedInfo.(*typeInfo).logicalType() + defer mapping.DestroyLogicalType(&reconstructedLT) + require.Equal(t, mapping.IdxT(2), mapping.UnionTypeMemberCount(reconstructedLT)) + require.Equal(t, "num", mapping.UnionTypeMemberName(reconstructedLT, 0)) + require.Equal(t, "text", mapping.UnionTypeMemberName(reconstructedLT, 1)) + + member0LT := mapping.UnionTypeMemberType(reconstructedLT, 0) + defer mapping.DestroyLogicalType(&member0LT) + require.Equal(t, TYPE_INTEGER, mapping.GetTypeId(member0LT)) + + member1LT := mapping.UnionTypeMemberType(reconstructedLT, 1) + defer mapping.DestroyLogicalType(&member1LT) + require.Equal(t, TYPE_VARCHAR, mapping.GetTypeId(member1LT)) +} - structInfo, err := NewStructInfo(entry1, entry2) - require.NoError(t, err) +func TestNewTypeInfoFromLogicalTypeNested(t *testing.T) { + // Create LIST of STRUCTs + intInfo, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + strInfo, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) - listInfo, err := NewListInfo(structInfo) - require.NoError(t, err) + entry1, err := NewStructEntry(intInfo, "id") + require.NoError(t, err) + entry2, err := NewStructEntry(strInfo, "name") + require.NoError(t, err) - lt := listInfo.logicalType() - defer mapping.DestroyLogicalType(<) + structInfo, err := NewStructInfo(entry1, entry2) + require.NoError(t, err) - reconstructedInfo, err := NewTypeInfoFromLogicalType(lt) - require.NoError(t, err) - require.Equal(t, TYPE_LIST, reconstructedInfo.InternalType()) + listInfo, err := NewListInfo(structInfo) + require.NoError(t, err) - details := reconstructedInfo.Details() - listDetails, ok := details.(*ListDetails) - require.True(t, ok) - require.Equal(t, TYPE_STRUCT, listDetails.Child.InternalType()) + lt := listInfo.logicalType() + defer mapping.DestroyLogicalType(<) - structDetails, ok := listDetails.Child.Details().(*StructDetails) - require.True(t, ok) - require.Len(t, structDetails.Entries, 2) - require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) - require.Equal(t, "id", structDetails.Entries[0].Name()) - require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) - require.Equal(t, "name", structDetails.Entries[1].Name()) - }) + reconstructedInfo, err := newTypeInfoFromLogicalType(lt) + require.NoError(t, err) + require.Equal(t, TYPE_LIST, reconstructedInfo.InternalType()) + + details := reconstructedInfo.Details() + listDetails, ok := details.(*ListDetails) + require.True(t, ok) + require.Equal(t, TYPE_STRUCT, listDetails.Child.InternalType()) + + structDetails, ok := listDetails.Child.Details().(*StructDetails) + require.True(t, ok) + require.Len(t, structDetails.Entries, 2) + require.Equal(t, TYPE_INTEGER, structDetails.Entries[0].Info().InternalType()) + require.Equal(t, "id", structDetails.Entries[0].Name()) + require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) + require.Equal(t, "name", structDetails.Entries[1].Name()) } func TestTypeInfoDetails(t *testing.T) { @@ -617,7 +608,6 @@ func TestTypeInfoDetails(t *testing.T) { } }) - // Test DECIMAL details t.Run("DecimalDetails", func(t *testing.T) { info, err := NewDecimalInfo(10, 3) require.NoError(t, err) @@ -631,7 +621,6 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, uint8(3), decimalDetails.Scale) }) - // Test ENUM details t.Run("EnumDetails", func(t *testing.T) { info, err := NewEnumInfo("red", "green", "blue") require.NoError(t, err) @@ -651,7 +640,6 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, "red", enumDetails2.Values[0]) }) - // Test LIST details t.Run("ListDetails", func(t *testing.T) { intInfo, err := NewTypeInfo(TYPE_INTEGER) require.NoError(t, err) @@ -667,7 +655,6 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, TYPE_INTEGER, listDetails.Child.InternalType()) }) - // Test ARRAY details t.Run("ArrayDetails", func(t *testing.T) { varcharInfo, err := NewTypeInfo(TYPE_VARCHAR) require.NoError(t, err) @@ -703,7 +690,6 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, TYPE_VARCHAR, mapDetails.Value.InternalType()) }) - // Test STRUCT details t.Run("StructDetails", func(t *testing.T) { intInfo, err := NewTypeInfo(TYPE_INTEGER) require.NoError(t, err) @@ -730,7 +716,6 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, TYPE_VARCHAR, structDetails.Entries[1].Info().InternalType()) }) - // Test UNION details t.Run("UnionDetails", func(t *testing.T) { intInfo, err := NewTypeInfo(TYPE_INTEGER) require.NoError(t, err) @@ -764,7 +749,6 @@ func TestTypeInfoDetails(t *testing.T) { require.Equal(t, "num", unionDetails2.Members[0].Name) }) - // Test nested type details t.Run("NestedTypeDetails", func(t *testing.T) { // Create a LIST of STRUCTs intInfo, err := NewTypeInfo(TYPE_INTEGER) From 55625e477c44073e0828aeab56fb55063a5bd156 Mon Sep 17 00:00:00 2001 From: taniabogatsch <44262898+taniabogatsch@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:13:37 +0100 Subject: [PATCH 13/14] linter --- type_info.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/type_info.go b/type_info.go index 2aaadc10..3200f8e2 100644 --- a/type_info.go +++ b/type_info.go @@ -579,9 +579,9 @@ func newStructInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { } entries := make([]StructEntry, count) - for i := mapping.IdxT(0); i < count; i++ { - name := mapping.StructTypeChildName(lt, i) - childLT := mapping.StructTypeChildType(lt, i) + for i := range int(count) { + name := mapping.StructTypeChildName(lt, mapping.IdxT(i)) + childLT := mapping.StructTypeChildType(lt, mapping.IdxT(i)) childInfo, err := newTypeInfoFromLogicalType(childLT) mapping.DestroyLogicalType(&childLT) @@ -608,9 +608,9 @@ func newUnionInfoFromLogicalType(lt mapping.LogicalType) (TypeInfo, error) { memberTypes := make([]TypeInfo, count) memberNames := make([]string, count) - for i := mapping.IdxT(0); i < count; i++ { - memberNames[i] = mapping.UnionTypeMemberName(lt, i) - memberLT := mapping.UnionTypeMemberType(lt, i) + for i := range int(count) { + memberNames[i] = mapping.UnionTypeMemberName(lt, mapping.IdxT(i)) + memberLT := mapping.UnionTypeMemberType(lt, mapping.IdxT(i)) memberInfo, err := newTypeInfoFromLogicalType(memberLT) mapping.DestroyLogicalType(&memberLT) From 249ff552a8e48df814e586fd0b30dc888a6b8bf0 Mon Sep 17 00:00:00 2001 From: taniabogatsch <44262898+taniabogatsch@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:54:30 +0100 Subject: [PATCH 14/14] merge main formatting --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index f5a00a6d..3d43d00d 100644 --- a/errors.go +++ b/errors.go @@ -86,7 +86,7 @@ const ( interfaceIsNilErrMsg = "interface is nil" duplicateNameErrMsg = "duplicate name" paramIndexErrMsg = "invalid parameter index" - columnIndexErrMsg = "invalid column index" + columnIndexErrMsg = "invalid column index" ) var (