Skip to content

Commit

Permalink
fix: always use text format for simple query mode (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 authored Nov 18, 2024
1 parent f1448a9 commit 1f08352
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 29 deletions.
17 changes: 12 additions & 5 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net"
"os"
"runtime/debug"
"slices"
"strings"
"unicode"

Expand Down Expand Up @@ -503,7 +504,14 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error {
return fmt.Errorf("prepared statement %s does not exist", message.Name)
}

fields = preparedStatementData.ReturnFields
// https://www.postgresql.org/docs/current/protocol-flow.html
// > Note that since Bind has not yet been issued, the formats to be used for returned columns are not yet known to the backend;
// > the format code fields in the RowDescription message will be zeroes in this case.
fields = slices.Clone(preparedStatementData.ReturnFields)
for i := range fields {
fields[i].Format = 0
}

bindvarTypes = preparedStatementData.BindVarTypes
tag = preparedStatementData.Query.StatementTag
} else {
Expand Down Expand Up @@ -826,9 +834,6 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {
callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false)
err := h.duckHandler.ComQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback)
if err != nil {
if strings.HasPrefix(err.Error(), "syntax error at position") {
return fmt.Errorf("This statement is not yet supported")
}
return err
}

Expand All @@ -841,17 +846,19 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute
// IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query.
isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE"
return func(res *Result) error {
logrus.Tracef("spooling %d rows for tag %s", res.RowsAffected, tag)
logrus.Tracef("spooling %d rows for tag %s (execute = %v)", res.RowsAffected, tag, isExecute)
if returnsRow(tag) {
// EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it
if !isExecute {
logrus.Tracef("sending RowDescription %+v for tag %s", res.Fields, tag)
if err := h.send(&pgproto3.RowDescription{
Fields: res.Fields,
}); err != nil {
return err
}
}

logrus.Tracef("sending Rows %+v for tag %s", res.Rows, tag)
for _, row := range res.Rows {
if err := h.send(&pgproto3.DataRow{
Values: row.val,
Expand Down
53 changes: 38 additions & 15 deletions pgserver/duck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ type Row struct {

const rowsBatch = 128

// QueryMode distinguishes the PostgreSQL simple query protocol from the
// extended (Parse/Bind/Execute) protocol when building result metadata.
// In simple query mode, returned column values are always encoded in the
// text format, per the PostgreSQL protocol-flow specification.
type QueryMode bool

// The two wire-protocol modes a query may be executed under.
const (
	// SimpleQueryMode is used for the simple Query message flow.
	SimpleQueryMode QueryMode = false
	// ExtendedQueryMode is used for the extended Parse/Bind/Execute flow.
	ExtendedQueryMode QueryMode = true
)

// DuckHandler is a handler uses DuckDB and the SQLe engine directly
// running Postgres specific queries.
type DuckHandler struct {
Expand Down Expand Up @@ -101,7 +108,7 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa

// ComExecuteBound implements the Handler interface.
func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error {
err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, h.executeBoundPlan, callback)
err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, ExtendedQueryMode, h.executeBoundPlan, callback)
if err != nil {
err = sql.CastSQLError(err)
}
Expand Down Expand Up @@ -193,7 +200,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query
if err != nil {
break
}
fields = schemaToFieldDescriptions(sqlCtx, schema)
fields = schemaToFieldDescriptions(sqlCtx, schema, ExtendedQueryMode)
default:
// For other statements, we just return the "affected rows" field.
fields = []pgproto3.FieldDescription{
Expand All @@ -214,7 +221,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query

// ComQuery implements the Handler interface.
func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error {
err := h.doQuery(ctx, c, query, parsed, nil, nil, h.executeQuery, callback)
err := h.doQuery(ctx, c, query, parsed, nil, nil, SimpleQueryMode, h.executeQuery, callback)
if err != nil {
err = sql.CastSQLError(err)
}
Expand Down Expand Up @@ -291,7 +298,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri

var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`)

func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, queryExec QueryExecutor, callback func(*Result) error) error {
func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, mode QueryMode, queryExec QueryExecutor, callback func(*Result) error) error {
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
if err != nil {
return err
Expand Down Expand Up @@ -351,10 +358,10 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string,
} else if schema == nil {
r, err = resultForEmptyIter(sqlCtx, rowIter)
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
resultFields := schemaToFieldDescriptions(sqlCtx, schema)
resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode)
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields)
} else {
resultFields := schemaToFieldDescriptions(sqlCtx, schema)
resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode)
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields)
}
if err != nil {
Expand Down Expand Up @@ -418,7 +425,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S
}))

default:
rows, err = adapter.Query(ctx, query)
rows, err = adapter.QueryCatalog(ctx, query)
if err != nil {
break
}
Expand Down Expand Up @@ -562,7 +569,7 @@ func (h *DuckHandler) maybeReleaseAllLocks(c *mysql.Conn) {
}
}

func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldDescription {
func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) []pgproto3.FieldDescription {
fields := make([]pgproto3.FieldDescription, len(s))
for i, c := range s {
var oid uint32
Expand All @@ -571,7 +578,13 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD
var err error
if pgType, ok := c.Type.(pgtypes.PostgresType); ok {
oid = pgType.PG.OID
format = pgType.PG.Codec.PreferredFormat()
if mode == SimpleQueryMode {
// https://www.postgresql.org/docs/current/protocol-flow.html
// > In simple Query mode, the format of retrieved values is always text, except ...
format = pgtype.TextFormatCode
} else {
format = pgType.PG.Codec.PreferredFormat()
}
size = int16(pgType.Size)
} else {
oid, err = VitessTypeToObjectID(c.Type.Type())
Expand Down Expand Up @@ -651,7 +664,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
return nil, err
}

outputRow, err := rowToBytes(ctx, schema, row)
outputRow, err := rowToBytes(ctx, schema, resultFields, row)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -752,12 +765,12 @@ func (h *DuckHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema,
continue
}

outputRow, err := rowToBytes(ctx, schema, row)
outputRow, err := rowToBytes(ctx, schema, resultFields, row)
if err != nil {
return err
}

ctx.GetLogger().Tracef("spooling result row %s", outputRow)
ctx.GetLogger().Tracef("spooling result row %+v", outputRow)
r.Rows = append(r.Rows, Row{outputRow})
r.RowsAffected++
case <-timer.C:
Expand Down Expand Up @@ -791,7 +804,17 @@ func (h *DuckHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema,
return
}

func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) {
func rowToBytes(ctx *sql.Context, s sql.Schema, fields []pgproto3.FieldDescription, row sql.Row) ([][]byte, error) {
if logger := ctx.GetLogger(); logger.Logger.Level >= logrus.TraceLevel {
logger = logger.WithField("func", rowToBytes)
logger.Tracef("row: %+v\n", row)
types := make([]sql.Type, len(s))
for i, c := range s {
types[i] = c.Type
}
logger.Tracef("types: %+v\n", types)
logger.Tracef("fields: %+v\n", fields)
}
if len(row) == 0 {
return nil, nil
}
Expand All @@ -807,8 +830,8 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) {
}

// TODO(fan): Preallocate the buffer
if pgType, ok := s[i].Type.(pgtypes.PostgresType); ok {
bytes, err := pgType.Encode(v, []byte{})
if _, ok := s[i].Type.(pgtypes.PostgresType); ok {
bytes, err := pgtypes.DefaultTypeMap.Encode(fields[i].DataTypeOID, fields[i].Format, v, nil)
if err != nil {
return nil, err
}
Expand Down
35 changes: 26 additions & 9 deletions pgtypes/pgtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,23 @@ var DuckdbTypeToPostgresOID = map[duckdb.Type]uint32{
var PostgresTypeSizes = map[uint32]int32{
pgtype.BoolOID: 1, // bool
pgtype.ByteaOID: -1, // bytea
pgtype.NameOID: -1, // name
pgtype.NameOID: 64, // name
pgtype.Int8OID: 8, // int8
pgtype.Int2OID: 2, // int2
pgtype.Int4OID: 4, // int4
pgtype.TextOID: -1, // text
pgtype.OIDOID: 4, // oid
pgtype.TIDOID: 8, // tid
pgtype.XIDOID: -1, // xid
pgtype.CIDOID: -1, // cid
pgtype.TIDOID: 6, // tid
pgtype.XIDOID: 4, // xid
pgtype.CIDOID: 4, // cid
pgtype.JSONOID: -1, // json
pgtype.XMLOID: -1, // xml
pgtype.PointOID: 8, // point
pgtype.Float4OID: 4, // float4
pgtype.Float8OID: 8, // float8
pgtype.UnknownOID: -1, // unknown
pgtype.MacaddrOID: -1, // macaddr
pgtype.UnknownOID: -2, // unknown
pgtype.MacaddrOID: 6, // macaddr
pgtype.Macaddr8OID: 8, // macaddr8
pgtype.InetOID: -1, // inet
pgtype.BoolArrayOID: -1, // bool[]
pgtype.ByteaArrayOID: -1, // bytea[]
Expand All @@ -115,8 +116,10 @@ var PostgresTypeSizes = map[uint32]int32{
pgtype.VarcharOID: -1, // varchar
pgtype.DateOID: 4, // date
pgtype.TimeOID: 8, // time
pgtype.TimetzOID: 12, // timetz
pgtype.TimestampOID: 8, // timestamp
pgtype.TimestamptzOID: 8, // timestamptz
pgtype.IntervalOID: 16, // interval
pgtype.NumericOID: -1, // numeric
pgtype.UUIDOID: 16, // uuid
}
Expand Down Expand Up @@ -182,11 +185,16 @@ func InferSchema(rows *stdsql.Rows) (sql.Schema, error) {
}
nullable, _ := t.Nullable()

size := int32(-1)
if s, ok := PostgresTypeSizes[pgType.OID]; ok {
size = s
}

schema[i] = &sql.Column{
Name: t.Name(),
Type: PostgresType{
PG: pgType,
Size: PostgresTypeSizes[pgType.OID],
Size: size,
},
Nullable: nullable,
}
Expand Down Expand Up @@ -216,11 +224,16 @@ func InferDriverSchema(rows driver.Rows) (sql.Schema, error) {
nullable, _ = colNullable.ColumnTypeNullable(i)
}

size := int32(-1)
if s, ok := PostgresTypeSizes[pgType.OID]; ok {
size = s
}

schema[i] = &sql.Column{
Name: colName,
Type: PostgresType{
PG: pgType,
Size: PostgresTypeSizes[pgType.OID],
Size: size,
},
Nullable: nullable,
}
Expand All @@ -239,9 +252,13 @@ func NewPostgresType(oid uint32) (PostgresType, error) {
if !ok {
return PostgresType{}, fmt.Errorf("unsupported type OID %d", oid)
}
size := int32(-1)
if s, ok := PostgresTypeSizes[oid]; ok {
size = s
}
return PostgresType{
PG: t,
Size: PostgresTypeSizes[oid],
Size: size,
}, nil
}

Expand Down

0 comments on commit 1f08352

Please sign in to comment.