Fix error around breaking of multistatements

When a multistatement query is received, any errors should
abort the execution of the remaining queries.

The `execQuery` and `handleNextCommand` were returning an error,
but not actually using the error value - just checking if it was nil or not.

We need to be able to know on the outside of `execQuery` if an error occured and if it was an error we need to close the connection for or if it was a simple execution error.

Signed-off-by: Andres Taylor <andres@planetscale.com>
This commit is contained in:
Andres Taylor 2020-10-01 10:33:13 +02:00
Родитель e2d0e1a9ba
Коммит 4f0210a1b4
4 изменённых файлов: 235 добавлений и 166 удалений

Просмотреть файл

@ -758,7 +758,7 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
// handleNextCommand is called in the server loop to process
// incoming packets.
func (c *Conn) handleNextCommand(handler Handler) error {
func (c *Conn) handleNextCommand(handler Handler) bool {
c.sequence = 0
data, err := c.readEphemeralPacket()
if err != nil {
@ -766,63 +766,56 @@ func (c *Conn) handleNextCommand(handler Handler) error {
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
log.Errorf("Error reading packet from %s: %v", c, err)
}
return err
return false
}
switch data[0] {
case ComQuit:
c.recycleReadPacket()
return errors.New("ComQuit")
return false
case ComInitDB:
db := c.parseComInitDB(data)
c.recycleReadPacket()
if err := c.execQuery("use "+sqlescape.EscapeID(db), handler, false); err != nil {
return err
}
res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false)
return res == execSuccess // TODO: we shouldn't drop the connection if the user is asking for the wrong db
case ComQuery:
err := func() error {
c.startWriterBuffering()
defer func() {
if err := c.endWriterBuffering(); err != nil {
log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
}
}()
queryStart := time.Now()
query := c.parseComQuery(data)
c.recycleReadPacket()
var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
queries, err = sqlparser.SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Conn %v: Error writing query error: %v", c, werr)
return werr
}
}
} else {
queries = []string{query}
c.startWriterBuffering()
defer func() {
if err := c.endWriterBuffering(); err != nil {
log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
}
for index, sql := range queries {
more := false
if index != len(queries)-1 {
more = true
}
if err := c.execQuery(sql, handler, more); err != nil {
return err
}
}
timings.Record(queryTimingKey, queryStart)
return nil
}()
if err != nil {
return err
queryStart := time.Now()
query := c.parseComQuery(data)
c.recycleReadPacket()
var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
queries, err = sqlparser.SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Conn %v: Error writing query error: %v", c, werr)
return false
}
}
} else {
queries = []string{query}
}
for index, sql := range queries {
more := false
if index != len(queries)-1 {
more = true
}
res := c.execQuery(sql, handler, more)
if res != execSuccess {
return res != connErr
}
}
timings.Record(queryTimingKey, queryStart)
case ComPing:
c.recycleReadPacket()
@ -830,14 +823,15 @@ func (c *Conn) handleNextCommand(handler Handler) error {
if c.listener.isShutdown() {
if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil {
log.Errorf("Error writing ComPing error to %s: %v", c, err)
return err
return false
}
} else {
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComPing result to %s: %v", c, err)
return err
return false
}
}
case ComSetOption:
operation, ok := c.parseComSetOption(data)
c.recycleReadPacket()
@ -851,20 +845,21 @@ func (c *Conn) handleNextCommand(handler Handler) error {
log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
return false
}
}
if err := c.writeEndResult(false, 0, 0, 0); err != nil {
log.Errorf("Error writeEndResult error %v ", err)
return err
return false
}
} else {
log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
return false
}
}
case ComPrepare:
query := c.parseComPrepare(data)
c.recycleReadPacket()
@ -877,7 +872,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Conn %v: Error writing query error: %v", c, werr)
return werr
return false
}
}
} else {
@ -885,7 +880,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
}
if len(queries) != 1 {
return fmt.Errorf("can not prepare multiple statements")
return false // TODO: do we really want to close the connection because of this?
}
// Popoulate PrepareData
@ -901,7 +896,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Conn %v: Error writing prepared statement error: %v", c, werr)
return werr
return false
}
}
@ -936,120 +931,116 @@ func (c *Conn) handleNextCommand(handler Handler) error {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
return false
}
return nil
return true
}
if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil {
return err
log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err)
return false
}
case ComStmtExecute:
err := func() error {
c.startWriterBuffering()
c.startWriterBuffering()
defer func() {
if err := c.endWriterBuffering(); err != nil {
log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
}
}()
queryStart := time.Now()
stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data)
c.recycleReadPacket()
if stmtID != uint32(0) {
defer func() {
if err := c.endWriterBuffering(); err != nil {
log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
}
// Allocate a new bindvar map every time since VTGate.Execute() mutates it.
prepare := c.PrepareData[stmtID]
prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount)
}()
queryStart := time.Now()
stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data)
c.recycleReadPacket()
}
if stmtID != uint32(0) {
defer func() {
// Allocate a new bindvar map every time since VTGate.Execute() mutates it.
prepare := c.PrepareData[stmtID]
prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount)
}()
if err != nil {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
return false
}
return true
}
fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false
prepare := c.PrepareData[stmtID]
err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error {
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}
if err != nil {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}
return nil
}
fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false
prepare := c.PrepareData[stmtID]
err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error {
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}
if !fieldSent {
fieldSent = true
if len(qr.Fields) == 0 {
sendFinished = true
// We should not send any more packets after this.
return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0)
}
if err := c.writeFields(qr); err != nil {
return err
}
}
return c.writeBinaryRows(qr)
})
// If no field was sent, we expect an error.
if !fieldSent {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
fieldSent = true
if len(qr.Fields) == 0 {
sendFinished = true
// We should not send any more packets after this.
return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0)
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
if err := c.writeFields(qr); err != nil {
return err
}
// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
}
}
timings.Record(queryTimingKey, queryStart)
return nil
}()
if err != nil {
return err
return c.writeBinaryRows(qr)
})
// If no field was sent, we expect an error.
if !fieldSent {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return false
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return false
}
// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return false
}
}
}
timings.Record(queryTimingKey, queryStart)
case ComStmtSendLongData:
stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data)
c.recycleReadPacket()
if !ok {
err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data)
log.Error(err.Error())
return err
return false // TODO: really break here?
}
prepare, ok := c.PrepareData[stmtID]
if !ok {
err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID)
log.Error(err.Error())
return err
return false // TODO: really break here?
}
if prepare.BindVars == nil ||
@ -1057,7 +1048,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
paramID >= prepare.ParamsCount {
err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt)
log.Error(err.Error())
return err
return false // TODO: really break here?
}
chunk := make([]byte, len(chunkData))
@ -1082,7 +1073,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Error("Error writing error packet to client: %v", err)
return err
return false
}
}
@ -1091,7 +1082,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil {
log.Error("Error writing error packet to client: %v", err)
return err
return false
}
}
@ -1103,7 +1094,7 @@ func (c *Conn) handleNextCommand(handler Handler) error {
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
return err
return false
}
case ComResetConnection:
@ -1122,14 +1113,22 @@ func (c *Conn) handleNextCommand(handler Handler) error {
c.recycleReadPacket()
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil {
log.Errorf("Error writing error packet to %s: %s", c, err)
return err
return false
}
}
return nil
return true
}
func (c *Conn) execQuery(query string, handler Handler, more bool) error {
type execResult byte
const (
execSuccess execResult = iota
execErr
connErr
)
func (c *Conn) execQuery(query string, handler Handler, more bool) execResult {
fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false
@ -1175,28 +1174,28 @@ func (c *Conn) execQuery(query string, handler Handler, more bool) error {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return err
return connErr
}
return execErr
}
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return connErr
}
// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return connErr
}
}
return nil
return execSuccess
}
//

