vitess-gh/go/mysql/conn.go

1587 lines
47 KiB
Go

/*
Copyright 2019 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/bucketpool"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/sync2"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
)
const (
	// connBufferSize is the size used for read/write buffering and for
	// the ephemeral buffers handed out by the pools (16 KiB).
	connBufferSize = 16 << 10
	// packetHeaderSize is the size of the per-packet MySQL header:
	// a 3-byte little-endian payload length plus a 1-byte sequence number.
	packetHeaderSize = 4
)

// States of the connection's single ephemeral buffer. They are checked
// so that a packet is never read or written while another one is still
// using the buffer.
const (
	// ephemeralUnused: the ephemeral buffer is idle. This is the zero
	// value, so a fresh Conn starts in this state.
	ephemeralUnused = 0
	// ephemeralWrite: a packet is currently being written from
	// currentEphemeralBuffer.
	ephemeralWrite = 1
	// ephemeralRead: a packet is currently being read into
	// currentEphemeralBuffer.
	ephemeralRead = 2
)
// Getter is implemented by values that can report the VTGate caller ID
// associated with a connection (see Conn.UserData).
type Getter interface {
	// Get returns the caller ID for the connection.
	Get() *querypb.VTGateCallerID
}
// Conn is a connection between a client and a server, using the MySQL
// binary protocol. It is built on top of an existing net.Conn, that
// has already been established.
//
// Use Connect on the client side to create a connection.
// Use NewListener to create a server side and listen for connections.
type Conn struct {
	// fields contains the fields definitions for an on-going
	// streaming query. It is set by ExecuteStreamFetch, and
	// cleared by the last FetchNext(). It is nil if no streaming
	// query is in progress. If the streaming query returned no
	// fields, this is set to an empty array (but not nil).
	fields []*querypb.Field
	// salt is sent by the server during initial handshake to be used for authentication.
	salt []byte
	// authPluginName is the name of server's authentication plugin.
	// It is set during the initial handshake.
	authPluginName AuthMethodDescription
	// schemaName is the default database name to use. It is set
	// during handshake, and by ComInitDb packets. Both client and
	// servers maintain it. This member is private because it's
	// non-authoritative: the client can change the schema name
	// through the 'USE' statement, which will bypass this variable.
	schemaName string
	// ClientData is a place where an application can store any
	// connection-related data. Mostly used on the server side, to
	// avoid maps indexed by ConnectionID for instance.
	ClientData any
	// conn is the underlying network connection.
	// Calling Close() on the Conn will close this connection.
	// If there are any ongoing reads or writes, they may get interrupted.
	conn net.Conn
	// flavor contains the auto-detected flavor for this client
	// connection. It is unused for server-side connections.
	flavor flavor
	// ServerVersion is set during Connect with the server
	// version. It is not changed afterwards. It is unused for
	// server-side connections.
	ServerVersion string
	// User is the name used by the client to connect.
	// It is set during the initial handshake.
	User string
	// UserData is custom data returned by the AuthServer module.
	// It is set during the initial handshake.
	UserData Getter
	// bufferedReader, when non-nil, buffers reads from conn (see getReader).
	// newServerConn may take it from readersPool; returnReader hands it back.
	bufferedReader *bufio.Reader
	// flushTimer flushes bufferedWriter after a period of inactivity
	// (see startFlushTimer). It is guarded by bufMu.
	flushTimer *time.Timer
	// header is scratch space for the 4-byte packet header read by readHeaderFrom.
	header [packetHeaderSize]byte
	// currentEphemeralPolicy records what the ephemeral buffer is
	// currently used for: ephemeralUnused, ephemeralWrite or ephemeralRead.
	// It is maintained by:
	// - startEphemeralPacketWithHeader / writeEphemeralPacket methods for writes.
	// - readEphemeralPacket / recycleReadPacket methods for reads.
	currentEphemeralPolicy int
	// currentEphemeralBuffer is the temporary buffer of the in-progress
	// ephemeral read or write. It can be allocated from bufPool or the heap
	// and should be recycled in the same manner.
	currentEphemeralBuffer *[]byte
	// listener points to the server object for server-side connections.
	// It is nil for connections created by newConn.
	listener *Listener
	// bufferedWriter is non-nil between startWriterBuffering and
	// endWriterBuffering. Buffered writing uses flushTimer to flush on
	// inactivity.
	bufferedWriter *bufio.Writer
	// PrepareData is the map to use a prepared statement, keyed by statement ID.
	PrepareData map[uint32]*PrepareData
	// bufMu protects bufferedWriter and flushTimer.
	bufMu sync.Mutex
	// Capabilities is the current set of features this connection
	// is using. It is the features that are both supported by
	// the client and the server, and currently in use.
	// It is set during the initial handshake.
	//
	// It is only used for CapabilityClientDeprecateEOF
	// and CapabilityClientFoundRows.
	Capabilities uint32
	// closed is set to true when Close() is called on the connection.
	closed sync2.AtomicBool
	// ConnectionID is set:
	// - at Connect() time for clients, with the value returned by
	// the server.
	// - at accept time for the server.
	ConnectionID uint32
	// StatementID is the prepared statement ID.
	StatementID uint32
	// StatusFlags are the status flags we will base our returned flags on.
	// This is a bit field, with values documented in constants.go.
	// An interesting value here would be ServerStatusAutocommit.
	// It is only used by the server. These flags can be changed
	// by Handler methods.
	StatusFlags uint16
	// CharacterSet is the charset for this connection, as negotiated
	// in our handshake with the server. Note that although the MySQL protocol lists this
	// as a "character set", the returned byte value is actually a Collation ID,
	// and hence it's casted as such here.
	// If the user has specified a custom Collation in the ConnParams for this
	// connection, once the CharacterSet has been negotiated, we will override
	// it via SQL and update this field accordingly.
	CharacterSet collations.ID
	// sequence is the packet sequence number expected on the next read /
	// used on the next write. It is reset to 0 at the start of each command.
	sequence uint8
	// ExpectSemiSyncIndicator is applicable when the connection is used for replication (ComBinlogDump).
	// When 'true', events are assumed to be padded with 2-byte semi-sync information
	// See https://dev.mysql.com/doc/internals/en/semi-sync-binlog-event.html
	ExpectSemiSyncIndicator bool
	// enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets
	// See: ConnParams.EnableQueryInfo
	enableQueryInfo bool
}
// splitStatementFunction is the function that is used to split the statement
// in case of a multi-statement query. It is a variable so it can be replaced
// (for example in tests).
var splitStatementFunction = sqlparser.SplitStatementToPieces
// PrepareData holds the metadata of a server-side prepared statement,
// stored in Conn.PrepareData keyed by statement ID.
type PrepareData struct {
	// ParamsType holds one type code per parameter.
	// NOTE(review): presumably MySQL column type codes sent by the client — confirm.
	ParamsType []int32
	// ColumnNames holds the names of the result columns.
	ColumnNames []string
	// PrepareStmt is the SQL text of the prepared statement; it is quoted in
	// error messages.
	PrepareStmt string
	// BindVars holds the bound parameter values, keyed "v1", "v2", ...
	// (1-based parameter index). COM_STMT_SEND_LONG_DATA appends to them and
	// COM_STMT_EXECUTE replaces the map after each execution.
	BindVars map[string]*querypb.BindVariable
	// StatementID is the ID under which the statement is registered.
	StatementID uint32
	// ParamsCount is the number of parameters; parameter IDs must be below it.
	ParamsCount uint16
}
// execResult reports how executing a query ended.
type execResult byte

// Possible execResult values.
const (
	// execSuccess: the query ran and its result was sent.
	execSuccess execResult = 0
	// execErr: the query failed, and the error was reported to the client.
	execErr execResult = 1
	// connErr: the connection itself failed and should be closed.
	connErr execResult = 2
)
var (
	// bufPool hands out the byte buffers used for ephemeral packets, in
	// sizes from connBufferSize up to MaxPacketSize.
	bufPool = bucketpool.New(connBufferSize, MaxPacketSize)

	// writersPool recycles bufio.Writer objects of connBufferSize bytes.
	writersPool = sync.Pool{New: func() any { return bufio.NewWriterSize(nil, connBufferSize) }}

	// readersPool recycles bufio.Reader objects of connBufferSize bytes.
	readersPool = sync.Pool{New: func() any { return bufio.NewReaderSize(nil, connBufferSize) }}
)
// newConn builds a Conn around an established net.Conn. It holds the
// construction code shared by the client and server sides. Reads are
// buffered with a private connBufferSize reader.
func newConn(conn net.Conn) *Conn {
	c := &Conn{conn: conn}
	c.closed = sync2.NewAtomicBool(false)
	c.bufferedReader = bufio.NewReaderSize(conn, connBufferSize)
	return c
}
// newServerConn creates the Conn for a connection accepted by listener.
//
// It keeps a reference to the listener so the connection can tell
// whether the server is shutting down. The listener also controls read
// buffering:
//   - connReadBufferSize <= 0 leaves reads unbuffered;
//   - with connBufferPooling, a reader is taken from readersPool;
//   - otherwise a private reader of connReadBufferSize bytes is allocated.
func newServerConn(conn net.Conn, listener *Listener) *Conn {
	c := &Conn{
		conn:        conn,
		listener:    listener,
		closed:      sync2.NewAtomicBool(false),
		PrepareData: make(map[uint32]*PrepareData),
	}
	if listener.connReadBufferSize <= 0 {
		return c
	}
	if !listener.connBufferPooling {
		c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize)
		return c
	}
	pooled := readersPool.Get().(*bufio.Reader)
	pooled.Reset(conn)
	c.bufferedReader = pooled
	return c
}
// startWriterBuffering switches the connection to buffered writes, using a
// bufio.Writer borrowed from writersPool. Every call must be paired with
// endWriterBuffering.
func (c *Conn) startWriterBuffering() {
	c.bufMu.Lock()
	defer c.bufMu.Unlock()
	w := writersPool.Get().(*bufio.Writer)
	w.Reset(c.conn)
	c.bufferedWriter = w
}
// endWriterBuffering ends a startWriterBuffering section. It cancels any
// pending flush timer, flushes the buffered data and returns the writer to
// writersPool. It returns the Flush error, or nil if buffering was not active.
func (c *Conn) endWriterBuffering() error {
	c.bufMu.Lock()
	defer c.bufMu.Unlock()

	w := c.bufferedWriter
	if w == nil {
		return nil
	}
	// Release the writer even if flushing fails.
	defer func() {
		w.Reset(nil)
		writersPool.Put(w)
		c.bufferedWriter = nil
	}()

	c.stopFlushTimer()
	return w.Flush()
}
// returnReader detaches the buffered reader from the connection and puts it
// back into readersPool. It is a no-op when there is no buffered reader.
//
// The field is cleared after the Put so that a second call cannot return
// the same reader to the pool twice. It also prevents this Conn from
// reading through a reader that another connection may already have
// taken from the pool.
func (c *Conn) returnReader() {
	if c.bufferedReader == nil {
		return
	}
	c.bufferedReader.Reset(nil)
	readersPool.Put(c.bufferedReader)
	c.bufferedReader = nil
}
// getWriter returns the writer to use for the next write, together with an
// unget function that must be called once the write is done.
//
// When buffering is active, the bufio.Writer is returned and bufMu stays
// locked until unget runs. unget also (re)starts the inactivity flush timer
// before unlocking. Otherwise the raw net.Conn is returned and unget does
// nothing.
func (c *Conn) getWriter() (w io.Writer, unget func()) {
	c.bufMu.Lock()
	buffered := c.bufferedWriter
	if buffered == nil {
		c.bufMu.Unlock()
		return c.conn, func() {}
	}
	// bufMu is intentionally left held; the closure releases it.
	release := func() {
		c.startFlushTimer()
		c.bufMu.Unlock()
	}
	return buffered, release
}
// startFlushTimer (re)arms the inactivity timer that flushes bufferedWriter
// after mysqlServerFlushDelay.
// It must be called while holding lock on bufMu.
func (c *Conn) startFlushTimer() {
	// Cancel the previously armed timer so only one flush is pending.
	c.stopFlushTimer()
	c.flushTimer = time.AfterFunc(mysqlServerFlushDelay, func() {
		// time.AfterFunc runs this callback on its own goroutine, so it must
		// take bufMu itself.
		c.bufMu.Lock()
		defer c.bufMu.Unlock()
		// Buffering may already have ended (endWriterBuffering) before the
		// timer fired.
		if c.bufferedWriter == nil {
			return
		}
		c.stopFlushTimer()
		// NOTE(review): the Flush error is discarded here; a failed flush
		// presumably surfaces on the next write or in endWriterBuffering — confirm.
		c.bufferedWriter.Flush()
	})
}
// stopFlushTimer cancels the pending inactivity flush, if any, and forgets
// the timer.
// It must be called while holding lock on bufMu.
func (c *Conn) stopFlushTimer() {
	timer := c.flushTimer
	if timer == nil {
		return
	}
	timer.Stop()
	c.flushTimer = nil
}
// getReader returns the reader to consume packets from: the buffered reader
// when one was configured (see newConn / newServerConn), else the raw
// net.Conn.
func (c *Conn) getReader() io.Reader {
	if c.bufferedReader == nil {
		return c.conn
	}
	return c.bufferedReader
}
// readHeaderFrom reads one 4-byte packet header from r into c.header and
// checks its sequence number against c.sequence. On success it advances
// c.sequence and returns the payload length (3-byte little-endian).
//
// A plain io.EOF, or a "connection reset by peer" error, is reported as
// io.EOF. The server side relies on that to avoid logging an error when a
// client simply disconnects. Other read failures are wrapped.
func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
	// io.ReadFull fails in two distinct ways here: if the runtime already
	// knows the socket is closed, it returns a non-EOF error such as
	// 'read: connection reset by peer'; if the socket closes while the read
	// is in progress, it returns io.EOF.
	_, err := io.ReadFull(r, c.header[:])
	switch {
	case err == nil:
		// Header read; validate it below.
	case err == io.EOF:
		return 0, err
	case strings.HasSuffix(err.Error(), "read: connection reset by peer"):
		return 0, io.EOF
	default:
		return 0, vterrors.Wrapf(err, "io.ReadFull(header size) failed")
	}

	if got := c.header[3]; got != c.sequence {
		return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence, expected %v got %v", c.sequence, got)
	}
	c.sequence++

	hdr := c.header
	return int(hdr[0]) | int(hdr[1])<<8 | int(hdr[2])<<16, nil
}
// readEphemeralPacket attempts to read a packet into a buffer from bufPool. Do
// not use this method if the contents of the packet needs to be kept
// after the next readEphemeralPacket.
//
// On a successful header read the connection enters the ephemeralRead state,
// and the caller must call recycleReadPacket once it is done with the
// returned slice. Payloads of MaxPacketSize or more (split over several
// packets) are reassembled into a freshly allocated slice instead.
//
// Note if the connection is closed already, an error will be
// returned, and it may not be io.EOF. If the connection closes while
// we are stuck waiting for data, an error will also be returned, and
// it most likely will be io.EOF.
func (c *Conn) readEphemeralPacket() ([]byte, error) {
	if c.currentEphemeralPolicy != ephemeralUnused {
		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
	}
	r := c.getReader()
	length, err := c.readHeaderFrom(r)
	if err != nil {
		return nil, err
	}
	// From here on the read state is claimed, even on the error paths below.
	c.currentEphemeralPolicy = ephemeralRead
	if length == 0 {
		// This can be caused by the packet after a packet of
		// exactly size MaxPacketSize.
		return nil, nil
	}
	// Use the bufPool.
	if length < MaxPacketSize {
		c.currentEphemeralBuffer = bufPool.Get(length)
		if _, err := io.ReadFull(r, *c.currentEphemeralBuffer); err != nil {
			return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
		}
		return *c.currentEphemeralBuffer, nil
	}
	// Much slower path, revert to allocating everything from scratch.
	// We're going to concatenate a lot of data anyway, can't really
	// optimize this code path easily.
	// currentEphemeralBuffer stays nil here, so recycleReadPacket has
	// nothing to return to the pool.
	data := make([]byte, length)
	if _, err := io.ReadFull(r, data); err != nil {
		return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", length)
	}
	// Keep reading continuation packets until one shorter than
	// MaxPacketSize (possibly empty) ends the payload.
	for {
		next, err := c.readOnePacket()
		if err != nil {
			return nil, err
		}
		if len(next) == 0 {
			// Again, the packet after a packet of exactly size MaxPacketSize.
			break
		}
		data = append(data, next...)
		if len(next) < MaxPacketSize {
			break
		}
	}
	return data, nil
}
// readEphemeralPacketDirect reads a packet straight from the socket,
// bypassing the buffered reader. It is needed for the first handshake packet
// the server receives, so that the SSL negotiation packet is not buffered.
// As a shortcut, only payloads smaller than MaxPacketSize are supported.
// Prefer readEphemeralPacket everywhere else.
//
// Like readEphemeralPacket, a successful header read puts the connection in
// the ephemeralRead state, and recycleReadPacket must be called afterwards.
func (c *Conn) readEphemeralPacketDirect() ([]byte, error) {
	if policy := c.currentEphemeralPolicy; policy != ephemeralUnused {
		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect: unexpected currentEphemeralPolicy: %v", policy))
	}

	src := io.Reader(c.conn)
	size, err := c.readHeaderFrom(src)
	if err != nil {
		return nil, err
	}
	c.currentEphemeralPolicy = ephemeralRead

	switch {
	case size == 0:
		// This can be caused by the packet after a packet of
		// exactly size MaxPacketSize.
		return nil, nil
	case size >= MaxPacketSize:
		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect doesn't support more than one packet")
	}

	c.currentEphemeralBuffer = bufPool.Get(size)
	if _, err := io.ReadFull(src, *c.currentEphemeralBuffer); err != nil {
		return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", size)
	}
	return *c.currentEphemeralBuffer, nil
}
// recycleReadPacket ends an ephemeral read started by readEphemeralPacket or
// readEphemeralPacketDirect. A pooled buffer is returned to bufPool, and the
// connection goes back to the ephemeralUnused state. Calling it in any other
// state is a programming error and panics.
func (c *Conn) recycleReadPacket() {
	if c.currentEphemeralPolicy != ephemeralRead {
		// Programming error.
		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
	}
	if buf := c.currentEphemeralBuffer; buf != nil {
		bufPool.Put(buf)
	}
	c.currentEphemeralBuffer = nil
	c.currentEphemeralPolicy = ephemeralUnused
}
// readOnePacket reads exactly one wire packet (no reassembly) into a newly
// allocated slice. A zero-length packet yields (nil, nil).
func (c *Conn) readOnePacket() ([]byte, error) {
	src := c.getReader()
	size, err := c.readHeaderFrom(src)
	switch {
	case err != nil:
		return nil, err
	case size == 0:
		// This can be caused by the packet after a packet of
		// exactly size MaxPacketSize.
		return nil, nil
	}

	payload := make([]byte, size)
	if _, err = io.ReadFull(src, payload); err != nil {
		return nil, vterrors.Wrapf(err, "io.ReadFull(packet body of length %v) failed", size)
	}
	return payload, nil
}
// readPacket reads a full logical packet from the underlying connection,
// reassembling payloads split into MaxPacketSize chunks. The payload is
// complete as soon as a chunk shorter than MaxPacketSize arrives (an empty
// chunk ends a payload that was an exact multiple of MaxPacketSize).
// This method returns a generic error, not a SQLError.
func (c *Conn) readPacket() ([]byte, error) {
	data, err := c.readOnePacket()
	if err != nil {
		return nil, err
	}
	// The common single-packet case never enters the loop.
	for last := data; len(last) >= MaxPacketSize; {
		if last, err = c.readOnePacket(); err != nil {
			return nil, err
		}
		data = append(data, last...)
	}
	return data, nil
}
// ReadPacket is the public counterpart of readPacket: it reads one
// (reassembled) packet and reports failures as a SQLError with code
// CRServerLost. The returned slice is freshly allocated and owned by the
// caller.
func (c *Conn) ReadPacket() ([]byte, error) {
	packet, err := c.readPacket()
	if err == nil {
		return packet, nil
	}
	return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
// writePacket writes a packet, possibly cutting it into multiple
// chunks. Note this is not very efficient, as the client probably
// has to build the []byte and that makes a memory copy.
// Try to use startEphemeralPacketWithHeader/writeEphemeralPacket instead.
//
// data must start with packetHeaderSize bytes reserved for the header,
// followed by the payload. Each chunk's header is written in place into the
// 4 bytes just before that chunk's payload. For chunks after the first, those
// bytes belong to the previous chunk's payload, so they are saved and restored
// around each write. On a write error the return happens before the restore,
// so data may be left modified.
//
// This method returns a generic error, not a SQLError.
func (c *Conn) writePacket(data []byte) error {
	index := 0
	dataLength := len(data) - packetHeaderSize
	// The unget returned by getWriter must always run (it may release bufMu).
	w, unget := c.getWriter()
	defer unget()
	var header [packetHeaderSize]byte
	for {
		// toBeSent is capped to MaxPacketSize.
		toBeSent := dataLength
		if toBeSent > MaxPacketSize {
			toBeSent = MaxPacketSize
		}
		// save the first 4 bytes of the payload, we will overwrite them with the
		// header below
		copy(header[0:packetHeaderSize], data[index:index+packetHeaderSize])
		// Compute and write the header: 3-byte little-endian length + sequence.
		data[index] = byte(toBeSent)
		data[index+1] = byte(toBeSent >> 8)
		data[index+2] = byte(toBeSent >> 16)
		data[index+3] = c.sequence
		// Write the body.
		if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil {
			return vterrors.Wrapf(err, "Write(packet) failed")
		} else if n != (toBeSent + packetHeaderSize) {
			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize))
		}
		// restore the first 4 bytes once the network send is done
		copy(data[index:index+packetHeaderSize], header[0:packetHeaderSize])
		// Update our state.
		c.sequence++
		dataLength -= toBeSent
		if dataLength == 0 {
			if toBeSent == MaxPacketSize {
				// The packet we just sent had exactly
				// MaxPacketSize size, we need to
				// sent a zero-size packet too.
				header[0] = 0
				header[1] = 0
				header[2] = 0
				header[3] = c.sequence
				if n, err := w.Write(header[:]); err != nil {
					return vterrors.Wrapf(err, "Write(empty header) failed")
				} else if n != packetHeaderSize {
					return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n)
				}
				c.sequence++
			}
			return nil
		}
		// Advance so the next header lands in the last 4 bytes of the chunk
		// just sent.
		index += toBeSent
	}
}
// startEphemeralPacketWithHeader claims the connection's ephemeral write
// buffer for a payload of length bytes. It returns the buffer (with
// packetHeaderSize leading bytes reserved for the header) and the offset at
// which the caller should start writing the payload. The packet must then be
// sent with writeEphemeralPacket. It panics if an ephemeral packet is already
// in progress.
func (c *Conn) startEphemeralPacketWithHeader(length int) ([]byte, int) {
	if c.currentEphemeralPolicy == ephemeralUnused {
		c.currentEphemeralPolicy = ephemeralWrite
		// Taken from the pool, or heap-allocated when length is too big.
		c.currentEphemeralBuffer = bufPool.Get(packetHeaderSize + length)
		return *c.currentEphemeralBuffer, packetHeaderSize
	}
	panic("startEphemeralPacketWithHeader cannot be used while a packet is already started.")
}
// writeEphemeralPacket sends the packet prepared with
// startEphemeralPacketWithHeader and always recycles the write buffer
// afterwards. Write errors are wrapped with the connection ID. Calling it
// without a started packet is a programming error and panics.
func (c *Conn) writeEphemeralPacket() error {
	defer c.recycleWritePacket()

	switch policy := c.currentEphemeralPolicy; policy {
	case ephemeralUnused, ephemeralRead:
		// Programming error.
		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), policy))
	case ephemeralWrite:
		err := c.writePacket(*c.currentEphemeralBuffer)
		if err != nil {
			return vterrors.Wrapf(err, "conn %v", c.ID())
		}
	}
	return nil
}
// recycleWritePacket returns the ephemeral write buffer to bufPool and
// resets the connection to ephemeralUnused. writeEphemeralPacket calls it
// via defer. Calling it outside the ephemeralWrite state panics.
func (c *Conn) recycleWritePacket() {
	if policy := c.currentEphemeralPolicy; policy != ephemeralWrite {
		// Programming error.
		panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleWritePacket while currentEphemeralPolicy is %d", policy))
	}
	bufPool.Put(c.currentEphemeralBuffer)
	// Drop our reference so the buffer is owned solely by the pool.
	c.currentEphemeralBuffer, c.currentEphemeralPolicy = nil, ephemeralUnused
}
// writeComQuit writes a Quit message for the server, to indicate we
// want to close the connection.
// Client -> Server.
// Returns SQLError(CRServerGone) if it can't.
func (c *Conn) writeComQuit() error {
	// This is a new command, need to reset the sequence.
	c.sequence = 0
	data, pos := c.startEphemeralPacketWithHeader(1)
	data[pos] = ComQuit
	if err := c.writeEphemeralPacket(); err != nil {
		// Pass the error as an argument rather than as the format string, so
		// that a '%' in the message is not interpreted as a verb.
		return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err)
	}
	return nil
}
// RemoteAddr reports the remote network address of the underlying net.Conn.
func (c *Conn) RemoteAddr() net.Addr {
	addr := c.conn.RemoteAddr()
	return addr
}
// ID reports the MySQL connection ID of this connection, widened to int64.
func (c *Conn) ID() int64 {
	id := c.ConnectionID
	return int64(id)
}
// String returns a useful identification string for error logging: the
// connection ID followed by the remote address.
func (c *Conn) String() string {
	// Hand the net.Addr to %s instead of calling .String() directly. fmt then
	// renders a nil address instead of panicking while logging.
	return fmt.Sprintf("client %v (%s)", c.ConnectionID, c.RemoteAddr())
}
// Close closes the connection. It may be called from another goroutine to
// interrupt in-flight reads or writes. Only the first call closes the
// underlying net.Conn; later calls do nothing.
func (c *Conn) Close() {
	if !c.closed.CompareAndSwap(false, true) {
		return
	}
	// The close error is deliberately ignored, as before.
	_ = c.conn.Close()
}
// IsClosed reports whether Close() has been called on this connection.
// A close initiated by the remote side is not reflected here unless
// Close() was also called.
func (c *Conn) IsClosed() bool {
	closed := c.closed.Get()
	return closed
}
//
// Packet writing methods, for generic packets.
//
// writeOKPacket sends packetOk as a regular OK packet (header byte OKPacket).
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacket(packetOk *PacketOK) error {
	header := OKPacket
	return c.writeOKPacketWithHeader(packetOk, header)
}
// writeOKPacketWithEOFHeader sends packetOk with the EOFPacket header byte.
// This form terminates a result set when CapabilityClientDeprecateEOF is
// negotiated.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacketWithEOFHeader(packetOk *PacketOK) error {
	header := EOFPacket
	return c.writeOKPacketWithHeader(packetOk, header)
}
// writeOKPacketWithHeader writes an OK-style packet whose first byte is
// headerType: OKPacket for a regular OK, or EOFPacket when the packet ends a
// result set under CapabilityClientDeprecateEOF.
//
// The payload is: header byte, affected rows and last insert ID
// (length-encoded), status flags, warnings, then the info string. When the
// client negotiated CapabilityClientSessionTrack, the info string is
// length-encoded. If packetOk.statusFlags also has ServerSessionStateChanged,
// it is followed by a session-state block carrying
// packetOk.sessionStateData. Otherwise info is written as a plain trailing
// string.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) error {
	sessionTrack := c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack
	stateChanged := packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged

	length := 1 + // header byte (OK or EOF)
		lenEncIntSize(packetOk.affectedRows) +
		lenEncIntSize(packetOk.lastInsertID)
	// assuming CapabilityClientProtocol41
	length += 4 // status_flags + warnings

	var gtidData []byte
	if sessionTrack {
		length += lenEncStringSize(packetOk.info) // info
		if stateChanged {
			// Build the session-state block from the inside out:
			// lenenc(data), prefixed with 0x00 and length-encoded again,
			// prefixed with the 0x03 state type (SESSION_TRACK_GTIDS in the
			// MySQL protocol), and finally prefixed with the block's total
			// length.
			gtidData = getLenEncString([]byte(packetOk.sessionStateData))
			gtidData = append([]byte{0x00}, gtidData...)
			gtidData = getLenEncString(gtidData)
			gtidData = append([]byte{0x03}, gtidData...)
			gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...)
			length += len(gtidData)
		}
	} else {
		length += len(packetOk.info) // info
	}

	bytes, pos := c.startEphemeralPacketWithHeader(length)
	data := &coder{data: bytes, pos: pos}
	data.writeByte(headerType) //header - OK or EOF
	data.writeLenEncInt(packetOk.affectedRows)
	data.writeLenEncInt(packetOk.lastInsertID)
	data.writeUint16(packetOk.statusFlags)
	data.writeUint16(packetOk.warnings)
	if sessionTrack {
		data.writeLenEncString(packetOk.info)
		if stateChanged {
			data.writeEOFString(string(gtidData))
		}
	} else {
		data.writeEOFString(packetOk.info)
	}
	return c.writeEphemeralPacket()
}
// getLenEncString encodes value as a MySQL length-encoded string: its
// length as a length-encoded integer, followed by the raw bytes.
func getLenEncString(value []byte) []byte {
	encoded := getLenEncInt(uint64(len(value)))
	encoded = append(encoded, value...)
	return encoded
}
// getLenEncInt encodes i as a MySQL length-encoded integer:
//   - below 251: the single byte i;
//   - below 2^16: 0xfc followed by 2 little-endian bytes;
//   - below 2^24: 0xfd followed by 3 little-endian bytes;
//   - otherwise: 0xfe followed by 8 little-endian bytes.
func getLenEncInt(i uint64) []byte {
	var prefix byte
	var width int
	switch {
	case i < 251:
		return []byte{byte(i)}
	case i < 1<<16:
		prefix, width = 0xfc, 2
	case i < 1<<24:
		prefix, width = 0xfd, 3
	default:
		prefix, width = 0xfe, 8
	}
	out := make([]byte, 1, 1+width)
	out[0] = prefix
	for shift := 0; shift < width*8; shift += 8 {
		out = append(out, byte(i>>uint(shift)))
	}
	return out
}
// writeErrorAndLog sends an error packet built from errorCode, sqlState and
// the printf-style message. It returns true when the packet was written. On
// failure it logs the write error and returns false, signalling that the
// connection should be dropped.
func (c *Conn) writeErrorAndLog(errorCode uint16, sqlState string, format string, args ...any) bool {
	err := c.writeErrorPacket(errorCode, sqlState, format, args...)
	if err == nil {
		return true
	}
	log.Errorf("Error writing error to %s: %v", c, err)
	return false
}
// writeErrorPacketFromErrorAndLog reports err to the client as an error
// packet. It returns true on success. If the packet cannot be written, the
// write failure is logged and false is returned.
func (c *Conn) writeErrorPacketFromErrorAndLog(err error) bool {
	if werr := c.writeErrorPacketFromError(err); werr != nil {
		log.Errorf("Error writing error to %s: %v", c, werr)
		return false
	}
	return true
}
// writeErrorPacket writes an error packet: the ErrPacket header byte, the
// 2-byte error code, '#', the 5-character SQL state (SSUnknownSQLState when
// sqlState is empty), and the printf-formatted message.
// Server -> Client.
// This method returns a generic error, not a SQLError.
// It panics if sqlState is non-empty and not exactly 5 characters long.
func (c *Conn) writeErrorPacket(errorCode uint16, sqlState string, format string, args ...any) error {
	errorMessage := fmt.Sprintf(format, args...)
	if sqlState == "" {
		sqlState = SSUnknownSQLState
	}
	// Validate before claiming the ephemeral write buffer. Panicking after
	// startEphemeralPacketWithHeader would leave the connection stuck in the
	// ephemeralWrite state if the panic were recovered.
	if len(sqlState) != 5 {
		panic("sqlState has to be 5 characters long")
	}

	length := 1 + 2 + 1 + 5 + len(errorMessage) // header + code + '#' + state + message
	data, pos := c.startEphemeralPacketWithHeader(length)
	pos = writeByte(data, pos, ErrPacket)
	pos = writeUint16(data, pos, errorCode)
	pos = writeByte(data, pos, '#')
	pos = writeEOFString(data, pos, sqlState)
	_ = writeEOFString(data, pos, errorMessage)
	return c.writeEphemeralPacket()
}
// writeErrorPacketFromError converts err into an error packet. A *SQLError
// keeps its own code, state and message. Any other error is reported as
// ERUnknownError with the message "unknown error: <err>".
// See writeErrorPacket for other info.
func (c *Conn) writeErrorPacketFromError(err error) error {
	se, isSQLErr := err.(*SQLError)
	if !isSQLErr {
		return c.writeErrorPacket(ERUnknownError, SSUnknownSQLState, "unknown error: %v", err)
	}
	return c.writeErrorPacket(uint16(se.Num), se.State, "%v", se.Message)
}
// writeEOFPacket writes an EOF packet (header byte, warnings, status flags)
// through the current writer. It does not flush, since it is used as part of
// a query result.
func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
	const eofPayloadLength = 1 + 2 + 2 // header + warnings + status flags
	buf, pos := c.startEphemeralPacketWithHeader(eofPayloadLength)
	pos = writeByte(buf, pos, EOFPacket)
	pos = writeUint16(buf, pos, warnings)
	_ = writeUint16(buf, pos, flags)
	return c.writeEphemeralPacket()
}
// handleNextCommand is called in the server loop to process incoming
// packets. It reads one command packet, dispatches it on its first byte,
// and returns whether the server loop should keep serving this connection
// (false means the connection should be closed).
func (c *Conn) handleNextCommand(handler Handler) bool {
	// Every command starts a new packet sequence.
	c.sequence = 0
	data, err := c.readEphemeralPacket()
	if err != nil {
		// Don't log EOF errors. They cause too much spam.
		if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
			log.Errorf("Error reading packet from %s: %v", c, err)
		}
		return false
	}
	if len(data) == 0 {
		// readEphemeralPacket entered the ephemeralRead state even for an
		// empty packet; release it before giving up on the connection.
		c.recycleReadPacket()
		return false
	}

	// Capture the command byte up front. Several branches below recycle the
	// ephemeral buffer, handing it back to the shared pool, before they
	// report the command, and data must not be read after that point.
	cmd := data[0]
	switch cmd {
	case ComQuit:
		c.recycleReadPacket()
		return false
	case ComInitDB:
		db := c.parseComInitDB(data)
		c.recycleReadPacket()
		res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false)
		return res != connErr
	case ComQuery:
		return c.handleComQuery(handler, data)
	case ComPing:
		return c.handleComPing()
	case ComSetOption:
		return c.handleComSetOption(data)
	case ComPrepare:
		return c.handleComPrepare(handler, data)
	case ComStmtExecute:
		return c.handleComStmtExecute(handler, data)
	case ComStmtSendLongData:
		return c.handleComStmtSendLongData(data)
	case ComStmtClose:
		stmtID, ok := c.parseComStmtClose(data)
		c.recycleReadPacket()
		if ok {
			delete(c.PrepareData, stmtID)
		}
	case ComStmtReset:
		return c.handleComStmtReset(data)
	case ComResetConnection:
		c.handleComResetConnection(handler)
		return true
	case ComFieldList:
		c.recycleReadPacket()
		if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "command handling not implemented yet: %v", cmd) {
			return false
		}
	case ComBinlogDumpGTID:
		return c.handleComBinlogDumpGTID(handler, data)
	default:
		// Log the packet while its buffer is still ours.
		log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data)
		c.recycleReadPacket()
		if !c.writeErrorAndLog(ERUnknownComError, SSNetError, "command handling not implemented yet: %v", cmd) {
			return false
		}
	}
	return true
}
// handleComBinlogDumpGTID serves a COM_BINLOG_DUMP_GTID request by handing
// the requested GTID set to handler.ComBinlogDumpGTID, with buffered writes
// enabled for the duration. It returns false (close the connection) when the
// request cannot be parsed or the final flush fails.
func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue bool) {
	defer c.recycleReadPacket()
	c.startWriterBuffering()
	defer func() {
		if err := c.endWriterBuffering(); err != nil {
			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
			kontinue = false
		}
	}()

	_, _, position, err := c.parseComBinlogDumpGTID(data)
	if err != nil {
		log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err)
		// Do not stream from an unparsed position. Before this fix,
		// kontinue=false was overwritten by the unconditional return true
		// below, so the error was effectively ignored.
		return false
	}
	handler.ComBinlogDumpGTID(c, position.GTIDSet)
	return true
}
// handleComResetConnection serves COM_RESET_CONNECTION. It releases the read
// buffer, lets the handler reset its session state, drops all prepared
// statements and acknowledges with an OK packet.
func (c *Conn) handleComResetConnection(handler Handler) {
	// Clean up and reset the connection
	c.recycleReadPacket()
	handler.ComResetConnection(c)
	// Reset prepared statements
	c.PrepareData = make(map[uint32]*PrepareData)
	if err := c.writeOKPacket(&PacketOK{}); err != nil {
		// Best effort: try to report the failure to the client, and log it
		// if even the error packet cannot be written instead of dropping
		// that error silently.
		c.writeErrorPacketFromErrorAndLog(err)
	}
}
// handleComStmtReset serves COM_STMT_RESET. It clears the bound values
// accumulated for the referenced prepared statement (keeping the keys) and
// replies with an OK packet. A malformed packet or unknown statement ID is
// answered with an error packet. Returns false when the connection should be
// closed.
func (c *Conn) handleComStmtReset(data []byte) bool {
	stmtID, parsed := c.parseComStmtReset(data)
	prepare, found := c.PrepareData[stmtID]
	// data lives in a pooled ephemeral buffer that recycleReadPacket hands
	// back to the shared pool, so render it for diagnostics beforehand.
	var packetDump string
	if !parsed || !found {
		packetDump = fmt.Sprintf("%v", data)
	}
	c.recycleReadPacket()

	if !parsed {
		log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, packetDump)
		// Stop here: continuing would look up a bogus statement and
		// dereference a nil *PrepareData.
		return c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", packetDump)
	}
	if !found {
		log.Errorf("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, packetDump)
		return c.writeErrorAndLog(CRCommandsOutOfSync, SSNetError, "commands were executed in an improper order: %v", packetDump)
	}

	// Ranging over a nil map is a no-op, so no nil check is needed.
	for k := range prepare.BindVars {
		prepare.BindVars[k] = nil
	}

	if err := c.writeOKPacket(&PacketOK{statusFlags: c.StatusFlags}); err != nil {
		log.Errorf("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
		return false
	}
	return true
}
// handleComStmtSendLongData serves COM_STMT_SEND_LONG_DATA by appending the
// received chunk to bind variable "v<paramID+1>" of the referenced prepared
// statement. Nothing is written on success. Invalid packets, unknown
// statements and out-of-range parameters are answered with an error packet.
// Returns false only when writing that error packet fails.
func (c *Conn) handleComStmtSendLongData(data []byte) bool {
	stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data)
	// Build the parse-failure error while data still references a live
	// buffer: recycleReadPacket returns it to the shared pool.
	var parseErr error
	if !ok {
		parseErr = fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data)
	}
	c.recycleReadPacket()
	if parseErr != nil {
		return c.writeErrorPacketFromErrorAndLog(parseErr)
	}

	prepare, ok := c.PrepareData[stmtID]
	if !ok {
		err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID)
		return c.writeErrorPacketFromErrorAndLog(err)
	}

	if prepare.BindVars == nil ||
		prepare.ParamsCount == uint16(0) ||
		paramID >= prepare.ParamsCount {
		err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt)
		return c.writeErrorPacketFromErrorAndLog(err)
	}

	// NOTE(review): chunk is retained in BindVars after the read buffer has
	// been recycled. This is only safe if parseComStmtSendLongData returns a
	// copy rather than a sub-slice of data — confirm.
	key := fmt.Sprintf("v%d", paramID+1)
	if val, ok := prepare.BindVars[key]; ok {
		val.Value = append(val.Value, chunk...)
	} else {
		prepare.BindVars[key] = sqltypes.BytesBindVariable(chunk)
	}
	return true
}
// handleComStmtExecute serves COM_STMT_EXECUTE. It runs the prepared
// statement through handler.ComStmtExecute with buffered writes enabled:
//   - a result without fields is answered with a single OK packet;
//   - otherwise the fields, then the rows in binary format, then an end
//     packet are streamed.
//
// It returns false when the connection should be closed: a failed
// flush/write, or an error in the middle of a stream.
func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool) {
	c.startWriterBuffering()
	defer func() {
		if err := c.endWriterBuffering(); err != nil {
			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
			kontinue = false
		}
	}()
	queryStart := time.Now()
	stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data)
	c.recycleReadPacket()

	if stmtID != uint32(0) {
		defer func() {
			// Allocate a new bindvar map every time since VTGate.Execute() mutates it.
			// Guard against a statement that is not (or no longer) registered,
			// so a parse error cannot turn into a nil-pointer panic here.
			if prepare, ok := c.PrepareData[stmtID]; ok && prepare != nil {
				prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount)
			}
		}()
	}

	if err != nil {
		return c.writeErrorPacketFromErrorAndLog(err)
	}

	fieldSent := false
	// sendFinished is set if the response should just be an OK packet.
	sendFinished := false
	prepare := c.PrepareData[stmtID]
	err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error {
		if sendFinished {
			// Failsafe: Unreachable if server is well-behaved.
			return io.EOF
		}

		if !fieldSent {
			fieldSent = true

			if len(qr.Fields) == 0 {
				sendFinished = true
				// We should not send any more packets after this.
				ok := PacketOK{
					affectedRows:     qr.RowsAffected,
					lastInsertID:     qr.InsertID,
					statusFlags:      c.StatusFlags,
					warnings:         0,
					info:             "",
					sessionStateData: qr.SessionStateChanges,
				}
				return c.writeOKPacket(&ok)
			}
			if err := c.writeFields(qr); err != nil {
				return err
			}
		}

		return c.writeBinaryRows(qr)
	})

	// If no field was sent, we expect an error.
	if !fieldSent {
		// This is just a failsafe. Should never happen.
		if err == nil || err == io.EOF {
			err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
		}
		if !c.writeErrorPacketFromErrorAndLog(err) {
			return false
		}
	} else {
		if err != nil {
			// We can't send an error in the middle of a stream.
			// All we can do is abort the send, which will cause a 2013.
			log.Errorf("Error in the middle of a stream to %s: %v", c, err)
			return false
		}

		// Send the end packet only sendFinished is false (results were streamed).
		// In this case the affectedRows and lastInsertID are always 0 since it
		// was a read operation.
		if !sendFinished {
			if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
				log.Errorf("Error writing result to %s: %v", c, err)
				return false
			}
		}
	}

	timings.Record(queryTimingKey, queryStart)
	return true
}
// handleComPrepare processes COM_STMT_PREPARE. It:
//   - extracts the query from data and rejects multi-statement input;
//   - registers a new PrepareData under the next statement ID;
//   - counts the "v"-prefixed bind arguments in the parsed statement;
//   - asks handler.ComPrepare for the result fields;
//   - writes the prepare response.
//
// Each call writes exactly one response. On any failure a single error packet
// is sent and processing stops. Writes are buffered and flushed on return, and
// a flush failure forces kontinue to false. It returns false if the connection
// should be closed.
func (c *Conn) handleComPrepare(handler Handler, data []byte) (kontinue bool) {
	c.startWriterBuffering()
	defer func() {
		if err := c.endWriterBuffering(); err != nil {
			log.Errorf("conn %v: flush() failed: %v", c.ID(), err)
			kontinue = false
		}
	}()

	query := c.parseComPrepare(data)
	c.recycleReadPacket()

	var queries []string
	if c.Capabilities&CapabilityClientMultiStatements != 0 {
		var err error
		queries, err = splitStatementFunction(query)
		if err != nil {
			log.Errorf("Conn %v: Error splitting query: %v", c, err)
			return c.writeErrorPacketFromErrorAndLog(err)
		}
		if len(queries) != 1 {
			log.Errorf("Conn %v: can not prepare multiple statements", c)
			// err is nil at this point, so build the error that is reported.
			return c.writeErrorPacketFromErrorAndLog(errors.New("can not prepare multiple statements"))
		}
	} else {
		queries = []string{query}
	}

	// Populate PrepareData.
	c.StatementID++
	prepare := &PrepareData{
		StatementID: c.StatementID,
		PrepareStmt: queries[0],
	}

	// Parse the statement actually being prepared (queries[0]) so that the
	// parameter count matches what is handed to the handler.
	statement, err := sqlparser.ParseStrictDDL(queries[0])
	if err != nil {
		log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
		// Stop after the error packet. Falling through would register the
		// statement and write a second response to the same command.
		return c.writeErrorPacketFromErrorAndLog(err)
	}

	// Every "v"-prefixed Argument node is a bind parameter. The visitor never
	// returns an error, so Walk's error result can be ignored.
	paramsCount := uint16(0)
	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
		if arg, ok := node.(sqlparser.Argument); ok && strings.HasPrefix(string(arg), "v") {
			paramsCount++
		}
		return true, nil
	}, statement)

	if paramsCount > 0 {
		prepare.ParamsCount = paramsCount
		prepare.ParamsType = make([]int32, paramsCount)
		prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount)
	}

	// Placeholder bind variables v1..vN handed to the handler.
	bindVars := make(map[string]*querypb.BindVariable, paramsCount)
	for i := uint16(0); i < paramsCount; i++ {
		parameterID := fmt.Sprintf("v%d", i+1)
		bindVars[parameterID] = &querypb.BindVariable{}
	}

	c.PrepareData[c.StatementID] = prepare
	fld, err := handler.ComPrepare(c, queries[0], bindVars)
	if err != nil {
		// The client never receives this statement ID, so drop the
		// registration instead of keeping it for the life of the connection.
		delete(c.PrepareData, prepare.StatementID)
		return c.writeErrorPacketFromErrorAndLog(err)
	}

	if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil {
		log.Errorf("Error writing prepare data to client %v: %v", c.ConnectionID, err)
		return false
	}
	return true
}
// handleComSetOption processes COM_SET_OPTION, which toggles multi-statement
// support on this connection. Operation 0 sets CapabilityClientMultiStatements
// and operation 1 clears it.
//
// A recognized operation is acknowledged with an end-of-result packet. An
// unknown operation or an unparseable packet gets an error packet instead.
// Exactly one response is written per call. It returns false if the
// connection should be closed.
func (c *Conn) handleComSetOption(data []byte) bool {
	operation, ok := c.parseComSetOption(data)
	c.recycleReadPacket()
	if !ok {
		log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data)
		return c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data)
	}

	switch operation {
	case 0:
		c.Capabilities |= CapabilityClientMultiStatements
	case 1:
		c.Capabilities &^= CapabilityClientMultiStatements
	default:
		log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data)
		// Stop after the error packet. Also writing the end-of-result packet
		// would send two responses to a single command.
		return c.writeErrorAndLog(ERUnknownComError, SSNetError, "error handling packet: %v", data)
	}

	if err := c.writeEndResult(false, 0, 0, 0); err != nil {
		log.Errorf("Error writeEndResult error %v ", err)
		return false
	}
	return true
}
// handleComPing answers COM_PING. While the listener is running it replies
// with an OK packet carrying the current status flags. Once the listener
// reports shutdown, it replies with an ERServerShutdown error instead.
// It returns false if the connection should be closed.
func (c *Conn) handleComPing() bool {
	c.recycleReadPacket()

	if c.listener.isShutdown() {
		return c.writeErrorAndLog(ERServerShutdown, SSNetError, "Server shutdown in progress")
	}

	reply := &PacketOK{statusFlags: c.StatusFlags}
	if writeErr := c.writeOKPacket(reply); writeErr != nil {
		log.Errorf("Error writing ComPing result to %s: %v", c, writeErr)
		return false
	}
	return true
}
// errEmptyStatement is the error sent to the client when a COM_QUERY yields
// no statements to execute (EREmptyQuery, "Query was empty").
var errEmptyStatement = NewSQLError(EREmptyQuery, SSClientError, "Query was empty")
// handleComQuery processes COM_QUERY. When the client enabled multi-statement
// support, the query text is split into individual statements. Each statement
// is executed in order via execQuery, and every statement except the last is
// flagged so that its reply announces more results. Execution stops at the
// first statement that does not succeed.
//
// Writes are buffered for the whole command and flushed on return, and a flush
// failure forces kontinue to false. Query timing is recorded only when every
// statement succeeds. It returns false if the connection should be closed.
func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) {
	c.startWriterBuffering()
	defer func() {
		if flushErr := c.endWriterBuffering(); flushErr != nil {
			log.Errorf("conn %v: flush() failed: %v", c.ID(), flushErr)
			kontinue = false
		}
	}()

	started := time.Now()
	query := c.parseComQuery(data)
	c.recycleReadPacket()

	var statements []string
	if c.Capabilities&CapabilityClientMultiStatements == 0 {
		statements = []string{query}
	} else {
		pieces, splitErr := splitStatementFunction(query)
		if splitErr != nil {
			log.Errorf("Conn %v: Error splitting query: %v", c, splitErr)
			return c.writeErrorPacketFromErrorAndLog(splitErr)
		}
		statements = pieces
	}

	if len(statements) == 0 {
		return c.writeErrorPacketFromErrorAndLog(errEmptyStatement)
	}

	lastIdx := len(statements) - 1
	for idx, stmt := range statements {
		// A failed statement ends this command; only a connection error also
		// ends the connection.
		if res := c.execQuery(stmt, handler, idx < lastIdx); res != execSuccess {
			return res != connErr
		}
	}

	timings.Record(queryTimingKey, started)
	return true
}
// execQuery runs a single statement through handler.ComQuery and writes its
// response. When more is true, the reply carries ServerMoreResultsExists so
// the client expects further result sets.
//
// The first callback decides the response shape:
//   - With no fields, it is a write-only statement, answered with one OK
//     packet.
//   - Otherwise the fields and then rows are streamed, and the result set is
//     terminated with an end-of-result packet.
//
// It returns:
//   - execSuccess on success;
//   - execErr if an error packet was written before anything was streamed;
//   - connErr if the connection is no longer usable.
func (c *Conn) execQuery(query string, handler Handler, more bool) execResult {
	// started records whether the handler invoked the callback at all.
	// okSent records that the reply was a lone OK packet, after which nothing
	// else may be written.
	var started, okSent bool

	err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error {
		if okSent {
			// Failsafe: Unreachable if server is well-behaved.
			return io.EOF
		}
		if started {
			return c.writeRows(qr)
		}
		started = true
		if len(qr.Fields) > 0 {
			if err := c.writeFields(qr); err != nil {
				return err
			}
			return c.writeRows(qr)
		}

		// No fields on the first callback: a DML or other write-only
		// statement. Reply with a single OK packet that carries the affected
		// row count and last insert ID, since clients expect them.
		okSent = true
		status := c.StatusFlags
		if more {
			status |= ServerMoreResultsExists
		}
		return c.writeOKPacket(&PacketOK{
			affectedRows:     qr.RowsAffected,
			lastInsertID:     qr.InsertID,
			statusFlags:      status,
			warnings:         handler.WarningCount(c),
			info:             "",
			sessionStateData: qr.SessionStateChanges,
		})
	})

	switch {
	case !started:
		// Nothing was written yet, so an error packet can still be sent.
		// This is just a failsafe. Should never happen.
		if err == nil || err == io.EOF {
			err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
		}
		if !c.writeErrorPacketFromErrorAndLog(err) {
			return connErr
		}
		return execErr
	case err != nil:
		// We can't send an error in the middle of a stream.
		// All we can do is abort the send, which will cause a 2013.
		log.Errorf("Error in the middle of a stream to %s: %v", c, err)
		return connErr
	case okSent:
		return execSuccess
	}

	// Rows were streamed, so terminate the result set. affectedRows and
	// lastInsertID are always 0 here because it was a read operation.
	if endErr := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); endErr != nil {
		log.Errorf("Error writing result to %s: %v", c, endErr)
		return connErr
	}
	return execSuccess
}
//
// Packet parsing methods, for generic packets.
//
// isEOFPacket reports whether data, a packet payload without its 4-byte
// header, is an EOF packet.
//
// A first byte of 0xfe (EOFPacket) is overloaded, so never compare only the
// first byte as you might for other packet types:
//
//   - Without CLIENT_DEPRECATE_EOF, a genuine EOF packet is shorter than 9
//     bytes. Per https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html,
//     a 0xfe-prefixed payload of 9 bytes or more is instead a
//     LengthEncodedInteger, typically the prefix of a LengthEncodedString.
//     Every EOF check in this mode must therefore validate the payload size.
//   - With CLIENT_DEPRECATE_EOF, the "EOF" is really an OK packet carrying
//     the 0xfe type code. Such a packet always fits in a single packet, so
//     anything shorter than MaxPacketSize is treated as EOF.
//
// An empty payload is never an EOF packet.
//
// More docs here:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_response_packets.html
func (c *Conn) isEOFPacket(data []byte) bool {
	if len(data) == 0 || data[0] != EOFPacket {
		return false
	}
	if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
		return len(data) < 9
	}
	return len(data) < MaxPacketSize
}
// parseEOFPacket extracts the warning count and the status flags from an EOF
// packet. data is the full payload including the leading type byte. The
// warning count is read at offset 1 and the status flags at offset 3. The
// status flags include the bit that signals whether more results follow.
//
// Note: This is only valid on actual EOF packets and not on OK packets with the
// EOF type code set, i.e. it should not be used if ClientDeprecateEOF is set.
func parseEOFPacket(data []byte) (warnings uint16, statusFlags uint16, err error) {
	const warningsOffset, statusOffset = 1, 3

	// Only the status-flag read is validated. It lies beyond the warning
	// count, so any packet long enough for it also holds the warning count.
	warningCount, _, _ := readUint16(data, warningsOffset)
	flags, _, ok := readUint16(data, statusOffset)
	if !ok {
		return 0, 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data)
	}
	return warningCount, flags, nil
}
// PacketOK contains the ok packet details. It is filled in when writing OK
// responses (writeOKPacket) and when decoding them (parseOKPacket).
type PacketOK struct {
	// affectedRows is the number of rows affected by the statement
	// (populated from sqltypes.Result.RowsAffected when writing).
	affectedRows uint64
	// lastInsertID is the insert ID reported for the statement
	// (populated from sqltypes.Result.InsertID when writing).
	lastInsertID uint64
	// statusFlags holds server status bits such as ServerMoreResultsExists
	// and ServerSessionStateChanged.
	statusFlags uint16
	// warnings is the warning count reported alongside the result.
	warnings uint16
	// info is the human-readable info string; parseOKPacket only stores it
	// when the connection has enableQueryInfo set.
	info string
	// at the moment, we only store GTID information in this field
	sessionStateData string
}
// parseOKPacket decodes an OK packet sent by a MySQL server. in is the packet
// payload. Its first byte, the packet type, has already been checked by the
// caller and is skipped.
//
// Fields are read in wire order:
//   - affected rows and last insert ID (length-encoded integers);
//   - status flags and warning count (two bytes each, assuming
//     CLIENT_PROTOCOL_41);
//   - the info string.
//
// If session tracking was negotiated and the status flags mark a session state
// change, the session state entries follow. Only the GTID entry is kept (in
// sessionStateData); every other entry type is skipped. The info string is
// stored only if c.enableQueryInfo is set.
//
// A truncated mandatory field yields a vterrors error with code INTERNAL.
func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
	data := &coder{
		data: in,
		pos:  1, // We already read the type.
	}
	packetOK := &PacketOK{}

	fail := func(format string, args ...any) (*PacketOK, error) {
		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...)
	}

	// Affected rows.
	affectedRows, ok := data.readLenEncInt()
	if !ok {
		return fail("invalid OK packet affectedRows: %v", data)
	}
	packetOK.affectedRows = affectedRows

	// Last Insert ID.
	lastInsertID, ok := data.readLenEncInt()
	if !ok {
		return fail("invalid OK packet lastInsertID: %v", data)
	}
	packetOK.lastInsertID = lastInsertID

	// Status flags.
	statusFlags, ok := data.readUint16()
	if !ok {
		return fail("invalid OK packet statusFlags: %v", data)
	}
	packetOK.statusFlags = statusFlags

	// assuming CapabilityClientProtocol41
	// Warnings.
	warnings, ok := data.readUint16()
	if !ok {
		return fail("invalid OK packet warnings: %v", data)
	}
	packetOK.warnings = warnings

	// info. The second result of readLenEncInfo is ignored, so an absent info
	// string is treated as empty.
	info, _ := data.readLenEncInfo()
	if c.enableQueryInfo {
		packetOK.info = info
	}

	if c.Capabilities&uint32(CapabilityClientSessionTrack) == CapabilityClientSessionTrack {
		// session tracking
		if statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
			length, ok := data.readLenEncInt()
			if !ok {
				return fail("invalid OK packet session state change length: %v", data)
			}
			// In case we have a zero length string, there's no additional information so
			// we can return the packet.
			if length == 0 {
				return packetOK, nil
			}

			// Alright, now we need to read each sub packet from the session state change.
			for {
				sscType, ok := data.readByte()
				if !ok {
					// We're done, there's no more session state parts in the packet.
					break
				}
				sessionLen, ok := data.readLenEncInt()
				if !ok {
					return fail("invalid OK packet session state change length for type %v", sscType)
				}

				if sscType != SessionTrackGtids {
					// Skip this entry's payload without interpreting it.
					// sessionLen comes from the peer, so clamp it to the
					// remaining bytes: converting an oversized uint64 straight
					// to int could wrap to a negative offset. Jumping to the
					// end makes the next readByte fail and ends the loop, as
					// for any over-long length.
					if remaining := len(data.data) - data.pos; remaining < 0 || sessionLen > uint64(remaining) {
						data.pos = len(data.data)
					} else {
						data.pos += int(sessionLen)
					}
					continue
				}

				// read (and ignore for now) the GTIDS encoding specification code: 1 byte
				_, ok = data.readByte()
				if !ok {
					return fail("invalid OK packet gtids type: %v", data)
				}

				gtids, ok := data.readLenEncString()
				if !ok {
					return fail("invalid OK packet gtids: %v", data)
				}
				packetOK.sessionStateData = gtids
			}
		}
	}

	return packetOK, nil
}
// isErrorPacket reports whether data is an error packet, i.e. whether its
// first byte is ErrPacket. It mostly exists for consistency with isEOFPacket.
// An empty payload is never an error packet.
func isErrorPacket(data []byte) bool {
	return len(data) > 0 && data[0] == ErrPacket
}
// ParseErrorPacket decodes a MySQL ERR packet into a SQLError. data is the
// full payload including the leading type byte, which is skipped.
//
// The layout read is:
//   - a 2-byte error code;
//   - one '#' marker byte (skipped, not validated);
//   - a 5-byte SQL state;
//   - the human-readable message, filling the rest of the payload.
//
// A truncated packet yields a CRUnknownError naming the missing field.
func ParseErrorPacket(data []byte) error {
	const (
		codeOffset     = 1 // skip the packet type byte
		sqlStateLength = 5
	)

	code, next, ok := readUint16(data, codeOffset)
	if !ok {
		return NewSQLError(CRUnknownError, SSUnknownSQLState, "invalid error packet code: %v", data)
	}

	// next+1 steps over the '#' marker that introduces the SQL state.
	sqlState, next, ok := readBytes(data, next+1, sqlStateLength)
	if !ok {
		return NewSQLError(CRUnknownError, SSUnknownSQLState, "invalid error packet sqlState: %v", data)
	}

	message := string(data[next:])
	return NewSQLError(int(code), string(sqlState), "%v", message)
}
// GetTLSClientCerts returns the certificates presented by the peer on a TLS
// connection. It returns nil when the underlying net.Conn is not a *tls.Conn.
func (c *Conn) GetTLSClientCerts() []*x509.Certificate {
	tlsConn, isTLS := c.conn.(*tls.Conn)
	if !isTLS {
		return nil
	}
	state := tlsConn.ConnectionState()
	return state.PeerCertificates
}
// TLSEnabled reports whether this connection is using TLS, i.e. whether
// CapabilityClientSSL is set in the connection's capability flags.
func (c *Conn) TLSEnabled() bool {
	return c.Capabilities&CapabilityClientSSL != 0
}
// IsUnixSocket returns true if this connection was accepted by a listener
// bound to a Unix socket. A connection with no listener set reports false
// instead of dereferencing a nil listener.
func (c *Conn) IsUnixSocket() bool {
	if c.listener == nil {
		return false
	}
	_, ok := c.listener.listener.(*net.UnixListener)
	return ok
}
// GetRawConn exposes the underlying net.Conn that this Conn was built on.
// Reading from or writing to it directly bypasses the packet framing and
// write buffering done by Conn, so callers must know what they are doing.
func (c *Conn) GetRawConn() net.Conn {
	raw := c.conn
	return raw
}