Skip to content

Commit

Permalink
feat:migrate to duckdb arrow for query execution (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
ddh-5230 authored Jan 6, 2025
1 parent 18113cf commit 0ba0b75
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 74 deletions.
28 changes: 23 additions & 5 deletions compatibility/flightsql/go/flightsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,36 @@ func executeQueryAndVerify(cnxn adbc.Connection, query string, expectedResults [
record := rows.Record()
numRows := record.NumRows()

id := record.Column(0).(*array.Int64)
name := record.Column(1).(*array.String)
value := record.Column(2).(*array.Int64)
for i := 0; i < int(numRows); i++ {
var id, value int64
switch idCol := record.Column(0).(type) {
case *array.Int32:
id = int64(idCol.Value(i))
case *array.Int64:
id = idCol.Value(i)
default:
t.Fatalf("unexpected type for id column: %T", record.Column(0))
}

name := record.Column(1).(*array.String)

switch valueCol := record.Column(2).(type) {
case *array.Int32:
value = int64(valueCol.Value(i))
case *array.Int64:
value = valueCol.Value(i)
default:
t.Fatalf("unexpected type for value column: %T", record.Column(2))
}

actualResults = append(actualResults, struct {
id int64
name string
value int64
}{
id: id.Value(i),
id: id,
name: name.Value(i),
value: value.Value(i),
value: value,
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion compatibility/flightsql/python/flightsql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def setUp(self):
value INT
)
""") # Create the table
self.conn.commit()

def test_insert_and_select(self):
"""Test inserting data and selecting it back to verify correctness."""
Expand All @@ -55,7 +56,6 @@ def test_drop_table(self):
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='intTable'") # Check if the table exists
rows = cursor.fetchall()
self.assertEqual(len(rows), 0, "Table 'intTable' should be dropped and not exist in the database.")
cursor.execute("COMMIT;")

if __name__ == "__main__":
unittest.main()
135 changes: 67 additions & 68 deletions flightsqlserver/sqlite_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ import (
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/apache/arrow-go/v18/arrow/scalar"
"github.com/marcboeker/go-duckdb"

"google.golang.org/grpc"
_ "github.com/marcboeker/go-duckdb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
_ "modernc.org/sqlite"
)
Expand Down Expand Up @@ -210,12 +208,14 @@ func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) {

type Statement struct {
stmt *sql.Stmt
query string
params [][]interface{}
}

type SQLiteFlightSQLServer struct {
flightsql.BaseServer
db *sql.DB
db *sql.DB
conn *duckdb.Conn

prepared sync.Map
openTransactions sync.Map
Expand Down Expand Up @@ -260,17 +260,20 @@ func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsq
if err != nil {
return nil, nil, err
}

var db dbQueryCtx = s.db
if txnid != "" {
tx, loaded := s.openTransactions.Load(txnid)
if !loaded {
return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid)
}
db = tx.(*sql.Tx)
return nil, nil, fmt.Errorf("transactions not yet supported with DuckDB")
}

return doGetQuery(ctx, s.Alloc, db, query, nil)
// var db dbQueryCtx = s.db
// if txnid != "" {
// tx, loaded := s.openTransactions.Load(txnid)
// if !loaded {
// return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid)
// }
// db = tx.(*sql.Tx)
// }

return doGetQuery(ctx, s.db, query, nil)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand Down Expand Up @@ -345,14 +348,20 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd fligh
func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
query := prepareQueryForGetTables(cmd)

rows, err := s.db.QueryContext(ctx, query)
conn, err := s.db.Conn(ctx)
var duckConn *duckdb.Conn
err = conn.Raw(func(driverConn any) error {
duckConn = driverConn.(*duckdb.Conn)
return nil
})
if err != nil {
return nil, nil, err
}

var rdr array.RecordReader

rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema_ref.Tables, rows)
arrow, err := duckdb.NewArrowFromConn(duckConn)
if err != nil {
return nil, nil, err
}
rdr, err := arrow.QueryContext(ctx, query)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -394,7 +403,7 @@ func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc

func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) {
query := "SELECT DISTINCT type AS table_type FROM sqlite_master"
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.TableTypes)
return doGetQuery(ctx, s.db, query, schema_ref.TableTypes)
}

func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) {
Expand Down Expand Up @@ -422,6 +431,7 @@ func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context,

func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
var stmt *sql.Stmt
query := req.GetQuery()

if len(req.GetTransactionId()) > 0 {
tx, loaded := s.openTransactions.Load(string(req.GetTransactionId()))
Expand All @@ -438,7 +448,10 @@ func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req
}

handle := genRandomString()
s.prepared.Store(string(handle), Statement{stmt: stmt})
s.prepared.Store(string(handle), Statement{
stmt: stmt,
query: query,
})

result.Handle = handle
// no way to get the dataset or parameter schemas from sql.DB
Expand Down Expand Up @@ -473,54 +486,58 @@ type dbQueryCtx interface {
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
}

func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) {
rows, err := db.QueryContext(ctx, query, args...)
func doGetQuery(ctx context.Context, db *sql.DB, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) {

conn, err := db.Conn(ctx)
var duckConn *duckdb.Conn
err = conn.Raw(func(driverConn any) error {
duckConn = driverConn.(*duckdb.Conn)
return nil
})
arrow, err := duckdb.NewArrowFromConn(duckConn)
if err != nil {
// Not really useful except for testing Flight SQL clients
trailers := metadata.Pairs("afsql-sqlite-query", query)
grpc.SetTrailer(ctx, trailers)
return nil, nil, err
}

var rdr *SqlBatchReader
if schema != nil {
rdr, err = NewSqlBatchReaderWithSchema(mem, schema, rows)
} else {
rdr, err = NewSqlBatchReader(mem, rows)
if err == nil {
schema = rdr.schema
}
}

rdr, err := arrow.QueryContext(ctx, query, args...)
if err != nil {
return nil, nil, err
}

schema = rdr.Schema()
ch := make(chan flight.StreamChunk)
go flight.StreamChunksFromReader(rdr, ch)
return schema, ch, nil
}

func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))

if !ok {
return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found")
}

conn, err := s.db.Conn(ctx)
var duckConn *duckdb.Conn
err = conn.Raw(func(driverConn any) error {
duckConn = driverConn.(*duckdb.Conn)
return nil
})
if err != nil {
return nil, nil, err
}

stmt := val.(Statement)
arrow, err := duckdb.NewArrowFromConn(duckConn)
if err != nil {
return nil, nil, err
}

readers := make([]array.RecordReader, 0, len(stmt.params))
if len(stmt.params) == 0 {
rows, err := stmt.stmt.QueryContext(ctx)
rdr, err := arrow.QueryContext(ctx, stmt.query)
if err != nil {
return nil, nil, err
}

rdr, err := NewSqlBatchReader(s.Alloc, rows)
if err != nil {
return nil, nil, err
}

schema = rdr.schema
schema = rdr.Schema()
readers = append(readers, rdr)
} else {
defer func() {
Expand All @@ -530,35 +547,17 @@ func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd
}
}
}()
var (
rows *sql.Rows
rdr *SqlBatchReader
)
// if we have multiple rows of bound params, execute the query
// multiple times and concatenate the result sets.
for _, p := range stmt.params {
rows, err = stmt.stmt.QueryContext(ctx, p...)
rdr, err := arrow.QueryContext(ctx, stmt.query, p...)
if err != nil {
return nil, nil, err
}

if schema == nil {
rdr, err = NewSqlBatchReader(s.Alloc, rows)
if err != nil {
return nil, nil, err
}
schema = rdr.schema
} else {
rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows)
if err != nil {
return nil, nil, err
}
}

schema = rdr.Schema()
readers = append(readers, rdr)
}
}

ch := make(chan flight.StreamChunk)
go flight.ConcatenateReaders(readers, ch)
out = ch
Expand Down Expand Up @@ -715,7 +714,7 @@ func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flight

fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table)

return doGetQuery(ctx, s.Alloc, s.db, b.String(), schema_ref.PrimaryKeys)
return doGetQuery(ctx, s.db, b.String(), schema_ref.PrimaryKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -731,7 +730,7 @@ func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref fligh
filter += " AND fk_schema_name = '" + *ref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ImportedKeys)
return doGetQuery(ctx, s.db, query, schema_ref.ImportedKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -747,7 +746,7 @@ func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref fligh
filter += " AND pk_schema_name = '" + *ref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys)
return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys)
}

func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
Expand All @@ -773,7 +772,7 @@ func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd fli
filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'"
}
query := prepareQueryForGetKeys(filter)
return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys)
return doGetQuery(ctx, s.db, query, schema_ref.ExportedKeys)
}

func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) {
Expand Down
1 change: 1 addition & 0 deletions flightsqltest/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func (s *SqlTestSuite) SetupSuite() {
if err != nil {
return nil, "", err
}

sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage())
if err != nil {
return nil, "", err
Expand Down

0 comments on commit 0ba0b75

Please sign in to comment.