From 333f10a93d31562cb5a1efd1a95977a19de0b35e Mon Sep 17 00:00:00 2001 From: Sean Wu <111744549+VWagen1989@users.noreply.github.com> Date: Thu, 28 Nov 2024 18:55:10 +0800 Subject: [PATCH] fix: adopt CR feedback --- pgserver/connection_data.go | 20 +++-- pgserver/connection_handler.go | 139 ++++++++++++++++----------------- pgserver/duck_handler.go | 3 +- pgserver/pg_catalog_handler.go | 7 +- 4 files changed, 82 insertions(+), 87 deletions(-) diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index e16cf5ad..eda6a774 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -85,20 +85,18 @@ type copyFromStdinState struct { } type PortalData struct { - Query ConvertedQuery - IsEmptyQuery bool - Fields []pgproto3.FieldDescription - Stmt *duckdb.Stmt - Vars []any - HandledOutsideEngine bool + Query ConvertedQuery + IsEmptyQuery bool + Fields []pgproto3.FieldDescription + Stmt *duckdb.Stmt + Vars []any } type PreparedStatementData struct { - Query ConvertedQuery - ReturnFields []pgproto3.FieldDescription - BindVarTypes []uint32 - Stmt *duckdb.Stmt - HandledOutsideEngine bool + Query ConvertedQuery + ReturnFields []pgproto3.FieldDescription + BindVarTypes []uint32 + Stmt *duckdb.Stmt } // VitessTypeToObjectID returns a type, as defined by Vitess, into a type as defined by Postgres. diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index cb354632..6d49f23b 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -567,50 +567,48 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { if err != nil { return err } - - if !handledOutsideEngine { - stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) - if err != nil { - return err + if handledOutsideEngine { + h.preparedStatements[message.Name] = PreparedStatementData{ + Query: query, + ReturnFields: nil, + BindVarTypes: nil, + Stmt: nil, } + return h.send(&pgproto3.ParseComplete{}) + } - if !query.PgParsable { - query.StatementTag = GetStatementTag(stmt) - } + stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) + if err != nil { + return err + } - // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY - // > A parameter data type can be left unspecified by setting it to zero, - // > or by making the array of parameter type OIDs shorter than the number of - // > parameter symbols ($n)used in the query string. - // > ... - // > Parameter data types can be specified by OID; - // > if not given, the parser attempts to infer the data types in the same way - // > as it would do for untyped literal string constants. - bindVarTypes := message.ParameterOIDs - if len(bindVarTypes) < len(params) { - bindVarTypes = append(bindVarTypes, params[len(bindVarTypes):]...) - } - for i := range params { - if bindVarTypes[i] == 0 { - bindVarTypes[i] = params[i] - } - } - h.preparedStatements[message.Name] = PreparedStatementData{ - Query: query, - ReturnFields: fields, - BindVarTypes: bindVarTypes, - Stmt: stmt, - HandledOutsideEngine: false, - } - } else { - h.preparedStatements[message.Name] = PreparedStatementData{ - Query: query, - ReturnFields: nil, - BindVarTypes: nil, - Stmt: nil, - HandledOutsideEngine: true, + if !query.PgParsable { + query.StatementTag = GetStatementTag(stmt) + } + + // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + // > A parameter data type can be left unspecified by setting it to zero, + // > or by making the array of parameter type OIDs shorter than the number of + // > parameter symbols ($n)used in the query string. + // > ... + // > Parameter data types can be specified by OID; + // > if not given, the parser attempts to infer the data types in the same way + // > as it would do for untyped literal string constants. + bindVarTypes := message.ParameterOIDs + if len(bindVarTypes) < len(params) { + bindVarTypes = append(bindVarTypes, params[len(bindVarTypes):]...) + } + for i := range params { + if bindVarTypes[i] == 0 { + bindVarTypes[i] = params[i] } } + h.preparedStatements[message.Name] = PreparedStatementData{ + Query: query, + ReturnFields: fields, + BindVarTypes: bindVarTypes, + Stmt: stmt, + } return h.send(&pgproto3.ParseComplete{}) } @@ -631,7 +629,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { // https://www.postgresql.org/docs/current/protocol-flow.html // > Note that since Bind has not yet been issued, the formats to be used for returned columns are not yet known to the backend; // > the format code fields in the RowDescription message will be zeroes in this case. - if !preparedStatementData.HandledOutsideEngine { + if preparedStatementData.Stmt != nil { fields = slices.Clone(preparedStatementData.ReturnFields) for i := range fields { fields[i].Format = 0 @@ -646,7 +644,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { return fmt.Errorf("portal %s does not exist", message.Name) } - if !portalData.HandledOutsideEngine { + if portalData.Stmt != nil { fields = portalData.Fields tag = portalData.Query.StatementTag } @@ -667,41 +665,40 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { return fmt.Errorf("prepared statement %s does not exist", message.PreparedStatement) } - if !preparedData.HandledOutsideEngine { - if preparedData.Query.AST == nil { - // special case: empty query - h.portals[message.DestinationPortal] = PortalData{ - Query: preparedData.Query, - IsEmptyQuery: true, - } - return h.send(&pgproto3.BindComplete{}) + if preparedData.Stmt == nil { + h.portals[message.DestinationPortal] = PortalData{ + Query: preparedData.Query, + Fields: nil, + Stmt: nil, + Vars: nil, } + return h.send(&pgproto3.BindComplete{}) + } - bindVars, err := h.convertBindParameters(preparedData.BindVarTypes, message.ParameterFormatCodes, message.Parameters) - if err != nil { - return err + if preparedData.Query.AST == nil { + // special case: empty query + h.portals[message.DestinationPortal] = PortalData{ + Query: preparedData.Query, + IsEmptyQuery: true, } + return h.send(&pgproto3.BindComplete{}) + } - fields, err := h.duckHandler.ComBind(context.Background(), h.mysqlConn, preparedData, bindVars) - if err != nil { - return err - } + bindVars, err := h.convertBindParameters(preparedData.BindVarTypes, message.ParameterFormatCodes, message.Parameters) + if err != nil { + return err + } - h.portals[message.DestinationPortal] = PortalData{ - Query: preparedData.Query, - Fields: fields, - Stmt: preparedData.Stmt, - Vars: bindVars, - HandledOutsideEngine: false, - } - } else { - h.portals[message.DestinationPortal] = PortalData{ - Query: preparedData.Query, - Fields: nil, - Stmt: nil, - Vars: nil, - HandledOutsideEngine: true, - } + fields, err := h.duckHandler.ComBind(context.Background(), h.mysqlConn, preparedData, bindVars) + if err != nil { + return err + } + + h.portals[message.DestinationPortal] = PortalData{ + Query: preparedData.Query, + Fields: fields, + Stmt: preparedData.Stmt, + Vars: bindVars, } return h.send(&pgproto3.BindComplete{}) } diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index f794a908..974eba5d 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -24,7 +24,6 @@ import ( "os" "regexp" "runtime/trace" - "strings" "sync" "time" @@ -189,7 +188,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query if stmtType == duckdb.DUCKDB_STATEMENT_TYPE_SELECT || stmtType == duckdb.DUCKDB_STATEMENT_TYPE_RELATION { // Add LIMIT 0 to avoid executing the actual query. - query = "SELECT * FROM (" + strings.Trim(query, "; \t\r\n") + ") LIMIT 0" + query = "SELECT * FROM (" + sql.RemoveSpaceAndDelimiter(query, ';') + ") LIMIT 0" } params := make([]any, len(paramTypes)) // all nil rows, err = conn.QueryContext(sqlCtx, query, params...) diff --git a/pgserver/pg_catalog_handler.go b/pgserver/pg_catalog_handler.go index 13606115..00159fd2 100644 --- a/pgserver/pg_catalog_handler.go +++ b/pgserver/pg_catalog_handler.go @@ -25,8 +25,8 @@ var pgWALLSNRegex = regexp.MustCompile(`(?i)^\s*select\s+pg_catalog\.(pg_current // precompile a regex to match "select pg_catalog.current_setting('xxx');". var currentSettingRegex = regexp.MustCompile(`(?i)^\s*select\s+(pg_catalog.)?current_setting\(\s*'([^']+)'\s*\)\s*;?\s*$`) -// precompile a regex to match any "from pg_catalog.xxx" in the query. -var pgCatalogRegex = regexp.MustCompile(`(?i)\s+from\s+pg_catalog\.`) +// precompile a regex to match any "from pg_catalog.pg_stat_replication" in the query. +var pgCatalogRegex = regexp.MustCompile(`(?i)\s+from\s+pg_catalog\.(pg_stat_replication)`) // isInRecovery will get the count of func (h *ConnectionHandler) isInRecovery() (string, error) { @@ -166,7 +166,7 @@ func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, er func (h *ConnectionHandler) handlePgCatalog(query ConvertedQuery) (bool, error) { sql := RemoveComments(query.String) return true, h.query(ConvertedQuery{ - String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__."), + String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__.$1"), StatementTag: "SELECT", }) } @@ -212,6 +212,7 @@ func isSpecialPgCatalog(query ConvertedQuery) bool { var pgCatalogHandlers = map[string]PGCatalogHandler{ "SELECT": { HandledInPlace: func(query ConvertedQuery) (bool, error) { + // TODO(sean): Evaluate the conditions by iterating over the AST. if isPgIsInRecovery(query) { return true, nil }