Просмотреть файл

@ -19,12 +19,18 @@ package mysql
import (
"bytes"
crypto_rand "crypto/rand"
"fmt"
"math/rand"
"net"
"reflect"
"runtime/debug"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)
func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
@ -288,3 +294,67 @@ func TestEOFOrLengthEncodedIntFuzz(t *testing.T) {
}
}
}
func TestMultiStatementStopsOnError(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
sConn.Capabilities |= CapabilityClientMultiStatements
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()
err := cConn.WriteComQuery("select 1;select 2")
require.NoError(t, err)
// this handler will return an error on the first run, and fail the test if it's run more times
handler := &singleRun{t: t, err: fmt.Errorf("execution failed")}
res := sConn.handleNextCommand(handler)
require.True(t, res, res, "we should not break the connection because of execution errors")
data, err := cConn.ReadPacket()
require.NoError(t, err)
require.NotEmpty(t, data)
require.EqualValues(t, data[0], ErrPacket) // we should see the error here
}
type singleRun struct {
hasRun bool
t *testing.T
err error
}
func (h *singleRun) NewConnection(*Conn) {
panic("implement me")
}
func (h *singleRun) ConnectionClosed(*Conn) {
panic("implement me")
}
func (h *singleRun) ComQuery(*Conn, string, func(*sqltypes.Result) error) error {
if h.hasRun {
debug.PrintStack()
h.t.Fatal("don't do this!")
}
h.hasRun = true
return h.err
}
func (h *singleRun) ComPrepare(*Conn, string, map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
panic("implement me")
}
func (h *singleRun) ComStmtExecute(*Conn, *PrepareData, func(*sqltypes.Result) error) error {
panic("implement me")
}
func (h *singleRun) WarningCount(*Conn) uint16 {
return 0
}
func (h *singleRun) ComResetConnection(*Conn) {
panic("implement me")
}
var _ Handler = (*singleRun)(nil)

Просмотреть файл

@ -628,9 +628,9 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *
}
for i := 0; i < count; i++ {
err := sConn.handleNextCommand(&handler)
if err != nil {
t.Fatalf("error handling command: %v", err)
kontinue := sConn.handleNextCommand(&handler)
if !kontinue {
t.Fatalf("error handling command: %d", i)
}
}

Просмотреть файл

@ -470,8 +470,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
for {
err := c.handleNextCommand(l.handler)
if err != nil {
kontinue := c.handleNextCommand(l.handler)
if !kontinue {
return
}
}