diff --git a/.github/workflows/clients-compatibility.yml b/.github/workflows/clients-compatibility.yml index 96f03c22..2535bdf6 100644 --- a/.github/workflows/clients-compatibility.yml +++ b/.github/workflows/clients-compatibility.yml @@ -6,6 +6,7 @@ on: - main - compatibility - test + - support_flightsql pull_request: branches: [ "main" ] @@ -133,4 +134,39 @@ jobs: - name: Run the Compatibility Test for Python Data Tools run: | - bats ./compatibility/pg-pytools/test.bats \ No newline at end of file + bats ./compatibility/pg-pytools/test.bats + + test-flightsql: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Install dependencies + run: | + go get . + + pip3 install "sqlglot[rs]" + pip3 install "psycopg[binary]" pandas pyarrow polars adbc_driver_flightsql + + - name: Build + run: go build -v + + - name: Start MyDuck Server + run: | + ./myduckserver --flightsql-port 47470 & + sleep 10 + + - name: Run the Compatibility Test for FlightSQL + run: | + go test -v ./compatibility/flightsql/go/flightsql_test.go + python3 -m unittest discover ./compatibility/flightsql/python -p "flightsql_test.py" \ No newline at end of file diff --git a/.github/workflows/password-auth.yml b/.github/workflows/password-auth.yml new file mode 100644 index 00000000..0f79dcd6 --- /dev/null +++ b/.github/workflows/password-auth.yml @@ -0,0 +1,67 @@ +name: Password Auth Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install system packages + uses: awalsh128/cache-apt-pkgs-action@latest + with: + packages: 
postgresql-client mysql-client + version: 1.0 + + - name: Install dependencies + run: | + go get . + + pip3 install "sqlglot[rs]" + + curl -LJO https://github.com/duckdb/duckdb/releases/latest/download/duckdb_cli-linux-amd64.zip + unzip duckdb_cli-linux-amd64.zip + chmod +x duckdb + sudo mv duckdb /usr/local/bin + + - name: Build + run: go build -v + + - name: Start MyDuck Server with password + run: | + ./myduckserver --superuser-password=testpass123 & + sleep 5 + + - name: Test PostgreSQL auth + run: | + export PGPASSWORD=testpass123 + # Basic connection test + psql -h 127.0.0.1 -U postgres -d postgres -c "SELECT 1 as test;" + # Create and query a table + psql -h 127.0.0.1 -U postgres -d postgres -c "CREATE TABLE test (id int); INSERT INTO test VALUES (42); SELECT * FROM test;" + # Test wrong password + ! PGPASSWORD=wrongpass psql -h 127.0.0.1 -U postgres -d postgres -c "SELECT 1" + + - name: Test MySQL auth + run: | + # Basic connection test + mysql -h127.0.0.1 -uroot -ptestpass123 -e "SELECT 1 as test;" + # Create and query a table + mysql -h127.0.0.1 -uroot -ptestpass123 -e "CREATE DATABASE IF NOT EXISTS test; USE test; CREATE TABLE t1 (id int); INSERT INTO t1 VALUES (42); SELECT * FROM t1;" + # Test wrong password + ! 
mysql -h127.0.0.1 -uroot -pwrongpass -e "SELECT 1" diff --git a/.github/workflows/psql.yml b/.github/workflows/psql.yml index 3002d260..75c08a04 100644 --- a/.github/workflows/psql.yml +++ b/.github/workflows/psql.yml @@ -35,12 +35,10 @@ jobs: pip3 install "sqlglot[rs]" pyarrow pandas - curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip + curl -LJO https://github.com/duckdb/duckdb/releases/latest/download/duckdb_cli-linux-amd64.zip unzip duckdb_cli-linux-amd64.zip chmod +x duckdb sudo mv duckdb /usr/local/bin - duckdb -c 'INSTALL json from core' - duckdb -c 'SELECT extension_name, loaded, install_path FROM duckdb_extensions() where installed' - name: Build run: go build -v diff --git a/README.md b/README.md index 79e6dae0..beee1863 100644 --- a/README.md +++ b/README.md @@ -163,18 +163,34 @@ With MyDuck's powerful analytics capabilities, you can create an hybrid transact To rename the default database, pass the `DEFAULT_DB` environment variable to the Docker container: ```bash -docker run -p 13306:3306 -p 15432:5432 --env=DEFAULT_DB=mydbname apecloud/myduckserver:latest +docker run -d -p 13306:3306 -p 15432:5432 \ + --env=DEFAULT_DB=mydbname \ + apecloud/myduckserver:latest ``` + +To set the superuser password, pass the `SUPERUSER_PASSWORD` environment variable to the Docker container: + +```bash +docker run -d -p 13306:3306 -p 15432:5432 \ + --env=SUPERUSER_PASSWORD=mysecretpassword \ + apecloud/myduckserver:latest +``` + + To initialize MyDuck Server with custom SQL statements, mount your `.sql` file to either `/docker-entrypoint-initdb.d/mysql/` or `/docker-entrypoint-initdb.d/postgres/` inside the Docker container, depending on the SQL dialect you're using. 
For example: ```bash # Execute `init.sql` via MySQL protocol -docker run -d -p 13306:3306 --name=myduck -v ./init.sql:/docker-entrypoint-initdb.d/mysql/init.sql apecloud/myduckserver:latest +docker run -d -p 13306:3306 --name=myduck \ + -v ./init.sql:/docker-entrypoint-initdb.d/mysql/init.sql \ + apecloud/myduckserver:latest # Execute `init.sql` via PostgreSQL protocol -docker run -d -p 15432:5432 --name=myduck -v ./init.sql:/docker-entrypoint-initdb.d/postgres/init.sql apecloud/myduckserver:latest +docker run -d -p 15432:5432 --name=myduck \ + -v ./init.sql:/docker-entrypoint-initdb.d/postgres/init.sql \ + apecloud/myduckserver:latest ``` ### Query Parquet Files diff --git a/compatibility/flightsql/go/flightsql_test.go b/compatibility/flightsql/go/flightsql_test.go new file mode 100644 index 00000000..2171a4c6 --- /dev/null +++ b/compatibility/flightsql/go/flightsql_test.go @@ -0,0 +1,152 @@ +package main + +import ( + "context" + "reflect" + "testing" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/flightsql" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +// Set connection options +var options = map[string]string{ + adbc.OptionKeyURI: "grpc://localhost:47470", + flightsql.OptionSSLSkipVerify: adbc.OptionValueEnabled, +} + +// Create database connection +func createDatabaseConnection(t *testing.T) adbc.Connection { + alloc := memory.NewGoAllocator() + drv := flightsql.NewDriver(alloc) + db, err := drv.NewDatabase(options) + if err != nil { + t.Fatalf("Error creating database: %v", err) + } + + cnxn, err := db.Open(context.Background()) + if err != nil { + t.Fatalf("Error opening connection: %v", err) + } + + return cnxn +} + +// Execute SQL statement +func executeSQLStatement(cnxn adbc.Connection, query string, t *testing.T) { + stmt, err := cnxn.NewStatement() + if err != nil { + t.Fatalf("failed to create statement: %v", err) + } + defer stmt.Close() + + err = 
stmt.SetSqlQuery(query) + if err != nil { + t.Fatalf("failed to set SQL query: %v", err) + } + + _, err = stmt.ExecuteUpdate(context.Background()) + if err != nil { + t.Fatalf("failed to execute SQL statement: %v", err) + } +} + +// Execute query and verify results +func executeQueryAndVerify(cnxn adbc.Connection, query string, expectedResults []struct { + id int64 + name string + value int64 +}, t *testing.T) { + stmt, err := cnxn.NewStatement() + if err != nil { + t.Fatalf("failed to create statement: %v", err) + } + defer stmt.Close() + + err = stmt.SetSqlQuery(query) + if err != nil { + t.Fatalf("failed to set SQL query: %v", err) + } + + rows, _, err := stmt.ExecuteQuery(context.Background()) + if err != nil { + t.Fatalf("failed to execute query: %v", err) + } + defer rows.Release() + + var actualResults []struct { + id int64 + name string + value int64 + } + + // Read query results and verify + for rows.Next() { + record := rows.Record() + numRows := record.NumRows() + + id := record.Column(0).(*array.Int64) + name := record.Column(1).(*array.String) + value := record.Column(2).(*array.Int64) + for i := 0; i < int(numRows); i++ { + actualResults = append(actualResults, struct { + id int64 + name string + value int64 + }{ + id: id.Value(i), + name: name.Value(i), + value: value.Value(i), + }) + } + } + + // Verify query results + if len(actualResults) != len(expectedResults) { + t.Errorf("Expected %d rows, but got %d", len(expectedResults), len(actualResults)) + } + + for i, result := range actualResults { + expected := expectedResults[i] + if !reflect.DeepEqual(result, expected) { + t.Errorf("Row %d: Expected %+v, but got %+v", i, expected, result) + } + } +} + +// Go test function +func TestSQLOperations(t *testing.T) { + cnxn := createDatabaseConnection(t) + defer cnxn.Close() + + // 1. Execute DROP TABLE IF EXISTS intTable + executeSQLStatement(cnxn, "DROP TABLE IF EXISTS intTable", t) + + // 2. 
Execute CREATE TABLE IF NOT EXISTS intTable + executeSQLStatement(cnxn, `CREATE TABLE IF NOT EXISTS intTable ( + id INTEGER PRIMARY KEY, + name VARCHAR(50), + value INT + )`, t) + + // 3. Execute INSERT INTO intTable + executeSQLStatement(cnxn, "INSERT INTO intTable (id, name, value) VALUES (1, 'TestName', 100)", t) + executeSQLStatement(cnxn, "INSERT INTO intTable (id, name, value) VALUES (2, 'AnotherName', 200)", t) + + // 4. Query data and verify insertion was successful + expectedResults := []struct { + id int64 + name string + value int64 + }{ + {id: 1, name: "TestName", value: 100}, + {id: 2, name: "AnotherName", value: 200}, + } + query := "SELECT id, name, value FROM intTable" + executeQueryAndVerify(cnxn, query, expectedResults, t) + + // 5. Execute DROP TABLE IF EXISTS intTable + executeSQLStatement(cnxn, "DROP TABLE IF EXISTS intTable", t) +} diff --git a/compatibility/flightsql/python/flightsql_test.py b/compatibility/flightsql/python/flightsql_test.py new file mode 100644 index 00000000..4d1c7d03 --- /dev/null +++ b/compatibility/flightsql/python/flightsql_test.py @@ -0,0 +1,61 @@ +import unittest +from adbc_driver_flightsql import DatabaseOptions +from adbc_driver_flightsql.dbapi import connect + +class TestFlightSQLDatabase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + """Runs once before any tests are executed, used to set up the database connection.""" + headers = {"foo": "bar"} + cls.conn = connect( + "grpc://localhost:47470", # FlightSQL server address + db_kwargs={ + DatabaseOptions.TLS_SKIP_VERIFY.value: "true", # Skip TLS verification + **{f"{DatabaseOptions.RPC_CALL_HEADER_PREFIX.value}{k}": v for k, v in headers.items()} + } + ) + + @classmethod + def tearDownClass(cls): + """Runs once after all tests have been executed, used to close the database connection.""" + cls.conn.close() + + def setUp(self): + """Runs before each individual test to ensure a clean environment by resetting the database.""" + with self.conn.cursor() as 
cursor: + cursor.execute("DROP TABLE IF EXISTS intTable") # Drop the table if it exists + cursor.execute(""" + CREATE TABLE IF NOT EXISTS intTable ( + id INTEGER PRIMARY KEY, + name VARCHAR(50), + value INT + ) + """) # Create the table + + def test_insert_and_select(self): + """Test inserting data and selecting it back to verify correctness.""" + with self.conn.cursor() as cursor: + # Insert sample data + cursor.execute("INSERT INTO intTable (id, name, value) VALUES (1, 'TestName', 100)") + cursor.execute("INSERT INTO intTable (id, name, value) VALUES (2, 'AnotherName', 200)") + + # Select data from the table + cursor.execute("SELECT * FROM intTable") + rows = cursor.fetchall() + + # Expected result after insertions + expected_rows = [(1, 'TestName', 100), (2, 'AnotherName', 200)] + self.assertEqual(rows, expected_rows, f"Expected rows: {expected_rows}, but got: {rows}") + + def test_drop_table(self): + """Test dropping the table to ensure the table can be deleted successfully.""" + with self.conn.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS intTable") # Drop the table + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='intTable'") # Check if the table exists + rows = cursor.fetchall() + self.assertEqual(len(rows), 0, "Table 'intTable' should be dropped and not exist in the database.") + cursor.execute("COMMIT;") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/devtools/replica-setup-postgres/replica_setup.sh b/devtools/replica-setup-postgres/replica_setup.sh index 5d48c495..644d0fe1 100644 --- a/devtools/replica-setup-postgres/replica_setup.sh +++ b/devtools/replica-setup-postgres/replica_setup.sh @@ -7,7 +7,7 @@ usage() { MYDUCK_HOST=${MYDUCK_HOST:-127.0.0.1} MYDUCK_PORT=${MYDUCK_PORT:-5432} -MYDUCK_USER=${MYDUCK_USER:-mysql} +MYDUCK_USER=${MYDUCK_USER:-postgres} MYDUCK_PASSWORD=${MYDUCK_PASSWORD:-} MYDUCK_SERVER_ID=${MYDUCK_SERVER_ID:-2} MYDUCK_IN_DOCKER=${MYDUCK_IN_DOCKER:-false} 
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 3e909838..a7f1bfc5 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -127,7 +127,7 @@ run_replica_setup() { run_server_in_background() { cd "$DATA_PATH" || { echo "Error: Could not change directory to ${DATA_PATH}"; exit 1; } - nohup myduckserver $DEFAULT_DB $LOG_LEVEL $PROFILER_PORT $RESTORE_FILE $RESTORE_ENDPOINT $RESTORE_ACCESS_KEY_ID $RESTORE_SECRET_ACCESS_KEY|tee -a "${LOG_PATH}/server.log" 2>&1 & + nohup myduckserver $DEFAULT_DB $SUPERUSER_PASSWORD $LOG_LEVEL $PROFILER_PORT $RESTORE_FILE $RESTORE_ENDPOINT $RESTORE_ACCESS_KEY_ID $RESTORE_SECRET_ACCESS_KEY | tee -a "${LOG_PATH}/server.log" 2>&1 & echo "$!" > "${PID_FILE}" } @@ -207,6 +207,10 @@ setup() { export DEFAULT_DB="--default-db=$DEFAULT_DB" fi + if [ -n "$SUPERUSER_PASSWORD" ]; then + export SUPERUSER_PASSWORD="--superuser-password=$SUPERUSER_PASSWORD" + fi + if [ -n "$LOG_LEVEL" ]; then export LOG_LEVEL="--loglevel=$LOG_LEVEL" fi diff --git a/flightsqlserver/sql_batch_reader.go b/flightsqlserver/sql_batch_reader.go new file mode 100644 index 00000000..cd0aff77 --- /dev/null +++ b/flightsqlserver/sql_batch_reader.go @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build go1.18 +// +build go1.18 + +package flightsqlserver + +import ( + "database/sql" + "reflect" + "strconv" + "strings" + "sync/atomic" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + + // "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/arrow/memory" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func getArrowTypeFromString(dbtype string) arrow.DataType { + dbtype = strings.ToLower(dbtype) + if dbtype == "" { + // SQLite may not know the type yet. + return &arrow.NullType{} + } + if strings.HasPrefix(dbtype, "varchar") { + return arrow.BinaryTypes.String + } + + switch dbtype { + case "tinyint": + return arrow.PrimitiveTypes.Int8 + case "mediumint": + return arrow.PrimitiveTypes.Int32 + case "int", "integer", "bigint": + return arrow.PrimitiveTypes.Int64 + case "float": + return arrow.PrimitiveTypes.Float32 + case "real", "double": + return arrow.PrimitiveTypes.Float64 + case "blob": + return arrow.BinaryTypes.Binary + case "text", "date", "char", "clob": + return arrow.BinaryTypes.String + case "boolean": + return arrow.FixedWidthTypes.Boolean + default: + panic("invalid sqlite type: " + dbtype) + } +} + +var sqliteDenseUnion = arrow.DenseUnionOf([]arrow.Field{ + {Name: "int", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "float", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "string", Type: arrow.BinaryTypes.String, Nullable: true}, +}, []arrow.UnionTypeCode{0, 1, 2}) + +func getArrowType(c *sql.ColumnType) arrow.DataType { + dbtype := strings.ToLower(c.DatabaseTypeName()) + if dbtype == "" { + if c.ScanType() == nil { + return sqliteDenseUnion + } + switch c.ScanType().Kind() { + case reflect.Int8, reflect.Uint8: + return arrow.PrimitiveTypes.Int8 + case reflect.Int32, reflect.Uint32: + return 
arrow.PrimitiveTypes.Int32 + case reflect.Int, reflect.Int64, reflect.Uint64: + return arrow.PrimitiveTypes.Int64 + case reflect.Float32: + return arrow.PrimitiveTypes.Float32 + case reflect.Float64: + return arrow.PrimitiveTypes.Float64 + case reflect.String: + return arrow.BinaryTypes.String + } + } + return getArrowTypeFromString(dbtype) +} + +const maxBatchSize = 1024 + +type SqlBatchReader struct { + refCount int64 + + schema *arrow.Schema + rows *sql.Rows + record arrow.Record + bldr *array.RecordBuilder + err error + + rowdest []interface{} +} + +func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, rows *sql.Rows) (*SqlBatchReader, error) { + rowdest := make([]interface{}, schema.NumFields()) + for i, f := range schema.Fields() { + switch f.Type.ID() { + case arrow.DENSE_UNION, arrow.SPARSE_UNION: + rowdest[i] = new(interface{}) + case arrow.UINT8, arrow.INT8: + if f.Nullable { + rowdest[i] = &sql.NullByte{} + } else { + rowdest[i] = new(uint8) + } + case arrow.INT32: + if f.Nullable { + rowdest[i] = &sql.NullInt32{} + } else { + rowdest[i] = new(int32) + } + case arrow.INT64: + if f.Nullable { + rowdest[i] = &sql.NullInt64{} + } else { + rowdest[i] = new(int64) + } + case arrow.FLOAT32, arrow.FLOAT64: + if f.Nullable { + rowdest[i] = &sql.NullFloat64{} + } else { + rowdest[i] = new(float64) + } + case arrow.BINARY: + var b []byte + rowdest[i] = &b + case arrow.STRING: + if f.Nullable { + rowdest[i] = &sql.NullString{} + } else { + rowdest[i] = new(string) + } + } + } + + return &SqlBatchReader{ + refCount: 1, + bldr: array.NewRecordBuilder(mem, schema), + schema: schema, + rowdest: rowdest, + rows: rows}, nil +} + +func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, error) { + bldr := flightsql.NewColumnMetadataBuilder() + + cols, err := rows.ColumnTypes() + if err != nil { + rows.Close() + return nil, err + } + + rowdest := make([]interface{}, len(cols)) + fields := make([]arrow.Field, len(cols)) + for i, 
c := range cols { + fields[i].Name = c.Name() + if c.Name() == "?" { + fields[i].Name += ":" + strconv.Itoa(i) + } + fields[i].Nullable, _ = c.Nullable() + fields[i].Type = getArrowType(c) + fields[i].Metadata = getColumnMetadata(bldr, getSqlTypeFromTypeName(c.DatabaseTypeName()), "") + switch fields[i].Type.ID() { + case arrow.DENSE_UNION, arrow.SPARSE_UNION: + rowdest[i] = new(interface{}) + case arrow.UINT8, arrow.INT8: + if fields[i].Nullable { + rowdest[i] = &sql.NullByte{} + } else { + rowdest[i] = new(uint8) + } + case arrow.INT32: + if fields[i].Nullable { + rowdest[i] = &sql.NullInt32{} + } else { + rowdest[i] = new(int32) + } + case arrow.INT64: + if fields[i].Nullable { + rowdest[i] = &sql.NullInt64{} + } else { + rowdest[i] = new(int64) + } + case arrow.FLOAT64, arrow.FLOAT32: + if fields[i].Nullable { + rowdest[i] = &sql.NullFloat64{} + } else { + rowdest[i] = new(float64) + } + case arrow.BINARY: + var b []byte + rowdest[i] = &b + case arrow.STRING: + if fields[i].Nullable { + rowdest[i] = &sql.NullString{} + } else { + rowdest[i] = new(string) + } + } + } + + schema := arrow.NewSchema(fields, nil) + return &SqlBatchReader{ + refCount: 1, + bldr: array.NewRecordBuilder(mem, schema), + schema: schema, + rowdest: rowdest, + rows: rows}, nil +} + +func (r *SqlBatchReader) Retain() { + atomic.AddInt64(&r.refCount, 1) +} + +func (r *SqlBatchReader) Release() { + // debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases") + + if atomic.AddInt64(&r.refCount, -1) == 0 { + r.rows.Close() + r.rows, r.schema, r.rowdest = nil, nil, nil + r.bldr.Release() + r.bldr = nil + if r.record != nil { + r.record.Release() + r.record = nil + } + } +} +func (r *SqlBatchReader) Schema() *arrow.Schema { return r.schema } + +func (r *SqlBatchReader) Record() arrow.Record { return r.record } + +func (r *SqlBatchReader) Err() error { return r.err } + +func (r *SqlBatchReader) Next() bool { + if r.record != nil { + r.record.Release() + r.record = nil + } + + rows := 0 
+ for rows < maxBatchSize && r.rows.Next() { + if err := r.rows.Scan(r.rowdest...); err != nil { + // Not really useful except for testing Flight SQL clients + detail := wrapperspb.StringValue{Value: r.schema.String()} + if st, sterr := status.New(codes.Unknown, err.Error()).WithDetails(&detail); sterr != nil { + r.err = err + } else { + r.err = st.Err() + } + return false + } + + for i, v := range r.rowdest { + fb := r.bldr.Field(i) + + switch v := v.(type) { + case *uint8: + fb.(*array.Uint8Builder).Append(*v) + case *sql.NullByte: + if !v.Valid { + fb.AppendNull() + } else { + fb.(*array.Uint8Builder).Append(v.Byte) + } + case *int64: + fb.(*array.Int64Builder).Append(*v) + case *sql.NullInt64: + if !v.Valid { + fb.AppendNull() + } else { + fb.(*array.Int64Builder).Append(v.Int64) + } + case *int32: + fb.(*array.Int32Builder).Append(*v) + case *sql.NullInt32: + if !v.Valid { + fb.AppendNull() + } else { + fb.(*array.Int32Builder).Append(v.Int32) + } + case *float64: + switch b := fb.(type) { + case *array.Float64Builder: + b.Append(*v) + case *array.Float32Builder: + b.Append(float32(*v)) + } + case *sql.NullFloat64: + if !v.Valid { + fb.AppendNull() + } else { + switch b := fb.(type) { + case *array.Float64Builder: + b.Append(v.Float64) + case *array.Float32Builder: + b.Append(float32(v.Float64)) + } + } + case *[]byte: + if v == nil { + fb.AppendNull() + } else { + fb.(*array.BinaryBuilder).Append(*v) + } + case *string: + fb.(*array.StringBuilder).Append(*v) + case *sql.NullString: + if !v.Valid { + fb.AppendNull() + } else { + fb.(*array.StringBuilder).Append(v.String) + } + } + } + + rows++ + } + + r.record = r.bldr.NewRecord() + return rows > 0 +} diff --git a/flightsqlserver/sqlite_info.go b/flightsqlserver/sqlite_info.go new file mode 100644 index 00000000..df489c61 --- /dev/null +++ b/flightsqlserver/sqlite_info.go @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 +// +build go1.18 + +package flightsqlserver + +import ( + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" +) + +func SqlInfoResultMap() flightsql.SqlInfoResultMap { + return flightsql.SqlInfoResultMap{ + uint32(flightsql.SqlInfoFlightSqlServerName): "db_name", + uint32(flightsql.SqlInfoFlightSqlServerVersion): "sqlite 3", + uint32(flightsql.SqlInfoFlightSqlServerArrowVersion): arrow.PkgVersion, + uint32(flightsql.SqlInfoFlightSqlServerReadOnly): false, + uint32(flightsql.SqlInfoDDLCatalog): false, + uint32(flightsql.SqlInfoDDLSchema): false, + uint32(flightsql.SqlInfoDDLTable): true, + uint32(flightsql.SqlInfoIdentifierCase): int64(flightsql.SqlCaseSensitivityCaseInsensitive), + uint32(flightsql.SqlInfoIdentifierQuoteChar): `"`, + uint32(flightsql.SqlInfoQuotedIdentifierCase): int64(flightsql.SqlCaseSensitivityCaseInsensitive), + uint32(flightsql.SqlInfoAllTablesAreASelectable): true, + uint32(flightsql.SqlInfoNullOrdering): int64(flightsql.SqlNullOrderingSortAtStart), + uint32(flightsql.SqlInfoFlightSqlServerTransaction): int32(flightsql.SqlTransactionTransaction), + uint32(flightsql.SqlInfoTransactionsSupported): true, + uint32(flightsql.SqlInfoKeywords): []string{"ABORT", + "ACTION", + "ADD", + "AFTER", + 
"ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + "DO", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FIRST", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GENERATED", + "GLOB", + "GROUP", + "GROUPS", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LAST", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MATERIALIZED", + "NATURAL", + "NO", + "NOT", + "NOTHING", + "NOTNULL", + "NULL", + "NULLS", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OTHERS", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RETURNING", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TIES", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT"}, + uint32(flightsql.SqlInfoNumericFunctions): []string{ + "ACOS", "ACOSH", "ASIN", "ASINH", "ATAN", "ATAN2", "ATANH", "CEIL", + "CEILING", "COS", "COSH", "DEGREES", "EXP", "FLOOR", "LN", "LOG", + "LOG10", "LOG2", "MOD", "PI", "POW", "POWER", "RADIANS", + "SIN", "SINH", "SQRT", "TAN", "TANH", "TRUNC"}, 
+ uint32(flightsql.SqlInfoStringFunctions): []string{"SUBSTR", "TRIM", "LTRIM", "RTRIM", "LENGTH", + "REPLACE", "UPPER", "LOWER", "INSTR"}, + uint32(flightsql.SqlInfoSupportsConvert): map[int32][]int32{ + int32(flightsql.SqlConvertBigInt): {int32(flightsql.SqlConvertInteger)}, + }, + } +} diff --git a/flightsqlserver/sqlite_server.go b/flightsqlserver/sqlite_server.go new file mode 100644 index 00000000..4669b1a8 --- /dev/null +++ b/flightsqlserver/sqlite_server.go @@ -0,0 +1,812 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 +// +build go1.18 + +// Package example contains a FlightSQL Server implementation using +// sqlite as the backing engine. +// +// In order to ensure portability we'll use modernc.org/sqlite instead +// of github.com/mattn/go-sqlite3 because modernc is a translation of the +// SQLite source into Go, such that it doesn't require CGO to run and +// doesn't need to link against the actual libsqlite3 libraries. This way +// we don't require CGO or libsqlite3 to run this example or the tests. 
+// +// That said, since both implement in terms of Go's standard database/sql +// package, it's easy to swap them out if desired as the modernc.org/sqlite +// package is slower than go-sqlite3. +// +// One other important note is that modernc.org/sqlite only works +// correctly (specifically pragma_table_info) in go 1.18+ so this +// entire package is given the build constraint to only build when +// using go1.18 or higher +package flightsqlserver + +import ( + "bytes" + "context" + "database/sql" + stdsql "database/sql" + "fmt" + "math/rand" + "path/filepath" + "strings" + "sync" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/apache/arrow-go/v18/arrow/scalar" + "github.com/marcboeker/go-duckdb" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + _ "modernc.org/sqlite" +) + +func genRandomString() []byte { + const length = 16 + max := int('z') + // don't include ':' as a valid byte to generate + // because we use it as a separator for the transactions + min := int('<') + + out := make([]byte, length) + for i := range out { + out[i] = byte(rand.Intn(max-min+1) + min) + } + return out +} + +func prepareQueryForGetTables(cmd flightsql.GetTables) string { + var b strings.Builder + b.WriteString(`SELECT 'main' AS catalog_name, '' AS schema_name, + name AS table_name, type AS table_type FROM sqlite_master WHERE 1=1`) + + if cmd.GetCatalog() != nil { + b.WriteString(" and catalog_name = '") + b.WriteString(*cmd.GetCatalog()) + b.WriteByte('\'') + } + + if cmd.GetDBSchemaFilterPattern() != nil { + b.WriteString(" and schema_name LIKE '") + b.WriteString(*cmd.GetDBSchemaFilterPattern()) + b.WriteByte('\'') + } + + if 
cmd.GetTableNameFilterPattern() != nil { + b.WriteString(" and table_name LIKE '") + b.WriteString(*cmd.GetTableNameFilterPattern()) + b.WriteByte('\'') + } + + if len(cmd.GetTableTypes()) > 0 { + b.WriteString(" and table_type IN (") + for i, t := range cmd.GetTableTypes() { + if i != 0 { + b.WriteByte(',') + } + fmt.Fprintf(&b, "'%s'", t) + } + b.WriteByte(')') + } + + b.WriteString(" order by table_name") + return b.String() +} + +func prepareQueryForGetKeys(filter string) string { + return `SELECT * FROM ( + SELECT + NULL AS pk_catalog_name, + NULL AS pk_schema_name, + p."table" AS pk_table_name, + p."to" AS pk_column_name, + NULL AS fk_catalog_name, + NULL AS fk_schema_name, + m.name AS fk_table_name, + p."from" AS fk_column_name, + p.seq AS key_sequence, + NULL AS pk_key_name, + NULL AS fk_key_name, + CASE + WHEN p.on_update = 'CASCADE' THEN 0 + WHEN p.on_update = 'RESTRICT' THEN 1 + WHEN p.on_update = 'SET NULL' THEN 2 + WHEN p.on_update = 'NO ACTION' THEN 3 + WHEN p.on_update = 'SET DEFAULT' THEN 4 + END AS update_rule, + CASE + WHEN p.on_delete = 'CASCADE' THEN 0 + WHEN p.on_delete = 'RESTRICT' THEN 1 + WHEN p.on_delete = 'SET NULL' THEN 2 + WHEN p.on_delete = 'NO ACTION' THEN 3 + WHEN p.on_delete = 'SET DEFAULT' THEN 4 + END AS delete_rule + FROM sqlite_master m + JOIN pragma_foreign_key_list(m.name) p ON m.name != p."table" + WHERE m.type = 'table') WHERE ` + filter + + ` ORDER BY pk_catalog_name, pk_schema_name, pk_table_name, pk_key_name, key_sequence` +} + +func CreateDB() (*sql.DB, error) { + + dbFile := "mysql.db" + dataDir := "." 
+ + dbFile = strings.TrimSpace(dbFile) + dsn := filepath.Join(dataDir, dbFile) + + connector, err := duckdb.NewConnector(dsn, nil) + if err != nil { + return nil, err + } + + db := stdsql.OpenDB(connector) + + _, err = db.Exec(` + CREATE TABLE foreignTable ( + id INTEGER PRIMARY KEY NOT NULL, + foreignName varchar(100), + value int); + + CREATE TABLE intTable ( + id INTEGER PRIMARY KEY NOT NULL, + keyName varchar(100), + value int, + foreignId int references foreignTable(id)); + + INSERT INTO foreignTable (id, foreignName, value) VALUES (1, 'keyOne', 1); + INSERT INTO foreignTable (id, foreignName, value) VALUES (2, 'keyTwo', 0); + INSERT INTO foreignTable (id, foreignName, value) VALUES (3, 'keyThree', -1); + INSERT INTO intTable (id, keyName, value, foreignId) VALUES (1, 'one', 1, 1); + INSERT INTO intTable (id, keyName, value, foreignId) VALUES (2, 'zero', 0, 1); + INSERT INTO intTable (id, keyName, value, foreignId) VALUES (5, 'negative one', -1, 1); + `) + if err != nil { + db.Close() + return nil, err + } + + return db, nil +} + +func encodeTransactionQuery(query string, transactionID flightsql.Transaction) ([]byte, error) { + return flightsql.CreateStatementQueryTicket( + bytes.Join([][]byte{transactionID, []byte(query)}, []byte(":"))) +} + +func decodeTransactionQuery(ticket []byte) (txnID, query string, err error) { + id, queryBytes, found := bytes.Cut(ticket, []byte(":")) + if !found { + err = fmt.Errorf("%w: malformed ticket", arrow.ErrInvalid) + return + } + + txnID = string(id) + query = string(queryBytes) + return +} + +type Statement struct { + stmt *sql.Stmt + params [][]interface{} +} + +type SQLiteFlightSQLServer struct { + flightsql.BaseServer + db *sql.DB + + prepared sync.Map + openTransactions sync.Map +} + +func NewSQLiteFlightSQLServer(db *sql.DB) (*SQLiteFlightSQLServer, error) { + ret := &SQLiteFlightSQLServer{db: db} + ret.Alloc = memory.DefaultAllocator + for k, v := range SqlInfoResultMap() { + ret.RegisterSqlInfo(flightsql.SqlInfo(k), 
v) + } + return ret, nil +} + +func (s *SQLiteFlightSQLServer) flightInfoForCommand(desc *flight.FlightDescriptor, schema *arrow.Schema) *flight.FlightInfo { + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + Schema: flight.SerializeSchema(schema, s.Alloc), + TotalRecords: -1, + TotalBytes: -1, + } +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + query, txnid := cmd.GetQuery(), cmd.GetTransactionId() + tkt, err := encodeTransactionQuery(query, txnid) + if err != nil { + return nil, err + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (s *SQLiteFlightSQLServer) DoGetStatement(ctx context.Context, cmd flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { + txnid, query, err := decodeTransactionQuery(cmd.GetStatementHandle()) + if err != nil { + return nil, nil, err + } + + var db dbQueryCtx = s.db + if txnid != "" { + tx, loaded := s.openTransactions.Load(txnid) + if !loaded { + return nil, nil, fmt.Errorf("%w: invalid transaction id specified: %s", arrow.ErrInvalid, txnid) + } + db = tx.(*sql.Tx) + } + + return doGetQuery(ctx, s.Alloc, db, query, nil) +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoCatalogs(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.Catalogs), nil +} + +func (s *SQLiteFlightSQLServer) DoGetCatalogs(context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { + // https://www.sqlite.org/cli.html + // > The ".databases" command shows a list of all databases open + // > in the current connection. There will always be at least + // > 2. 
The first one is "main", the original database opened. The + // > second is "temp", the database used for temporary tables. + // For our purposes, return only "main" and ignore other databases. + + schema := schema_ref.Catalogs + + catalogs, _, err := array.FromJSON(s.Alloc, arrow.BinaryTypes.String, strings.NewReader(`["main"]`)) + if err != nil { + return nil, nil, err + } + defer catalogs.Release() + + batch := array.NewRecord(schema, []arrow.Array{catalogs}, 1) + + ch := make(chan flight.StreamChunk, 1) + ch <- flight.StreamChunk{Data: batch} + close(ch) + + return schema, ch, nil +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoSchemas(_ context.Context, cmd flightsql.GetDBSchemas, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.DBSchemas), nil +} + +func (s *SQLiteFlightSQLServer) DoGetDBSchemas(_ context.Context, cmd flightsql.GetDBSchemas) (*arrow.Schema, <-chan flight.StreamChunk, error) { + // SQLite doesn't support schemas, so pretend we have a single unnamed schema. 
+ schema := schema_ref.DBSchemas + + ch := make(chan flight.StreamChunk, 1) + + if cmd.GetDBSchemaFilterPattern() == nil || *cmd.GetDBSchemaFilterPattern() == "" { + catalogs, _, err := array.FromJSON(s.Alloc, arrow.BinaryTypes.String, strings.NewReader(`["main"]`)) + if err != nil { + return nil, nil, err + } + defer catalogs.Release() + + dbSchemas, _, err := array.FromJSON(s.Alloc, arrow.BinaryTypes.String, strings.NewReader(`[""]`)) + if err != nil { + return nil, nil, err + } + defer dbSchemas.Release() + + batch := array.NewRecord(schema, []arrow.Array{catalogs, dbSchemas}, 1) + ch <- flight.StreamChunk{Data: batch} + } + + close(ch) + + return schema, ch, nil +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoTables(_ context.Context, cmd flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + schema := schema_ref.Tables + if cmd.GetIncludeSchema() { + schema = schema_ref.TablesWithIncludedSchema + } + return s.flightInfoForCommand(desc, schema), nil +} + +func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { + query := prepareQueryForGetTables(cmd) + + rows, err := s.db.QueryContext(ctx, query) + if err != nil { + return nil, nil, err + } + + var rdr array.RecordReader + + rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema_ref.Tables, rows) + if err != nil { + return nil, nil, err + } + + ch := make(chan flight.StreamChunk, 2) + if cmd.GetIncludeSchema() { + rdr, err = NewSqliteTablesSchemaBatchReader(ctx, s.Alloc, rdr, s.db, query) + if err != nil { + return nil, nil, err + } + } + + schema := rdr.Schema() + go flight.StreamChunksFromReader(rdr, ch) + return schema, ch, nil +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoXdbcTypeInfo(_ context.Context, _ flightsql.GetXdbcTypeInfo, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.XdbcTypeInfo), nil +} + +func (s 
*SQLiteFlightSQLServer) DoGetXdbcTypeInfo(_ context.Context, cmd flightsql.GetXdbcTypeInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) { + var batch arrow.Record + if cmd.GetDataType() == nil { + batch = GetTypeInfoResult(s.Alloc) + } else { + batch = GetFilteredTypeInfoResult(s.Alloc, *cmd.GetDataType()) + } + + ch := make(chan flight.StreamChunk, 1) + ch <- flight.StreamChunk{Data: batch} + close(ch) + return batch.Schema(), ch, nil +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoTableTypes(_ context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.TableTypes), nil +} + +func (s *SQLiteFlightSQLServer) DoGetTableTypes(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) { + query := "SELECT DISTINCT type AS table_type FROM sqlite_master" + return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.TableTypes) +} + +func (s *SQLiteFlightSQLServer) DoPutCommandStatementUpdate(ctx context.Context, cmd flightsql.StatementUpdate) (int64, error) { + var ( + res sql.Result + err error + ) + + if len(cmd.GetTransactionId()) > 0 { + tx, loaded := s.openTransactions.Load(string(cmd.GetTransactionId())) + if !loaded { + return -1, status.Error(codes.InvalidArgument, "invalid transaction handle provided") + } + + res, err = tx.(*sql.Tx).ExecContext(ctx, cmd.GetQuery()) + } else { + res, err = s.db.ExecContext(ctx, cmd.GetQuery()) + } + + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (s *SQLiteFlightSQLServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) { + var stmt *sql.Stmt + + if len(req.GetTransactionId()) > 0 { + tx, loaded := s.openTransactions.Load(string(req.GetTransactionId())) + if !loaded { + return result, status.Error(codes.InvalidArgument, "invalid transaction handle provided") + } + stmt, err = 
tx.(*sql.Tx).PrepareContext(ctx, req.GetQuery()) + } else { + stmt, err = s.db.PrepareContext(ctx, req.GetQuery()) + } + + if err != nil { + return result, err + } + + handle := genRandomString() + s.prepared.Store(string(handle), Statement{stmt: stmt}) + + result.Handle = handle + // no way to get the dataset or parameter schemas from sql.DB + return +} + +func (s *SQLiteFlightSQLServer) ClosePreparedStatement(ctx context.Context, request flightsql.ActionClosePreparedStatementRequest) error { + handle := request.GetPreparedStatementHandle() + if val, loaded := s.prepared.LoadAndDelete(string(handle)); loaded { + stmt := val.(Statement) + return stmt.stmt.Close() + } + + return status.Error(codes.InvalidArgument, "prepared statement not found") +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + _, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle())) + if !ok { + return nil, status.Error(codes.InvalidArgument, "prepared statement not found") + } + + return &flight.FlightInfo{ + Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: desc.Cmd}}}, + FlightDescriptor: desc, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +type dbQueryCtx interface { + QueryContext(context.Context, string, ...any) (*sql.Rows, error) +} + +func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query string, schema *arrow.Schema, args ...interface{}) (*arrow.Schema, <-chan flight.StreamChunk, error) { + rows, err := db.QueryContext(ctx, query, args...) 
+ if err != nil { + // Not really useful except for testing Flight SQL clients + trailers := metadata.Pairs("afsql-sqlite-query", query) + grpc.SetTrailer(ctx, trailers) + return nil, nil, err + } + + var rdr *SqlBatchReader + if schema != nil { + rdr, err = NewSqlBatchReaderWithSchema(mem, schema, rows) + } else { + rdr, err = NewSqlBatchReader(mem, rows) + if err == nil { + schema = rdr.schema + } + } + + if err != nil { + return nil, nil, err + } + + ch := make(chan flight.StreamChunk) + go flight.StreamChunksFromReader(rdr, ch) + return schema, ch, nil +} + +func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) { + val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle())) + if !ok { + return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found") + } + + stmt := val.(Statement) + readers := make([]array.RecordReader, 0, len(stmt.params)) + if len(stmt.params) == 0 { + rows, err := stmt.stmt.QueryContext(ctx) + if err != nil { + return nil, nil, err + } + + rdr, err := NewSqlBatchReader(s.Alloc, rows) + if err != nil { + return nil, nil, err + } + + schema = rdr.schema + readers = append(readers, rdr) + } else { + defer func() { + if err != nil { + for _, r := range readers { + r.Release() + } + } + }() + var ( + rows *sql.Rows + rdr *SqlBatchReader + ) + // if we have multiple rows of bound params, execute the query + // multiple times and concatenate the result sets. + for _, p := range stmt.params { + rows, err = stmt.stmt.QueryContext(ctx, p...) 
+ if err != nil { + return nil, nil, err + } + + if schema == nil { + rdr, err = NewSqlBatchReader(s.Alloc, rows) + if err != nil { + return nil, nil, err + } + schema = rdr.schema + } else { + rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows) + if err != nil { + return nil, nil, err + } + } + + readers = append(readers, rdr) + } + } + + ch := make(chan flight.StreamChunk) + go flight.ConcatenateReaders(readers, ch) + out = ch + return +} + +func scalarToIFace(s scalar.Scalar) (interface{}, error) { + if !s.IsValid() { + return nil, nil + } + + switch val := s.(type) { + case *scalar.Int8: + return val.Value, nil + case *scalar.Uint8: + return val.Value, nil + case *scalar.Int32: + return val.Value, nil + case *scalar.Int64: + return val.Value, nil + case *scalar.Float32: + return val.Value, nil + case *scalar.Float64: + return val.Value, nil + case *scalar.String: + return string(val.Value.Bytes()), nil + case *scalar.Binary: + return val.Value.Bytes(), nil + case scalar.DateScalar: + return val.ToTime(), nil + case scalar.TimeScalar: + return val.ToTime(), nil + case *scalar.DenseUnion: + return scalarToIFace(val.Value) + default: + return nil, fmt.Errorf("unsupported type: %s", val) + } +} + +func getParamsForStatement(rdr flight.MessageReader) (params [][]interface{}, err error) { + params = make([][]interface{}, 0) + for rdr.Next() { + rec := rdr.Record() + + nrows := int(rec.NumRows()) + ncols := int(rec.NumCols()) + + for i := 0; i < nrows; i++ { + invokeParams := make([]interface{}, ncols) + for c := 0; c < ncols; c++ { + col := rec.Column(c) + sc, err := scalar.GetScalar(col, i) + if err != nil { + return nil, err + } + if r, ok := sc.(scalar.Releasable); ok { + r.Release() + } + + invokeParams[c], err = scalarToIFace(sc) + if err != nil { + return nil, err + } + } + params = append(params, invokeParams) + } + } + + return params, rdr.Err() +} + +func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd 
flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) ([]byte, error) { + val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle())) + if !ok { + return nil, status.Error(codes.InvalidArgument, "prepared statement not found") + } + + stmt := val.(Statement) + args, err := getParamsForStatement(rdr) + if err != nil { + return nil, status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error()) + } + + stmt.params = args + s.prepared.Store(string(cmd.GetPreparedStatementHandle()), stmt) + return cmd.GetPreparedStatementHandle(), nil +} + +func (s *SQLiteFlightSQLServer) DoPutPreparedStatementUpdate(ctx context.Context, cmd flightsql.PreparedStatementUpdate, rdr flight.MessageReader) (int64, error) { + val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle())) + if !ok { + return 0, status.Error(codes.InvalidArgument, "prepared statement not found") + } + + stmt := val.(Statement) + args, err := getParamsForStatement(rdr) + if err != nil { + return 0, status.Errorf(codes.Internal, "error gathering parameters for prepared statement: %s", err.Error()) + } + + if len(args) == 0 { + result, err := stmt.stmt.ExecContext(ctx) + if err != nil { + if strings.Contains(err.Error(), "no such table") { + return 0, status.Error(codes.NotFound, err.Error()) + } + return 0, err + } + + return result.RowsAffected() + } + + var totalAffected int64 + for _, p := range args { + result, err := stmt.stmt.ExecContext(ctx, p...) 
+ if err != nil { + if strings.Contains(err.Error(), "no such table") { + return totalAffected, status.Error(codes.NotFound, err.Error()) + } + return totalAffected, err + } + + n, err := result.RowsAffected() + if err != nil { + return totalAffected, err + } + totalAffected += n + } + + return totalAffected, nil +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoPrimaryKeys(_ context.Context, cmd flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.PrimaryKeys), nil +} + +func (s *SQLiteFlightSQLServer) DoGetPrimaryKeys(ctx context.Context, cmd flightsql.TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { + // the field key_name cannot be recovered by sqlite so it is + // being set to null following the same pattern for catalog name and schema_name + var b strings.Builder + + b.WriteString(` + SELECT null AS catalog_name, null AS schema_name, table_name, name AS column_name, pk AS key_sequence, null as key_name + FROM pragma_table_info(table_name) + JOIN (SELECT null AS catalog_name, null AS schema_name, name AS table_name, type AS table_type + FROM sqlite_master) where 1=1 AND pk !=0`) + + if cmd.Catalog != nil { + fmt.Fprintf(&b, " and catalog_name LIKE '%s'", *cmd.Catalog) + } + if cmd.DBSchema != nil { + fmt.Fprintf(&b, " and schema_name LIKE '%s'", *cmd.DBSchema) + } + + fmt.Fprintf(&b, " and table_name LIKE '%s'", cmd.Table) + + return doGetQuery(ctx, s.Alloc, s.db, b.String(), schema_ref.PrimaryKeys) +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoImportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.ImportedKeys), nil +} + +func (s *SQLiteFlightSQLServer) DoGetImportedKeys(ctx context.Context, ref flightsql.TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { + filter := "fk_table_name = '" + ref.Table + "'" + if ref.Catalog != nil { + filter += " AND 
fk_catalog_name = '" + *ref.Catalog + "'" + } + if ref.DBSchema != nil { + filter += " AND fk_schema_name = '" + *ref.DBSchema + "'" + } + query := prepareQueryForGetKeys(filter) + return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ImportedKeys) +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoExportedKeys(_ context.Context, _ flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.ExportedKeys), nil +} + +func (s *SQLiteFlightSQLServer) DoGetExportedKeys(ctx context.Context, ref flightsql.TableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { + filter := "pk_table_name = '" + ref.Table + "'" + if ref.Catalog != nil { + filter += " AND pk_catalog_name = '" + *ref.Catalog + "'" + } + if ref.DBSchema != nil { + filter += " AND pk_schema_name = '" + *ref.DBSchema + "'" + } + query := prepareQueryForGetKeys(filter) + return doGetQuery(ctx, s.Alloc, s.db, query, schema_ref.ExportedKeys) +} + +func (s *SQLiteFlightSQLServer) GetFlightInfoCrossReference(_ context.Context, _ flightsql.CrossTableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return s.flightInfoForCommand(desc, schema_ref.CrossReference), nil +} + +func (s *SQLiteFlightSQLServer) DoGetCrossReference(ctx context.Context, cmd flightsql.CrossTableRef) (*arrow.Schema, <-chan flight.StreamChunk, error) { + pkref := cmd.PKRef + filter := "pk_table_name = '" + pkref.Table + "'" + if pkref.Catalog != nil { + filter += " AND pk_catalog_name = '" + *pkref.Catalog + "'" + } + if pkref.DBSchema != nil { + filter += " AND pk_schema_name = '" + *pkref.DBSchema + "'" + } + + fkref := cmd.FKRef + filter += " AND fk_table_name = '" + fkref.Table + "'" + if fkref.Catalog != nil { + filter += " AND fk_catalog_name = '" + *fkref.Catalog + "'" + } + if fkref.DBSchema != nil { + filter += " AND fk_schema_name = '" + *fkref.DBSchema + "'" + } + query := prepareQueryForGetKeys(filter) + return doGetQuery(ctx, s.Alloc, s.db, 
query, schema_ref.ExportedKeys) +} + +func (s *SQLiteFlightSQLServer) BeginTransaction(_ context.Context, req flightsql.ActionBeginTransactionRequest) (id []byte, err error) { + tx, err := s.db.Begin() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to begin transaction: %s", err.Error()) + } + + handle := genRandomString() + s.openTransactions.Store(string(handle), tx) + return handle, nil +} + +func (s *SQLiteFlightSQLServer) EndTransaction(_ context.Context, req flightsql.ActionEndTransactionRequest) error { + if req.GetAction() == flightsql.EndTransactionUnspecified { + return status.Error(codes.InvalidArgument, "must specify Commit or Rollback to end transaction") + } + + handle := string(req.GetTransactionId()) + if tx, loaded := s.openTransactions.LoadAndDelete(handle); loaded { + txn := tx.(*sql.Tx) + switch req.GetAction() { + case flightsql.EndTransactionCommit: + if err := txn.Commit(); err != nil { + return status.Error(codes.Internal, "failed to commit transaction: "+err.Error()) + } + case flightsql.EndTransactionRollback: + if err := txn.Rollback(); err != nil { + return status.Error(codes.Internal, "failed to rollback transaction: "+err.Error()) + } + } + return nil + } + + return status.Error(codes.InvalidArgument, "transaction id not found") +} diff --git a/flightsqlserver/sqlite_tables_schema_batch_reader.go b/flightsqlserver/sqlite_tables_schema_batch_reader.go new file mode 100644 index 00000000..0f9a8010 --- /dev/null +++ b/flightsqlserver/sqlite_tables_schema_batch_reader.go @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 +// +build go1.18 + +package flightsqlserver + +import ( + "context" + "database/sql" + "strings" + "sync/atomic" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + + // "github.com/apache/arrow-go/v18/arrow/internal/debug" + "github.com/apache/arrow-go/v18/arrow/memory" + sqlite3 "modernc.org/sqlite/lib" +) + +type SqliteTablesSchemaBatchReader struct { + refCount int64 + + mem memory.Allocator + ctx context.Context + rdr array.RecordReader + stmt *sql.Stmt + schemaBldr *array.BinaryBuilder + record arrow.Record + err error +} + +func NewSqliteTablesSchemaBatchReader(ctx context.Context, mem memory.Allocator, rdr array.RecordReader, db *sql.DB, mainQuery string) (*SqliteTablesSchemaBatchReader, error) { + schemaQuery := `SELECT table_name, name, type, [notnull] + FROM pragma_table_info(table_name) + JOIN (` + mainQuery + `) WHERE table_name = ?` + + stmt, err := db.PrepareContext(ctx, schemaQuery) + if err != nil { + rdr.Release() + return nil, err + } + + return &SqliteTablesSchemaBatchReader{ + refCount: 1, + ctx: ctx, + rdr: rdr, + stmt: stmt, + mem: mem, + schemaBldr: array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary), + }, nil +} + +func (s *SqliteTablesSchemaBatchReader) Err() error { return s.err } + +func (s *SqliteTablesSchemaBatchReader) Retain() { atomic.AddInt64(&s.refCount, 1) } + +func (s *SqliteTablesSchemaBatchReader) Release() { + // 
debug.Assert(atomic.LoadInt64(&s.refCount) > 0, "too many releases") + + if atomic.AddInt64(&s.refCount, -1) == 0 { + s.rdr.Release() + s.stmt.Close() + s.schemaBldr.Release() + if s.record != nil { + s.record.Release() + s.record = nil + } + } +} + +func (s *SqliteTablesSchemaBatchReader) Schema() *arrow.Schema { + fields := append(s.rdr.Schema().Fields(), + arrow.Field{Name: "table_schema", Type: arrow.BinaryTypes.Binary}) + return arrow.NewSchema(fields, nil) +} + +func (s *SqliteTablesSchemaBatchReader) Record() arrow.Record { return s.record } + +func getSqlTypeFromTypeName(sqltype string) int { + if sqltype == "" { + return sqlite3.SQLITE_NULL + } + + sqltype = strings.ToLower(sqltype) + + if strings.HasPrefix(sqltype, "varchar") || strings.HasPrefix(sqltype, "char") { + return sqlite3.SQLITE_TEXT + } + + switch sqltype { + case "int", "integer": + return sqlite3.SQLITE_INTEGER + case "real": + return sqlite3.SQLITE_FLOAT + case "blob": + return sqlite3.SQLITE_BLOB + case "text", "date": + return sqlite3.SQLITE_TEXT + default: + return sqlite3.SQLITE_NULL + } +} + +func getPrecisionFromCol(sqltype int) int { + switch sqltype { + case sqlite3.SQLITE_INTEGER: + return 10 + case sqlite3.SQLITE_FLOAT: + return 15 + } + return 0 +} + +func getColumnMetadata(bldr *flightsql.ColumnMetadataBuilder, sqltype int, table string) arrow.Metadata { + defer bldr.Clear() + + bldr.Scale(15).IsReadOnly(false).IsAutoIncrement(false) + if table != "" { + bldr.TableName(table) + } + switch sqltype { + case sqlite3.SQLITE_TEXT, sqlite3.SQLITE_BLOB: + default: + bldr.Precision(int32(getPrecisionFromCol(sqltype))) + } + + return bldr.Metadata() +} + +func (s *SqliteTablesSchemaBatchReader) Next() bool { + if s.record != nil { + s.record.Release() + s.record = nil + } + + if !s.rdr.Next() { + return false + } + + rec := s.rdr.Record() + tableNameArr := rec.Column(rec.Schema().FieldIndices("table_name")[0]).(*array.String) + + bldr := flightsql.NewColumnMetadataBuilder() + columnFields 
:= make([]arrow.Field, 0) + for i := 0; i < tableNameArr.Len(); i++ { + table := tableNameArr.Value(i) + rows, err := s.stmt.QueryContext(s.ctx, table) + if err != nil { + s.err = err + return false + } + + var tableName, name, typ string + var nn int + for rows.Next() { + if err := rows.Scan(&tableName, &name, &typ, &nn); err != nil { + rows.Close() + s.err = err + return false + } + + columnFields = append(columnFields, arrow.Field{ + Name: name, + Type: getArrowTypeFromString(typ), + Nullable: nn == 0, + Metadata: getColumnMetadata(bldr, getSqlTypeFromTypeName(typ), tableName), + }) + } + + rows.Close() + if rows.Err() != nil { + s.err = rows.Err() + return false + } + val := flight.SerializeSchema(arrow.NewSchema(columnFields, nil), s.mem) + s.schemaBldr.Append(val) + + columnFields = columnFields[:0] + } + + schemaCol := s.schemaBldr.NewArray() + defer schemaCol.Release() + + s.record = array.NewRecord(s.Schema(), append(rec.Columns(), schemaCol), rec.NumRows()) + return true +} diff --git a/flightsqlserver/type_info.go b/flightsqlserver/type_info.go new file mode 100644 index 00000000..c512a28f --- /dev/null +++ b/flightsqlserver/type_info.go @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build go1.18 +// +build go1.18 + +package flightsqlserver + +import ( + "strings" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql/schema_ref" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +func GetTypeInfoResult(mem memory.Allocator) arrow.Record { + typeNames, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, + strings.NewReader(`["bit", "tinyint", "bigint", "longvarbinary", + "varbinary", "text", "longvarchar", "char", + "integer", "smallint", "float", "double", + "numeric", "varchar", "date", "time", "timestamp"]`)) + defer typeNames.Release() + + dataType, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, + strings.NewReader(`[-7, -6, -5, -4, -3, -1, -1, 1, 4, 5, 6, 8, 8, 12, 91, 92, 93]`)) + defer dataType.Release() + + columnSize, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, + strings.NewReader(`[1, 3, 19, 65536, 255, 65536, 65536, 255, 9, 5, 7, 15, 15, 255, 10, 8, 32]`)) + defer columnSize.Release() + + literalPrefix, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, + strings.NewReader(`[null, null, null, null, null, "'", "'", "'", null, null, null, null, null, "'" ,"'", "'", "'"]`)) + defer literalPrefix.Release() + + literalSuffix, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, + strings.NewReader(`[null, null, null, null, null, "'", "'", "'", null, null, null, null, null, "'" ,"'", "'", "'"]`)) + defer literalSuffix.Release() + + createParams, _, _ := array.FromJSON(mem, arrow.ListOfField(arrow.Field{Name: "item", Type: arrow.BinaryTypes.String, Nullable: false}), + strings.NewReader(`[[], [], [], [], [], ["length"], ["length"], ["length"], [], [], [], [], [], ["length"], [], [], []]`)) + defer createParams.Release() + + nullable, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, + strings.NewReader(`[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]`)) + defer nullable.Release() + + // reference for 
creating a boolean() array with only zeros + zeroBoolArray, _, err := array.FromJSON(mem, arrow.FixedWidthTypes.Boolean, + strings.NewReader(`[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]`), array.WithUseNumber()) + if err != nil { + panic(err) + } + defer zeroBoolArray.Release() + caseSensitive := zeroBoolArray + + searchable, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, + strings.NewReader(`[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]`)) + defer searchable.Release() + + unsignedAttribute := zeroBoolArray + fixedPrecScale := zeroBoolArray + autoUniqueVal := zeroBoolArray + + localTypeName := typeNames + + zeroIntArray, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, + strings.NewReader(`[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]`)) + defer zeroIntArray.Release() + + minimalScale := zeroIntArray + maximumScale := zeroIntArray + sqlDataType := dataType + sqlDateTimeSub := zeroIntArray + numPrecRadix := zeroIntArray + intervalPrecision := zeroIntArray + + return array.NewRecord(schema_ref.XdbcTypeInfo, []arrow.Array{ + typeNames, dataType, columnSize, literalPrefix, literalSuffix, + createParams, nullable, caseSensitive, searchable, unsignedAttribute, + fixedPrecScale, autoUniqueVal, localTypeName, minimalScale, maximumScale, + sqlDataType, sqlDateTimeSub, numPrecRadix, intervalPrecision}, 17) +} + +func GetFilteredTypeInfoResult(mem memory.Allocator, filter int32) arrow.Record { + batch := GetTypeInfoResult(mem) + defer batch.Release() + + dataTypeVector := []int32{-7, -6, -5, -4, -3, -1, -1, 1, 4, 5, 6, 8, 8, 12, 91, 92, 93} + start, end := -1, -1 + for i, v := range dataTypeVector { + if filter == v { + if start == -1 { + start = i + } + } else if start != -1 && end == -1 { + end = i + break + } + } + + return batch.NewSlice(int64(start), int64(end)) +} diff --git a/flightsqltest/driver_test.go b/flightsqltest/driver_test.go new file mode 100644 index 00000000..a1305667 --- /dev/null +++ b/flightsqltest/driver_test.go @@ 
-0,0 +1,1929 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 +// +build go1.18 + +package flightsqltest + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math/rand" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/apecloud/myduckserver/flightsqlserver" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql/driver" + "github.com/apecloud/myduckserver/catalog" + + // "github.com/apache/arrow-go/v18/arrow/flight/flightsql/example" + "github.com/apache/arrow-go/v18/arrow/memory" +) + +const defaultTableName = "drivertest" +const dataDirectory = "." 
+const dbFileName = "mysql.db" + +var defaultStatements = map[string]string{ + "create table": ` +CREATE TABLE %s ( + id INTEGER PRIMARY KEY, + name varchar(100), + value int +);`, + "insert": `INSERT INTO %s (id, name, value) VALUES (%d, '%s', %d);`, + "query": `SELECT * FROM %s;`, + "constraint query": `SELECT * FROM %s WHERE name LIKE '%%%s%%'`, + "placeholder query": `SELECT * FROM %s WHERE name LIKE ?`, +} + +type SqlTestSuite struct { + suite.Suite + + Config driver.DriverConfig + TableName string + Statements map[string]string + + createServer func() (flight.Server, string, error) + startServer func(flight.Server) error + stopServer func(flight.Server) +} + +func (s *SqlTestSuite) SetupSuite() { + if s.TableName == "" { + s.TableName = defaultTableName + } + + if s.Statements == nil { + s.Statements = make(map[string]string) + } + // Fill in the statements. Keep statements already defined e.g. by the + // user or suite-generator. + for k, v := range defaultStatements { + if _, found := s.Statements[k]; !found { + s.Statements[k] = v + } + } + + s.createServer = func() (flight.Server, string, error) { + provider, err := catalog.NewDBProvider(dataDirectory, dbFileName) + if err != nil { + return nil, "", err + } + sqliteServer, err := flightsqlserver.NewSQLiteFlightSQLServer(provider.Storage()) + if err != nil { + return nil, "", err + } + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(sqliteServer)) + if err := server.Init("localhost:7878"); err != nil { + return nil, "", err + } + return server, server.Addr().String(), nil + } + + require.Contains(s.T(), s.Statements, "create table") + require.Contains(s.T(), s.Statements, "insert") + require.Contains(s.T(), s.Statements, "query") + require.Contains(s.T(), s.Statements, "constraint query") + require.Contains(s.T(), s.Statements, "placeholder query") +} + +func (s *SqlTestSuite) TestOpenClose() { + t := s.T() + + // Create and start the server + server, 
addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + require.NoError(t, db.Close()) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestCreateTable() { + t := s.T() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.TableName)) + require.NoError(t, err) + + result, err := db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + affected, err := result.RowsAffected() + require.Equal(t, int64(0), affected) + require.NoError(t, err) + + last, err := result.LastInsertId() + require.Equal(t, int64(-1), last) + require.ErrorIs(t, err, driver.ErrNotSupported) + + require.NoError(t, db.Close()) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestInsert() { + t := s.T() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) 
+ defer db.Close() + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.TableName)) + require.NoError(t, err) + + // Create the table + _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + values := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + var stmts []string + id := 0 + for k, v := range values { + stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, id, k, v)) + id++ + } + result, err := db.Exec(strings.Join(stmts, "\n")) + require.NoError(t, err) + + affected, err := result.RowsAffected() + require.Equal(t, int64(1), affected) + require.NoError(t, err) + + require.NoError(t, db.Close()) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestQuery() { + t := s.T() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.TableName)) + require.NoError(t, err) + + // Create the table + _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + expected := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + var stmts []string + id := 0 + for k, v := range expected { + stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, id, k, v)) + id++ + } + _, err = db.Exec(strings.Join(stmts, "\n")) + require.NoError(t, err) + + rows, err := db.Query(fmt.Sprintf(s.Statements["query"], s.TableName)) + require.NoError(t, err) + + // Check result + actual := 
make(map[string]int, len(expected)) + for rows.Next() { + var name string + var id, value int + require.NoError(t, rows.Scan(&id, &name, &value)) + actual[name] = value + } + require.NoError(t, db.Close()) + require.EqualValues(t, expected, actual) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestQueryWithEmptyResultset() { + t := s.T() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.TableName)) + require.NoError(t, err) + + // Create the table + _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + rows, err := db.Query(fmt.Sprintf(s.Statements["query"], s.TableName)) + require.NoError(t, err) + require.False(t, rows.Next()) + + row := db.QueryRow(fmt.Sprintf(s.Statements["query"], s.TableName)) + require.NotNil(t, row) + require.NoError(t, row.Err()) + + target := make(map[string]any) + err = row.Scan(&target) + require.ErrorIs(t, err, sql.ErrNoRows) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestPreparedQuery() { + t := s.T() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + _, err = 
db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.TableName)) + require.NoError(t, err) + + // Create the table + _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + expected := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + var stmts []string + id := 0 + for k, v := range expected { + stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, id, k, v)) + id++ + } + _, err = db.Exec(strings.Join(stmts, "\n")) + require.NoError(t, err) + + // Do query + stmt, err := db.Prepare(fmt.Sprintf(s.Statements["query"], s.TableName)) + require.NoError(t, err) + + rows, err := stmt.Query() + require.NoError(t, err) + + // Check result + actual := make(map[string]int, len(expected)) + for rows.Next() { + var name string + var id, value int + require.NoError(t, rows.Scan(&id, &name, &value)) + actual[name] = value + } + require.NoError(t, db.Close()) + require.EqualValues(t, expected, actual) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsManualPrematureClose tests concurrent rows implementation for closing right after loading. +// Is expected that rows' internal engine update its status, preventing errors and inconsistent further operations. 
+func (s *SqlTestSuite) TestRowsManualPrematureClose() { + t := s.T() + t.Skip() + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsManualPrematureClose` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY , name VARCHAR(300), value INT);` + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) + require.NoError(t, err) + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount int = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + rows, err := db.QueryContext(context.TODO(), sqlSelectAll) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + // Close Rows normally + require.NoError(t, rows.Close()) + + require.False(t, rows.Next()) + + // Safe double-closing + require.NoError(t, rows.Close()) + + // Columns() should return an error after 
rows.Close() (sql: Rows are closed) + columns, err := rows.Columns() + require.Error(t, err) + require.Empty(t, columns) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsNormalExhaustion tests concurrent rows implementation for normal query/next/close operation +func (s *SqlTestSuite) TestRowsNormalExhaustion() { + t := s.T() + t.Skip() + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsNormalExhaustion` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY , name VARCHAR(300), value INT);` + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) + require.NoError(t, err) + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount int = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do Query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer
cancel() + + rows, err := db.QueryContext(ctx, sqlSelectAll) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + var ( + actualCount = 0 + xid, + xvalue int + xname string + ) + + for rows.Next() { + require.NoError(t, rows.Scan(&xid, &xname, &xvalue)) + actualCount++ + } + + require.Equal(t, rowCount, actualCount) + require.NoError(t, rows.Close()) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsPrematureCloseDuringNextLoop ensures that: +// - closing during Next() loop doesn't trigger concurrency errors. +// - the iteration is properly/promptly interrupted. +func (s *SqlTestSuite) TestRowsPrematureCloseDuringNextLoop() { + t := s.T() + t.Skip() + + // Create and start the server. + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table.
+ const tableName = `TestRowsPrematureCloseDuringNextLoop` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY, name VARCHAR(300), value INT);` + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + time.Sleep(200 * time.Millisecond) + // Do query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + rows, err := db.QueryContext(context.TODO(), sqlSelectAll) + require.NoError(t, err) + require.NotNil(t, rows) + + const closeAfterNRows = 10 + var ( + i, + xid, + xvalue int + xname string + ) + + for rows.Next() { + err = rows.Scan(&xid, &xname, &xvalue) + require.NoError(t, err) + + i++ + if i >= closeAfterNRows { + require.NoError(t, rows.Close()) + } + } + require.NoError(t, rows.Err()) + + require.Equal(t, closeAfterNRows, i) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsInterruptionByContextManualCancellation cancels the context before it starts retrieving rows.Next(). +// it gives time for cancellation propagation, and ensures that no further data was retrieved. 
+func (s *SqlTestSuite) TestRowsInterruptionByContextManualCancellation() { + t := s.T() + t.Skip() + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + // Create the table + const tableName = `TestRowsInterruptionByContextManualCancellation` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY , name VARCHAR(300), value BigINT);` + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) + require.NoError(t, err) + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (id,name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`(%d, '%s', %d),`, i, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + rows, err := db.QueryContext(ctx, sqlSelectAll) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + defer rows.Close() + + go cancel() + + time.Sleep(100 * time.Millisecond) + + count 
:= 0 + for rows.Next() { + count++ + } + + require.Zero(t, count) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsInterruptionByContextTimeout forces a timeout, and ensures no further data is retrieved after that. +func (s *SqlTestSuite) TestRowsInterruptionByContextTimeout() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsInterruptionByContextTimeout` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY , name VARCHAR(300), value Bigint);` + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) + require.NoError(t, err) + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (id, name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`(%d, '%s', %d),`, i, getRandomString(gen, randStringLen), gen.Int())) + // fmt.Println(i, getRandomString(gen, randStringLen), gen.Int()) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const ( + timeout = 1500 * time.Millisecond + sqlSelectAll = `SELECT id, name, value 
FROM ` + tableName + ) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + rows, err := db.QueryContext(ctx, sqlSelectAll) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + defer rows.Close() + + // eventually, after time.Sleep(), the context will be cancelled. + // then, rows.Next() should return false, and <-ctx.Done() will never be tested. + for rows.Next() { + select { + case <-ctx.Done(): + t.Fatal("cancellation didn't prevent more records to be read") + default: + time.Sleep(time.Second) + } + } + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsManualPrematureCloseStmt tests concurrent rows implementation for closing right after loading. +// Is expected that rows' internal engine update its status, preventing errors and inconsistent further operations. +func (s *SqlTestSuite) TestRowsManualPrematureCloseStmt() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsManualPrematureCloseStmt` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY AUTOINCREMENT, name VARCHAR(300), value INT);` + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount int = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; 
i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + stmt, err := db.PrepareContext(ctx, sqlSelectAll) + require.NoError(t, err) + + rows, err := stmt.QueryContext(ctx) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + // Close Rows normally + require.NoError(t, rows.Close()) + + require.False(t, rows.Next()) + + // Safe double-closing + require.NoError(t, rows.Close()) + + // Columns() should return an error after rows.Close() (sql: Rows are closed) + columns, err := rows.Columns() + require.Error(t, err) + require.Empty(t, columns) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsNormalExhaustionStmt tests concurrent rows implementation for normal query/next/close operation +func (s *SqlTestSuite) TestRowsNormalExhaustionStmt() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsNormalExhaustionStmt` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY AUTOINCREMENT, name VARCHAR(300), value INT);` + + _, err =
db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount int = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do Query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + stmt, err := db.PrepareContext(ctx, sqlSelectAll) + require.NoError(t, err) + + rows, err := stmt.QueryContext(ctx) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + var ( + actualCount = 0 + xid, + xvalue int + xname string + ) + + for rows.Next() { + require.NoError(t, rows.Scan(&xid, &xname, &xvalue)) + actualCount++ + } + + require.Equal(t, rowCount, actualCount) + require.NoError(t, rows.Close()) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsPrematureCloseDuringNextLoopStmt ensures that: +// - closing during Next() loop doesn't trigger concurrency errors. +// - the iteration is properly/promptly interrupted. +func (s *SqlTestSuite) TestRowsPrematureCloseDuringNextLoopStmt() { + t := s.T() + t.Skip() + // Create and start the server.
+ server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table. + const tableName = `TestRowsPrematureCloseDuringNextLoopStmt` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY AUTOINCREMENT, name VARCHAR(300), value INT);` + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + stmt, err := db.PrepareContext(ctx, sqlSelectAll) + require.NoError(t, err) + + rows, err := stmt.QueryContext(ctx) + + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + const closeAfterNRows = 10 + var ( + i, + xid, + xvalue int + xname string + ) + + for rows.Next() { + err = rows.Scan(&xid, &xname, &xvalue) + require.NoError(t, err) + + i++ + if i >= closeAfterNRows { + require.NoError(t, rows.Close()) + } + } + + 
require.Equal(t, closeAfterNRows, i) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsInterruptionByContextManualCancellationStmt cancels the context before it starts retrieving rows.Next(). +// it gives time for cancellation propagation, and ensures that no further data was retrieved. +func (s *SqlTestSuite) TestRowsInterruptionByContextManualCancellationStmt() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsInterruptionByContextManualCancellationStmt` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY AUTOINCREMENT, name VARCHAR(300), value INT);` + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const sqlSelectAll = `SELECT id, name, value FROM ` + tableName + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer 
cancel() + + stmt, err := db.PrepareContext(ctx, sqlSelectAll) + require.NoError(t, err) + + rows, err := stmt.QueryContext(ctx) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + defer rows.Close() + + go cancel() + + time.Sleep(100 * time.Millisecond) + + count := 0 + for rows.Next() { + count++ + } + + require.Zero(t, count) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +// TestRowsInterruptionByContextTimeoutStmt forces a timeout, and ensures no further data is retrieved after that. +func (s *SqlTestSuite) TestRowsInterruptionByContextTimeoutStmt() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + + defer s.stopServer(server) + + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + + defer db.Close() + + // Create the table + const tableName = `TestRowsInterruptionByContextTimeoutStmt` + const ddlCreateTable = `CREATE TABLE ` + tableName + ` (id INTEGER PRIMARY KEY AUTOINCREMENT, name VARCHAR(300), value INT);` + + _, err = db.Exec(ddlCreateTable) + require.NoError(t, err) + + // generate data enough for chunked concurrent test: + const rowCount = 6000 + const randStringLen = 250 + const sqlInsert = `INSERT INTO ` + tableName + ` (name,value) VALUES ` + + gen := rand.New(rand.NewSource(time.Now().UnixNano())) + + var sb strings.Builder + sb.WriteString(sqlInsert) + + for i := 0; i < rowCount; i++ { + sb.WriteString(fmt.Sprintf(`('%s', %d),`, getRandomString(gen, randStringLen), gen.Int())) + } + + insertQuery := strings.TrimSuffix(sb.String(), ",") + + rs, err := db.Exec(insertQuery) + require.NoError(t, err) + + insertedRows, err := rs.RowsAffected() + require.NoError(t, err) + 
require.Equal(t, int64(rowCount), insertedRows) + + // Do query + const ( + timeout = 1500 * time.Millisecond + sqlSelectAll = `SELECT id, name, value FROM ` + tableName + ) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + stmt, err := db.PrepareContext(ctx, sqlSelectAll) + require.NoError(t, err) + + rows, err := stmt.QueryContext(ctx) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Err()) + + defer rows.Close() + + // eventually, after time.Sleep(), the context will be cancelled. + // then, rows.Next() should return false, and <-ctx.Done() will never be tested. + for rows.Next() { + select { + case <-ctx.Done(): + t.Fatal("cancellation didn't prevent more records to be read") + default: + time.Sleep(time.Second) + } + } + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestPreparedQueryWithConstraint() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + // Create the table + _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + data := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + var stmts []string + for k, v := range data { + stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v)) + } + _, err = db.Exec(strings.Join(stmts, "\n")) + require.NoError(t, err) + + // Do query + stmt, err := db.Prepare(fmt.Sprintf(s.Statements["constraint query"], s.TableName, "one")) + require.NoError(t, err) + + rows, err := stmt.Query() 
+ require.NoError(t, err) + + // Check result + expected := map[string]int{ + "one": 1, + "minus one": -1, + } + actual := make(map[string]int, len(expected)) + for rows.Next() { + var name string + var id, value int + require.NoError(t, rows.Scan(&id, &name, &value)) + actual[name] = value + } + require.NoError(t, db.Close()) + require.EqualValues(t, expected, actual) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestPreparedQueryWithPlaceholder() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + // Create the table + _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + data := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + var stmts []string + for k, v := range data { + stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v)) + } + _, err = db.Exec(strings.Join(stmts, "\n")) + require.NoError(t, err) + + // Do query + query := fmt.Sprintf(s.Statements["placeholder query"], s.TableName) + stmt, err := db.Prepare(query) + require.NoError(t, err) + + params := []interface{}{"%%one%%"} + rows, err := stmt.Query(params...) 
+ require.NoError(t, err) + + // Check result + expected := map[string]int{ + "one": 1, + "minus one": -1, + } + actual := make(map[string]int, len(expected)) + for rows.Next() { + var name string + var id, value int + require.NoError(t, rows.Scan(&id, &name, &value)) + actual[name] = value + } + require.NoError(t, db.Close()) + require.EqualValues(t, expected, actual) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestTxRollback() { + t := s.T() + t.Skip() + + // Create and start the server + server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + tx, err := db.Begin() + require.NoError(t, err) + + _, err = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", s.TableName)) + require.NoError(t, err) + + // Create the table + _, err = tx.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + data := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + for k, v := range data { + stmt := fmt.Sprintf(s.Statements["insert"], s.TableName, k, v) + _, err = tx.Exec(stmt) + require.NoError(t, err) + } + + // Rollback the transaction + require.NoError(t, tx.Rollback()) + + // Check result + tbls := `SELECT name FROM sqlite_schema WHERE type ='table' AND name NOT LIKE 'sqlite_%';` + rows, err := db.Query(tbls) + require.NoError(t, err) + count := 0 + for rows.Next() { + count++ + } + require.Equal(t, 0, count) + require.NoError(t, db.Close()) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +func (s *SqlTestSuite) TestTxCommit() { + t := s.T() + t.Skip() + + // Create and start the server + 
server, addr, err := s.createServer() + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(s.T(), s.startServer(server)) + }() + defer s.stopServer(server) + time.Sleep(100 * time.Millisecond) + + // Configure client + cfg := s.Config + cfg.Address = addr + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + tx, err := db.Begin() + require.NoError(t, err) + + // Create the table + _, err = tx.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName)) + require.NoError(t, err) + + // Insert data + data := map[string]int{ + "zero": 0, + "one": 1, + "minus one": -1, + "twelve": 12, + } + for k, v := range data { + stmt := fmt.Sprintf(s.Statements["insert"], s.TableName, k, v) + _, err = tx.Exec(stmt) + require.NoError(t, err) + } + + // Commit the transaction + require.NoError(t, tx.Commit()) + + // Check if the table exists + tbls := `SELECT name FROM sqlite_schema WHERE type ='table' AND name NOT LIKE 'sqlite_%';` + rows, err := db.Query(tbls) + require.NoError(t, err) + + var tables []string + for rows.Next() { + var name string + require.NoError(t, rows.Scan(&name)) + tables = append(tables, name) + } + require.Contains(t, tables, "drivertest") + + // Check the actual data + stmt, err := db.Prepare(fmt.Sprintf(s.Statements["query"], s.TableName)) + require.NoError(t, err) + + rows, err = stmt.Query() + require.NoError(t, err) + + // Check result + actual := make(map[string]int, len(data)) + for rows.Next() { + var name string + var id, value int + require.NoError(t, rows.Scan(&id, &name, &value)) + actual[name] = value + } + require.NoError(t, db.Close()) + require.EqualValues(t, data, actual) + + // Tear-down server + s.stopServer(server) + wg.Wait() +} + +/*** BACKEND tests ***/ + +func TestSqliteBackend(t *testing.T) { + // mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + s := &SqlTestSuite{ + Config: driver.DriverConfig{ + Timeout: 5 * 
time.Second, + }, + } + + // s.createServer = func() (flight.Server, string, error) { + // server := flight.NewServerWithMiddleware(nil) + + // // Setup the SQLite backend + // db, err := sql.Open("sqlite", ":memory:") + // if err != nil { + // return nil, "", err + // } + // sqliteServer, err := example.NewSQLiteFlightSQLServer(db) + // if err != nil { + // return nil, "", err + // } + // sqliteServer.Alloc = mem + + // // Connect the FlightSQL frontend to the backend + // server.RegisterFlightService(flightsql.NewFlightServer(sqliteServer)) + // if err := server.Init("localhost:0"); err != nil { + // return nil, "", err + // } + // server.SetShutdownOnSignals(os.Interrupt, os.Kill) + // return server, server.Addr().String(), nil + // } + s.startServer = func(server flight.Server) error { return server.Serve() } + s.stopServer = func(server flight.Server) { server.Shutdown() } + + suite.Run(t, s) +} + +func TestPreparedStatementSchema(t *testing.T) { + t.Skip() + // Setup the expected test + backend := &MockServer{ + PreparedStatementParameterSchema: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}, Nullable: false}}, nil), + DataSchema: arrow.NewSchema([]arrow.Field{ + {Name: "time", Type: &arrow.Time64Type{Unit: arrow.Nanosecond}, Nullable: true}, + {Name: "value", Type: &arrow.Int64Type{}, Nullable: false}, + }, nil), + Data: "[]", + } + + // Instantiate a mock server + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(backend)) + require.NoError(t, server.Init("localhost:0")) + server.SetShutdownOnSignals(os.Interrupt, os.Kill) + go server.Serve() + defer server.Shutdown() + + // Configure client + cfg := driver.DriverConfig{ + Timeout: 5 * time.Second, + Address: server.Addr().String(), + } + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + // Do query + stmt, err := db.Prepare("SELECT * FROM foo WHERE name LIKE ?") + require.NoError(t, err) + + _, err = 
stmt.Query() + require.ErrorContains(t, err, "expected 1 arguments, got 0") + + // Test for error issues by driver + _, err = stmt.Query(23) + require.ErrorContains(t, err, "invalid value type int64 for builder *array.StringBuilder") + + rows, err := stmt.Query("master") + require.NoError(t, err) + require.NotNil(t, rows) +} + +func TestPreparedStatementNoSchema(t *testing.T) { + t.Skip() + // Setup the expected test + backend := &MockServer{ + DataSchema: arrow.NewSchema([]arrow.Field{ + {Name: "time", Type: &arrow.Time64Type{Unit: arrow.Nanosecond}, Nullable: true}, + {Name: "value", Type: &arrow.Int64Type{}, Nullable: false}, + }, nil), + Data: "[]", + ExpectedPreparedStatementSchema: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}, Nullable: false}}, nil), + } + + // Instantiate a mock server + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(backend)) + require.NoError(t, server.Init("localhost:0")) + server.SetShutdownOnSignals(os.Interrupt, os.Kill) + go server.Serve() + defer server.Shutdown() + + // Configure client + cfg := driver.DriverConfig{ + Timeout: 5 * time.Second, + Address: server.Addr().String(), + } + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + // Do query + stmt, err := db.Prepare("SELECT * FROM foo WHERE name LIKE ?") + require.NoError(t, err) + + _, err = stmt.Query() + require.NoError(t, err, "expected 1 arguments, got 0") + + // Test for error issued by server due to missing parameter schema + _, err = stmt.Query(23) + require.ErrorContains(t, err, "parameter schema: unexpected") + + rows, err := stmt.Query("master") + require.NoError(t, err) + require.NotNil(t, rows) +} + +func TestNoPreparedStatementImplemented(t *testing.T) { + t.Skip() + // Setup the expected test + backend := &MockServer{ + DataSchema: arrow.NewSchema([]arrow.Field{ + {Name: "time", Type: &arrow.Time64Type{Unit: arrow.Nanosecond}, Nullable: true}, + {Name: 
"value", Type: &arrow.Int64Type{}, Nullable: false}, + }, nil), + Data: "[]", + PreparedStatementError: "not supported", + } + + // Instantiate a mock server + server := flight.NewServerWithMiddleware(nil) + server.RegisterFlightService(flightsql.NewFlightServer(backend)) + require.NoError(t, server.Init("localhost:0")) + server.SetShutdownOnSignals(os.Interrupt, os.Kill) + go server.Serve() + defer server.Shutdown() + + // Configure client + cfg := driver.DriverConfig{ + Timeout: 5 * time.Second, + Address: server.Addr().String(), + } + db, err := sql.Open("flightsql", cfg.DSN()) + require.NoError(t, err) + defer db.Close() + + // Do query + _, err = db.Query("SELECT * FROM foo") + require.NoError(t, err) +} + +// Mockup database server +type MockServer struct { + flightsql.BaseServer + DataSchema *arrow.Schema + PreparedStatementParameterSchema *arrow.Schema + PreparedStatementError string + Data string + + ExpectedPreparedStatementSchema *arrow.Schema +} + +func (s *MockServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (flightsql.ActionCreatePreparedStatementResult, error) { + if s.PreparedStatementError != "" { + return flightsql.ActionCreatePreparedStatementResult{}, errors.New(s.PreparedStatementError) + } + return flightsql.ActionCreatePreparedStatementResult{ + Handle: []byte("prepared"), + DatasetSchema: s.DataSchema, + ParameterSchema: s.PreparedStatementParameterSchema, + }, nil +} + +func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) ([]byte, error) { + if s.ExpectedPreparedStatementSchema != nil { + if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) { + return nil, errors.New("parameter schema: unexpected") + } + return qry.GetPreparedStatementHandle(), nil + } + + if s.PreparedStatementParameterSchema != nil && !s.PreparedStatementParameterSchema.Equal(r.Schema()) { + return nil, 
fmt.Errorf("parameter schema: %w", arrow.ErrInvalid) + } + + // GH-35328: it's rare, but this function can complete execution and return + // closing the reader *after* the schema is written but *before* the parameter batch + // is written (race condition based on goroutine scheduling). In that situation, + // the client call to Write the parameter record batch will return an io.EOF because + // this end of the connection will have closed before it attempted to send the batch. + // This created a flaky test situation that was difficult to reproduce (1-4 failures + // in 5000 runs). We can avoid this flakiness by simply *explicitly* draining the + // record batch messages from the reader before returning. + for r.Next() { + } + + return qry.GetPreparedStatementHandle(), nil +} + +func (s *MockServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { + record, _, err := array.RecordFromJSON(memory.DefaultAllocator, s.DataSchema, strings.NewReader(s.Data)) + if err != nil { + return nil, nil, err + } + chunk := make(chan flight.StreamChunk) + go func() { + defer close(chunk) + chunk <- flight.StreamChunk{ + Data: record, + Desc: nil, + Err: nil, + } + }() + return s.DataSchema, chunk, nil +} + +func (s *MockServer) GetFlightInfoPreparedStatement(ctx context.Context, stmt flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + handle := stmt.GetPreparedStatementHandle() + ticket, err := flightsql.CreateStatementQueryTicket(handle) + if err != nil { + return nil, err + } + return &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{ + {Ticket: &flight.Ticket{Ticket: ticket}}, + }, + TotalRecords: -1, + TotalBytes: -1, + }, nil +} + +func (s *MockServer) GetFlightInfoStatement(_ context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) { + handle := query.GetTransactionId() + 
ticket, err := flightsql.CreateStatementQueryTicket(handle) + if err != nil { + return nil, err + } + return &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{ + {Ticket: &flight.Ticket{Ticket: ticket}}, + }, + TotalRecords: -1, + TotalBytes: -1, + }, nil + } + + const getRandomStringCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789. " + + var getRandomStringCharsetLen = len(getRandomStringCharset) + + func getRandomString(gen *rand.Rand, length int) string { + result := make([]byte, length) + + for i := range result { + result[i] = getRandomStringCharset[gen.Intn(getRandomStringCharsetLen)] + } + + return string(result) + } diff --git a/go.mod index 9fd48c52..74ab1115 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.23.4 require ( github.com/Shopify/toxiproxy/v2 v2.9.0 + github.com/apache/arrow-adbc/go/adbc v1.3.0 github.com/apache/arrow-go/v18 v18.0.0 github.com/aws/aws-sdk-go-v2 v1.30.4 github.com/aws/aws-sdk-go-v2/config v1.27.31 @@ -28,10 +29,13 @@ require ( github.com/prometheus/client_golang v1.20.3 github.com/rs/zerolog v1.33.0 github.com/shopspring/decimal v1.3.1 - github.com/sirupsen/logrus v1.8.1 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 golang.org/x/text v0.19.0 + google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.35.1 gopkg.in/src-d/go-errors.v1 v1.0.0 + modernc.org/sqlite v1.33.1 vitess.io/vitess v0.21.1 ) @@ -63,6 +67,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/biogo/store v0.0.0-20201120204734-aad293a2328f // indirect github.com/blevesearch/snowballstem v0.9.0 // indirect + github.com/bluele/gcache v0.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect github.com/cockroachdb/redact v1.1.3 // indirect @@ -85,6 +90,7 @@ require ( github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // 
indirect github.com/hashicorp/golang-lru v1.0.2 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -97,6 +103,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pierrre/geohash v1.0.0 // indirect @@ -107,6 +114,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.59.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rs/xid v1.5.0 // indirect github.com/sasha-s/go-deadlock v0.3.1 // indirect @@ -118,18 +126,23 @@ require ( github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/otel v1.31.0 // indirect go.opentelemetry.io/otel/trace v1.31.0 // indirect - golang.org/x/crypto v0.27.0 // indirect + golang.org/x/crypto v0.28.0 // indirect golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect golang.org/x/mod v0.21.0 // indirect + golang.org/x/net v0.30.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/tools v0.26.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect - google.golang.org/grpc v1.67.1 // indirect - 
google.golang.org/protobuf v1.35.1 // indirect + google.golang.org/genproto v0.0.0-20241021214115-324edc3d5d38 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/gc/v3 v3.0.0-20240801135723-a856999a2e4a // indirect + modernc.org/libc v1.60.1 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.8.0 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index d8bd5f8f..cae0cc4d 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/apache/arrow-adbc/go/adbc v1.3.0 h1:cdH/jmQX+3vdSVjt2CLNrlwfE7hE0Dfe3i/vnWD6OIg= +github.com/apache/arrow-adbc/go/adbc v1.3.0/go.mod h1:KJTcRJ1+Dknd/K6bNHwv1+DaEVKZnqcApqf3IMKIkuk= github.com/apache/arrow-go/v18 v18.0.0 h1:1dBDaSbH3LtulTyOVYaBCHO3yVRwjV+TZaqn3g6V7ZM= github.com/apache/arrow-go/v18 v18.0.0/go.mod h1:t6+cWRSmKgdQ6HsxisQjok+jBpKGhRDiqcf3p0p/F+A= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= @@ -113,6 +115,8 @@ github.com/biogo/store v0.0.0-20201120204734-aad293a2328f h1:+6okTAeUsUrdQr/qN7f github.com/biogo/store v0.0.0-20201120204734-aad293a2328f/go.mod h1:z52shMwD6SGwRg2iYFjjDwX5Ene4ENTw6HfXraUy/08= github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s= github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs= 
+github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= +github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/broady/gogeohash v0.0.0-20120525094510-7b2c40d64042 h1:iEdmkrNMLXbM7ecffOAtZJQOQUTE4iMonxrb5opUgE4= github.com/broady/gogeohash v0.0.0-20120525094510-7b2c40d64042/go.mod h1:f1L9YvXvlt9JTa+A17trQjSMM6bV40f+tHjB+Pi+Fqk= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= @@ -292,6 +296,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -330,6 +336,8 @@ github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= 
github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= @@ -496,6 +504,8 @@ github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7 github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= @@ -578,6 +588,8 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -610,8 +622,8 @@ github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod 
h1:1NzhyTcUVG4SuEtjjoZeV github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= -github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -631,6 +643,8 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3 github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -712,8 +726,8 @@ golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= @@ -797,7 +811,6 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -817,6 +830,7 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -884,12 +898,12 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 h1:BulPr26Jqjnd4eYDVe+YvyR7Yc2vJGkO5/0UxD0/jZU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:hL97c3SYopEHblzpxRL4lSs523++l8DYxGM1FQiYmb4= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 h1:hjSy6tcFQZ171igDaN5QHOw2n6vx40juYbC/x67CEhc= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:qpvKtACPCQhAdu3PyQgV4l3LMXZEtft7y8QcarRsp9I= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20241021214115-324edc3d5d38 h1:Q3nlH8iSQSRUwOskjbcSMcF2jiYMNiQYZ0c2KEJLKKU= +google.golang.org/genproto v0.0.0-20241021214115-324edc3d5d38/go.mod h1:xBI+tzfqGGN2JBeSebfKXFSdBpWVQ7sLW40PTupVRm4= 
+google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 h1:fVoAXEKA4+yufmbdVYv+SE73+cPZbbbe8paLsHfkK+U= +google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53/go.mod h1:riSXTwQ4+nqmPGtobMFyW5FqVAmIs0St6VPp4Ug7CE4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= @@ -957,6 +971,32 @@ honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= +modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.21.0 h1:kKPI3dF7RIag8YcToh5ZwDcVMIv6VGa0ED5cvh0LMW4= +modernc.org/ccgo/v4 v4.21.0/go.mod h1:h6kt6H/A2+ew/3MW/p6KEoQmrq/i3pr0J/SiwiaF/g0= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.5.0 h1:bJ9ChznK1L1mUtAQtxi0wi5AtAs5jQuw4PrPHO5pb6M= +modernc.org/gc/v2 v2.5.0/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= +modernc.org/gc/v3 v3.0.0-20240801135723-a856999a2e4a h1:CfbpOLEo2IwNzJdMvE8aiRbPMxoTpgAJeyePh0SmO8M= +modernc.org/gc/v3 v3.0.0-20240801135723-a856999a2e4a/go.mod 
h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.60.1 h1:at373l8IFRTkJIkAU85BIuUoBM4T1b51ds0E1ovPG2s= +modernc.org/libc v1.60.1/go.mod h1:xJuobKuNxKH3RUatS7GjR+suWj+5c2K7bi4m/S5arOY= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= +modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= +modernc.org/sqlite v1.33.1 h1:trb6Z3YYoeM9eDL1O8do81kP+0ejv+YzgyFo+Gwy0nM= +modernc.org/sqlite v1.33.1/go.mod h1:pXV2xHxhzXZsgT/RtTFAPY6JJDEvOTcTdwADQCCWD4k= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= vitess.io/vitess v0.21.1 h1:XpuyM1Jit6eKz4tPodcl1fOxAdIa86m3k2rPmQnw2co= diff --git a/main.go b/main.go index d6ada236..5cfdc511 100644 --- a/main.go +++ b/main.go @@ -18,9 +18,16 @@ import ( "context" "flag" "fmt" + "log" + "net" + "os" + "strconv" + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" "github.com/apecloud/myduckserver/backend" "github.com/apecloud/myduckserver/catalog" + 
"github.com/apecloud/myduckserver/flightsqlserver" "github.com/apecloud/myduckserver/myfunc" "github.com/apecloud/myduckserver/pgserver" "github.com/apecloud/myduckserver/pgserver/logrepl" @@ -52,6 +59,9 @@ var ( postgresPort = 5432 + // Shared between the MySQL and Postgres servers. + superuserPassword = "" + defaultTimeZone = "" // for Restore @@ -59,6 +69,9 @@ var ( restoreEndpoint = "" restoreAccessKeyId = "" restoreSecretAccessKey = "" + + flightsqlHost = "localhost" + flightsqlPort = -1 // Disabled by default ) func init() { @@ -71,6 +84,8 @@ func init() { flag.StringVar(&defaultDb, "default-db", defaultDb, "The default database name to use.") flag.IntVar(&logLevel, "loglevel", logLevel, "The log level to use.") + flag.StringVar(&superuserPassword, "superuser-password", superuserPassword, "The password for the superuser account.") + flag.StringVar(&replicaOptions.ReportHost, "report-host", replicaOptions.ReportHost, "The host name or IP address of the replica to be reported to the source during replica registration.") flag.IntVar(&replicaOptions.ReportPort, "report-port", replicaOptions.ReportPort, "The TCP/IP port number for connecting to the replica, to be reported to the source during replica registration.") flag.StringVar(&replicaOptions.ReportUser, "report-user", replicaOptions.ReportUser, "The account user name of the replica to be reported to the source during replica registration.") @@ -83,6 +98,9 @@ func init() { flag.StringVar(&restoreEndpoint, "restore-endpoint", restoreEndpoint, "The endpoint of object storage service to restore from.") flag.StringVar(&restoreAccessKeyId, "restore-access-key-id", restoreAccessKeyId, "The access key ID to restore from.") flag.StringVar(&restoreSecretAccessKey, "restore-secret-access-key", restoreSecretAccessKey, "The secret access key to restore from.") + + flag.StringVar(&flightsqlHost, "flightsql-host", flightsqlHost, "hostname for the Flight SQL service") + flag.IntVar(&flightsqlPort, "flightsql-port", 
flightsqlPort, "port number for the Flight SQL service")
 }
 
 func ensureSQLTranslate() {
@@ -128,7 +146,7 @@ func main() {
 	engine.Analyzer.Catalog.RegisterFunction(sql.NewContext(context.Background()), myfunc.ExtraBuiltIns...)
 	engine.Analyzer.Catalog.MySQLDb.SetPlugins(plugin.AuthPlugins)
 
-	if err := setPersister(provider, engine); err != nil {
+	if err := setPersister(provider, engine, "root", superuserPassword); err != nil {
 		logrus.Fatalln("Failed to set the persister:", err)
 	}
@@ -149,6 +167,7 @@ func main() {
 	pgServer, err := pgserver.NewServer(
 		provider, address, postgresPort,
+		superuserPassword,
 		func() *sql.Context {
 			session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider)
 			return sql.NewContext(context.Background(), sql.WithSession(session))
@@ -172,6 +191,29 @@ func main() {
 		go pgServer.Start()
 	}
 
+	if flightsqlPort > 0 {
+		// provider.Storage() hands back the provider-owned storage handle;
+		// the provider owns its lifetime, so it must not be closed here.
+		db := provider.Storage()
+
+		srv, err := flightsqlserver.NewSQLiteFlightSQLServer(db)
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		server := flight.NewServerWithMiddleware(nil)
+		server.RegisterFlightService(flightsql.NewFlightServer(srv))
+		if err := server.Init(net.JoinHostPort(flightsqlHost, strconv.Itoa(flightsqlPort))); err != nil {
+			log.Fatal(err)
+		}
+		// os.Kill (SIGKILL) cannot be caught, so only os.Interrupt is trapped.
+		server.SetShutdownOnSignals(os.Interrupt)
+
+		fmt.Println("Starting SQLite Flight SQL Server on", server.Addr(), "...")
+
+		go server.Serve()
+	}
+
 	if err = myServer.Start(); err != nil {
 		logrus.WithError(err).Fatalln("Failed to start MySQL-protocol server")
 	}
diff --git a/persist.go b/persist.go
index ccd285f8..a7fc1fc5 100644
--- a/persist.go
+++ b/persist.go
@@ -39,7 +39,7 @@ func (m *MySQLPersister) Persist(ctx *sql.Context, data []byte) error {
 }
 
 // https://github.com/dolthub/go-mysql-server/blob/main/_example/users_example.go
-func setPersister(provider sql.DatabaseProvider, engine *sqle.Engine) error {
+func setPersister(provider sql.DatabaseProvider, engine *sqle.Engine, superuser, password string) error {
 	session := 
memory.NewSession(sql.NewBaseSession(), provider) ctx := sql.NewContext(context.Background(), sql.WithSession(session)) ctx.SetCurrentDatabase("mysql") @@ -67,16 +67,16 @@ func setPersister(provider sql.DatabaseProvider, engine *sqle.Engine) error { } } - addAccount := func(account string, address string) { + addAccount := func(account, password, address string) { ed := mysqlDb.Editor() defer ed.Close() - mysqlDb.AddSuperUser(ed, account, address, "") + mysqlDb.AddSuperUser(ed, account, address, password) } // Modify it to "%" to allow accepting connections outside when myduckserver runs in Docker // TODO should add a config to decide this or some better way to support this // addAccount("root", "localhost") - addAccount("root", "%") + addAccount(superuser, password, "%") return nil } diff --git a/pgserver/authentication_scram.go b/pgserver/authentication_scram.go index c5c1a2e5..55355dd2 100644 --- a/pgserver/authentication_scram.go +++ b/pgserver/authentication_scram.go @@ -19,7 +19,6 @@ import ( "encoding/base64" "fmt" "net" - "os" "strings" "github.com/dolthub/doltgresql/server/auth" @@ -39,26 +38,27 @@ const ( ) // EnableAuthentication handles whether authentication is enabled. If enabled, it verifies that the given user exists, -// and checks that the encrypted password is derivable from the stored encrypted password. As the feature is still in -// development, it is disabled by default. It may be enabled by supplying the environment variable -// "DOLTGRES_ENABLE_AUTHENTICATION", or by simply setting this boolean to true. -var EnableAuthentication = false - -func init() { - if _, ok := os.LookupEnv("DOLTGRES_ENABLE_AUTHENTICATION"); ok { - EnableAuthentication = true - } +// and checks that the encrypted password is derivable from the stored encrypted password. 
+var EnableAuthentication = true +func InitSuperuser(password string) { auth.DropRole("doltgres") + auth.DropRole("postgres") var err error - mysql := auth.CreateDefaultRole("mysql") - mysql.CanLogin = true - mysql.Password, err = auth.NewScramSha256Password("") + postgres := auth.CreateDefaultRole("postgres") + postgres.CanLogin = true + postgres.Password, err = auth.NewScramSha256Password(password) if err != nil { panic(err) } - auth.SetRole(mysql) + auth.SetRole(postgres) + + // Postgres does not allow empty passwords, + // so we disable authentication if the superuser password is empty. + if password == "" { + EnableAuthentication = false + } } // SASLBindingFlag are the flags for gs2-cbind-flag, used in SASL authentication. @@ -111,7 +111,7 @@ func (h *ConnectionHandler) handleAuthentication(startupMessage *pgproto3.Startu } } } else { - username = "doltgres" // TODO: should we use this, or the default "postgres" since programs may default to it? + username = "postgres" host = "localhost" } h.mysqlConn.User = username @@ -119,7 +119,7 @@ func (h *ConnectionHandler) handleAuthentication(startupMessage *pgproto3.Startu User: username, Host: host, } - // Since this is all still in development, we'll check if authentication is enabled. + // Currently, regression tests disable authentication, since we can't just replay the messages due to nonces. 
if !EnableAuthentication { return h.send(&pgproto3.AuthenticationOk{}) } diff --git a/pgtest/server.go b/pgtest/server.go index 673b0a74..a72c0c48 100644 --- a/pgtest/server.go +++ b/pgtest/server.go @@ -54,6 +54,7 @@ func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pg pgServer, err = pgserver.NewServer( provider, "127.0.0.1", port, + "", func() *sql.Context { session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider) return sql.NewContext(context.Background(), sql.WithSession(session)) @@ -79,7 +80,7 @@ func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pg } // Since we use the in-memory DuckDB storage, we need to connect to the `memory` database - dsn := fmt.Sprintf("postgres://mysql:@127.0.0.1:%d/memory", port) + dsn := fmt.Sprintf("postgres://postgres:@127.0.0.1:%d/memory", port) conn, err = pgx.Connect(ctx, dsn) if err != nil { close()