From 3ee9422a265ce907bf8f168ac5e083b9ba514715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20de=20Giessen?= Date: Mon, 24 Mar 2025 11:07:38 +0100 Subject: [PATCH 1/2] Implement server side cursors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To be able to use the MySQL foreign data wrapper for Postgres, we have to support server side cursors. To that end each connection now keeps track of a list of open cursors, and implements the COM_STMT_FETCH command for retrieving data from an open cursor. Based on the implementation in the Dolthub fork of Vitess: https://github.com/dolthub/vitess/pull/228 https://github.com/dolthub/vitess/pull/232 Co-Authored-By: James Cor Co-Authored-By: Aaron Son Signed-off-by: Daniël van de Giessen --- go/mysql/conn.go | 325 ++++++++++++++++++++++++++++++++------- go/mysql/conn_test.go | 13 +- go/mysql/constants.go | 13 +- go/mysql/mysql_fuzzer.go | 1 + go/mysql/query.go | 175 +++++++++++++++------ go/mysql/query_test.go | 246 ++++++++++++++++++++++++++++- go/mysql/server.go | 2 + go/mysql/server_test.go | 71 ++++++++- 8 files changed, 738 insertions(+), 108 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 95016bf21d7..ade148a2dbe 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -154,6 +154,12 @@ type Conn struct { // PrepareData is the map to use a prepared statement. PrepareData map[uint32]*PrepareData + // cursorStates represent queries which are running as a result of a + // COM_STMT_EXECUTE. Rows will be fetched with COM_STMT_FETCH, and + // possibly a query will need to be terminated as a result of what + // happens in the protocol, etc. + cursorStates map[uint32]*CursorState + // protects the bufferedWriter and bufferedReader bufMu sync.Mutex @@ -234,6 +240,31 @@ type PrepareData struct { ParamsCount uint16 } +// CursorSate +type CursorState struct { + // The goroutine which generates the results delivers |Result|s with a + // batch of rows one at a time on the |next| channel. As long as the + // query has not delivered EOF, we will have pending rows, because we + // prefetch them as soon as they are exhausted. That is, if we have 10 + // pending rows and a client fetches only 10 rows, we will block on + // fetching more rows before ew return the currently cached 10 rows. + // This allows us to detect EOF and correctly return CursorExhausted. + pending *sqltypes.Result + + // The channel on which the running query is sending batched results. + next chan *sqltypes.Result + + // The channel on which the running query will return its final error + // status, either `nil` or an error. + done chan error + + // A control channel on which the running query is listening. If the + // server needs to cancel the inflight query based on what is happening + // in the wire protocol handling, it will send on this channel and then + // block on `done` being sent to. + quit chan error +} + // execResult is an enum signifying the result of executing a query type execResult byte @@ -285,6 +316,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn { conn: conn, listener: listener, PrepareData: make(map[uint32]*PrepareData), + cursorStates: make(map[uint32]*CursorState), keepAliveOn: enabledKeepAlive, flushDelay: listener.flushDelay, truncateErrLen: listener.truncateErrLen, @@ -941,6 +973,9 @@ func (c *Conn) handleNextCommand(handler Handler) bool { if ok { delete(c.PrepareData, stmtID) } + c.discardCursor(stmtID) + case ComStmtFetch: + return c.handleComStmtFetch(handler, data) case ComStmtReset: return c.handleComStmtReset(data) case ComResetConnection: @@ -1037,6 +1072,7 @@ func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue b func (c *Conn) handleComResetConnection(handler Handler) { // Clean up and reset the connection c.recycleReadPacket() + c.discardAllCursors() handler.ComResetConnection(c) // Reset prepared statements c.PrepareData = make(map[uint32]*PrepareData) @@ -1070,6 +1106,8 @@ func (c *Conn) handleComStmtReset(data []byte) bool { } } + c.discardCursor(stmtID) + if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil { log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) return false @@ -1077,6 +1115,81 @@ func (c *Conn) handleComStmtReset(data []byte) bool { return true } +func (c *Conn) handleComStmtFetch(handler Handler, data []byte) (kontinue bool) { + c.startWriterBuffering() + + stmtID, numRows, ok := c.parseComStmtFetch(data) + c.recycleReadPacket() + if !ok { + log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, "error handling packet: %v", data) { + return false + } + return c.endWriterBuffering() == nil + } + + // fetching from wrong statement + cs, ok := c.cursorStates[stmtID] + if !ok || cs == nil { + log.Errorf("Requested stmtID does not have an open cursor. Client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, "error handling packet: %v", data) { + return false + } + return c.endWriterBuffering() == nil + } + + // There is always a pending result set, because we prefetch it to detect EOF. + // When we detect EOF, we set c.cs = nil. + + for cs != nil && numRows != 0 { + toSend := uint32(len(cs.pending.Rows)) + if toSend > numRows { + toSend = numRows + } + nextRows := cs.pending.Rows[toSend:] + cs.pending.Rows = cs.pending.Rows[:toSend] + + if err := c.writeBinaryRows(cs.pending); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false + } + cs.pending.Rows = nextRows + numRows -= toSend + if len(cs.pending.Rows) == 0 { + var ok bool + cs.pending, ok = <-cs.next + if !ok { + // Query has terminated. Check for an error. + err := <-cs.done + cs = nil + delete(c.cursorStates, stmtID) + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false + } + } + } + } + + if cs == nil { + c.StatusFlags |= uint16(ServerStatusLastRowSent) + } + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false + } + if cs == nil { + c.StatusFlags &= ^uint16(ServerStatusLastRowSent) + } + if err := c.endWriterBuffering(); err != nil { + log.Errorf("Conn %v: endWriterBuffering() failed: %v", c.ID(), err) + return false + } + return true +} + func (c *Conn) handleComStmtSendLongData(data []byte) bool { stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() @@ -1115,10 +1228,19 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool kontinue = false } }() - queryStart := time.Now() - stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + + // flush is called at the end of this block. + // To simplify error handling, we do not + // encapsulate it with a defer'd func() + c.startWriterBuffering() + + stmtID, cursorType, err := c.parseComStmtExecute(c.PrepareData, data) c.recycleReadPacket() + if err != nil { + return c.writeErrorPacketFromErrorAndLog(err) + } + if stmtID != uint32(0) { defer func() { // Allocate a new bindvar map every time since VTGate.Execute() mutates it. @@ -1127,73 +1249,144 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool }() } - if err != nil { - return c.writeErrorPacketFromErrorAndLog(err) - } - - receivedResult := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false + queryStart := time.Now() prepare := c.PrepareData[stmtID] - err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } + if cursorType == CursorTypeNoCursor { + receivedResult := false + sendFinished := false // sendFinished is set if the response should just be an OK packet. + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !receivedResult { + receivedResult = true + + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + ok := PacketOK{ + affectedRows: qr.RowsAffected, + lastInsertID: qr.InsertID, + statusFlags: c.StatusFlags, + warnings: 0, + info: "", + sessionStateData: qr.SessionStateChanges, + } + return c.writeOKPacket(&ok) + } + if err := c.writeFields(qr); err != nil { + return err + } + } + + return c.writeBinaryRows(qr) + }) + // If no field was sent, we expect an error. if !receivedResult { - receivedResult = true + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = sqlerror.NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if !c.writeErrorPacketFromErrorAndLog(err) { + return false + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false + } - if len(qr.Fields) == 0 { - sendFinished = true - // We should not send any more packets after this. - ok := PacketOK{ - affectedRows: qr.RowsAffected, - lastInsertID: qr.InsertID, - statusFlags: c.StatusFlags, - warnings: 0, - info: "", - sessionStateData: qr.SessionStateChanges, + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false } - return c.writeOKPacket(&ok) - } - if err := c.writeFields(qr); err != nil { - return err } } + return true + } - return c.writeBinaryRows(qr) - }) + next := make(chan *sqltypes.Result) + done, quit := make(chan error), make(chan error) - // If no field was sent, we expect an error. - if !receivedResult { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = sqlerror.NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if !c.writeErrorPacketFromErrorAndLog(err) { - return false - } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) + go func() { + var err error + defer func() { + // pass along error, even if there's a panic + if r := recover(); r != nil { + err = fmt.Errorf("panic while running query for server-side cursor: %v", r) + } + close(next) + done <- err + close(done) + }() + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + // block until query results are sent or receive signal to quit + var qerr error + select { + case next <- qr: + case qerr = <-quit: + } + return qerr + }) + }() + + // Immediately receive the very first query result to write the fields + qr, ok := <-next + if !ok { + <-done + if !c.writeErrorAndLog(sqlerror.ERUnknownError, sqlerror.SSUnknownSQLState, "error handling packet: %v", "missing result set") { + log.Errorf("Error writing query error to %s", c) return false } + return false + } - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return false - } - } + defer timings.Record(queryTimingKey, queryStart) + + if len(qr.Fields) == 0 { + // DML or something without a result set. We do not open a cursor here. + <-done + return c.writeOKPacket(&PacketOK{ + affectedRows: qr.RowsAffected, + lastInsertID: qr.InsertID, + statusFlags: c.StatusFlags, + warnings: 0, + }) == nil } - timings.Record(queryTimingKey, queryStart) + // Open the cursor and write the fields. + c.StatusFlags |= uint16(ServerStatusCursorExists) + if err := c.writeFieldsWithoutEOF(qr); err != nil { + log.Errorf("Error writing fields to %s: %v", c, err) + return false + } + + // TODO: Look into whether accessing WarningCount + // here after passing `c` to ComStmtExecute in the + // goroutine above races. + if werr := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); werr != nil { + log.Errorf("Error writing result to %s: %v", c, werr) + return false + } + + // After writing the EOF_Packet/OK_Packet above, we + // have told the client the cursor is open. + c.StatusFlags &= ^uint16(ServerStatusCursorExists) + c.cursorStates[stmtID] = &CursorState{ + next: next, + done: done, + quit: quit, + pending: qr, + } return true } @@ -1469,6 +1662,28 @@ func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) { return true } +// discardCursor stops the statement execute goroutine and clears the cursor state, if it exists +func (c *Conn) discardCursor(stmtID uint32) { + // close cursor if open with unread results + cs, ok := c.cursorStates[stmtID] + if ok { + select { + case cs.quit <- errors.New("cancel cursor query"): + <-cs.done + case <-cs.done: + } + delete(c.cursorStates, stmtID) + } +} + +// discardAllCursors discards all open cursors on the connection +func (c *Conn) discardAllCursors() { + for stmtID := range c.cursorStates { + c.discardCursor(stmtID) + } + c.cursorStates = make(map[uint32]*CursorState) +} + func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { callbackCalled := false // sendFinished is set if the response should just be an OK packet. diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index aab26763fcd..0ca3f89be45 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -79,6 +79,7 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) { cConn := newConn(clientConn, DefaultFlushDelay, 0) sConn := newConn(serverConn, DefaultFlushDelay, 0) sConn.PrepareData = map[uint32]*PrepareData{} + sConn.cursorStates = map[uint32]*CursorState{} return listener, sConn, cConn } @@ -886,17 +887,17 @@ func TestMultiStatement(t *testing.T) { // The queries run will be select 1; and select 2; These queries do not return any errors, so the connection should still be open require.True(t, res, "we should not break the connection in case of no errors") // Read the result of the query and assert that it is indeed what we want. This will contain the result of the first query. - data, more, _, err := cConn.ReadQueryResult(100, true) + data, status, _, err := cConn.ReadQueryResult(100, true) require.NoError(t, err) // Since we executed 2 queries, there should be more results to be read - require.True(t, more) + require.True(t, (status&ServerMoreResultsExists) != 0) require.True(t, data.Equal(selectRowsResult)) // Read the results for the second query and verify the correctness - data, more, _, err = cConn.ReadQueryResult(100, true) + data, status, _, err = cConn.ReadQueryResult(100, true) require.NoError(t, err) // This was the final query run, so we expect that more should be false as there are no more queries. - require.False(t, more) + require.False(t, (status&ServerMoreResultsExists) != 0) require.True(t, data.Equal(selectRowsResult)) // This time we run two queries fist of which will return an error @@ -908,10 +909,10 @@ func TestMultiStatement(t *testing.T) { require.True(t, res, "we should not break the connection because of execution errors") // Read the result and assert that we indeed see the error that testRun throws. - data, more, _, err = cConn.ReadQueryResult(100, true) + data, status, _, err = cConn.ReadQueryResult(100, true) require.EqualError(t, err, "cannot get column number (errno 2027) (sqlstate HY000)") // In case of errors in a multi-statement, the following statements are not executed, therefore we want that more should be false - require.False(t, more) + require.False(t, (status&ServerMoreResultsExists) != 0) require.Nil(t, data) }) } diff --git a/go/mysql/constants.go b/go/mysql/constants.go index b1e95a549bf..49afb7cc225 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -156,7 +156,8 @@ const ( // Status flags. They are returned by the server in a few cases. // Originally found in include/mysql/mysql_com.h -// See http://dev.mysql.com/doc/internals/en/status-flags.html +// See https://dev.mysql.com/doc/dev/mysql-server/latest/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad +// and https://mariadb.com/kb/en/ok_packet/#server-status-flag const ( // a transaction is active ServerStatusInTrans uint16 = 0x0001 @@ -183,6 +184,16 @@ const ( ServerSessionStateChanged uint16 = 0x4000 ) +// Cursor Types. They are received on COM_STMT_EXECUTE() +// See https://dev.mysql.com/doc/dev/mysql-server/latest/mysql__com_8h.html#a3e5e9e744ff6f7b989a604fd669977da +// and https://mariadb.com/kb/en/com_stmt_execute/#flag +const ( + CursorTypeNoCursor = iota + CursorTypeReadOnly + CursorTypeCursorForUpdate + CursorTypeScrollableCursor +) + // State Change Information const ( // one or more system variables changed. diff --git a/go/mysql/mysql_fuzzer.go b/go/mysql/mysql_fuzzer.go index 5ce82cd56c0..99e5007115f 100644 --- a/go/mysql/mysql_fuzzer.go +++ b/go/mysql/mysql_fuzzer.go @@ -199,6 +199,7 @@ func FuzzHandleNextCommand(data []byte) int { queryPacket: data, }, DefaultFlushDelay, 0, false) sConn.PrepareData = map[uint32]*PrepareData{} + sConn.cursorStates = map[uint32]*CursorState{} handler := &fuzztestRun{} _ = sConn.handleNextCommand(handler) diff --git a/go/mysql/query.go b/go/mysql/query.go index c20102c8c91..60d90c452cf 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -318,13 +318,13 @@ func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte // 2. if the server closes the connection when a command is in flight, // readComQueryResponse will fail, and we'll return CRServerLost(2013). func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result *sqltypes.Result, err error) { - result, more, err := c.ExecuteFetchMulti(query, maxrows, wantfields) - if more { + result, status, err := c.ExecuteFetchMulti(query, maxrows, wantfields) + if (status & ServerMoreResultsExists) != 0 { // Multiple results are unexpected. Prioritize this "unexpected" error over whatever error we got from the first result. err = errors.Join(ErrExecuteFetchMultipleResults, err) } // draining to make the connection clean. - err = c.drainMoreResults(more, err) + err = c.drainMoreResults(status, err) return result, err } @@ -332,16 +332,16 @@ func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result // caring for any results. The function returns an error if any of the statements fail. // The function drains the query results of all statements, even if there's an error. func (c *Conn) ExecuteFetchMultiDrain(query string) (err error) { - _, more, err := c.ExecuteFetchMulti(query, FETCH_NO_ROWS, false) - return c.drainMoreResults(more, err) + _, status, err := c.ExecuteFetchMulti(query, FETCH_NO_ROWS, false) + return c.drainMoreResults(status, err) } // drainMoreResults ensures to drain all query results, even if there's an error. // We collect all errors until we consume all results. -func (c *Conn) drainMoreResults(more bool, err error) error { - for more { +func (c *Conn) drainMoreResults(status uint16, err error) error { + for (status & ServerMoreResultsExists) != 0 { var moreErr error - _, more, _, moreErr = c.ReadQueryResult(FETCH_NO_ROWS, false) + _, status, _, moreErr = c.ReadQueryResult(FETCH_NO_ROWS, false) if moreErr != nil { err = errors.Join(err, moreErr) } @@ -352,7 +352,7 @@ func (c *Conn) drainMoreResults(more bool, err error) error { // ExecuteFetchMulti is for fetching multiple results from a multi-statement result. // It returns an additional 'more' flag. If it is set, you must fetch the additional // results using ReadQueryResult. -func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, more bool, err error) { +func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, status uint16, err error) { defer func() { if err != nil { if sqlerr, ok := err.(*sqlerror.SQLError); ok { @@ -363,14 +363,14 @@ func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (re // Send the query as a COM_QUERY packet. if err = c.WriteComQuery(query); err != nil { - return nil, false, err + return nil, 0, err } - res, more, _, err := c.ReadQueryResult(maxrows, wantfields) + res, status, _, err := c.ReadQueryResult(maxrows, wantfields) if err != nil { - return nil, false, err + return nil, 0, err } - return res, more, err + return res, status, err } // ExecuteFetchWithWarningCount is for fetching results and a warning count @@ -395,14 +395,14 @@ func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfield } // ReadQueryResult gets the result from the last written query. -func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) { +func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, uint16, uint16, error) { var packetOk PacketOK // Get the result. colNumber, err := c.readComQueryResponse(&packetOk) if err != nil { - return nil, false, 0, err + return nil, 0, 0, err } - more := packetOk.statusFlags&ServerMoreResultsExists != 0 + status := packetOk.statusFlags & ServerMoreResultsExists warnings := packetOk.warnings if colNumber == 0 { // OK packet, means no results. Just use the numbers. @@ -413,7 +413,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, SessionStateChanges: packetOk.sessionStateData, StatusFlags: packetOk.statusFlags, Info: packetOk.info, - }, more, warnings, nil + }, status, warnings, nil } fields := make([]querypb.Field, colNumber) @@ -428,11 +428,11 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, if wantfields { if err := c.readColumnDefinition(result.Fields[i], i); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } } else { if err := c.readColumnDefinitionType(result.Fields[i], i); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } } } @@ -441,19 +441,28 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, // EOF is only present here if it's not deprecated. data, err := c.readEphemeralPacket() if err != nil { - return nil, false, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + return nil, 0, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) } if c.isEOFPacket(data) { // This is what we expect. - // Warnings and status flags are ignored. + _, status, err = parseEOFPacket(data) + if err != nil { + return nil, 0, 0, err + } c.recycleReadPacket() + + if status&ServerStatusCursorExists != 0 { + // if we are using cursors, do not go into the read row loop below + return result, status, 0, nil + } + // goto: read row loop } else if isErrorPacket(data) { defer c.recycleReadPacket() - return nil, false, 0, ParseErrorPacket(data) + return nil, 0, 0, ParseErrorPacket(data) } else { defer c.recycleReadPacket() - return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data) + return nil, 0, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data) } } @@ -461,7 +470,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, for { data, err := c.readEphemeralPacket() if err != nil { - return nil, false, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + return nil, 0, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) } if c.isEOFPacket(data) { @@ -475,29 +484,27 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, // The deprecated EOF packets change means that this is either an // EOF packet or an OK packet with the EOF type code. if c.Capabilities&CapabilityClientDeprecateEOF == 0 { - var statusFlags uint16 - warnings, statusFlags, err = parseEOFPacket(data) + warnings, status, err = parseEOFPacket(data) if err != nil { - return nil, false, 0, err + return nil, 0, 0, err } - more = (statusFlags & ServerMoreResultsExists) != 0 - result.StatusFlags = statusFlags + result.StatusFlags = status } else { var packetEof PacketOK if err := c.parseOKPacket(&packetEof, data); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } warnings = packetEof.warnings - more = (packetEof.statusFlags & ServerMoreResultsExists) != 0 + status = packetEof.statusFlags result.SessionStateChanges = packetEof.sessionStateData result.StatusFlags = packetEof.statusFlags result.Info = packetEof.info } - return result, more, warnings, nil + return result, status, warnings, nil } else if isErrorPacket(data) { defer c.recycleReadPacket() // Error packet. - return nil, false, 0, ParseErrorPacket(data) + return nil, 0, 0, ParseErrorPacket(data) } if maxrows == FETCH_NO_ROWS { @@ -509,16 +516,73 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, if maxrows != FETCH_ALL_ROWS && len(result.Rows) == maxrows { c.recycleReadPacket() if err := c.drainResults(); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } - return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) + return nil, 0, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) } // Regular row. row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil) if err != nil { c.recycleReadPacket() - return nil, false, 0, err + return nil, 0, 0, err + } + result.Rows = append(result.Rows, row) + c.recycleReadPacket() + } +} + +// FetchQueryResult gets the reset set from the last executed query. +func (c *Conn) FetchQueryResult(maxrows int, fields []*querypb.Field) (result *sqltypes.Result, status uint16, warnings uint16, err error) { + result = &sqltypes.Result{} + + // read each row until EOF or OK packet. + for { + data, err := c.readEphemeralPacket() + if err != nil { + return nil, 0, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + } + + if c.isEOFPacket(data) { + defer c.recycleReadPacket() + + result.RowsAffected = uint64(len(result.Rows)) + + // The deprecated EOF packets change means that this is either an + // EOF packet or an OK packet with the EOF type code. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + warnings, status, err = parseEOFPacket(data) + if err != nil { + return nil, 0, 0, err + } + } else { + var packetEof PacketOK + if err := c.parseOKPacket(&packetEof, data); err != nil { + return nil, 0, 0, err + } + warnings = packetEof.warnings + status = packetEof.statusFlags + } + return result, status, warnings, nil + } else if isErrorPacket(data) { + defer c.recycleReadPacket() + return nil, 0, 0, ParseErrorPacket(data) + } + + // Check we're not over the limit before we add more. + if len(result.Rows) == maxrows { + defer c.recycleReadPacket() + if err := c.drainResults(); err != nil { + return nil, 0, 0, err + } + return nil, 0, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) + } + + // Regular row. + row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil) + if err != nil { + c.recycleReadPacket() + return nil, 0, 0, err } result.Rows = append(result.Rows, row) c.recycleReadPacket() @@ -954,6 +1018,15 @@ func (c *Conn) parseComStmtReset(data []byte) (uint32, bool) { return val, ok } +func (c *Conn) parseComStmtFetch(data []byte) (uint32, uint32, bool) { + stmtId, pos, ok := readUint32(data, 1) + if !ok { + return 0, 0, false + } + numRows, _, ok := readUint32(data, pos) + return stmtId, numRows, ok +} + func (c *Conn) parseComInitDB(data []byte) string { return string(data[1:]) } @@ -1039,6 +1112,28 @@ func (c *Conn) writeRow(row []sqltypes.Value) error { // writeFields writes the fields of a Result. It should be called only // if there are valid columns in the result. func (c *Conn) writeFields(result *sqltypes.Result) error { + err := c.writeFieldsWithoutEOF(result) + if err != nil { + return err + } + + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err + } + } + + return nil +} + +// Writes the fields for a Result, but never adds the EOF_Packet on the end, even +// if ClientDeprecateEOF == 0. This is used when returning fields in a +// COM_STMT_EXECUTE that opens a cursor, since we immediately follow the fields +// up with a writeEndResult, which appropriately adds an EOF or an OK_Packet +// depending on the client capabilities. +func (c *Conn) writeFieldsWithoutEOF(result *sqltypes.Result) error { // Send the number of fields first. if err := c.sendColumnCount(uint64(len(result.Fields))); err != nil { return err @@ -1050,14 +1145,6 @@ func (c *Conn) writeFields(result *sqltypes.Result) error { return err } } - - // Now send an EOF packet. - if c.Capabilities&CapabilityClientDeprecateEOF == 0 { - // With CapabilityClientDeprecateEOF, we do not send this EOF. - if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { - return err - } - } return nil } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 37501f9329f..98ff8341b4b 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -18,6 +18,7 @@ package mysql import ( "fmt" + "io" "reflect" "sync" "testing" @@ -142,6 +143,8 @@ func TestComStmtPrepare(t *testing.T) { sConn.PrepareData = make(map[uint32]*PrepareData) sConn.PrepareData[prepare.StatementID] = prepare + sConn.cursorStates = make(map[uint32]*CursorState) + // write the response to the client if err := sConn.writePrepare(result.Fields, prepare); err != nil { t.Fatalf("sConn.writePrepare failed: %v", err) @@ -181,6 +184,8 @@ func TestComStmtPrepareUpdStmt(t *testing.T) { sConn.PrepareData = make(map[uint32]*PrepareData) sConn.PrepareData[prepare.StatementID] = prepare + sConn.cursorStates = make(map[uint32]*CursorState) + // write the response to the client err = sConn.writePrepare(nil, prepare) require.NoError(t, err, "sConn.writePrepare failed") @@ -211,6 +216,8 @@ func TestComStmtSendLongData(t *testing.T) { t.Fatalf("writePrepare failed: %v", err) } + cConn.cursorStates = make(map[uint32]*CursorState) + // Since there's no writeComStmtSendLongData, we'll write a prepareStmt and check if we can read the StatementID data, err := sConn.ReadPacket() if err != nil || len(data) == 0 { @@ -237,8 +244,10 @@ func TestComStmtExecute(t *testing.T) { cConn.PrepareData = make(map[uint32]*PrepareData) cConn.PrepareData[prepare.StatementID] = prepare + cConn.cursorStates = make(map[uint32]*CursorState) + // This is simulated packets for `select * from test_table where id = ?` - data := []byte{23, 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1} + data := []byte{23, 18, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 128, 1} stmtID, _, err := sConn.parseComStmtExecute(cConn.PrepareData, data) require.NoError(t, err, "parseComStmtExeute failed: %v", err) @@ -319,6 +328,33 @@ func TestComStmtExecuteUpdStmt(t *testing.T) { assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[28], "got: %s", querypb.Type(prepData.ParamsType[28])) } +func TestComStmtFetch(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepare, _ := MockPrepareData(t) + cConn.PrepareData = make(map[uint32]*PrepareData) + cConn.PrepareData[prepare.StatementID] = prepare + + cConn.cursorStates = make(map[uint32]*CursorState) + + // This is simulated packets for `select * from test_table where id = ?` + data := []byte{23, 18, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + stmtID, cursorType, err := sConn.parseComStmtExecute(cConn.PrepareData, data) + require.NoError(t, err, "parseComStmtExeute failed: %v", err) + if stmtID != 18 { + t.Fatalf("Parsed incorrect values") + } + if cursorType != CursorTypeReadOnly { + t.Fatalf("Expected read-only cursor") + } +} + func TestComStmtClose(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { @@ -334,6 +370,8 @@ func TestComStmtClose(t *testing.T) { t.Fatalf("writePrepare failed: %v", err) } + cConn.cursorStates = make(map[uint32]*CursorState) + // Since there's no writeComStmtClose, we'll write a prepareStmt and check if we can read the StatementID data, err := sConn.ReadPacket() if err != nil || len(data) == 0 { @@ -767,3 +805,209 @@ func RowString(row []sqltypes.Value) string { } return result } + +func writeRawPacketToConn(c *Conn, packet []byte) error { + c.sequence = 0 + data, pos := c.startEphemeralPacketWithHeader(len(packet)) + copy(data[pos:], packet) + return c.writeEphemeralPacket() +} + +type testExec struct { + query string + useCursor byte + expectedNumFields int + expectedNumRows int + maxRows int +} + +func clientExecute(t *testing.T, cConn *Conn, useCursor byte, maxRows int) (*sqltypes.Result, uint16) { + if useCursor != 0 { + cConn.StatusFlags |= uint16(ServerStatusCursorExists) + } else { + cConn.StatusFlags &= ^uint16(ServerStatusCursorExists) + } + + // Write a COM_STMT_EXECUTE packet + mockPacket := []byte{ComStmtExecute, 0, 0, 0, 0, useCursor, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + if err := writeRawPacketToConn(cConn, mockPacket); err != nil { + t.Fatalf("WriteMockExecuteToConn failed with error: %v", err) + } + + qr, status, _, err := cConn.ReadQueryResult(maxRows, true) + require.NoError(t, err, "ReadQueryResult failed with error: %v", err) + + return qr, status +} + +func clientFetch(t *testing.T, cConn *Conn, useCursor byte, maxRows int, fields []*querypb.Field) (*sqltypes.Result, uint16) { + // Write a COM_STMT_FETCH packet + mockPacket := []byte{ComStmtFetch, 0, 0, 0, 0, useCursor, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + if err := writeRawPacketToConn(cConn, mockPacket); err != nil { + t.Fatalf("WriteMockDataToConn failed with error: %v", err) + } + + qr, status, _, err := cConn.FetchQueryResult(maxRows, fields) + if err != nil && err != io.EOF { + t.Fatalf("FetchQueryResult failed with error: %v", err) + } + + return qr, status +} + +func checkExecute(t *testing.T, sConn, cConn *Conn, test testExec) { + // Pretend a successful COM_PREPARE was sent to server + prepare := &PrepareData{ + StatementID: 0, + PrepareStmt: test.query, + } + sConn.PrepareData = make(map[uint32]*PrepareData) + sConn.PrepareData[prepare.StatementID] = prepare + + sConn.cursorStates = make(map[uint32]*CursorState) + + // use go routine to emulate client calls + var qr *sqltypes.Result + var status uint16 + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + qr, status = clientExecute(t, cConn, test.useCursor, test.maxRows) + }() + + // handle a single client command + if !sConn.handleNextCommand(&testHandler{}) { + t.Fatalf("handleNextComamnd failed") + } + + // wait until client receives the query result back + wg.Wait() + + if qr == nil { + t.Fatalf("execute result is nil") + } + var fields = qr.Fields + if test.expectedNumFields != len(fields) { + t.Fatalf("Expected %d fields, Received %d", test.expectedNumFields, len(fields)) + } + + // if not using cursor, we should have results without fetching + if test.useCursor == 0 { + if (status & ServerStatusCursorExists) != 0 { + t.Fatalf("Server StatusFlag should indicate that Cursor does not exist") + } + if test.expectedNumRows != len(qr.Rows) { + t.Fatalf("Expected %d rows, Received %d", test.expectedNumRows, len(qr.Rows)) + } + return + } + + // using cursor, use client to fetch results + if (status & ServerStatusCursorExists) == 0 { + t.Fatalf("Server StatusFlag should indicate that Cursor exists, status flags were: %d", status) + } + + qr = nil + wg.Add(1) + go func() { + defer wg.Done() + qr, status = clientFetch(t, cConn, test.useCursor, test.maxRows, fields) + }() + + // handle a single client command + if !sConn.handleNextCommand(&testHandler{}) { + t.Fatalf("handleNextComamnd failed") + } + + // wait until client fetches the rows + wg.Wait() + if qr == nil { + t.Fatalf("fetch result is nil") + } + + if (status & ServerStatusCursorExists) != 0 { + t.Fatalf("Server StatusFlag should not indicate a new Cursor, status flags were: %d", status) + } + if (status & ServerStatusLastRowSent) == 0 { + t.Fatalf("Server StatusFlag should indicate that we exhausted the cursor, status flags were: %d", status) + } + + if test.expectedNumRows != len(qr.Rows) { + t.Fatalf("Expected %d rows, Received %d", test.expectedNumRows, len(qr.Rows)) + } +} + +func TestExecuteQueries(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + tests := []testExec{ + { + query: "empty result", + useCursor: 0, + expectedNumFields: 3, + expectedNumRows: 0, + maxRows: 100, + }, + { + query: "select rows", + useCursor: 0, + expectedNumFields: 2, + expectedNumRows: 2, + maxRows: 100, + }, + { + query: "large batch", + useCursor: 0, + expectedNumFields: 2, + expectedNumRows: 256, + maxRows: 1000, + }, + { + query: "empty result", + useCursor: 1, + expectedNumFields: 3, + expectedNumRows: 0, + maxRows: 100, + }, + { + query: "select rows", + useCursor: 1, + expectedNumFields: 2, + expectedNumRows: 2, + maxRows: 100, + }, + { + query: "large batch", + useCursor: 1, + expectedNumFields: 2, + expectedNumRows: 256, + maxRows: 1000, + }, + } + + t.Run("WithoutDeprecateEOF", func(t *testing.T) { + for i, test := range tests { + t.Run(fmt.Sprintf("%d %s", i, test.query), func(t *testing.T) { + checkExecute(t, sConn, cConn, test) + }) + } + }) + + sConn.Capabilities = CapabilityClientDeprecateEOF + cConn.Capabilities = CapabilityClientDeprecateEOF + t.Run("WithDeprecateEOF", func(t *testing.T) { + for i, test := range tests { + t.Run(fmt.Sprintf("%d %s", i, test.query), func(t *testing.T) { + checkExecute(t, sConn, cConn, test) + }) + } + }) +} diff --git a/go/mysql/server.go b/go/mysql/server.go index 17c113248e6..fbd4c1cffcb 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -400,6 +400,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti // Adjust the count of open connections defer connCount.Add(-1) + defer c.discardAllCursors() + // First build and send the server handshake packet. serverAuthPluginData, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, uint8(l.charset), l.TLSConfig.Load() != nil) if err != nil { diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index ee30dded978..6d158d4fe8d 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -256,7 +256,76 @@ func (th *testHandler) ComPrepare(*Conn, string) ([]*querypb.Field, uint16, erro } func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { - return nil + switch prepare.PrepareStmt { + case "empty result": + callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + { + Name: "name", + Type: querypb.Type_VARCHAR, + }, + { + Name: "name2", + Type: querypb.Type_VARCHAR, + }, + }, + Rows: [][]sqltypes.Value{}, + RowsAffected: 0, + }) + case "select rows": + callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + { + Name: "name", + Type: querypb.Type_VARCHAR, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), + }, + { + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nicer name")), + }, + }, + RowsAffected: 2, + }) + case "large batch": + n := 256 + rows := make([][]sqltypes.Value, n) + for i := 0; i < n; i++ { + rows[i] = []sqltypes.Value{ + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), + } + } + callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + { + Name: "name", + Type: querypb.Type_VARCHAR, + }, + }, + Rows: rows, + RowsAffected: uint64(n), + }) + default: + return fmt.Errorf("unrecorgnized test command") + } } func (th *testHandler) ComRegisterReplica(c *Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error { From 0c5f127519db92f9d7bcea4383060c01b87cce23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20de=20Giessen?= Date: Fri, 16 May 2025 11:17:09 +0200 Subject: [PATCH 2/2] Use "status" integer instead of "more" boolean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit which implemented support for server side cursors also updated the function signature of the MySQL query functions to return a integer "status" value instead of a boolean "more" value. One of the flags in that status integer represents the "more" flag. This commit updates all places where the MySQL query functions are called to use this new status integer. Signed-off-by: Daniël van de Giessen --- go/mysql/conn_test.go | 10 +++--- go/mysql/constants.go | 5 +++ go/mysql/endtoend/client_test.go | 32 +++++++++---------- go/mysql/query.go | 4 +-- go/test/endtoend/topotest/consul/main_test.go | 6 ++-- go/test/endtoend/topotest/etcd2/main_test.go | 6 ++-- go/test/endtoend/topotest/zk2/main_test.go | 6 ++-- go/test/endtoend/utils/cmp.go | 24 +++++++++----- .../queries/multi_query/multi_query_test.go | 6 ++-- .../endtoend/vtgate/reservedconn/udv_test.go | 4 +-- .../endtoend/vtgate/unsharded/main_test.go | 6 ++-- go/vt/binlog/binlogplayer/dbclient.go | 6 ++-- go/vt/mysqlctl/mysqld.go | 6 ++-- .../evalengine/integration/comparison_test.go | 6 ++-- go/vt/vttablet/tabletmanager/rpc_query.go | 6 ++-- 15 files changed, 73 insertions(+), 60 deletions(-) diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index 0ca3f89be45..97a1b01b748 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -857,9 +857,9 @@ func TestEmptyQuery(t *testing.T) { // The queries run will be an empty query; Even with the empty error, the connection should be fine require.True(t, res, "we should not break the connection in case of no errors") // Read the result and assert that we indeed see the error for empty query. - data, more, _, err := cConn.ReadQueryResult(100, true) + data, status, _, err := cConn.ReadQueryResult(100, true) require.EqualError(t, err, "Query was empty (errno 1065) (sqlstate 42000)") - require.False(t, more) + require.False(t, IsMoreResultsExists(status)) require.Nil(t, data) }) } @@ -890,14 +890,14 @@ func TestMultiStatement(t *testing.T) { data, status, _, err := cConn.ReadQueryResult(100, true) require.NoError(t, err) // Since we executed 2 queries, there should be more results to be read - require.True(t, (status&ServerMoreResultsExists) != 0) + require.True(t, IsMoreResultsExists(status)) require.True(t, data.Equal(selectRowsResult)) // Read the results for the second query and verify the correctness data, status, _, err = cConn.ReadQueryResult(100, true) require.NoError(t, err) // This was the final query run, so we expect that more should be false as there are no more queries. - require.False(t, (status&ServerMoreResultsExists) != 0) + require.False(t, IsMoreResultsExists(status)) require.True(t, data.Equal(selectRowsResult)) // This time we run two queries fist of which will return an error @@ -912,7 +912,7 @@ func TestMultiStatement(t *testing.T) { data, status, _, err = cConn.ReadQueryResult(100, true) require.EqualError(t, err, "cannot get column number (errno 2027) (sqlstate HY000)") // In case of errors in a multi-statement, the following statements are not executed, therefore we want that more should be false - require.False(t, (status&ServerMoreResultsExists) != 0) + require.False(t, IsMoreResultsExists(status)) require.Nil(t, data) }) } diff --git a/go/mysql/constants.go b/go/mysql/constants.go index 49afb7cc225..f19dc4bada6 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -184,6 +184,11 @@ const ( ServerSessionStateChanged uint16 = 0x4000 ) +// Helper for checking if SERVER_MORE_RESULTS_EXISTS is set +func IsMoreResultsExists(status uint16) bool { + return status&ServerMoreResultsExists == ServerMoreResultsExists +} + // Cursor Types. They are received on COM_STMT_EXECUTE() // See https://dev.mysql.com/doc/dev/mysql-server/latest/mysql__com_8h.html#a3e5e9e744ff6f7b989a604fd669977da // and https://mariadb.com/kb/en/com_stmt_execute/#flag diff --git a/go/mysql/endtoend/client_test.go b/go/mysql/endtoend/client_test.go index c1105e668ba..7b79d9c410a 100644 --- a/go/mysql/endtoend/client_test.go +++ b/go/mysql/endtoend/client_test.go @@ -161,29 +161,29 @@ func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) { expectNoError(t, err) defer conn.Close() - qr, more, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true) + qr, status, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(multi result)", more, true) + expectFlag(t, "ExecuteMultiFetch(multi result)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 1, len(qr.Rows)) - qr, more, _, err = conn.ReadQueryResult(10, true) + qr, status, _, err = conn.ReadQueryResult(10, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(1)", more, true) + expectFlag(t, "ReadQueryResult(1)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 0, len(qr.Rows)) - qr, more, _, err = conn.ReadQueryResult(10, true) + qr, status, _, err = conn.ReadQueryResult(10, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(2)", more, false) + expectFlag(t, "ReadQueryResult(2)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 1, len(qr.Rows)) - qr, more, err = conn.ExecuteFetchMulti("select 1 from dual", 10, true) + qr, status, err = conn.ExecuteFetchMulti("select 1 from dual", 10, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(single result)", more, false) + expectFlag(t, "ExecuteMultiFetch(single result)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 1, len(qr.Rows)) - qr, more, err = conn.ExecuteFetchMulti("set autocommit=1", 10, true) + qr, status, err = conn.ExecuteFetchMulti("set autocommit=1", 10, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(no result)", more, false) + expectFlag(t, "ExecuteMultiFetch(no result)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 0, len(qr.Rows)) // The ClientDeprecateEOF protocol change has a subtle twist in which an EOF or OK @@ -212,19 +212,19 @@ func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) { err = conn.ExecuteFetchMultiDrain("update a set name = concat(name, ', multi drain 1'); select * from a; select count(*) from a") expectNoError(t, err) // If the previous command leaves packet in invalid state, this will fail. - qr, more, err = conn.ExecuteFetchMulti("update a set name = concat(name, ', fetch multi'); select * from a; select count(*) from a", 300, true) + qr, status, err = conn.ExecuteFetchMulti("update a set name = concat(name, ', fetch multi'); select * from a; select count(*) from a", 300, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(multi result)", more, true) + expectFlag(t, "ExecuteMultiFetch(multi result)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 255, qr.RowsAffected) - qr, more, _, err = conn.ReadQueryResult(300, true) + qr, status, _, err = conn.ReadQueryResult(300, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(1)", more, true) + expectFlag(t, "ReadQueryResult(1)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 255, len(qr.Rows), "ReadQueryResult(1)") - qr, more, _, err = conn.ReadQueryResult(300, true) + qr, status, _, err = conn.ReadQueryResult(300, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(2)", more, false) + expectFlag(t, "ReadQueryResult(2)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 1, len(qr.Rows), "ReadQueryResult(1)") // Verify that a ExecuteFetchMultiDrain is happy to operate again after all the above. diff --git a/go/mysql/query.go b/go/mysql/query.go index 60d90c452cf..0e191bdf909 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -319,7 +319,7 @@ func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte // readComQueryResponse will fail, and we'll return CRServerLost(2013). func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result *sqltypes.Result, err error) { result, status, err := c.ExecuteFetchMulti(query, maxrows, wantfields) - if (status & ServerMoreResultsExists) != 0 { + if IsMoreResultsExists(status) { // Multiple results are unexpected. Prioritize this "unexpected" error over whatever error we got from the first result. err = errors.Join(ErrExecuteFetchMultipleResults, err) } @@ -339,7 +339,7 @@ func (c *Conn) ExecuteFetchMultiDrain(query string) (err error) { // drainMoreResults ensures to drain all query results, even if there's an error. // We collect all errors until we consume all results. func (c *Conn) drainMoreResults(status uint16, err error) error { - for (status & ServerMoreResultsExists) != 0 { + for IsMoreResultsExists(status) { var moreErr error _, status, _, moreErr = c.ReadQueryResult(FETCH_NO_ROWS, false) if moreErr != nil { diff --git a/go/test/endtoend/topotest/consul/main_test.go b/go/test/endtoend/topotest/consul/main_test.go index 33f7677f857..7694c1b13bb 100644 --- a/go/test/endtoend/topotest/consul/main_test.go +++ b/go/test/endtoend/topotest/consul/main_test.go @@ -277,11 +277,11 @@ func execute(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/test/endtoend/topotest/etcd2/main_test.go b/go/test/endtoend/topotest/etcd2/main_test.go index b274219b41a..cda087b38d8 100644 --- a/go/test/endtoend/topotest/etcd2/main_test.go +++ b/go/test/endtoend/topotest/etcd2/main_test.go @@ -271,11 +271,11 @@ func TestNamedLocking(t *testing.T) { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/test/endtoend/topotest/zk2/main_test.go b/go/test/endtoend/topotest/zk2/main_test.go index 29c5cb89406..94be097344d 100644 --- a/go/test/endtoend/topotest/zk2/main_test.go +++ b/go/test/endtoend/topotest/zk2/main_test.go @@ -246,11 +246,11 @@ func TestKeyspaceLocking(t *testing.T) { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 05e65cc7fda..e7fbb1a978f 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -225,11 +225,13 @@ func (mcmp *MySQLCompare) ExecMulti(sql string) []*sqltypes.Result { mcmp.t.Helper() stmts, err := sqlparser.NewTestParser().SplitStatementToPieces(sql) require.NoError(mcmp.t, err) - vtQr, vtMore, err := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) + vtQr, vtStatus, err := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) require.NoError(mcmp.t, err, "[Vitess Error] for sql: "+sql) + vtMore := mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, err := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) + mysqlQr, mysqlStatus, err := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) require.NoError(mcmp.t, err, "[MySQL Error] for sql: "+sql) + mysqlMore := mysql.IsMoreResultsExists(mysqlStatus) sql = stmts[0] CompareVitessAndMySQLResults(mcmp.t, sql, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{}) if vtMore != mysqlMore { @@ -241,11 +243,13 @@ func (mcmp *MySQLCompare) ExecMulti(sql string) []*sqltypes.Result { for vtMore { sql = stmts[idx] idx++ - vtQr, vtMore, _, err = mcmp.VtConn.ReadQueryResult(1000, true) + vtQr, vtStatus, _, err = mcmp.VtConn.ReadQueryResult(1000, true) require.NoError(mcmp.t, err, "[Vitess Error] for sql: "+sql) + vtMore = mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, _, err = mcmp.MySQLConn.ReadQueryResult(1000, true) + mysqlQr, mysqlStatus, _, err = mcmp.MySQLConn.ReadQueryResult(1000, true) require.NoError(mcmp.t, err, "[MySQL Error] for sql: "+sql) + mysqlMore = mysql.IsMoreResultsExists(mysqlStatus) CompareVitessAndMySQLResults(mcmp.t, sql, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{}) if vtMore != mysqlMore { mcmp.AsT().Errorf("Vitess and MySQL have different More flags: %v vs %v", vtMore, mysqlMore) @@ -262,9 +266,11 @@ func (mcmp *MySQLCompare) ExecMultiAllowError(sql string) { mcmp.t.Helper() stmts, err := sqlparser.NewTestParser().SplitStatementToPieces(sql) require.NoError(mcmp.t, err) - vtQr, vtMore, vtErr := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) + vtQr, vtStatus, vtErr := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) + vtMore := mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, mysqlErr := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) + mysqlQr, mysqlStatus, mysqlErr := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) + mysqlMore := mysql.IsMoreResultsExists(mysqlStatus) sql = stmts[0] compareVitessAndMySQLErrors(mcmp.t, vtErr, mysqlErr) if vtErr == nil && mysqlErr == nil { @@ -278,9 +284,11 @@ func (mcmp *MySQLCompare) ExecMultiAllowError(sql string) { for vtMore { sql = stmts[idx] idx++ - vtQr, vtMore, _, vtErr = mcmp.VtConn.ReadQueryResult(1000, true) + vtQr, vtStatus, _, vtErr = mcmp.VtConn.ReadQueryResult(1000, true) + vtMore = mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, _, mysqlErr = mcmp.MySQLConn.ReadQueryResult(1000, true) + mysqlQr, mysqlStatus, _, mysqlErr = mcmp.MySQLConn.ReadQueryResult(1000, true) + mysqlMore = mysql.IsMoreResultsExists(mysqlStatus) compareVitessAndMySQLErrors(mcmp.t, vtErr, mysqlErr) if vtErr == nil && mysqlErr == nil { CompareVitessAndMySQLResults(mcmp.t, sql, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{}) diff --git a/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go b/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go index 2da16d78138..68ba794e029 100644 --- a/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go +++ b/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go @@ -171,12 +171,12 @@ func TestMultiQuery(t *testing.T) { // results obtained from the gRPC connection. func getMySqlResults(conn *mysql.Conn, sql string) ([]*sqltypes.Result, error) { var results []*sqltypes.Result - mysqlQr, mysqlMore, mysqlErr := conn.ExecuteFetchMulti(sql, 1000, true) + mysqlQr, mysqlStatus, mysqlErr := conn.ExecuteFetchMulti(sql, 1000, true) if mysqlQr != nil { results = append(results, mysqlQr) } - for mysqlMore { - mysqlQr, mysqlMore, _, mysqlErr = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(mysqlStatus) { + mysqlQr, mysqlStatus, _, mysqlErr = conn.ReadQueryResult(1000, true) if mysqlQr != nil { results = append(results, mysqlQr) } diff --git a/go/test/endtoend/vtgate/reservedconn/udv_test.go b/go/test/endtoend/vtgate/reservedconn/udv_test.go index 14b65dbcd35..e34402b5a47 100644 --- a/go/test/endtoend/vtgate/reservedconn/udv_test.go +++ b/go/test/endtoend/vtgate/reservedconn/udv_test.go @@ -150,9 +150,9 @@ func TestMysqlDumpInitialLog(t *testing.T) { for _, query := range queries { t.Run(query, func(t *testing.T) { - _, more, err := conn.ExecuteFetchMulti(query, 1000, true) + _, status, err := conn.ExecuteFetchMulti(query, 1000, true) require.NoError(t, err) - require.False(t, more) + require.False(t, mysql.IsMoreResultsExists(status)) }) } } diff --git a/go/test/endtoend/vtgate/unsharded/main_test.go b/go/test/endtoend/vtgate/unsharded/main_test.go index 8d072907873..3f71569fb4a 100644 --- a/go/test/endtoend/vtgate/unsharded/main_test.go +++ b/go/test/endtoend/vtgate/unsharded/main_test.go @@ -449,11 +449,11 @@ func TestFloatValueDefault(t *testing.T) { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/vt/binlog/binlogplayer/dbclient.go b/go/vt/binlog/binlogplayer/dbclient.go index 4cbfd962528..f24da0e29e6 100644 --- a/go/vt/binlog/binlogplayer/dbclient.go +++ b/go/vt/binlog/binlogplayer/dbclient.go @@ -162,14 +162,14 @@ func (dc *dbClientImpl) ExecuteFetch(query string, maxrows int) (*sqltypes.Resul func (dc *dbClientImpl) ExecuteFetchMulti(query string, maxrows int) ([]*sqltypes.Result, error) { results := make([]*sqltypes.Result, 0) - mqr, more, err := dc.dbConn.ExecuteFetchMulti(query, maxrows, true) + mqr, status, err := dc.dbConn.ExecuteFetchMulti(query, maxrows, true) if err != nil { dc.handleError(err) return nil, err } results = append(results, mqr) - for more { - mqr, more, _, err = dc.dbConn.ReadQueryResult(maxrows, false) + for mysql.IsMoreResultsExists(status) { + mqr, status, _, err = dc.dbConn.ReadQueryResult(maxrows, false) if err != nil { dc.handleError(err) return nil, err diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index 3bbad31f1c2..e7d6594d5e7 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -1171,12 +1171,12 @@ func (mysqld *Mysqld) executeMysqlScript(ctx context.Context, connParams *mysql. } defer conn.Close() - _, more, err := conn.ExecuteFetchMulti(sql, -1, false) + _, status, err := conn.ExecuteFetchMulti(sql, -1, false) if err != nil { return err } - for more { - _, more, _, err = conn.ReadQueryResult(0, false) + for mysql.IsMoreResultsExists(status) { + _, status, _, err = conn.ReadQueryResult(0, false) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index c1931514dc4..77f2b533c20 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -238,12 +238,12 @@ func initTimezoneData(t *testing.T, conn *mysql.Conn) { t.Fatalf("failed to retrieve timezone info: %v", err) } - _, more, err := conn.ExecuteFetchMulti(fmt.Sprintf("USE mysql; %s\n", string(out)), -1, false) + _, status, err := conn.ExecuteFetchMulti(fmt.Sprintf("USE mysql; %s\n", string(out)), -1, false) if err != nil { t.Fatalf("failed to insert timezone info: %v", err) } - for more { - _, more, _, err = conn.ReadQueryResult(-1, false) + for mysql.IsMoreResultsExists(status) { + _, status, _, err = conn.ReadQueryResult(-1, false) if err != nil { t.Fatalf("failed to insert timezone info: %v", err) } diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 3a3fdc39b2b..45912b1bee8 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -129,12 +129,12 @@ func (tm *TabletManager) executeMultiFetchAsDba( return nil, err } results := make([]*querypb.QueryResult, 0, len(queries)) - result, more, err := conn.ExecuteFetchMulti(uq, maxRows, true /*wantFields*/) + result, status, err := conn.ExecuteFetchMulti(uq, maxRows, true /*wantFields*/) if err == nil { results = append(results, sqltypes.ResultToProto3(result)) } - for more { - result, more, _, err = conn.ReadQueryResult(maxRows, true /*wantFields*/) + for (status & sqltypes.ServerMoreResultsExists) != 0 { + result, status, _, err = conn.ReadQueryResult(maxRows, true /*wantFields*/) if err != nil { return nil, err }