diff --git a/compatibility/pg-pytools/psycopg_test.py b/compatibility/pg-pytools/psycopg_test.py index 0578c912..b08b41d1 100644 --- a/compatibility/pg-pytools/psycopg_test.py +++ b/compatibility/pg-pytools/psycopg_test.py @@ -1,4 +1,3 @@ -from psycopg import sql import psycopg rows = [ @@ -13,16 +12,14 @@ with psycopg.connect("dbname=postgres user=postgres host=127.0.0.1 port=5432", autocommit=True) as conn: # Open a cursor to perform database operations with conn.cursor() as cur: - cur.execute("DROP SCHEMA IF EXISTS test CASCADE") - cur.execute("CREATE SCHEMA test") - cur.execute(""" - CREATE TABLE test.tb1 ( - id integer PRIMARY KEY, - num integer, - data text) - """) - + DROP SCHEMA IF EXISTS test CASCADE; + CREATE SCHEMA test; + CREATE TABLE test.tb1 ( + id integer PRIMARY KEY, + num integer, + data text) + """) # Pass data to fill a query placeholders and let Psycopg perform the correct conversion cur.execute( diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index f5a596e4..42cf7560 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -48,11 +48,11 @@ const ( ReadyForQueryTransactionIndicator_FailedTransactionBlock ReadyForQueryTransactionIndicator = 'E' ) -// ConvertedQuery represents a query that has been converted from the Postgres representation to the Vitess -// representation. String may contain the string version of the converted query. AST will contain the tree -// version of the converted query, and is the recommended form to use. If AST is nil, then use the String version, +// ConvertedStatement represents a statement that has been converted from the Postgres representation to the Vitess +// representation. String may contain the string version of the converted statement. AST will contain the tree +// version of the converted statement, and is the recommended form to use. If AST is nil, then use the String version, // otherwise always prefer to AST. -type ConvertedQuery struct { +type ConvertedStatement struct { String string AST tree.Statement StatementTag string @@ -86,7 +86,7 @@ type copyFromStdinState struct { } type PortalData struct { - Query ConvertedQuery + Query ConvertedStatement IsEmptyQuery bool Fields []pgproto3.FieldDescription ResultFormatCodes []int16 @@ -96,7 +96,7 @@ type PortalData struct { } type PreparedStatementData struct { - Query ConvertedQuery + Query ConvertedStatement ReturnFields []pgproto3.FieldDescription BindVarTypes []uint32 Stmt *duckdb.Stmt diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 3b6e3164..7a4fc85e 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -470,62 +470,73 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages return true, err } - query, err := h.convertQuery(message.String) + statements, err := h.convertQuery(message.String) if err != nil { return true, err } - // A query message destroys the unnamed statement and the unnamed portal + // A statements message destroys the unnamed statement and the unnamed portal h.deletePreparedStatement("") h.deletePortal("") - // Certain statement types get handled directly by the handler instead of being passed to the engine - handled, endOfMessages, err = h.handleQueryOutsideEngine(query) - if handled { - return endOfMessages, err - } else if err != nil { - h.logger.Warnf("Failed to handle query %v outside engine: %v", query, err) + for _, statement := range statements { + // Certain statement types get handled directly by the handler instead of being passed to the engine + handled, endOfMessages, err = h.handleStatementOutsideEngine(statement) + if handled { + if err != nil { + h.logger.Warnf("Failed to handle statement %v outside engine: %v", statement, err) + return true, err + } + } else { + if err != nil { + h.logger.Warnf("Failed to handle statement %v outside engine: %v", statement, err) + } + endOfMessages, err = true, h.runStatement(statement) + if err != nil { + return true, err + } + } } - return true, h.query(query) + return endOfMessages, nil } -// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being +// handleStatementOutsideEngine handles any queries that should be handled by the handler directly, rather than being // passed to the engine. The response parameter |handled| is true if the query was handled, |endOfMessages| is true // if no more messages are expected for this query and server should send the client a READY FOR QUERY message, // and any error that occurred while handling the query. -func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (handled bool, endOfMessages bool, err error) { - switch stmt := query.AST.(type) { +func (h *ConnectionHandler) handleStatementOutsideEngine(statement ConvertedStatement) (handled bool, endOfMessages bool, err error) { + switch stmt := statement.AST.(type) { case *tree.Deallocate: // TODO: handle ALL keyword - return true, true, h.deallocatePreparedStatement(stmt.Name.String(), h.preparedStatements, query, h.Conn()) + return true, true, h.deallocatePreparedStatement(stmt.Name.String(), h.preparedStatements, statement, h.Conn()) case *tree.Discard: - return true, true, h.discardAll(query) + return true, true, h.discardAll(statement) case *tree.CopyFrom: // When copying data from STDIN, the data is sent to the server as CopyData messages // We send endOfMessages=false since the server will be in COPY DATA mode and won't // be ready for more queries util COPY DATA mode is completed. if stmt.Stdin { - return true, false, h.handleCopyFromStdinQuery(query, stmt, "") + return true, false, h.handleCopyFromStdinQuery(statement, stmt, "") } case *tree.CopyTo: - return true, true, h.handleCopyToStdout(query, stmt, "" /* unused */, stmt.Options.CopyFormat, "") + return true, true, h.handleCopyToStdout(statement, stmt, "" /* unused */, stmt.Options.CopyFormat, "") } - if query.StatementTag == "COPY" { - if target, format, options, ok := ParseCopyFrom(query.String); ok { + if statement.StatementTag == "COPY" { + if target, format, options, ok := ParseCopyFrom(statement.String); ok { stmt, err := parser.ParseOne("COPY " + target + " FROM STDIN") if err != nil { return false, true, err } copyFrom := stmt.AST.(*tree.CopyFrom) copyFrom.Options.CopyFormat = format - return true, false, h.handleCopyFromStdinQuery(query, copyFrom, options) + return true, false, h.handleCopyFromStdinQuery(statement, copyFrom, options) } - if subquery, format, options, ok := ParseCopyTo(query.String); ok { + if subquery, format, options, ok := ParseCopyTo(statement.String); ok { if strings.HasPrefix(subquery, "(") && strings.HasSuffix(subquery, ")") { // subquery may be richer than Postgres supports, so we just pass it as a string - return true, true, h.handleCopyToStdout(query, nil, subquery, format, options) + return true, true, h.handleCopyToStdout(statement, nil, subquery, format, options) } // subquery is "table [(column_list)]", so we can parse it and pass the AST stmt, err := parser.ParseOne("COPY " + subquery + " TO STDOUT") @@ -534,11 +545,11 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand } copyTo := stmt.AST.(*tree.CopyTo) copyTo.Options.CopyFormat = format - return true, true, h.handleCopyToStdout(query, copyTo, "", format, options) + return true, true, h.handleCopyToStdout(statement, copyTo, "", format, options) } } - handled, err = h.handlePgCatalogQueries(query) + handled, err = h.handlePgCatalogQueries(statement) if handled || err != nil { return true, true, err } @@ -551,11 +562,13 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { h.waitForSync = true // TODO: "Named prepared statements must be explicitly closed before they can be redefined by another Parse message, but this is not required for the unnamed statement" - query, err := h.convertQuery(message.Query) + statements, err := h.convertQuery(message.Query) if err != nil { return err } + // TODO(Noy): handle multiple statements + query := statements[0] if query.AST == nil { // special case: empty query h.preparedStatements[message.Name] = PreparedStatementData{ @@ -730,7 +743,7 @@ func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error { } // Certain statement types get handled directly by the handler instead of being passed to the engine - handled, _, err := h.handleQueryOutsideEngine(query) + handled, _, err := h.handleStatementOutsideEngine(query) if handled { return err } @@ -912,7 +925,7 @@ func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, end return false, true, nil } -func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error { +func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedStatement, conn net.Conn) error { _, ok := preparedStatements[name] if !ok { return fmt.Errorf("prepared statement %s does not exist", name) @@ -965,27 +978,27 @@ func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes [] return vars, nil } -// query runs the given query and sends a CommandComplete message to the client -func (h *ConnectionHandler) query(query ConvertedQuery) error { - h.logger.Tracef("running query %v", query) +// runStatement runs the given query and sends a CommandComplete message to the client +func (h *ConnectionHandler) runStatement(statement ConvertedStatement) error { + h.logger.Tracef("running statement %v", statement) // |rowsAffected| gets altered by the callback below rowsAffected := int32(0) - // Get the accurate statement tag for the query - if !query.PgParsable && !IsWellKnownStatementTag(query.StatementTag) { - tag, err := h.duckHandler.getStatementTag(h.mysqlConn, query.String) + // Get the accurate statement tag for the statement + if !statement.PgParsable && !IsWellKnownStatementTag(statement.StatementTag) { + tag, err := h.duckHandler.getStatementTag(h.mysqlConn, statement.String) if err != nil { return err } - h.logger.Tracef("getting statement tag for query %v via preparing in DuckDB: %s", query, tag) - query.StatementTag = tag + h.logger.Tracef("getting statement tag for statement %v via preparing in DuckDB: %s", statement, tag) + statement.StatementTag = tag } - if query.SubscriptionConfig != nil { - return h.executeSubscriptionSQL(query.SubscriptionConfig) - } else if query.BackupConfig != nil { - msg, err := h.executeBackup(query.BackupConfig) + if statement.SubscriptionConfig != nil { + return h.executeSubscriptionSQL(statement.SubscriptionConfig) + } else if statement.BackupConfig != nil { + msg, err := h.executeBackup(statement.BackupConfig) if err != nil { return err } @@ -994,18 +1007,18 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { }) } - callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) + callback := h.spoolRowsCallback(statement.StatementTag, &rowsAffected, false) if err := h.duckHandler.ComQuery( context.Background(), h.mysqlConn, - query.String, - query.AST, + statement.String, + statement.AST, callback, ); err != nil { - return fmt.Errorf("fallback query execution failed: %w", err) + return fmt.Errorf("fallback statement execution failed: %w", err) } - return h.send(makeCommandComplete(query.StatementTag, rowsAffected)) + return h.send(makeCommandComplete(statement.StatementTag, rowsAffected)) } // spoolRowsCallback returns a callback function that will send RowDescription message, @@ -1076,7 +1089,7 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) if err != nil { return false, err } - return true, h.query(query) + return true, h.runStatement(query[0]) } // Command: \l on psql 16 if statement == "select\n d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n null as \"icu rules\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" { @@ -1084,25 +1097,25 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) if err != nil { return false, err } - return true, h.query(query) + return true, h.runStatement(query[0]) } // Command: \dt if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, StatementTag: "SELECT", }) } // Command: \d if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, StatementTag: "SELECT", }) } // Alternate \d for psql 14 if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 's' then 'special' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, StatementTag: "SELECT", }) @@ -1115,21 +1128,21 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) } // Command: \dn if statement == "select n.nspname as \"name\",\n pg_catalog.pg_get_userbyid(n.nspowner) as \"owner\"\nfrom pg_catalog.pg_namespace n\nwhere n.nspname !~ '^pg_' and n.nspname <> 'information_schema'\norder by 1;" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT 'public' AS "Name", 'pg_database_owner' AS "Owner";`, StatementTag: "SELECT", }) } // Command: \df if statement == "select n.nspname as \"schema\",\n p.proname as \"name\",\n pg_catalog.pg_get_function_result(p.oid) as \"result data type\",\n pg_catalog.pg_get_function_arguments(p.oid) as \"argument data types\",\n case p.prokind\n when 'a' then 'agg'\n when 'w' then 'window'\n when 'p' then 'proc'\n else 'func'\n end as \"type\"\nfrom pg_catalog.pg_proc p\n left join pg_catalog.pg_namespace n on n.oid = p.pronamespace\nwhere pg_catalog.pg_function_is_visible(p.oid)\n and n.nspname <> 'pg_catalog'\n and n.nspname <> 'information_schema'\norder by 1, 2, 4;" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT '' AS "Schema", '' AS "Name", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`, StatementTag: "SELECT", }) } // Command: \dv if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relkind in ('v','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`, StatementTag: "SELECT", }) @@ -1137,7 +1150,7 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) // Command: \du if statement == "select r.rolname, r.rolsuper, r.rolinherit,\n r.rolcreaterole, r.rolcreatedb, r.rolcanlogin,\n r.rolconnlimit, r.rolvaliduntil,\n array(select b.rolname\n from pg_catalog.pg_auth_members m\n join pg_catalog.pg_roles b on (m.roleid = b.oid)\n where m.member = r.oid) as memberof\n, r.rolreplication\n, r.rolbypassrls\nfrom pg_catalog.pg_roles r\nwhere r.rolname !~ '^pg_'\norder by 1;" { // We don't support users yet, so we'll just return nothing for now - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT '' FROM dual LIMIT 0;`, StatementTag: "SELECT", }) @@ -1149,13 +1162,13 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) func (h *ConnectionHandler) handledWorkbenchCommands(statement string) (bool, error) { lower := strings.ToLower(statement) if lower == "select * from current_schema()" || lower == "select * from current_schema();" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT search_path AS "current_schema";`, StatementTag: "SELECT", }) } if lower == "select * from current_database()" || lower == "select * from current_database();" { - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: `SELECT DATABASE() AS "current_database";`, StatementTag: "SELECT", }) @@ -1193,8 +1206,8 @@ func (h *ConnectionHandler) sendError(err error) { } } -// convertQuery takes the given Postgres query, and converts it as an ast.ConvertedQuery that will work with the handler. -func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifier) (ConvertedQuery, error) { +// convertQuery takes the given Postgres query, and converts it as a list of ast.ConvertedStatement that will work with the handler. +func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifier) ([]ConvertedStatement, error) { for _, modifier := range modifiers { query = modifier(query) } @@ -1204,21 +1217,21 @@ func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifie // Check if the query is a subscription query, and if so, parse it as a subscription query. subscriptionConfig, err := parseSubscriptionSQL(query) if subscriptionConfig != nil && err == nil { - return ConvertedQuery{ + return []ConvertedStatement{{ String: query, PgParsable: true, SubscriptionConfig: subscriptionConfig, - }, nil + }}, nil } // Check if the query is a backup query, and if so, parse it as a backup query. backupConfig, err := parseBackupSQL(query) if backupConfig != nil && err == nil { - return ConvertedQuery{ + return []ConvertedStatement{{ String: query, PgParsable: true, BackupConfig: backupConfig, - }, nil + }}, nil } stmts, err := parser.Parse(query) @@ -1228,30 +1241,31 @@ func (h *ConnectionHandler) convertQuery(query string, modifiers ...QueryModifie stmts, _ = parser.Parse("SELECT 'SQL syntax is incompatible with PostgreSQL' AS error") } - if len(stmts) > 1 { - return ConvertedQuery{}, fmt.Errorf("only a single statement at a time is currently supported") - } + //if len(stmts) > 1 { + // return []ConvertedStatement{}, fmt.Errorf("only a single statement at a time is currently supported") + //} if len(stmts) == 0 { - return ConvertedQuery{String: query}, nil + return []ConvertedStatement{{String: query}}, nil } - var stmtTag string - if parsable { - stmtTag = stmts[0].AST.StatementTag() - } else { - stmtTag = GuessStatementTag(query) + convertedStmts := make([]ConvertedStatement, len(stmts)) + logrus.Warnf("String: %s", query) + for i, stmt := range stmts { + logrus.Warnf("SQL: %s", stmt.SQL) + convertedStmts[i].String = stmt.SQL + convertedStmts[i].AST = stmt.AST + if parsable { + convertedStmts[i].StatementTag = stmt.AST.StatementTag() + } else { + convertedStmts[i].StatementTag = GuessStatementTag(stmt.SQL) + } + convertedStmts[i].PgParsable = parsable } - - return ConvertedQuery{ - String: query, - AST: stmts[0].AST, - StatementTag: stmtTag, - PgParsable: parsable, - }, nil + return convertedStmts, nil } // discardAll handles the DISCARD ALL command -func (h *ConnectionHandler) discardAll(query ConvertedQuery) error { +func (h *ConnectionHandler) discardAll(query ConvertedStatement) error { h.closeBackendConn() return h.send(&pgproto3.CommandComplete{ @@ -1263,7 +1277,7 @@ func (h *ConnectionHandler) discardAll(query ConvertedQuery) error { // COPY FROM STDIN can't be handled directly by the GMS engine, since COPY FROM STDIN relies on multiple messages sent // over the wire. func (h *ConnectionHandler) handleCopyFromStdinQuery( - query ConvertedQuery, copyFrom *tree.CopyFrom, + query ConvertedStatement, copyFrom *tree.CopyFrom, rawOptions string, // For non-PG-parseable COPY FROM ) error { sqlCtx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) @@ -1327,7 +1341,7 @@ func returnsRow(tag string) bool { } } -func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error { +func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error { ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) if err != nil { return err diff --git a/pgserver/pg_catalog_handler.go b/pgserver/pg_catalog_handler.go index 026b4be4..578a6ca6 100644 --- a/pgserver/pg_catalog_handler.go +++ b/pgserver/pg_catalog_handler.go @@ -130,7 +130,7 @@ func (h *ConnectionHandler) handleIsInRecovery() (bool, error) { if err != nil { return false, err } - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: fmt.Sprintf(`SELECT '%s' AS "pg_is_in_recovery";`, isInRecovery), StatementTag: "SELECT", }) @@ -142,14 +142,14 @@ func (h *ConnectionHandler) handleWALSN() (bool, error) { if err != nil { return false, err } - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: fmt.Sprintf(`SELECT '%s' AS "%s";`, lsnStr, "pg_current_wal_lsn"), StatementTag: "SELECT", }) } // handler for currentSetting -func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, error) { +func (h *ConnectionHandler) handleCurrentSetting(query ConvertedStatement) (bool, error) { sql := RemoveComments(query.String) matches := currentSettingRegex.FindStringSubmatch(sql) if len(matches) != 3 { @@ -159,16 +159,16 @@ func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, er if err != nil { return false, err } - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: fmt.Sprintf(`SELECT '%s' AS "current_setting";`, fmt.Sprintf("%v", setting)), StatementTag: "SELECT", }) } // handler for pgCatalog -func (h *ConnectionHandler) handlePgCatalog(query ConvertedQuery) (bool, error) { +func (h *ConnectionHandler) handlePgCatalog(query ConvertedStatement) (bool, error) { sql := RemoveComments(query.String) - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__.$1"), StatementTag: "SELECT", }) @@ -176,21 +176,21 @@ func (h *ConnectionHandler) handlePgCatalog(query ConvertedQuery) (bool, error) type PGCatalogHandler struct { // HandledInPlace is a function that determines if the query should be handled in place and not passed to the engine. - HandledInPlace func(ConvertedQuery) (bool, error) - Handler func(*ConnectionHandler, ConvertedQuery) (bool, error) + HandledInPlace func(ConvertedStatement) (bool, error) + Handler func(*ConnectionHandler, ConvertedStatement) (bool, error) } -func isPgIsInRecovery(query ConvertedQuery) bool { +func isPgIsInRecovery(query ConvertedStatement) bool { sql := RemoveComments(query.String) return pgIsInRecoveryRegex.MatchString(sql) } -func isPgWALSN(query ConvertedQuery) bool { +func isPgWALSN(query ConvertedStatement) bool { sql := RemoveComments(query.String) return pgWALLSNRegex.MatchString(sql) } -func isPgCurrentSetting(query ConvertedQuery) bool { +func isPgCurrentSetting(query ConvertedStatement) bool { sql := RemoveComments(query.String) if !currentSettingRegex.MatchString(sql) { return false @@ -206,7 +206,7 @@ func isPgCurrentSetting(query ConvertedQuery) bool { return true } -func isSpecialPgCatalog(query ConvertedQuery) bool { +func isSpecialPgCatalog(query ConvertedStatement) bool { sql := RemoveComments(query.String) return pgCatalogRegex.MatchString(sql) } @@ -214,7 +214,7 @@ func isSpecialPgCatalog(query ConvertedQuery) bool { // The key is the statement tag of the query. var pgCatalogHandlers = map[string]PGCatalogHandler{ "SELECT": { - HandledInPlace: func(query ConvertedQuery) (bool, error) { + HandledInPlace: func(query ConvertedStatement) (bool, error) { // TODO(sean): Evaluate the conditions by iterating over the AST. if isPgIsInRecovery(query) { return true, nil @@ -230,7 +230,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { if isPgIsInRecovery(query) { return h.handleIsInRecovery() } @@ -248,14 +248,14 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ }, }, "SHOW": { - HandledInPlace: func(query ConvertedQuery) (bool, error) { + HandledInPlace: func(query ConvertedStatement) (bool, error) { switch query.AST.(type) { case *tree.ShowVar: return true, nil } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { showVar, ok := query.AST.(*tree.ShowVar) if !ok { return false, nil @@ -266,7 +266,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ if err != nil { return false, err } - return true, h.query(ConvertedQuery{ + return true, h.runStatement(ConvertedStatement{ String: fmt.Sprintf(`SELECT '%s' AS "%s";`, fmt.Sprintf("%v", setting), key), StatementTag: "SELECT", }) @@ -281,7 +281,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ }, }, "SET": { - HandledInPlace: func(query ConvertedQuery) (bool, error) { + HandledInPlace: func(query ConvertedStatement) (bool, error) { switch stmt := query.AST.(type) { case *tree.SetVar: key := strings.ToLower(stmt.Name) @@ -301,7 +301,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { setVar, ok := query.AST.(*tree.SetVar) if !ok { return false, fmt.Errorf("error: invalid set statement: %v", query.String) @@ -337,7 +337,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ }, }, "RESET": { - HandledInPlace: func(query ConvertedQuery) (bool, error) { + HandledInPlace: func(query ConvertedStatement) (bool, error) { switch stmt := query.AST.(type) { case *tree.SetVar: if !stmt.Reset && !stmt.ResetAll { @@ -351,7 +351,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ } return false, nil }, - Handler: func(h *ConnectionHandler, query ConvertedQuery) (bool, error) { + Handler: func(h *ConnectionHandler, query ConvertedStatement) (bool, error) { resetVar, ok := query.AST.(*tree.SetVar) if !ok || (!resetVar.Reset && !resetVar.ResetAll) { return false, fmt.Errorf("error: invalid reset statement: %v", query.String) @@ -378,7 +378,7 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ // shouldQueryBeHandledInPlace determines whether a query should be handled in place, rather than being // passed to the engine. This is useful for queries that are not supported by the engine, or that require // special handling. -func shouldQueryBeHandledInPlace(sql ConvertedQuery) (bool, error) { +func shouldQueryBeHandledInPlace(sql ConvertedStatement) (bool, error) { handler, ok := pgCatalogHandlers[sql.StatementTag] if !ok { return false, nil @@ -392,7 +392,7 @@ func shouldQueryBeHandledInPlace(sql ConvertedQuery) (bool, error) { // TODO(sean): This is a temporary work around for clients that query the views from schema 'pg_catalog'. // Remove this once we add the views for 'pg_catalog'. -func (h *ConnectionHandler) handlePgCatalogQueries(sql ConvertedQuery) (bool, error) { +func (h *ConnectionHandler) handlePgCatalogQueries(sql ConvertedStatement) (bool, error) { handler, ok := pgCatalogHandlers[sql.StatementTag] if !ok { return false, nil