Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support multiple statements within one query
Browse files Browse the repository at this point in the history
NoyException committed Dec 17, 2024
1 parent 167d6de commit 5b8289b
Showing 4 changed files with 129 additions and 118 deletions.
17 changes: 7 additions & 10 deletions compatibility/pg-pytools/psycopg_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from psycopg import sql
import psycopg

rows = [
@@ -13,16 +12,14 @@
with psycopg.connect("dbname=postgres user=postgres host=127.0.0.1 port=5432", autocommit=True) as conn:
# Open a cursor to perform database operations
with conn.cursor() as cur:
cur.execute("DROP SCHEMA IF EXISTS test CASCADE")
cur.execute("CREATE SCHEMA test")

cur.execute("""
CREATE TABLE test.tb1 (
id integer PRIMARY KEY,
num integer,
data text)
""")

DROP SCHEMA IF EXISTS test CASCADE;
CREATE SCHEMA test;
CREATE TABLE test.tb1 (
id integer PRIMARY KEY,
num integer,
data text)
""")

# Pass data to fill a query placeholders and let Psycopg perform the correct conversion
cur.execute(
12 changes: 6 additions & 6 deletions pgserver/connection_data.go
Original file line number Diff line number Diff line change
@@ -48,11 +48,11 @@ const (
ReadyForQueryTransactionIndicator_FailedTransactionBlock ReadyForQueryTransactionIndicator = 'E'
)

// ConvertedQuery represents a query that has been converted from the Postgres representation to the Vitess
// representation. String may contain the string version of the converted query. AST will contain the tree
// version of the converted query, and is the recommended form to use. If AST is nil, then use the String version,
// ConvertedStatement represents a statement that has been converted from the Postgres representation to the Vitess
// representation. String may contain the string version of the converted statement. AST will contain the tree
// version of the converted statement, and is the recommended form to use. If AST is nil, then use the String version,
// otherwise always prefer to AST.
type ConvertedQuery struct {
type ConvertedStatement struct {
String string
AST tree.Statement
StatementTag string
@@ -86,7 +86,7 @@ type copyFromStdinState struct {
}

type PortalData struct {
Query ConvertedQuery
Query ConvertedStatement
IsEmptyQuery bool
Fields []pgproto3.FieldDescription
ResultFormatCodes []int16
@@ -96,7 +96,7 @@ type PortalData struct {
}

type PreparedStatementData struct {
Query ConvertedQuery
Query ConvertedStatement
ReturnFields []pgproto3.FieldDescription
BindVarTypes []uint32
Stmt *duckdb.Stmt
172 changes: 93 additions & 79 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
@@ -470,62 +470,73 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages
return true, err
}

query, err := h.convertQuery(message.String)
statements, err := h.convertQuery(message.String)
if err != nil {
return true, err
}

// A query message destroys the unnamed statement and the unnamed portal
// A statements message destroys the unnamed statement and the unnamed portal
h.deletePreparedStatement("")
h.deletePortal("")

// Certain statement types get handled directly by the handler instead of being passed to the engine
handled, endOfMessages, err = h.handleQueryOutsideEngine(query)
if handled {
return endOfMessages, err
} else if err != nil {
h.logger.Warnf("Failed to handle query %v outside engine: %v", query, err)
for _, statement := range statements {
// Certain statement types get handled directly by the handler instead of being passed to the engine
handled, endOfMessages, err = h.handleStatementOutsideEngine(statement)
if handled {
if err != nil {
h.logger.Warnf("Failed to handle statement %v outside engine: %v", statement, err)
return true, err
}
} else {
if err != nil {
h.logger.Warnf("Failed to handle statement %v outside engine: %v", statement, err)
}
endOfMessages, err = true, h.runStatement(statement)
if err != nil {
return true, err
}
}
}

return true, h.query(query)
return endOfMessages, nil
}

// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being
// handleStatementOutsideEngine handles any queries that should be handled by the handler directly, rather than being
// passed to the engine. The response parameter |handled| is true if the query was handled, |endOfMessages| is true
// if no more messages are expected for this query and server should send the client a READY FOR QUERY message,
// and any error that occurred while handling the query.
func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (handled bool, endOfMessages bool, err error) {
switch stmt := query.AST.(type) {
func (h *ConnectionHandler) handleStatementOutsideEngine(statement ConvertedStatement) (handled bool, endOfMessages bool, err error) {
switch stmt := statement.AST.(type) {
case *tree.Deallocate:
// TODO: handle ALL keyword
return true, true, h.deallocatePreparedStatement(stmt.Name.String(), h.preparedStatements, query, h.Conn())
return true, true, h.deallocatePreparedStatement(stmt.Name.String(), h.preparedStatements, statement, h.Conn())
case *tree.Discard:
return true, true, h.discardAll(query)
return true, true, h.discardAll(statement)
case *tree.CopyFrom:
// When copying data from STDIN, the data is sent to the server as CopyData messages
// We send endOfMessages=false since the server will be in COPY DATA mode and won't
// be ready for more queries util COPY DATA mode is completed.
if stmt.Stdin {
return true, false, h.handleCopyFromStdinQuery(query, stmt, "")
return true, false, h.handleCopyFromStdinQuery(statement, stmt, "")
}
case *tree.CopyTo:
return true, true, h.handleCopyToStdout(query, stmt, "" /* unused */, stmt.Options.CopyFormat, "")
return true, true, h.handleCopyToStdout(statement, stmt, "" /* unused */, stmt.Options.CopyFormat, "")
}

if query.StatementTag == "COPY" {
if target, format, options, ok := ParseCopyFrom(query.String); ok {
if statement.StatementTag == "COPY" {
if target, format, options, ok := ParseCopyFrom(statement.String); ok {
stmt, err := parser.ParseOne("COPY " + target + " FROM STDIN")
if err != nil {
return false, true, err
}
copyFrom := stmt.AST.(*tree.CopyFrom)
copyFrom.Options.CopyFormat = format
return true, false, h.handleCopyFromStdinQuery(query, copyFrom, options)
return true, false, h.handleCopyFromStdinQuery(statement, copyFrom, options)
}
if subquery, format, options, ok := ParseCopyTo(query.String); ok {
if subquery, format, options, ok := ParseCopyTo(statement.String); ok {
if strings.HasPrefix(subquery, "(") && strings.HasSuffix(subquery, ")") {
// subquery may be richer than Postgres supports, so we just pass it as a string
return true, true, h.handleCopyToStdout(query, nil, subquery, format, options)
return true, true, h.handleCopyToStdout(statement, nil, subquery, format, options)
}
// subquery is "table [(column_list)]", so we can parse it and pass the AST
stmt, err := parser.ParseOne("COPY " + subquery + " TO STDOUT")
@@ -534,11 +545,11 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand
}
copyTo := stmt.AST.(*tree.CopyTo)
copyTo.Options.CopyFormat = format
return true, true, h.handleCopyToStdout(query, copyTo, "", format, options)
return true, true, h.handleCopyToStdout(statement, copyTo, "", format, options)
}
}

handled, err = h.handlePgCatalogQueries(query)
handled, err = h.handlePgCatalogQueries(statement)
if handled || err != nil {
return true, true, err
}
@@ -551,11 +562,13 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error {
h.waitForSync = true

// TODO: "Named prepared statements must be explicitly closed before they can be redefined by another Parse message, but this is not required for the unnamed statement"
query, err := h.convertQuery(message.Query)
statements, err := h.convertQuery(message.Query)
if err != nil {
return err
}

// TODO(Noy): handle multiple statements
query := statements[0]
if query.AST == nil {
// special case: empty query
h.preparedStatements[message.Name] = PreparedStatementData{
@@ -730,7 +743,7 @@ func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error {
}

// Certain statement types get handled directly by the handler instead of being passed to the engine
handled, _, err := h.handleQueryOutsideEngine(query)
handled, _, err := h.handleStatementOutsideEngine(query)
if handled {
return err
}
@@ -912,7 +925,7 @@ func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, end
return false, true, nil
}

func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error {
func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedStatement, conn net.Conn) error {
_, ok := preparedStatements[name]
if !ok {
return fmt.Errorf("prepared statement %s does not exist", name)
@@ -965,27 +978,27 @@ func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []
return vars, nil
}

// query runs the given query and sends a CommandComplete message to the client
func (h *ConnectionHandler) query(query ConvertedQuery) error {
h.logger.Tracef("running query %v", query)
// runStatement runs the given query and sends a CommandComplete message to the client
func (h *ConnectionHandler) runStatement(statement ConvertedStatement) error {
h.logger.Tracef("running statement %v", statement)

// |rowsAffected| gets altered by the callback below
rowsAffected := int32(0)

// Get the accurate statement tag for the query
if !query.PgParsable && !IsWellKnownStatementTag(query.StatementTag) {
tag, err := h.duckHandler.getStatementTag(h.mysqlConn, query.String)
// Get the accurate statement tag for the statement
if !statement.PgParsable && !IsWellKnownStatementTag(statement.StatementTag) {
tag, err := h.duckHandler.getStatementTag(h.mysqlConn, statement.String)
if err != nil {
return err
}
h.logger.Tracef("getting statement tag for query %v via preparing in DuckDB: %s", query, tag)
query.StatementTag = tag
h.logger.Tracef("getting statement tag for statement %v via preparing in DuckDB: %s", statement, tag)
statement.StatementTag = tag
}

if query.SubscriptionConfig != nil {
return h.executeSubscriptionSQL(query.SubscriptionConfig)
} else if query.BackupConfig != nil {
msg, err := h.executeBackup(query.BackupConfig)
if statement.SubscriptionConfig != nil {
return h.executeSubscriptionSQL(statement.SubscriptionConfig)
} else if statement.BackupConfig != nil {
msg, err := h.executeBackup(statement.BackupConfig)
if err != nil {
return err
}
@@ -994,18 +1007,18 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {
})
}

callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false)
callback := h.spoolRowsCallback(statement.StatementTag, &rowsAffected, false)
if err := h.duckHandler.ComQuery(
context.Background(),
h.mysqlConn,
query.String,
query.AST,
statement.String,
statement.AST,
callback,
); err != nil {
return fmt.Errorf("fallback query execution failed: %w", err)
return fmt.Errorf("fallback statement execution failed: %w", err)
}

return h.send(makeCommandComplete(query.StatementTag, rowsAffected))
return h.send(makeCommandComplete(statement.StatementTag, rowsAffected))
}

// spoolRowsCallback returns a callback function that will send RowDescription message,
@@ -1076,33 +1089,33 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error)
if err != nil {
return false, err
}
return true, h.query(query)
return true, h.runStatement(query[0])
}
// Command: \l on psql 16
if statement == "select\n d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n null as \"icu rules\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" {
query, err := h.convertQuery(`select d.datname as "Name", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`)
if err != nil {
return false, err
}
return true, h.query(query)
return true, h.runStatement(query[0])
}
// Command: \dt
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`,
StatementTag: "SELECT",
})
}
// Command: \d
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`,
StatementTag: "SELECT",
})
}
// Alternate \d for psql 14
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 's' then 'special' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`,
StatementTag: "SELECT",
})
@@ -1115,29 +1128,29 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error)
}
// Command: \dn
if statement == "select n.nspname as \"name\",\n pg_catalog.pg_get_userbyid(n.nspowner) as \"owner\"\nfrom pg_catalog.pg_namespace n\nwhere n.nspname !~ '^pg_' and n.nspname <> 'information_schema'\norder by 1;" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT 'public' AS "Name", 'pg_database_owner' AS "Owner";`,
StatementTag: "SELECT",
})
}
// Command: \df
if statement == "select n.nspname as \"schema\",\n p.proname as \"name\",\n pg_catalog.pg_get_function_result(p.oid) as \"result data type\",\n pg_catalog.pg_get_function_arguments(p.oid) as \"argument data types\",\n case p.prokind\n when 'a' then 'agg'\n when 'w' then 'window'\n when 'p' then 'proc'\n else 'func'\n end as \"type\"\nfrom pg_catalog.pg_proc p\n left join pg_catalog.pg_namespace n on n.oid = p.pronamespace\nwhere pg_catalog.pg_function_is_visible(p.oid)\n and n.nspname <> 'pg_catalog'\n and n.nspname <> 'information_schema'\norder by 1, 2, 4;" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT '' AS "Schema", '' AS "Name", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`,
StatementTag: "SELECT",
})
}
// Command: \dv
if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relkind in ('v','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`,
StatementTag: "SELECT",
})
}
// Command: \du
if statement == "select r.rolname, r.rolsuper, r.rolinherit,\n r.rolcreaterole, r.rolcreatedb, r.rolcanlogin,\n r.rolconnlimit, r.rolvaliduntil,\n array(select b.rolname\n from pg_catalog.pg_auth_members m\n join pg_catalog.pg_roles b on (m.roleid = b.oid)\n where m.member = r.oid) as memberof\n, r.rolreplication\n, r.rolbypassrls\nfrom pg_catalog.pg_roles r\nwhere r.rolname !~ '^pg_'\norder by 1;" {
// We don't support users yet, so we'll just return nothing for now
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT '' FROM dual LIMIT 0;`,
StatementTag: "SELECT",
})
@@ -1149,13 +1162,13 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error)
func (h *ConnectionHandler) handledWorkbenchCommands(statement string) (bool, error) {
lower := strings.ToLower(statement)
if lower == "select * from current_schema()" || lower == "select * from current_schema();" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT search_path AS "current_schema";`,
StatementTag: "SELECT",
})
}
if lower == "select * from current_database()" || lower == "select * from current_database();" {
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: `SELECT DATABASE() AS "current_database";`,
StatementTag: "SELECT",
})
@@ -1193,8 +1206,8 @@ func (h *ConnectionHandler) sendError(err error) {
}
}

// convertQuery takes the given Postgres query, and converts it as an ast.ConvertedQuery that will work with the handler.
func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifier) (ConvertedQuery, error) {
// convertQuery takes the given Postgres query, and converts it as a list of ast.ConvertedStatement that will work with the handler.
func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifier) ([]ConvertedStatement, error) {
for _, modifier := range modifiers {
query = modifier(query)
}
@@ -1204,21 +1217,21 @@ func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifie
// Check if the query is a subscription query, and if so, parse it as a subscription query.
subscriptionConfig, err := parseSubscriptionSQL(query)
if subscriptionConfig != nil && err == nil {
return ConvertedQuery{
return []ConvertedStatement{{
String: query,
PgParsable: true,
SubscriptionConfig: subscriptionConfig,
}, nil
}}, nil
}

// Check if the query is a backup query, and if so, parse it as a backup query.
backupConfig, err := parseBackupSQL(query)
if backupConfig != nil && err == nil {
return ConvertedQuery{
return []ConvertedStatement{{
String: query,
PgParsable: true,
BackupConfig: backupConfig,
}, nil
}}, nil
}

stmts, err := parser.Parse(query)
@@ -1228,30 +1241,31 @@ func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifie
stmts, _ = parser.Parse("SELECT 'SQL syntax is incompatible with PostgreSQL' AS error")
}

if len(stmts) > 1 {
return ConvertedQuery{}, fmt.Errorf("only a single statement at a time is currently supported")
}
//if len(stmts) > 1 {
// return []ConvertedStatement{}, fmt.Errorf("only a single statement at a time is currently supported")
//}
if len(stmts) == 0 {
return ConvertedQuery{String: query}, nil
return []ConvertedStatement{{String: query}}, nil
}

var stmtTag string
if parsable {
stmtTag = stmts[0].AST.StatementTag()
} else {
stmtTag = GuessStatementTag(query)
convertedStmts := make([]ConvertedStatement, len(stmts))
logrus.Warnf("String: %s", query)
for i, stmt := range stmts {
logrus.Warnf("SQL: %s", stmt.SQL)
convertedStmts[i].String = stmt.SQL
convertedStmts[i].AST = stmt.AST
if parsable {
convertedStmts[i].StatementTag = stmt.AST.StatementTag()
} else {
convertedStmts[i].StatementTag = GuessStatementTag(stmt.SQL)
}
convertedStmts[i].PgParsable = parsable
}

return ConvertedQuery{
String: query,
AST: stmts[0].AST,
StatementTag: stmtTag,
PgParsable: parsable,
}, nil
return convertedStmts, nil
}

// discardAll handles the DISCARD ALL command
func (h *ConnectionHandler) discardAll(query ConvertedQuery) error {
func (h *ConnectionHandler) discardAll(query ConvertedStatement) error {
h.closeBackendConn()

return h.send(&pgproto3.CommandComplete{
@@ -1263,7 +1277,7 @@ func (h *ConnectionHandler) discardAll(query ConvertedQuery) error {
// COPY FROM STDIN can't be handled directly by the GMS engine, since COPY FROM STDIN relies on multiple messages sent
// over the wire.
func (h *ConnectionHandler) handleCopyFromStdinQuery(
query ConvertedQuery, copyFrom *tree.CopyFrom,
query ConvertedStatement, copyFrom *tree.CopyFrom,
rawOptions string, // For non-PG-parseable COPY FROM
) error {
sqlCtx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String)
@@ -1327,7 +1341,7 @@ func returnsRow(tag string) bool {
}
}

func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error {
func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error {
ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String)
if err != nil {
return err
46 changes: 23 additions & 23 deletions pgserver/pg_catalog_handler.go
Original file line number Diff line number Diff line change
@@ -130,7 +130,7 @@ func (h *ConnectionHandler) handleIsInRecovery() (bool, error) {
if err != nil {
return false, err
}
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: fmt.Sprintf(`SELECT '%s' AS "pg_is_in_recovery";`, isInRecovery),
StatementTag: "SELECT",
})
@@ -142,14 +142,14 @@ func (h *ConnectionHandler) handleWALSN() (bool, error) {
if err != nil {
return false, err
}
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: fmt.Sprintf(`SELECT '%s' AS "%s";`, lsnStr, "pg_current_wal_lsn"),
StatementTag: "SELECT",
})
}

// handler for currentSetting
func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, error) {
func (h *ConnectionHandler) handleCurrentSetting(query ConvertedStatement) (bool, error) {
sql := RemoveComments(query.String)
matches := currentSettingRegex.FindStringSubmatch(sql)
if len(matches) != 3 {
@@ -159,38 +159,38 @@ func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, er
if err != nil {
return false, err
}
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: fmt.Sprintf(`SELECT '%s' AS "current_setting";`, fmt.Sprintf("%v", setting)),
StatementTag: "SELECT",
})
}

// handler for pgCatalog
func (h *ConnectionHandler) handlePgCatalog(query ConvertedQuery) (bool, error) {
func (h *ConnectionHandler) handlePgCatalog(query ConvertedStatement) (bool, error) {
sql := RemoveComments(query.String)
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__.$1"),
StatementTag: "SELECT",
})
}

type PGCatalogHandler struct {
// HandledInPlace is a function that determines if the query should be handled in place and not passed to the engine.
HandledInPlace func(ConvertedQuery) (bool, error)
Handler func(*ConnectionHandler, ConvertedQuery) (bool, error)
HandledInPlace func(ConvertedStatement) (bool, error)
Handler func(*ConnectionHandler, ConvertedStatement) (bool, error)
}

func isPgIsInRecovery(query ConvertedQuery) bool {
func isPgIsInRecovery(query ConvertedStatement) bool {
sql := RemoveComments(query.String)
return pgIsInRecoveryRegex.MatchString(sql)
}

func isPgWALSN(query ConvertedQuery) bool {
func isPgWALSN(query ConvertedStatement) bool {
sql := RemoveComments(query.String)
return pgWALLSNRegex.MatchString(sql)
}

func isPgCurrentSetting(query ConvertedQuery) bool {
func isPgCurrentSetting(query ConvertedStatement) bool {
sql := RemoveComments(query.String)
if !currentSettingRegex.MatchString(sql) {
return false
@@ -206,15 +206,15 @@ func isPgCurrentSetting(query ConvertedQuery) bool {
return true
}

func isSpecialPgCatalog(query ConvertedQuery) bool {
func isSpecialPgCatalog(query ConvertedStatement) bool {
sql := RemoveComments(query.String)
return pgCatalogRegex.MatchString(sql)
}

// The key is the statement tag of the query.
var pgCatalogHandlers = map[string]PGCatalogHandler{
"SELECT": {
HandledInPlace: func(query ConvertedQuery) (bool, error) {
HandledInPlace: func(query ConvertedStatement) (bool, error) {
// TODO(sean): Evaluate the conditions by iterating over the AST.
if isPgIsInRecovery(query) {
return true, nil
@@ -230,7 +230,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
}
return false, nil
},
Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) {
Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) {
if isPgIsInRecovery(query) {
return h.handleIsInRecovery()
}
@@ -248,14 +248,14 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
},
},
"SHOW": {
HandledInPlace: func(query ConvertedQuery) (bool, error) {
HandledInPlace: func(query ConvertedStatement) (bool, error) {
switch query.AST.(type) {
case *tree.ShowVar:
return true, nil
}
return false, nil
},
Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) {
Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) {
showVar, ok := query.AST.(*tree.ShowVar)
if !ok {
return false, nil
@@ -266,7 +266,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
if err != nil {
return false, err
}
return true, h.query(ConvertedQuery{
return true, h.runStatement(ConvertedStatement{
String: fmt.Sprintf(`SELECT '%s' AS "%s";`, fmt.Sprintf("%v", setting), key),
StatementTag: "SELECT",
})
@@ -281,7 +281,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
},
},
"SET": {
HandledInPlace: func(query ConvertedQuery) (bool, error) {
HandledInPlace: func(query ConvertedStatement) (bool, error) {
switch stmt := query.AST.(type) {
case *tree.SetVar:
key := strings.ToLower(stmt.Name)
@@ -301,7 +301,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
}
return false, nil
},
Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) {
Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) {
setVar, ok := query.AST.(*tree.SetVar)
if !ok {
return false, fmt.Errorf("error: invalid set statement: %v", query.String)
@@ -337,7 +337,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
},
},
"RESET": {
HandledInPlace: func(query ConvertedQuery) (bool, error) {
HandledInPlace: func(query ConvertedStatement) (bool, error) {
switch stmt := query.AST.(type) {
case *tree.SetVar:
if !stmt.Reset && !stmt.ResetAll {
@@ -351,7 +351,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
}
return false, nil
},
Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) {
Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) {
resetVar, ok := query.AST.(*tree.SetVar)
if !ok || (!resetVar.Reset && !resetVar.ResetAll) {
return false, fmt.Errorf("error: invalid reset statement: %v", query.String)
@@ -378,7 +378,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
// shouldQueryBeHandledInPlace determines whether a query should be handled in place, rather than being
// passed to the engine. This is useful for queries that are not supported by the engine, or that require
// special handling.
func shouldQueryBeHandledInPlace(sql ConvertedQuery) (bool, error) {
func shouldQueryBeHandledInPlace(sql ConvertedStatement) (bool, error) {
handler, ok := pgCatalogHandlers[sql.StatementTag]
if !ok {
return false, nil
@@ -392,7 +392,7 @@ func shouldQueryBeHandledInPlace(sql ConvertedQuery) (bool, error) {

// TODO(sean): This is a temporary work around for clients that query the views from schema 'pg_catalog'.
// Remove this once we add the views for 'pg_catalog'.
func (h *ConnectionHandler) handlePgCatalogQueries(sql ConvertedQuery) (bool, error) {
func (h *ConnectionHandler) handlePgCatalogQueries(sql ConvertedStatement) (bool, error) {
handler, ok := pgCatalogHandlers[sql.StatementTag]
if !ok {
return false, nil

0 comments on commit 5b8289b

Please sign in to comment.