Skip to content

Commit

Permalink
fix: adopt CR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
VWagen1989 committed Nov 28, 2024
1 parent 24dbc91 commit f06d6d8
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 87 deletions.
20 changes: 9 additions & 11 deletions pgserver/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,18 @@ type copyFromStdinState struct {
}

type PortalData struct {
Query ConvertedQuery
IsEmptyQuery bool
Fields []pgproto3.FieldDescription
Stmt *duckdb.Stmt
Vars []any
HandledOutsideEngine bool
Query ConvertedQuery
IsEmptyQuery bool
Fields []pgproto3.FieldDescription
Stmt *duckdb.Stmt
Vars []any
}

type PreparedStatementData struct {
Query ConvertedQuery
ReturnFields []pgproto3.FieldDescription
BindVarTypes []uint32
Stmt *duckdb.Stmt
HandledOutsideEngine bool
Query ConvertedQuery
ReturnFields []pgproto3.FieldDescription
BindVarTypes []uint32
Stmt *duckdb.Stmt
}

// VitessTypeToObjectID returns a type, as defined by Vitess, into a type as defined by Postgres.
Expand Down
139 changes: 68 additions & 71 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,50 +564,48 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error {
if err != nil {
return err
}

if !handledOutsideEngine {
stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST)
if err != nil {
return err
if handledOutsideEngine {
h.preparedStatements[message.Name] = PreparedStatementData{
Query: query,
ReturnFields: nil,
BindVarTypes: nil,
Stmt: nil,
}
return h.send(&pgproto3.ParseComplete{})
}

if !query.PgParsable {
query.StatementTag = GetStatementTag(stmt)
}
stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST)
if err != nil {
return err
}

// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
// > A parameter data type can be left unspecified by setting it to zero,
// > or by making the array of parameter type OIDs shorter than the number of
// > parameter symbols ($n)used in the query string.
// > ...
// > Parameter data types can be specified by OID;
// > if not given, the parser attempts to infer the data types in the same way
// > as it would do for untyped literal string constants.
bindVarTypes := message.ParameterOIDs
if len(bindVarTypes) < len(params) {
bindVarTypes = append(bindVarTypes, params[len(bindVarTypes):]...)
}
for i := range params {
if bindVarTypes[i] == 0 {
bindVarTypes[i] = params[i]
}
}
h.preparedStatements[message.Name] = PreparedStatementData{
Query: query,
ReturnFields: fields,
BindVarTypes: bindVarTypes,
Stmt: stmt,
HandledOutsideEngine: false,
}
} else {
h.preparedStatements[message.Name] = PreparedStatementData{
Query: query,
ReturnFields: nil,
BindVarTypes: nil,
Stmt: nil,
HandledOutsideEngine: true,
if !query.PgParsable {
query.StatementTag = GetStatementTag(stmt)
}

// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
// > A parameter data type can be left unspecified by setting it to zero,
// > or by making the array of parameter type OIDs shorter than the number of
// > parameter symbols ($n)used in the query string.
// > ...
// > Parameter data types can be specified by OID;
// > if not given, the parser attempts to infer the data types in the same way
// > as it would do for untyped literal string constants.
bindVarTypes := message.ParameterOIDs
if len(bindVarTypes) < len(params) {
bindVarTypes = append(bindVarTypes, params[len(bindVarTypes):]...)
}
for i := range params {
if bindVarTypes[i] == 0 {
bindVarTypes[i] = params[i]
}
}
h.preparedStatements[message.Name] = PreparedStatementData{
Query: query,
ReturnFields: fields,
BindVarTypes: bindVarTypes,
Stmt: stmt,
}

return h.send(&pgproto3.ParseComplete{})
}
Expand All @@ -628,7 +626,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error {
// https://www.postgresql.org/docs/current/protocol-flow.html
// > Note that since Bind has not yet been issued, the formats to be used for returned columns are not yet known to the backend;
// > the format code fields in the RowDescription message will be zeroes in this case.
if !preparedStatementData.HandledOutsideEngine {
if preparedStatementData.Stmt != nil {
fields = slices.Clone(preparedStatementData.ReturnFields)
for i := range fields {
fields[i].Format = 0
Expand All @@ -643,7 +641,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error {
return fmt.Errorf("portal %s does not exist", message.Name)
}

if !portalData.HandledOutsideEngine {
if portalData.Stmt != nil {
fields = portalData.Fields
tag = portalData.Query.StatementTag
}
Expand All @@ -664,41 +662,40 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error {
return fmt.Errorf("prepared statement %s does not exist", message.PreparedStatement)
}

if !preparedData.HandledOutsideEngine {
if preparedData.Query.AST == nil {
// special case: empty query
h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
IsEmptyQuery: true,
}
return h.send(&pgproto3.BindComplete{})
if preparedData.Stmt == nil {
h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
Fields: nil,
Stmt: nil,
Vars: nil,
}
return h.send(&pgproto3.BindComplete{})
}

bindVars, err := h.convertBindParameters(preparedData.BindVarTypes, message.ParameterFormatCodes, message.Parameters)
if err != nil {
return err
if preparedData.Query.AST == nil {
// special case: empty query
h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
IsEmptyQuery: true,
}
return h.send(&pgproto3.BindComplete{})
}

fields, err := h.duckHandler.ComBind(context.Background(), h.mysqlConn, preparedData, bindVars)
if err != nil {
return err
}
bindVars, err := h.convertBindParameters(preparedData.BindVarTypes, message.ParameterFormatCodes, message.Parameters)
if err != nil {
return err
}

h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
Fields: fields,
Stmt: preparedData.Stmt,
Vars: bindVars,
HandledOutsideEngine: false,
}
} else {
h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
Fields: nil,
Stmt: nil,
Vars: nil,
HandledOutsideEngine: true,
}
fields, err := h.duckHandler.ComBind(context.Background(), h.mysqlConn, preparedData, bindVars)
if err != nil {
return err
}

h.portals[message.DestinationPortal] = PortalData{
Query: preparedData.Query,
Fields: fields,
Stmt: preparedData.Stmt,
Vars: bindVars,
}
return h.send(&pgproto3.BindComplete{})
}
Expand Down
3 changes: 1 addition & 2 deletions pgserver/duck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"os"
"regexp"
"runtime/trace"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -189,7 +188,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query
if stmtType == duckdb.DUCKDB_STATEMENT_TYPE_SELECT ||
stmtType == duckdb.DUCKDB_STATEMENT_TYPE_RELATION {
// Add LIMIT 0 to avoid executing the actual query.
query = "SELECT * FROM (" + strings.Trim(query, "; \t\r\n") + ") LIMIT 0"
query = "SELECT * FROM (" + sql.RemoveSpaceAndDelimiter(query, ';') + ") LIMIT 0"
}
params := make([]any, len(paramTypes)) // all nil
rows, err = conn.QueryContext(sqlCtx, query, params...)
Expand Down
7 changes: 4 additions & 3 deletions pgserver/pg_catalog_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ var pgWALLSNRegex = regexp.MustCompile(`(?i)^\s*select\s+pg_catalog\.(pg_current
// precompile a regex to match "select pg_catalog.current_setting('xxx');".
var currentSettingRegex = regexp.MustCompile(`(?i)^\s*select\s+(pg_catalog.)?current_setting\(\s*'([^']+)'\s*\)\s*;?\s*$`)

// precompile a regex to match any "from pg_catalog.xxx" in the query.
var pgCatalogRegex = regexp.MustCompile(`(?i)\s+from\s+pg_catalog\.`)
// precompile a regex to match any "from pg_catalog.pg_stat_replication" in the query.
var pgCatalogRegex = regexp.MustCompile(`(?i)\s+from\s+pg_catalog\.(pg_stat_replication)`)

// isInRecovery will get the count of
func (h *ConnectionHandler) isInRecovery() (string, error) {
Expand Down Expand Up @@ -166,7 +166,7 @@ func (h *ConnectionHandler) handleCurrentSetting(query ConvertedQuery) (bool, er
func (h *ConnectionHandler) handlePgCatalog(query ConvertedQuery) (bool, error) {
sql := RemoveComments(query.String)
return true, h.query(ConvertedQuery{
String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__."),
String: pgCatalogRegex.ReplaceAllString(sql, " FROM __sys__.$1"),
StatementTag: "SELECT",
})
}
Expand Down Expand Up @@ -212,6 +212,7 @@ func isSpecialPgCatalog(query ConvertedQuery) bool {
var pgCatalogHandlers = map[string]PGCatalogHandler{
"SELECT": {
HandledInPlace: func(query ConvertedQuery) (bool, error) {
// TODO(sean): Evaluate the conditions by iterating over the AST.
if isPgIsInRecovery(query) {
return true, nil
}
Expand Down

0 comments on commit f06d6d8

Please sign in to comment.