From 1f08352a6c847ed729923bcccb1eaeee1f76d345 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Mon, 18 Nov 2024 20:24:10 +0800 Subject: [PATCH] fix: always use text format for simple query mode (#168) --- pgserver/connection_handler.go | 17 +++++++---- pgserver/duck_handler.go | 53 ++++++++++++++++++++++++---------- pgtypes/pgtypes.go | 35 ++++++++++++++++------ 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 3e764ffb..385ce188 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -25,6 +25,7 @@ import ( "net" "os" "runtime/debug" + "slices" "strings" "unicode" @@ -503,7 +504,14 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { return fmt.Errorf("prepared statement %s does not exist", message.Name) } - fields = preparedStatementData.ReturnFields + // https://www.postgresql.org/docs/current/protocol-flow.html + // > Note that since Bind has not yet been issued, the formats to be used for returned columns are not yet known to the backend; + // > the format code fields in the RowDescription message will be zeroes in this case. + fields = slices.Clone(preparedStatementData.ReturnFields) + for i := range fields { + fields[i].Format = 0 + } + bindvarTypes = preparedStatementData.BindVarTypes tag = preparedStatementData.Query.StatementTag } else { @@ -826,9 +834,6 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) err := h.duckHandler.ComQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback) if err != nil { - if strings.HasPrefix(err.Error(), "syntax error at position") { - return fmt.Errorf("This statement is not yet supported") - } return err } @@ -841,10 +846,11 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute // IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query. isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE" return func(res *Result) error { - logrus.Tracef("spooling %d rows for tag %s", res.RowsAffected, tag) + logrus.Tracef("spooling %d rows for tag %s (execute = %v)", res.RowsAffected, tag, isExecute) if returnsRow(tag) { // EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it if !isExecute { + logrus.Tracef("sending RowDescription %+v for tag %s", res.Fields, tag) if err := h.send(&pgproto3.RowDescription{ Fields: res.Fields, }); err != nil { @@ -852,6 +858,7 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute } } + logrus.Tracef("sending Rows %+v for tag %s", res.Rows, tag) for _, row := range res.Rows { if err := h.send(&pgproto3.DataRow{ Values: row.val, diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index a4032262..ce932dd8 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -69,6 +69,13 @@ type Row struct { const rowsBatch = 128 +type QueryMode bool + +const ( + SimpleQueryMode QueryMode = false + ExtendedQueryMode QueryMode = true +) + // DuckHandler is a handler uses DuckDB and the SQLe engine directly // running Postgres specific queries. type DuckHandler struct { @@ -101,7 +108,7 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa // ComExecuteBound implements the Handler interface. func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error { - err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, h.executeBoundPlan, callback) + err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, ExtendedQueryMode, h.executeBoundPlan, callback) if err != nil { err = sql.CastSQLError(err) } @@ -193,7 +200,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query if err != nil { break } - fields = schemaToFieldDescriptions(sqlCtx, schema) + fields = schemaToFieldDescriptions(sqlCtx, schema, ExtendedQueryMode) default: // For other statements, we just return the "affected rows" field. fields = []pgproto3.FieldDescription{ @@ -214,7 +221,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query // ComQuery implements the Handler interface. func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error { - err := h.doQuery(ctx, c, query, parsed, nil, nil, h.executeQuery, callback) + err := h.doQuery(ctx, c, query, parsed, nil, nil, SimpleQueryMode, h.executeQuery, callback) if err != nil { err = sql.CastSQLError(err) } @@ -291,7 +298,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`) -func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, queryExec QueryExecutor, callback func(*Result) error) error { +func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, mode QueryMode, queryExec QueryExecutor, callback func(*Result) error) error { sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) if err != nil { return err @@ -351,10 +358,10 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, } else if schema == nil { r, err = resultForEmptyIter(sqlCtx, rowIter) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { - resultFields := schemaToFieldDescriptions(sqlCtx, schema) + resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode) r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields) } else { - resultFields := schemaToFieldDescriptions(sqlCtx, schema) + resultFields := schemaToFieldDescriptions(sqlCtx, schema, mode) r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields) } if err != nil { @@ -418,7 +425,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S })) default: - rows, err = adapter.Query(ctx, query) + rows, err = adapter.QueryCatalog(ctx, query) if err != nil { break } @@ -562,7 +569,7 @@ func (h *DuckHandler) maybeReleaseAllLocks(c *mysql.Conn) { } } -func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldDescription { +func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, mode QueryMode) []pgproto3.FieldDescription { fields := make([]pgproto3.FieldDescription, len(s)) for i, c := range s { var oid uint32 @@ -571,7 +578,13 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD var err error if pgType, ok := c.Type.(pgtypes.PostgresType); ok { oid = pgType.PG.OID - format = pgType.PG.Codec.PreferredFormat() + if mode == SimpleQueryMode { + // https://www.postgresql.org/docs/current/protocol-flow.html + // > In simple Query mode, the format of retrieved values is always text, except ... + format = pgtype.TextFormatCode + } else { + format = pgType.PG.Codec.PreferredFormat() + } size = int16(pgType.Size) } else { oid, err = VitessTypeToObjectID(c.Type.Type()) @@ -651,7 +664,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, return nil, err } - outputRow, err := rowToBytes(ctx, schema, row) + outputRow, err := rowToBytes(ctx, schema, resultFields, row) if err != nil { return nil, err } @@ -752,12 +765,12 @@ func (h *DuckHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, continue } - outputRow, err := rowToBytes(ctx, schema, row) + outputRow, err := rowToBytes(ctx, schema, resultFields, row) if err != nil { return err } - ctx.GetLogger().Tracef("spooling result row %s", outputRow) + ctx.GetLogger().Tracef("spooling result row %+v", outputRow) r.Rows = append(r.Rows, Row{outputRow}) r.RowsAffected++ case <-timer.C: @@ -791,7 +804,17 @@ func (h *DuckHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, return } -func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) { +func rowToBytes(ctx *sql.Context, s sql.Schema, fields []pgproto3.FieldDescription, row sql.Row) ([][]byte, error) { + if logger := ctx.GetLogger(); logger.Logger.Level >= logrus.TraceLevel { + logger = logger.WithField("func", rowToBytes) + logger.Tracef("row: %+v\n", row) + types := make([]sql.Type, len(s)) + for i, c := range s { + types[i] = c.Type + } + logger.Tracef("types: %+v\n", types) + logger.Tracef("fields: %+v\n", fields) + } if len(row) == 0 { return nil, nil } @@ -807,8 +830,8 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) { } // TODO(fan): Preallocate the buffer - if pgType, ok := s[i].Type.(pgtypes.PostgresType); ok { - bytes, err := pgType.Encode(v, []byte{}) + if _, ok := s[i].Type.(pgtypes.PostgresType); ok { + bytes, err := pgtypes.DefaultTypeMap.Encode(fields[i].DataTypeOID, fields[i].Format, v, nil) if err != nil { return nil, err } diff --git a/pgtypes/pgtypes.go b/pgtypes/pgtypes.go index 250298e5..c39ad561 100644 --- a/pgtypes/pgtypes.go +++ b/pgtypes/pgtypes.go @@ -88,22 +88,23 @@ var DuckdbTypeToPostgresOID = map[duckdb.Type]uint32{ var PostgresTypeSizes = map[uint32]int32{ pgtype.BoolOID: 1, // bool pgtype.ByteaOID: -1, // bytea - pgtype.NameOID: -1, // name + pgtype.NameOID: 64, // name pgtype.Int8OID: 8, // int8 pgtype.Int2OID: 2, // int2 pgtype.Int4OID: 4, // int4 pgtype.TextOID: -1, // text pgtype.OIDOID: 4, // oid - pgtype.TIDOID: 8, // tid - pgtype.XIDOID: -1, // xid - pgtype.CIDOID: -1, // cid + pgtype.TIDOID: 6, // tid + pgtype.XIDOID: 4, // xid + pgtype.CIDOID: 4, // cid pgtype.JSONOID: -1, // json pgtype.XMLOID: -1, // xml pgtype.PointOID: 8, // point pgtype.Float4OID: 4, // float4 pgtype.Float8OID: 8, // float8 - pgtype.UnknownOID: -1, // unknown - pgtype.MacaddrOID: -1, // macaddr + pgtype.UnknownOID: -2, // unknown + pgtype.MacaddrOID: 6, // macaddr + pgtype.Macaddr8OID: 8, // macaddr8 pgtype.InetOID: -1, // inet pgtype.BoolArrayOID: -1, // bool[] pgtype.ByteaArrayOID: -1, // bytea[] @@ -115,8 +116,10 @@ var PostgresTypeSizes = map[uint32]int32{ pgtype.VarcharOID: -1, // varchar pgtype.DateOID: 4, // date pgtype.TimeOID: 8, // time + pgtype.TimetzOID: 12, // timetz pgtype.TimestampOID: 8, // timestamp pgtype.TimestamptzOID: 8, // timestamptz + pgtype.IntervalOID: 16, // interval pgtype.NumericOID: -1, // numeric pgtype.UUIDOID: 16, // uuid } @@ -182,11 +185,16 @@ func InferSchema(rows *stdsql.Rows) (sql.Schema, error) { } nullable, _ := t.Nullable() + size := int32(-1) + if s, ok := PostgresTypeSizes[pgType.OID]; ok { + size = s + } + schema[i] = &sql.Column{ Name: t.Name(), Type: PostgresType{ PG: pgType, - Size: PostgresTypeSizes[pgType.OID], + Size: size, }, Nullable: nullable, } @@ -216,11 +224,16 @@ func InferDriverSchema(rows driver.Rows) (sql.Schema, error) { nullable, _ = colNullable.ColumnTypeNullable(i) } + size := int32(-1) + if s, ok := PostgresTypeSizes[pgType.OID]; ok { + size = s + } + schema[i] = &sql.Column{ Name: colName, Type: PostgresType{ PG: pgType, - Size: PostgresTypeSizes[pgType.OID], + Size: size, }, Nullable: nullable, } @@ -239,9 +252,13 @@ func NewPostgresType(oid uint32) (PostgresType, error) { if !ok { return PostgresType{}, fmt.Errorf("unsupported type OID %d", oid) } + size := int32(-1) + if s, ok := PostgresTypeSizes[oid]; ok { + size = s + } return PostgresType{ PG: t, - Size: PostgresTypeSizes[oid], + Size: size, }, nil }