From 0ba0b75d652a0fba7001504117583775c33c9d83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=BA=E8=BE=BE?= <130376214+ddh-5230@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:43:19 +0800 Subject: [PATCH] feat:migrate to duckdb arrow for query execution (#334) --- compatibility/flightsql/go/flightsql_test.go | 28 +++- .../flightsql/python/flightsql_test.py | 2 +- flightsqlserver/sqlite_server.go | 135 +++++++++--------- flightsqltest/driver_test.go | 1 + 4 files changed, 92 insertions(+), 74 deletions(-) diff --git a/compatibility/flightsql/go/flightsql_test.go b/compatibility/flightsql/go/flightsql_test.go index 2171a4c..e41a145 100644 --- a/compatibility/flightsql/go/flightsql_test.go +++ b/compatibility/flightsql/go/flightsql_test.go @@ -87,18 +87,36 @@ func executeQueryAndVerify(cnxn adbc.Connection, query string, expectedResults [ record := rows.Record() numRows := record.NumRows() - id := record.Column(0).(*array.Int64) - name := record.Column(1).(*array.String) - value := record.Column(2).(*array.Int64) for i := 0; i < int(numRows); i++ { + var id, value int64 + switch idCol := record.Column(0).(type) { + case *array.Int32: + id = int64(idCol.Value(i)) + case *array.Int64: + id = idCol.Value(i) + default: + t.Fatalf("unexpected type for id column: %T", record.Column(0)) + } + + name := record.Column(1).(*array.String) + + switch valueCol := record.Column(2).(type) { + case *array.Int32: + value = int64(valueCol.Value(i)) + case *array.Int64: + value = valueCol.Value(i) + default: + t.Fatalf("unexpected type for value column: %T", record.Column(2)) + } + actualResults = append(actualResults, struct { id int64 name string value int64 }{ - id: id.Value(i), + id: id, name: name.Value(i), - value: value.Value(i), + value: value, }) } } diff --git a/compatibility/flightsql/python/flightsql_test.py b/compatibility/flightsql/python/flightsql_test.py index 4d1c7d0..576f486 100644 --- a/compatibility/flightsql/python/flightsql_test.py +++ b/compatibility/flightsql/python/flightsql_test.py @@ -32,6 +32,7 @@ def setUp(self): value INT ) """) # Create the table + self.conn.commit() def test_insert_and_select(self): """Test inserting data and selecting it back to verify correctness.""" @@ -55,7 +56,6 @@ def test_drop_table(self): cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='intTable'") # Check if the table exists rows = cursor.fetchall() self.assertEqual(len(rows), 0, "Table 'intTable' should be dropped and not exist in the database.") - cursor.execute("COMMIT;") if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/flightsqlserver/sqlite_server.go b/flightsqlserver/sqlite_server.go index 4669b1a..927dcd1 100644 --- a/flightsqlserver/sqlite_server.go +++ b/flightsqlserver/sqlite_server.go @@ -55,10 +55,8 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/apache/arrow-go/v18/arrow/scalar" "github.com/marcboeker/go-duckdb" - - "google.golang.org/grpc" + _ "github.com/marcboeker/go-duckdb" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" _ "modernc.org/sqlite" ) @@ -210,12 +208,14 @@ func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) { type Statement struct { stmt *sql.Stmt + query string params [][]interface{} } type SQLiteFlightSQLServer struct { flightsql.BaseServer - db *sql.DB + db *sql.DB + conn *duckdb.Conn prepared sync.Map openTransactions sync.Map @@ -260,17 +260,20 @@ func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsq if err != nil { return nil, nil, err } - - var db dbQueryCtx = s.db if txnid != "" { - tx, loaded := s.openTransactions.Load(txnid) - if !loaded { - return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid) - } - db = tx.(*sql.Tx) + return nil, nil, fmt.Errorf("transactions not yet supported with DuckDB") } - return doGetQuery(ctx, s.Alloc, db, query, nil) + // var db dbQueryCtx = s.db + // if txnid != "" { + // tx, loaded := s.openTransactions.Load(txnid) + // if !loaded { + // return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid) + // } + // db = tx.(*sql.Tx) + // } + + return doGetQuery(ctx, s.db, query, nil) } func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -345,14 +348,20 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd fligh func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { query := prepareQueryForGetTables(cmd) - rows, err := s.db.QueryContext(ctx, query) + conn, err := s.db.Conn(ctx) + var duckConn *duckdb.Conn + err = conn.Raw(func(driverConn any) error { + duckConn = driverConn.(*duckdb.Conn) + return nil + }) if err != nil { return nil, nil, err } - - var rdr array.RecordReader - - rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema_ref.Tables, rows) + arrow, err := duckdb.NewArrowFromConn(duckConn) + if err != nil { + return nil, nil, err + } + rdr, err := arrow.QueryContext(ctx, query) if err != nil { return nil, nil, err } @@ -394,7 +403,7 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { query := "SELECT DISTINCT type AS table_type FROM sqlite_master" - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.TableTypes) + return doGetQuery(ctx, s.db, query, schema_ref.TableTypes) } func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) { @@ -422,6 +431,7 @@ func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { var stmt *sql.Stmt + query := req.GetQuery() if len(req.GetTransactionId()) > 0 { tx, loaded := s.openTransactions.Load(string(req.GetTransactionId())) @@ -438,7 +448,10 @@ func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req } handle := genRandomString() - s.prepared.Store(string(handle), Statement{stmt: stmt}) + s.prepared.Store(string(handle), Statement{ + stmt: stmt, + query: query, + }) result.Handle = handle // no way to get the dataset or parameter schemas from sql.DB @@ -473,29 +486,23 @@ type dbQueryCtx interface { QueryContext(context.Context, string, ...any) (*sql.Rows, error) } -func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { - rows, err := db.QueryContext(ctx, query, args...) +func doGetQuery(ctx context.Context, db *sql.DB, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { + + conn, err := db.Conn(ctx) + var duckConn *duckdb.Conn + err = conn.Raw(func(driverConn any) error { + duckConn = driverConn.(*duckdb.Conn) + return nil + }) + arrow, err := duckdb.NewArrowFromConn(duckConn) if err != nil { - // Not really useful except for testing Flight SQL clients - trailers := metadata.Pairs("afsql-sqlite-query", query) - grpc.SetTrailer(ctx, trailers) return nil, nil, err } - - var rdr *SqlBatchReader - if schema != nil { - rdr, err = NewSqlBatchReaderWithSchema(mem, schema, rows) - } else { - rdr, err = NewSqlBatchReader(mem, rows) - if err == nil { - schema = rdr.schema - } - } - + rdr, err := arrow.QueryContext(ctx, query, args...) if err != nil { return nil, nil, err } - + schema = rdr.Schema() ch := make(chan flight.StreamChunk) go flight.StreamChunksFromReader(rdr, ch) return schema, ch, nil @@ -503,24 +510,34 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle())) + if !ok { return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found") } + conn, err := s.db.Conn(ctx) + var duckConn *duckdb.Conn + err = conn.Raw(func(driverConn any) error { + duckConn = driverConn.(*duckdb.Conn) + return nil + }) + if err != nil { + return nil, nil, err + } + stmt := val.(Statement) + arrow, err := duckdb.NewArrowFromConn(duckConn) + if err != nil { + return nil, nil, err + } + readers := make([]array.RecordReader, 0, len(stmt.params)) if len(stmt.params) == 0 { - rows, err := stmt.stmt.QueryContext(ctx) + rdr, err := arrow.QueryContext(ctx, stmt.query) if err != nil { return nil, nil, err } - - rdr, err := NewSqlBatchReader(s.Alloc, rows) - if err != nil { - return nil, nil, err - } - - schema = rdr.schema + schema = rdr.Schema() readers = append(readers, rdr) } else { defer func() { @@ -530,35 +547,17 @@ func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd } } }() - var ( - rows *sql.Rows - rdr *SqlBatchReader - ) // if we have multiple rows of bound params, execute the query // multiple times and concatenate the result sets. for _, p := range stmt.params { - rows, err = stmt.stmt.QueryContext(ctx, p...) + rdr, err := arrow.QueryContext(ctx, stmt.query, p...) if err != nil { return nil, nil, err } - - if schema == nil { - rdr, err = NewSqlBatchReader(s.Alloc, rows) - if err != nil { - return nil, nil, err - } - schema = rdr.schema - } else { - rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows) - if err != nil { - return nil, nil, err - } - } - + schema = rdr.Schema() readers = append(readers, rdr) } } - ch := make(chan flight.StreamChunk) go flight.ConcatenateReaders(readers, ch) out = ch @@ -715,7 +714,7 @@ func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flight fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table) - return doGetQuery(ctx, s.Alloc, s.db, b.String(), schema_ref.PrimaryKeys) + return doGetQuery(ctx, s.db, b.String(), schema_ref.PrimaryKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -731,7 +730,7 @@ func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref fligh filter += " AND fk_schema_name = '" + *ref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ImportedKeys) + return doGetQuery(ctx, s.db, query, schema_ref.ImportedKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -747,7 +746,7 @@ func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref fligh filter += " AND pk_schema_name = '" + *ref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys) + return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys) } func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { @@ -773,7 +772,7 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd fli filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'" } query := prepareQueryForGetKeys(filter) - return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys) + return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys) } func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) { diff --git a/flightsqltest/driver_test.go b/flightsqltest/driver_test.go index 8a79504..78ccfea 100644 --- a/flightsqltest/driver_test.go +++ b/flightsqltest/driver_test.go @@ -96,6 +96,7 @@ func (s *SqlTestSuite) SetupSuite() { if err != nil { return nil, "", err } + sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage()) if err != nil { return nil, "", err