diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 95016bf21d7..ade148a2dbe 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -154,6 +154,12 @@ type Conn struct { // PrepareData is the map to use a prepared statement. PrepareData map[uint32]*PrepareData + // cursorStates represent queries which are running as a result of a + // COM_STMT_EXECUTE. Rows will be fetched with COM_STMT_FETCH, and + // possibly a query will need to be terminated as a result of what + // happens in the protocol, etc. + cursorStates map[uint32]*CursorState + // protects the bufferedWriter and bufferedReader bufMu sync.Mutex @@ -234,6 +240,31 @@ type PrepareData struct { ParamsCount uint16 } +// CursorSate +type CursorState struct { + // The goroutine which generates the results delivers |Result|s with a + // batch of rows one at a time on the |next| channel. As long as the + // query has not delivered EOF, we will have pending rows, because we + // prefetch them as soon as they are exhausted. That is, if we have 10 + // pending rows and a client fetches only 10 rows, we will block on + // fetching more rows before ew return the currently cached 10 rows. + // This allows us to detect EOF and correctly return CursorExhausted. + pending *sqltypes.Result + + // The channel on which the running query is sending batched results. + next chan *sqltypes.Result + + // The channel on which the running query will return its final error + // status, either `nil` or an error. + done chan error + + // A control channel on which the running query is listening. If the + // server needs to cancel the inflight query based on what is happening + // in the wire protocol handling, it will send on this channel and then + // block on `done` being sent to. + quit chan error +} + // execResult is an enum signifying the result of executing a query type execResult byte @@ -285,6 +316,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn { conn: conn, listener: listener, PrepareData: make(map[uint32]*PrepareData), + cursorStates: make(map[uint32]*CursorState), keepAliveOn: enabledKeepAlive, flushDelay: listener.flushDelay, truncateErrLen: listener.truncateErrLen, @@ -941,6 +973,9 @@ func (c *Conn) handleNextCommand(handler Handler) bool { if ok { delete(c.PrepareData, stmtID) } + c.discardCursor(stmtID) + case ComStmtFetch: + return c.handleComStmtFetch(handler, data) case ComStmtReset: return c.handleComStmtReset(data) case ComResetConnection: @@ -1037,6 +1072,7 @@ func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue b func (c *Conn) handleComResetConnection(handler Handler) { // Clean up and reset the connection c.recycleReadPacket() + c.discardAllCursors() handler.ComResetConnection(c) // Reset prepared statements c.PrepareData = make(map[uint32]*PrepareData) @@ -1070,6 +1106,8 @@ func (c *Conn) handleComStmtReset(data []byte) bool { } } + c.discardCursor(stmtID) + if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil { log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) return false @@ -1077,6 +1115,81 @@ func (c *Conn) handleComStmtReset(data []byte) bool { return true } +func (c *Conn) handleComStmtFetch(handler Handler, data []byte) (kontinue bool) { + c.startWriterBuffering() + + stmtID, numRows, ok := c.parseComStmtFetch(data) + c.recycleReadPacket() + if !ok { + log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, "error handling packet: %v", data) { + return false + } + return c.endWriterBuffering() == nil + } + + // fetching from wrong statement + cs, ok := c.cursorStates[stmtID] + if !ok || cs == nil { + log.Errorf("Requested stmtID does not have an open cursor. Client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(sqlerror.ERUnknownComError, sqlerror.SSNetError, "error handling packet: %v", data) { + return false + } + return c.endWriterBuffering() == nil + } + + // There is always a pending result set, because we prefetch it to detect EOF. + // When we detect EOF, we set c.cs = nil. + + for cs != nil && numRows != 0 { + toSend := uint32(len(cs.pending.Rows)) + if toSend > numRows { + toSend = numRows + } + nextRows := cs.pending.Rows[toSend:] + cs.pending.Rows = cs.pending.Rows[:toSend] + + if err := c.writeBinaryRows(cs.pending); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false + } + cs.pending.Rows = nextRows + numRows -= toSend + if len(cs.pending.Rows) == 0 { + var ok bool + cs.pending, ok = <-cs.next + if !ok { + // Query has terminated. Check for an error. + err := <-cs.done + cs = nil + delete(c.cursorStates, stmtID) + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false + } + } + } + } + + if cs == nil { + c.StatusFlags |= uint16(ServerStatusLastRowSent) + } + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false + } + if cs == nil { + c.StatusFlags &= ^uint16(ServerStatusLastRowSent) + } + if err := c.endWriterBuffering(); err != nil { + log.Errorf("Conn %v: endWriterBuffering() failed: %v", c.ID(), err) + return false + } + return true +} + func (c *Conn) handleComStmtSendLongData(data []byte) bool { stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() @@ -1115,10 +1228,19 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool kontinue = false } }() - queryStart := time.Now() - stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + + // flush is called at the end of this block. + // To simplify error handling, we do not + // encapsulate it with a defer'd func() + c.startWriterBuffering() + + stmtID, cursorType, err := c.parseComStmtExecute(c.PrepareData, data) c.recycleReadPacket() + if err != nil { + return c.writeErrorPacketFromErrorAndLog(err) + } + if stmtID != uint32(0) { defer func() { // Allocate a new bindvar map every time since VTGate.Execute() mutates it. @@ -1127,73 +1249,144 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool }() } - if err != nil { - return c.writeErrorPacketFromErrorAndLog(err) - } - - receivedResult := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false + queryStart := time.Now() prepare := c.PrepareData[stmtID] - err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } + if cursorType == CursorTypeNoCursor { + receivedResult := false + sendFinished := false // sendFinished is set if the response should just be an OK packet. + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !receivedResult { + receivedResult = true + + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + ok := PacketOK{ + affectedRows: qr.RowsAffected, + lastInsertID: qr.InsertID, + statusFlags: c.StatusFlags, + warnings: 0, + info: "", + sessionStateData: qr.SessionStateChanges, + } + return c.writeOKPacket(&ok) + } + if err := c.writeFields(qr); err != nil { + return err + } + } + + return c.writeBinaryRows(qr) + }) + // If no field was sent, we expect an error. if !receivedResult { - receivedResult = true + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = sqlerror.NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if !c.writeErrorPacketFromErrorAndLog(err) { + return false + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false + } - if len(qr.Fields) == 0 { - sendFinished = true - // We should not send any more packets after this. - ok := PacketOK{ - affectedRows: qr.RowsAffected, - lastInsertID: qr.InsertID, - statusFlags: c.StatusFlags, - warnings: 0, - info: "", - sessionStateData: qr.SessionStateChanges, + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false } - return c.writeOKPacket(&ok) - } - if err := c.writeFields(qr); err != nil { - return err } } + return true + } - return c.writeBinaryRows(qr) - }) + next := make(chan *sqltypes.Result) + done, quit := make(chan error), make(chan error) - // If no field was sent, we expect an error. - if !receivedResult { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = sqlerror.NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if !c.writeErrorPacketFromErrorAndLog(err) { - return false - } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) + go func() { + var err error + defer func() { + // pass along error, even if there's a panic + if r := recover(); r != nil { + err = fmt.Errorf("panic while running query for server-side cursor: %v", r) + } + close(next) + done <- err + close(done) + }() + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + // block until query results are sent or receive signal to quit + var qerr error + select { + case next <- qr: + case qerr = <-quit: + } + return qerr + }) + }() + + // Immediately receive the very first query result to write the fields + qr, ok := <-next + if !ok { + <-done + if !c.writeErrorAndLog(sqlerror.ERUnknownError, sqlerror.SSUnknownSQLState, "error handling packet: %v", "missing result set") { + log.Errorf("Error writing query error to %s", c) return false } + return false + } - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return false - } - } + defer timings.Record(queryTimingKey, queryStart) + + if len(qr.Fields) == 0 { + // DML or something without a result set. We do not open a cursor here. + <-done + return c.writeOKPacket(&PacketOK{ + affectedRows: qr.RowsAffected, + lastInsertID: qr.InsertID, + statusFlags: c.StatusFlags, + warnings: 0, + }) == nil } - timings.Record(queryTimingKey, queryStart) + // Open the cursor and write the fields. + c.StatusFlags |= uint16(ServerStatusCursorExists) + if err := c.writeFieldsWithoutEOF(qr); err != nil { + log.Errorf("Error writing fields to %s: %v", c, err) + return false + } + + // TODO: Look into whether accessing WarningCount + // here after passing `c` to ComStmtExecute in the + // goroutine above races. + if werr := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); werr != nil { + log.Errorf("Error writing result to %s: %v", c, werr) + return false + } + + // After writing the EOF_Packet/OK_Packet above, we + // have told the client the cursor is open. + c.StatusFlags &= ^uint16(ServerStatusCursorExists) + c.cursorStates[stmtID] = &CursorState{ + next: next, + done: done, + quit: quit, + pending: qr, + } return true } @@ -1469,6 +1662,28 @@ func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) { return true } +// discardCursor stops the statement execute goroutine and clears the cursor state, if it exists +func (c *Conn) discardCursor(stmtID uint32) { + // close cursor if open with unread results + cs, ok := c.cursorStates[stmtID] + if ok { + select { + case cs.quit <- errors.New("cancel cursor query"): + <-cs.done + case <-cs.done: + } + delete(c.cursorStates, stmtID) + } +} + +// discardAllCursors discards all open cursors on the connection +func (c *Conn) discardAllCursors() { + for stmtID := range c.cursorStates { + c.discardCursor(stmtID) + } + c.cursorStates = make(map[uint32]*CursorState) +} + func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { callbackCalled := false // sendFinished is set if the response should just be an OK packet. diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index aab26763fcd..97a1b01b748 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -79,6 +79,7 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) { cConn := newConn(clientConn, DefaultFlushDelay, 0) sConn := newConn(serverConn, DefaultFlushDelay, 0) sConn.PrepareData = map[uint32]*PrepareData{} + sConn.cursorStates = map[uint32]*CursorState{} return listener, sConn, cConn } @@ -856,9 +857,9 @@ func TestEmptyQuery(t *testing.T) { // The queries run will be an empty query; Even with the empty error, the connection should be fine require.True(t, res, "we should not break the connection in case of no errors") // Read the result and assert that we indeed see the error for empty query. - data, more, _, err := cConn.ReadQueryResult(100, true) + data, status, _, err := cConn.ReadQueryResult(100, true) require.EqualError(t, err, "Query was empty (errno 1065) (sqlstate 42000)") - require.False(t, more) + require.False(t, IsMoreResultsExists(status)) require.Nil(t, data) }) } @@ -886,17 +887,17 @@ func TestMultiStatement(t *testing.T) { // The queries run will be select 1; and select 2; These queries do not return any errors, so the connection should still be open require.True(t, res, "we should not break the connection in case of no errors") // Read the result of the query and assert that it is indeed what we want. This will contain the result of the first query. - data, more, _, err := cConn.ReadQueryResult(100, true) + data, status, _, err := cConn.ReadQueryResult(100, true) require.NoError(t, err) // Since we executed 2 queries, there should be more results to be read - require.True(t, more) + require.True(t, IsMoreResultsExists(status)) require.True(t, data.Equal(selectRowsResult)) // Read the results for the second query and verify the correctness - data, more, _, err = cConn.ReadQueryResult(100, true) + data, status, _, err = cConn.ReadQueryResult(100, true) require.NoError(t, err) // This was the final query run, so we expect that more should be false as there are no more queries. - require.False(t, more) + require.False(t, IsMoreResultsExists(status)) require.True(t, data.Equal(selectRowsResult)) // This time we run two queries fist of which will return an error @@ -908,10 +909,10 @@ func TestMultiStatement(t *testing.T) { require.True(t, res, "we should not break the connection because of execution errors") // Read the result and assert that we indeed see the error that testRun throws. - data, more, _, err = cConn.ReadQueryResult(100, true) + data, status, _, err = cConn.ReadQueryResult(100, true) require.EqualError(t, err, "cannot get column number (errno 2027) (sqlstate HY000)") // In case of errors in a multi-statement, the following statements are not executed, therefore we want that more should be false - require.False(t, more) + require.False(t, IsMoreResultsExists(status)) require.Nil(t, data) }) } diff --git a/go/mysql/constants.go b/go/mysql/constants.go index b1e95a549bf..f19dc4bada6 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -156,7 +156,8 @@ const ( // Status flags. They are returned by the server in a few cases. // Originally found in include/mysql/mysql_com.h -// See http://dev.mysql.com/doc/internals/en/status-flags.html +// See https://dev.mysql.com/doc/dev/mysql-server/latest/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad +// and https://mariadb.com/kb/en/ok_packet/#server-status-flag const ( // a transaction is active ServerStatusInTrans uint16 = 0x0001 @@ -183,6 +184,21 @@ const ( ServerSessionStateChanged uint16 = 0x4000 ) +// Helper for checking if SERVER_MORE_RESULTS_EXISTS is set +func IsMoreResultsExists(status uint16) bool { + return status&ServerMoreResultsExists == ServerMoreResultsExists +} + +// Cursor Types. They are received on COM_STMT_EXECUTE() +// See https://dev.mysql.com/doc/dev/mysql-server/latest/mysql__com_8h.html#a3e5e9e744ff6f7b989a604fd669977da +// and https://mariadb.com/kb/en/com_stmt_execute/#flag +const ( + CursorTypeNoCursor = iota + CursorTypeReadOnly + CursorTypeCursorForUpdate + CursorTypeScrollableCursor +) + // State Change Information const ( // one or more system variables changed. diff --git a/go/mysql/endtoend/client_test.go b/go/mysql/endtoend/client_test.go index c1105e668ba..7b79d9c410a 100644 --- a/go/mysql/endtoend/client_test.go +++ b/go/mysql/endtoend/client_test.go @@ -161,29 +161,29 @@ func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) { expectNoError(t, err) defer conn.Close() - qr, more, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true) + qr, status, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(multi result)", more, true) + expectFlag(t, "ExecuteMultiFetch(multi result)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 1, len(qr.Rows)) - qr, more, _, err = conn.ReadQueryResult(10, true) + qr, status, _, err = conn.ReadQueryResult(10, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(1)", more, true) + expectFlag(t, "ReadQueryResult(1)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 0, len(qr.Rows)) - qr, more, _, err = conn.ReadQueryResult(10, true) + qr, status, _, err = conn.ReadQueryResult(10, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(2)", more, false) + expectFlag(t, "ReadQueryResult(2)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 1, len(qr.Rows)) - qr, more, err = conn.ExecuteFetchMulti("select 1 from dual", 10, true) + qr, status, err = conn.ExecuteFetchMulti("select 1 from dual", 10, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(single result)", more, false) + expectFlag(t, "ExecuteMultiFetch(single result)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 1, len(qr.Rows)) - qr, more, err = conn.ExecuteFetchMulti("set autocommit=1", 10, true) + qr, status, err = conn.ExecuteFetchMulti("set autocommit=1", 10, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(no result)", more, false) + expectFlag(t, "ExecuteMultiFetch(no result)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 0, len(qr.Rows)) // The ClientDeprecateEOF protocol change has a subtle twist in which an EOF or OK @@ -212,19 +212,19 @@ func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) { err = conn.ExecuteFetchMultiDrain("update a set name = concat(name, ', multi drain 1'); select * from a; select count(*) from a") expectNoError(t, err) // If the previous command leaves packet in invalid state, this will fail. - qr, more, err = conn.ExecuteFetchMulti("update a set name = concat(name, ', fetch multi'); select * from a; select count(*) from a", 300, true) + qr, status, err = conn.ExecuteFetchMulti("update a set name = concat(name, ', fetch multi'); select * from a; select count(*) from a", 300, true) expectNoError(t, err) - expectFlag(t, "ExecuteMultiFetch(multi result)", more, true) + expectFlag(t, "ExecuteMultiFetch(multi result)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 255, qr.RowsAffected) - qr, more, _, err = conn.ReadQueryResult(300, true) + qr, status, _, err = conn.ReadQueryResult(300, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(1)", more, true) + expectFlag(t, "ReadQueryResult(1)", mysql.IsMoreResultsExists(status), true) assert.EqualValues(t, 255, len(qr.Rows), "ReadQueryResult(1)") - qr, more, _, err = conn.ReadQueryResult(300, true) + qr, status, _, err = conn.ReadQueryResult(300, true) expectNoError(t, err) - expectFlag(t, "ReadQueryResult(2)", more, false) + expectFlag(t, "ReadQueryResult(2)", mysql.IsMoreResultsExists(status), false) assert.EqualValues(t, 1, len(qr.Rows), "ReadQueryResult(1)") // Verify that a ExecuteFetchMultiDrain is happy to operate again after all the above. diff --git a/go/mysql/mysql_fuzzer.go b/go/mysql/mysql_fuzzer.go index 5ce82cd56c0..99e5007115f 100644 --- a/go/mysql/mysql_fuzzer.go +++ b/go/mysql/mysql_fuzzer.go @@ -199,6 +199,7 @@ func FuzzHandleNextCommand(data []byte) int { queryPacket: data, }, DefaultFlushDelay, 0, false) sConn.PrepareData = map[uint32]*PrepareData{} + sConn.cursorStates = map[uint32]*CursorState{} handler := &fuzztestRun{} _ = sConn.handleNextCommand(handler) diff --git a/go/mysql/query.go b/go/mysql/query.go index c20102c8c91..0e191bdf909 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -318,13 +318,13 @@ func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte // 2. if the server closes the connection when a command is in flight, // readComQueryResponse will fail, and we'll return CRServerLost(2013). func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result *sqltypes.Result, err error) { - result, more, err := c.ExecuteFetchMulti(query, maxrows, wantfields) - if more { + result, status, err := c.ExecuteFetchMulti(query, maxrows, wantfields) + if IsMoreResultsExists(status) { // Multiple results are unexpected. Prioritize this "unexpected" error over whatever error we got from the first result. err = errors.Join(ErrExecuteFetchMultipleResults, err) } // draining to make the connection clean. - err = c.drainMoreResults(more, err) + err = c.drainMoreResults(status, err) return result, err } @@ -332,16 +332,16 @@ func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result // caring for any results. The function returns an error if any of the statements fail. // The function drains the query results of all statements, even if there's an error. func (c *Conn) ExecuteFetchMultiDrain(query string) (err error) { - _, more, err := c.ExecuteFetchMulti(query, FETCH_NO_ROWS, false) - return c.drainMoreResults(more, err) + _, status, err := c.ExecuteFetchMulti(query, FETCH_NO_ROWS, false) + return c.drainMoreResults(status, err) } // drainMoreResults ensures to drain all query results, even if there's an error. // We collect all errors until we consume all results. -func (c *Conn) drainMoreResults(more bool, err error) error { - for more { +func (c *Conn) drainMoreResults(status uint16, err error) error { + for IsMoreResultsExists(status) { var moreErr error - _, more, _, moreErr = c.ReadQueryResult(FETCH_NO_ROWS, false) + _, status, _, moreErr = c.ReadQueryResult(FETCH_NO_ROWS, false) if moreErr != nil { err = errors.Join(err, moreErr) } @@ -352,7 +352,7 @@ func (c *Conn) drainMoreResults(more bool, err error) error { // ExecuteFetchMulti is for fetching multiple results from a multi-statement result. // It returns an additional 'more' flag. If it is set, you must fetch the additional // results using ReadQueryResult. -func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, more bool, err error) { +func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, status uint16, err error) { defer func() { if err != nil { if sqlerr, ok := err.(*sqlerror.SQLError); ok { @@ -363,14 +363,14 @@ func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (re // Send the query as a COM_QUERY packet. if err = c.WriteComQuery(query); err != nil { - return nil, false, err + return nil, 0, err } - res, more, _, err := c.ReadQueryResult(maxrows, wantfields) + res, status, _, err := c.ReadQueryResult(maxrows, wantfields) if err != nil { - return nil, false, err + return nil, 0, err } - return res, more, err + return res, status, err } // ExecuteFetchWithWarningCount is for fetching results and a warning count @@ -395,14 +395,14 @@ func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfield } // ReadQueryResult gets the result from the last written query. -func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) { +func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, uint16, uint16, error) { var packetOk PacketOK // Get the result. colNumber, err := c.readComQueryResponse(&packetOk) if err != nil { - return nil, false, 0, err + return nil, 0, 0, err } - more := packetOk.statusFlags&ServerMoreResultsExists != 0 + status := packetOk.statusFlags & ServerMoreResultsExists warnings := packetOk.warnings if colNumber == 0 { // OK packet, means no results. Just use the numbers. @@ -413,7 +413,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, SessionStateChanges: packetOk.sessionStateData, StatusFlags: packetOk.statusFlags, Info: packetOk.info, - }, more, warnings, nil + }, status, warnings, nil } fields := make([]querypb.Field, colNumber) @@ -428,11 +428,11 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, if wantfields { if err := c.readColumnDefinition(result.Fields[i], i); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } } else { if err := c.readColumnDefinitionType(result.Fields[i], i); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } } } @@ -441,19 +441,28 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, // EOF is only present here if it's not deprecated. data, err := c.readEphemeralPacket() if err != nil { - return nil, false, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + return nil, 0, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) } if c.isEOFPacket(data) { // This is what we expect. - // Warnings and status flags are ignored. + _, status, err = parseEOFPacket(data) + if err != nil { + return nil, 0, 0, err + } c.recycleReadPacket() + + if status&ServerStatusCursorExists != 0 { + // if we are using cursors, do not go into the read row loop below + return result, status, 0, nil + } + // goto: read row loop } else if isErrorPacket(data) { defer c.recycleReadPacket() - return nil, false, 0, ParseErrorPacket(data) + return nil, 0, 0, ParseErrorPacket(data) } else { defer c.recycleReadPacket() - return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data) + return nil, 0, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data) } } @@ -461,7 +470,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, for { data, err := c.readEphemeralPacket() if err != nil { - return nil, false, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + return nil, 0, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) } if c.isEOFPacket(data) { @@ -475,29 +484,27 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, // The deprecated EOF packets change means that this is either an // EOF packet or an OK packet with the EOF type code. if c.Capabilities&CapabilityClientDeprecateEOF == 0 { - var statusFlags uint16 - warnings, statusFlags, err = parseEOFPacket(data) + warnings, status, err = parseEOFPacket(data) if err != nil { - return nil, false, 0, err + return nil, 0, 0, err } - more = (statusFlags & ServerMoreResultsExists) != 0 - result.StatusFlags = statusFlags + result.StatusFlags = status } else { var packetEof PacketOK if err := c.parseOKPacket(&packetEof, data); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } warnings = packetEof.warnings - more = (packetEof.statusFlags & ServerMoreResultsExists) != 0 + status = packetEof.statusFlags result.SessionStateChanges = packetEof.sessionStateData result.StatusFlags = packetEof.statusFlags result.Info = packetEof.info } - return result, more, warnings, nil + return result, status, warnings, nil } else if isErrorPacket(data) { defer c.recycleReadPacket() // Error packet. - return nil, false, 0, ParseErrorPacket(data) + return nil, 0, 0, ParseErrorPacket(data) } if maxrows == FETCH_NO_ROWS { @@ -509,16 +516,73 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, if maxrows != FETCH_ALL_ROWS && len(result.Rows) == maxrows { c.recycleReadPacket() if err := c.drainResults(); err != nil { - return nil, false, 0, err + return nil, 0, 0, err } - return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) + return nil, 0, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) } // Regular row. row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil) if err != nil { c.recycleReadPacket() - return nil, false, 0, err + return nil, 0, 0, err + } + result.Rows = append(result.Rows, row) + c.recycleReadPacket() + } +} + +// FetchQueryResult gets the reset set from the last executed query. +func (c *Conn) FetchQueryResult(maxrows int, fields []*querypb.Field) (result *sqltypes.Result, status uint16, warnings uint16, err error) { + result = &sqltypes.Result{} + + // read each row until EOF or OK packet. + for { + data, err := c.readEphemeralPacket() + if err != nil { + return nil, 0, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err) + } + + if c.isEOFPacket(data) { + defer c.recycleReadPacket() + + result.RowsAffected = uint64(len(result.Rows)) + + // The deprecated EOF packets change means that this is either an + // EOF packet or an OK packet with the EOF type code. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + warnings, status, err = parseEOFPacket(data) + if err != nil { + return nil, 0, 0, err + } + } else { + var packetEof PacketOK + if err := c.parseOKPacket(&packetEof, data); err != nil { + return nil, 0, 0, err + } + warnings = packetEof.warnings + status = packetEof.statusFlags + } + return result, status, warnings, nil + } else if isErrorPacket(data) { + defer c.recycleReadPacket() + return nil, 0, 0, ParseErrorPacket(data) + } + + // Check we're not over the limit before we add more. + if len(result.Rows) == maxrows { + defer c.recycleReadPacket() + if err := c.drainResults(); err != nil { + return nil, 0, 0, err + } + return nil, 0, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) + } + + // Regular row. + row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil) + if err != nil { + c.recycleReadPacket() + return nil, 0, 0, err } result.Rows = append(result.Rows, row) c.recycleReadPacket() @@ -954,6 +1018,15 @@ func (c *Conn) parseComStmtReset(data []byte) (uint32, bool) { return val, ok } +func (c *Conn) parseComStmtFetch(data []byte) (uint32, uint32, bool) { + stmtId, pos, ok := readUint32(data, 1) + if !ok { + return 0, 0, false + } + numRows, _, ok := readUint32(data, pos) + return stmtId, numRows, ok +} + func (c *Conn) parseComInitDB(data []byte) string { return string(data[1:]) } @@ -1039,6 +1112,28 @@ func (c *Conn) writeRow(row []sqltypes.Value) error { // writeFields writes the fields of a Result. It should be called only // if there are valid columns in the result. func (c *Conn) writeFields(result *sqltypes.Result) error { + err := c.writeFieldsWithoutEOF(result) + if err != nil { + return err + } + + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err + } + } + + return nil +} + +// Writes the fields for a Result, but never adds the EOF_Packet on the end, even +// if ClientDeprecateEOF == 0. This is used when returning fields in a +// COM_STMT_EXECUTE that opens a cursor, since we immediately follow the fields +// up with a writeEndResult, which appropriately adds an EOF or an OK_Packet +// depending on the client capabilities. +func (c *Conn) writeFieldsWithoutEOF(result *sqltypes.Result) error { // Send the number of fields first. if err := c.sendColumnCount(uint64(len(result.Fields))); err != nil { return err @@ -1050,14 +1145,6 @@ func (c *Conn) writeFields(result *sqltypes.Result) error { return err } } - - // Now send an EOF packet. - if c.Capabilities&CapabilityClientDeprecateEOF == 0 { - // With CapabilityClientDeprecateEOF, we do not send this EOF. - if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { - return err - } - } return nil } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 37501f9329f..98ff8341b4b 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -18,6 +18,7 @@ package mysql import ( "fmt" + "io" "reflect" "sync" "testing" @@ -142,6 +143,8 @@ func TestComStmtPrepare(t *testing.T) { sConn.PrepareData = make(map[uint32]*PrepareData) sConn.PrepareData[prepare.StatementID] = prepare + sConn.cursorStates = make(map[uint32]*CursorState) + // write the response to the client if err := sConn.writePrepare(result.Fields, prepare); err != nil { t.Fatalf("sConn.writePrepare failed: %v", err) @@ -181,6 +184,8 @@ func TestComStmtPrepareUpdStmt(t *testing.T) { sConn.PrepareData = make(map[uint32]*PrepareData) sConn.PrepareData[prepare.StatementID] = prepare + sConn.cursorStates = make(map[uint32]*CursorState) + // write the response to the client err = sConn.writePrepare(nil, prepare) require.NoError(t, err, "sConn.writePrepare failed") @@ -211,6 +216,8 @@ func TestComStmtSendLongData(t *testing.T) { t.Fatalf("writePrepare failed: %v", err) } + cConn.cursorStates = make(map[uint32]*CursorState) + // Since there's no writeComStmtSendLongData, we'll write a prepareStmt and check if we can read the StatementID data, err := sConn.ReadPacket() if err != nil || len(data) == 0 { @@ -237,8 +244,10 @@ func TestComStmtExecute(t *testing.T) { cConn.PrepareData = make(map[uint32]*PrepareData) cConn.PrepareData[prepare.StatementID] = prepare + cConn.cursorStates = make(map[uint32]*CursorState) + // This is simulated packets for `select * from test_table where id = ?` - data := []byte{23, 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1} + data := []byte{23, 18, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 128, 1} stmtID, _, err := sConn.parseComStmtExecute(cConn.PrepareData, data) require.NoError(t, err, "parseComStmtExeute failed: %v", err) @@ -319,6 +328,33 @@ func TestComStmtExecuteUpdStmt(t *testing.T) { assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[28], "got: %s", querypb.Type(prepData.ParamsType[28])) } +func TestComStmtFetch(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepare, _ := MockPrepareData(t) + cConn.PrepareData = make(map[uint32]*PrepareData) + cConn.PrepareData[prepare.StatementID] = prepare + + cConn.cursorStates = make(map[uint32]*CursorState) + + // This is simulated packets for `select * from test_table where id = ?` + data := []byte{23, 18, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + stmtID, cursorType, err := sConn.parseComStmtExecute(cConn.PrepareData, data) + require.NoError(t, err, "parseComStmtExeute failed: %v", err) + if stmtID != 18 { + t.Fatalf("Parsed incorrect values") + } + if cursorType != CursorTypeReadOnly { + t.Fatalf("Expected read-only cursor") + } +} + func TestComStmtClose(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { @@ -334,6 +370,8 @@ func TestComStmtClose(t *testing.T) { t.Fatalf("writePrepare failed: %v", err) } + cConn.cursorStates = make(map[uint32]*CursorState) + // Since there's no writeComStmtClose, we'll write a prepareStmt and check if we can read the StatementID data, err := sConn.ReadPacket() if err != nil || len(data) == 0 { @@ -767,3 +805,209 @@ func RowString(row []sqltypes.Value) string { } return result } + +func writeRawPacketToConn(c *Conn, packet []byte) error { + c.sequence = 0 + data, pos := c.startEphemeralPacketWithHeader(len(packet)) + copy(data[pos:], packet) + return c.writeEphemeralPacket() +} + +type testExec struct { + query string + useCursor byte + expectedNumFields int + expectedNumRows int + maxRows int +} + +func clientExecute(t *testing.T, cConn *Conn, useCursor byte, maxRows int) (*sqltypes.Result, uint16) { + if useCursor != 0 { + cConn.StatusFlags |= uint16(ServerStatusCursorExists) + } else { + cConn.StatusFlags &= ^uint16(ServerStatusCursorExists) + } + + // Write a COM_STMT_EXECUTE packet + mockPacket := []byte{ComStmtExecute, 0, 0, 0, 0, useCursor, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + if err := writeRawPacketToConn(cConn, mockPacket); err != nil { + t.Fatalf("WriteMockExecuteToConn failed with error: %v", err) + } + + qr, status, _, err := cConn.ReadQueryResult(maxRows, true) + require.NoError(t, err, "ReadQueryResult failed with error: %v", err) + + return qr, status +} + +func clientFetch(t *testing.T, cConn *Conn, useCursor byte, maxRows int, fields []*querypb.Field) (*sqltypes.Result, uint16) { + // Write a COM_STMT_FETCH packet + mockPacket := []byte{ComStmtFetch, 0, 0, 0, 0, useCursor, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + if err := writeRawPacketToConn(cConn, mockPacket); err != nil { + t.Fatalf("WriteMockDataToConn failed with error: %v", err) + } + + qr, status, _, err := cConn.FetchQueryResult(maxRows, fields) + if err != nil && err != io.EOF { + t.Fatalf("FetchQueryResult failed with error: %v", err) + } + + return qr, status +} + +func checkExecute(t *testing.T, sConn, cConn *Conn, test testExec) { + // Pretend a successful COM_PREPARE was sent to server + prepare := &PrepareData{ + StatementID: 0, + PrepareStmt: test.query, + } + sConn.PrepareData = make(map[uint32]*PrepareData) + sConn.PrepareData[prepare.StatementID] = prepare + + sConn.cursorStates = make(map[uint32]*CursorState) + + // use go routine to emulate client calls + var qr *sqltypes.Result + var status uint16 + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + qr, status = clientExecute(t, cConn, test.useCursor, test.maxRows) + }() + + // handle a single client command + if !sConn.handleNextCommand(&testHandler{}) { + t.Fatalf("handleNextComamnd failed") + } + + // wait until client receives the query result back + wg.Wait() + + if qr == nil { + t.Fatalf("execute result is nil") + } + var fields = qr.Fields + if test.expectedNumFields != len(fields) { + t.Fatalf("Expected %d fields, Received %d", test.expectedNumFields, len(fields)) + } + + // if not using cursor, we should have results without fetching + if test.useCursor == 0 { + if (status & ServerStatusCursorExists) != 0 { + t.Fatalf("Server StatusFlag should indicate that Cursor does not exist") + } + if test.expectedNumRows != len(qr.Rows) { + t.Fatalf("Expected %d rows, Received %d", test.expectedNumRows, len(qr.Rows)) + } + return + } + + // using cursor, use client to fetch results + if (status & ServerStatusCursorExists) == 0 { + t.Fatalf("Server StatusFlag should indicate that Cursor exists, status flags were: %d", status) + } + + qr = nil + wg.Add(1) + go func() { + defer wg.Done() + qr, status = clientFetch(t, cConn, test.useCursor, test.maxRows, fields) + }() + + // handle a single client command + if !sConn.handleNextCommand(&testHandler{}) { + t.Fatalf("handleNextComamnd failed") + } + + // wait until client fetches the rows + wg.Wait() + if qr == nil { + t.Fatalf("fetch result is nil") + } + + if (status & ServerStatusCursorExists) != 0 { + t.Fatalf("Server StatusFlag should not indicate a new Cursor, status flags were: %d", status) + } + if (status & ServerStatusLastRowSent) == 0 { + t.Fatalf("Server StatusFlag should indicate that we exhausted the cursor, status flags were: %d", status) + } + + if test.expectedNumRows != len(qr.Rows) { + t.Fatalf("Expected %d rows, Received %d", test.expectedNumRows, len(qr.Rows)) + } +} + +func TestExecuteQueries(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + tests := []testExec{ + { + query: "empty result", + useCursor: 0, + expectedNumFields: 3, + expectedNumRows: 0, + maxRows: 100, + }, + { + query: "select rows", + useCursor: 0, + expectedNumFields: 2, + expectedNumRows: 2, + maxRows: 100, + }, + { + query: "large batch", + useCursor: 0, + expectedNumFields: 2, + expectedNumRows: 256, + maxRows: 1000, + }, + { + query: "empty result", + useCursor: 1, + expectedNumFields: 3, + expectedNumRows: 0, + maxRows: 100, + }, + { + query: "select rows", + useCursor: 1, + expectedNumFields: 2, + expectedNumRows: 2, + maxRows: 100, + }, + { + query: "large batch", + useCursor: 1, + expectedNumFields: 2, + expectedNumRows: 256, + maxRows: 1000, + }, + } + + t.Run("WithoutDeprecateEOF", func(t *testing.T) { + for i, test := range tests { + t.Run(fmt.Sprintf("%d %s", i, test.query), func(t *testing.T) { + checkExecute(t, sConn, cConn, test) + }) + } + }) + + sConn.Capabilities = CapabilityClientDeprecateEOF + cConn.Capabilities = CapabilityClientDeprecateEOF + t.Run("WithDeprecateEOF", func(t *testing.T) { + for i, test := range tests { + t.Run(fmt.Sprintf("%d %s", i, test.query), func(t *testing.T) { + checkExecute(t, sConn, cConn, test) + }) + } + }) +} diff --git a/go/mysql/server.go b/go/mysql/server.go index 17c113248e6..fbd4c1cffcb 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -400,6 +400,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti // Adjust the count of open connections defer connCount.Add(-1) + defer c.discardAllCursors() + // First build and send the server handshake packet. serverAuthPluginData, err := c.writeHandshakeV10(l.ServerVersion, l.authServer, uint8(l.charset), l.TLSConfig.Load() != nil) if err != nil { diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index ee30dded978..6d158d4fe8d 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -256,7 +256,76 @@ func (th *testHandler) ComPrepare(*Conn, string) ([]*querypb.Field, uint16, erro } func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { - return nil + switch prepare.PrepareStmt { + case "empty result": + callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + { + Name: "name", + Type: querypb.Type_VARCHAR, + }, + { + Name: "name2", + Type: querypb.Type_VARCHAR, + }, + }, + Rows: [][]sqltypes.Value{}, + RowsAffected: 0, + }) + case "select rows": + callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + { + Name: "name", + Type: querypb.Type_VARCHAR, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), + }, + { + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nicer name")), + }, + }, + RowsAffected: 2, + }) + case "large batch": + n := 256 + rows := make([][]sqltypes.Value, n) + for i := 0; i < n; i++ { + rows[i] = []sqltypes.Value{ + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), + sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), + } + } + callback(&sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + { + Name: "name", + Type: querypb.Type_VARCHAR, + }, + }, + Rows: rows, + RowsAffected: uint64(n), + }) + default: + return fmt.Errorf("unrecorgnized test command") + } } func (th *testHandler) ComRegisterReplica(c *Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error { diff --git a/go/test/endtoend/topotest/consul/main_test.go b/go/test/endtoend/topotest/consul/main_test.go index 6b61aa4c633..4b61001dbd8 100644 --- a/go/test/endtoend/topotest/consul/main_test.go +++ b/go/test/endtoend/topotest/consul/main_test.go @@ -277,11 +277,11 @@ func execute(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/test/endtoend/topotest/etcd2/main_test.go b/go/test/endtoend/topotest/etcd2/main_test.go index f17532e65f5..fae2bf222a8 100644 --- a/go/test/endtoend/topotest/etcd2/main_test.go +++ b/go/test/endtoend/topotest/etcd2/main_test.go @@ -271,11 +271,11 @@ func TestNamedLocking(t *testing.T) { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/test/endtoend/topotest/zk2/main_test.go b/go/test/endtoend/topotest/zk2/main_test.go index 5f95b78c507..c237d93ff3e 100644 --- a/go/test/endtoend/topotest/zk2/main_test.go +++ b/go/test/endtoend/topotest/zk2/main_test.go @@ -246,11 +246,11 @@ func TestKeyspaceLocking(t *testing.T) { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 05e65cc7fda..e7fbb1a978f 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -225,11 +225,13 @@ func (mcmp *MySQLCompare) ExecMulti(sql string) []*sqltypes.Result { mcmp.t.Helper() stmts, err := sqlparser.NewTestParser().SplitStatementToPieces(sql) require.NoError(mcmp.t, err) - vtQr, vtMore, err := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) + vtQr, vtStatus, err := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) require.NoError(mcmp.t, err, "[Vitess Error] for sql: "+sql) + vtMore := mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, err := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) + mysqlQr, mysqlStatus, err := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) require.NoError(mcmp.t, err, "[MySQL Error] for sql: "+sql) + mysqlMore := mysql.IsMoreResultsExists(mysqlStatus) sql = stmts[0] CompareVitessAndMySQLResults(mcmp.t, sql, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{}) if vtMore != mysqlMore { @@ -241,11 +243,13 @@ func (mcmp *MySQLCompare) ExecMulti(sql string) []*sqltypes.Result { for vtMore { sql = stmts[idx] idx++ - vtQr, vtMore, _, err = mcmp.VtConn.ReadQueryResult(1000, true) + vtQr, vtStatus, _, err = mcmp.VtConn.ReadQueryResult(1000, true) require.NoError(mcmp.t, err, "[Vitess Error] for sql: "+sql) + vtMore = mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, _, err = mcmp.MySQLConn.ReadQueryResult(1000, true) + mysqlQr, mysqlStatus, _, err = mcmp.MySQLConn.ReadQueryResult(1000, true) require.NoError(mcmp.t, err, "[MySQL Error] for sql: "+sql) + mysqlMore = mysql.IsMoreResultsExists(mysqlStatus) CompareVitessAndMySQLResults(mcmp.t, sql, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{}) if vtMore != mysqlMore { mcmp.AsT().Errorf("Vitess and MySQL have different More flags: %v vs %v", vtMore, mysqlMore) @@ -262,9 +266,11 @@ func (mcmp *MySQLCompare) ExecMultiAllowError(sql string) { mcmp.t.Helper() stmts, err := sqlparser.NewTestParser().SplitStatementToPieces(sql) require.NoError(mcmp.t, err) - vtQr, vtMore, vtErr := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) + vtQr, vtStatus, vtErr := mcmp.VtConn.ExecuteFetchMulti(sql, 1000, true) + vtMore := mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, mysqlErr := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) + mysqlQr, mysqlStatus, mysqlErr := mcmp.MySQLConn.ExecuteFetchMulti(sql, 1000, true) + mysqlMore := mysql.IsMoreResultsExists(mysqlStatus) sql = stmts[0] compareVitessAndMySQLErrors(mcmp.t, vtErr, mysqlErr) if vtErr == nil && mysqlErr == nil { @@ -278,9 +284,11 @@ func (mcmp *MySQLCompare) ExecMultiAllowError(sql string) { for vtMore { sql = stmts[idx] idx++ - vtQr, vtMore, _, vtErr = mcmp.VtConn.ReadQueryResult(1000, true) + vtQr, vtStatus, _, vtErr = mcmp.VtConn.ReadQueryResult(1000, true) + vtMore = mysql.IsMoreResultsExists(vtStatus) - mysqlQr, mysqlMore, _, mysqlErr = mcmp.MySQLConn.ReadQueryResult(1000, true) + mysqlQr, mysqlStatus, _, mysqlErr = mcmp.MySQLConn.ReadQueryResult(1000, true) + mysqlMore = mysql.IsMoreResultsExists(mysqlStatus) compareVitessAndMySQLErrors(mcmp.t, vtErr, mysqlErr) if vtErr == nil && mysqlErr == nil { CompareVitessAndMySQLResults(mcmp.t, sql, mcmp.VtConn, vtQr, mysqlQr, CompareOptions{}) diff --git a/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go b/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go index 2a898b61d07..2d8f7c2b491 100644 --- a/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go +++ b/go/test/endtoend/vtgate/queries/multi_query/multi_query_test.go @@ -170,12 +170,12 @@ func TestMultiQuery(t *testing.T) { // results obtained from the gRPC connection. func getMySqlResults(conn *mysql.Conn, sql string) ([]*sqltypes.Result, error) { var results []*sqltypes.Result - mysqlQr, mysqlMore, mysqlErr := conn.ExecuteFetchMulti(sql, 1000, true) + mysqlQr, mysqlStatus, mysqlErr := conn.ExecuteFetchMulti(sql, 1000, true) if mysqlQr != nil { results = append(results, mysqlQr) } - for mysqlMore { - mysqlQr, mysqlMore, _, mysqlErr = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(mysqlStatus) { + mysqlQr, mysqlStatus, _, mysqlErr = conn.ReadQueryResult(1000, true) if mysqlQr != nil { results = append(results, mysqlQr) } diff --git a/go/test/endtoend/vtgate/reservedconn/udv_test.go b/go/test/endtoend/vtgate/reservedconn/udv_test.go index c5eee88ae8b..16bbce5be4a 100644 --- a/go/test/endtoend/vtgate/reservedconn/udv_test.go +++ b/go/test/endtoend/vtgate/reservedconn/udv_test.go @@ -146,9 +146,9 @@ func TestMysqlDumpInitialLog(t *testing.T) { for _, query := range queries { t.Run(query, func(t *testing.T) { - _, more, err := conn.ExecuteFetchMulti(query, 1000, true) + _, status, err := conn.ExecuteFetchMulti(query, 1000, true) require.NoError(t, err) - require.False(t, more) + require.False(t, mysql.IsMoreResultsExists(status)) }) } } diff --git a/go/test/endtoend/vtgate/unsharded/main_test.go b/go/test/endtoend/vtgate/unsharded/main_test.go index acb9a7be4a1..00a7e2e066c 100644 --- a/go/test/endtoend/vtgate/unsharded/main_test.go +++ b/go/test/endtoend/vtgate/unsharded/main_test.go @@ -449,11 +449,11 @@ func TestFloatValueDefault(t *testing.T) { func execMulti(t *testing.T, conn *mysql.Conn, query string) []*sqltypes.Result { t.Helper() var res []*sqltypes.Result - qr, more, err := conn.ExecuteFetchMulti(query, 1000, true) + qr, status, err := conn.ExecuteFetchMulti(query, 1000, true) res = append(res, qr) require.NoError(t, err) - for more == true { - qr, more, _, err = conn.ReadQueryResult(1000, true) + for mysql.IsMoreResultsExists(status) { + qr, status, _, err = conn.ReadQueryResult(1000, true) require.NoError(t, err) res = append(res, qr) } diff --git a/go/vt/binlog/binlogplayer/dbclient.go b/go/vt/binlog/binlogplayer/dbclient.go index 4cbfd962528..f24da0e29e6 100644 --- a/go/vt/binlog/binlogplayer/dbclient.go +++ b/go/vt/binlog/binlogplayer/dbclient.go @@ -162,14 +162,14 @@ func (dc *dbClientImpl) ExecuteFetch(query string, maxrows int) (*sqltypes.Resul func (dc *dbClientImpl) ExecuteFetchMulti(query string, maxrows int) ([]*sqltypes.Result, error) { results := make([]*sqltypes.Result, 0) - mqr, more, err := dc.dbConn.ExecuteFetchMulti(query, maxrows, true) + mqr, status, err := dc.dbConn.ExecuteFetchMulti(query, maxrows, true) if err != nil { dc.handleError(err) return nil, err } results = append(results, mqr) - for more { - mqr, more, _, err = dc.dbConn.ReadQueryResult(maxrows, false) + for mysql.IsMoreResultsExists(status) { + mqr, status, _, err = dc.dbConn.ReadQueryResult(maxrows, false) if err != nil { dc.handleError(err) return nil, err diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index b255fc32f4c..b3ebcb70997 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -1201,12 +1201,12 @@ func (mysqld *Mysqld) executeMysqlScript(ctx context.Context, connParams *mysql. } defer conn.Close() - _, more, err := conn.ExecuteFetchMulti(sql, -1, false) + _, status, err := conn.ExecuteFetchMulti(sql, -1, false) if err != nil { return err } - for more { - _, more, _, err = conn.ReadQueryResult(0, false) + for mysql.IsMoreResultsExists(status) { + _, status, _, err = conn.ReadQueryResult(0, false) if err != nil { return err } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index c1931514dc4..77f2b533c20 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -238,12 +238,12 @@ func initTimezoneData(t *testing.T, conn *mysql.Conn) { t.Fatalf("failed to retrieve timezone info: %v", err) } - _, more, err := conn.ExecuteFetchMulti(fmt.Sprintf("USE mysql; %s\n", string(out)), -1, false) + _, status, err := conn.ExecuteFetchMulti(fmt.Sprintf("USE mysql; %s\n", string(out)), -1, false) if err != nil { t.Fatalf("failed to insert timezone info: %v", err) } - for more { - _, more, _, err = conn.ReadQueryResult(-1, false) + for mysql.IsMoreResultsExists(status) { + _, status, _, err = conn.ReadQueryResult(-1, false) if err != nil { t.Fatalf("failed to insert timezone info: %v", err) } diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index bb3075175db..4dab1dff871 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -129,12 +129,12 @@ func (tm *TabletManager) executeMultiFetchAsDba( return nil, err } results := make([]*querypb.QueryResult, 0, len(queries)) - result, more, err := conn.ExecuteFetchMulti(uq, maxRows, true /*wantFields*/) + result, status, err := conn.ExecuteFetchMulti(uq, maxRows, true /*wantFields*/) if err == nil { results = append(results, sqltypes.ResultToProto3(result)) } - for more { - result, more, _, err = conn.ReadQueryResult(maxRows, true /*wantFields*/) + for (status & sqltypes.ServerMoreResultsExists) != 0 { + result, status, _, err = conn.ReadQueryResult(maxRows, true /*wantFields*/) if err != nil { return nil, err }