vitess-gh/go/mysql/query.go

1520 lines
42 KiB
Go

/*
Copyright 2019 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import (
"fmt"
"math"
"strconv"
"strings"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
querypb "vitess.io/vitess/go/vt/proto/query"
)
// This file contains the methods related to queries.
//
// Client side methods.
//
// WriteComQuery sends query to the server as a COM_QUERY packet.
// Client -> Server.
// Every COM_QUERY starts a new command, so the packet sequence number is
// reset to zero before anything is written.
// Returns SQLError(CRServerGone) if the packet cannot be written.
func (c *Conn) WriteComQuery(query string) error {
	c.sequence = 0 // new command: restart sequence numbering

	// Payload: one command byte followed by the raw query text.
	packet, offset := c.startEphemeralPacketWithHeader(1 + len(query))
	packet[offset] = ComQuery
	copy(packet[offset+1:], query)

	writeErr := c.writeEphemeralPacket()
	if writeErr == nil {
		return nil
	}
	return NewSQLError(CRServerGone, SSUnknownSQLState, writeErr.Error())
}
// writeComInitDB asks the server to make db the connection's default
// database by sending a COM_INIT_DB packet.
// Client -> Server.
// Returns SQLError(CRServerGone) if the packet cannot be written.
func (c *Conn) writeComInitDB(db string) error {
	// Payload: one command byte followed by the database name.
	buf, start := c.startEphemeralPacketWithHeader(1 + len(db))
	buf[start] = ComInitDB
	copy(buf[start+1:], db)

	if werr := c.writeEphemeralPacket(); werr != nil {
		return NewSQLError(CRServerGone, SSUnknownSQLState, werr.Error())
	}
	return nil
}
// writeComSetOption sends a COM_SET_OPTION packet, which changes the
// connection's capability of executing multi statements.
// operation is the 2-byte option value written after the command byte
// (presumably one of the MYSQL_OPTION_MULTI_STATEMENTS_ON/OFF values —
// confirm against callers).
// Returns SQLError(CRServerGone) if it can't.
func (c *Conn) writeComSetOption(operation uint16) error {
	// The payload is exactly one command byte plus a uint16 option. The
	// previous size of 16+1 confused bits with bytes and appended 14 extra
	// bytes of whatever the ephemeral buffer contained to the packet.
	data, pos := c.startEphemeralPacketWithHeader(1 + 2)
	data[pos] = ComSetOption
	pos++
	writeUint16(data, pos, operation)
	if err := c.writeEphemeralPacket(); err != nil {
		return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error())
	}
	return nil
}
// readColumnDefinition reads the next Column Definition packet and fills in
// field with its full contents.
//
// Fields are read in this order: catalog (length-encoded string, skipped),
// schema, table, org_table, name and org_name (length-encoded strings), one
// skipped byte (length of the fixed-length fields), character set (uint16),
// column length (uint32), type (byte), flags (uint16) and decimals (byte).
//
// index is the column number; it is only used in error messages.
// The read packet is recycled when the function returns.
// Returns a SQLError: CRServerLost if the packet cannot be read,
// CRMalformedPacket if it is truncated or its type/flags are not convertible.
func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error {
	colDef, err := c.readEphemeralPacket()
	if err != nil {
		return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
	}
	defer c.recycleReadPacket()
	// Catalog is ignored, always set to "def"
	pos, ok := skipLenEncString(colDef, 0)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v catalog failed", index)
	}
	// schema, table, orgTable, name and OrgName are strings.
	field.Database, pos, ok = readLenEncString(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v schema failed", index)
	}
	field.Table, pos, ok = readLenEncString(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v table failed", index)
	}
	field.OrgTable, pos, ok = readLenEncString(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v org_table failed", index)
	}
	field.Name, pos, ok = readLenEncString(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v name failed", index)
	}
	field.OrgName, pos, ok = readLenEncString(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v org_name failed", index)
	}
	// Skip length of fixed-length fields. No bounds check is needed here:
	// the next read fails if the packet ends early.
	pos++
	// characterSet is a uint16.
	characterSet, pos, ok := readUint16(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v characterSet failed", index)
	}
	field.Charset = uint32(characterSet)
	// columnLength is a uint32.
	field.ColumnLength, pos, ok = readUint32(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v columnLength failed", index)
	}
	// type is one byte.
	t, pos, ok := readByte(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v type failed", index)
	}
	// flags is 2 bytes.
	flags, pos, ok := readUint16(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v flags failed", index)
	}
	// Convert MySQL type to Vitess type.
	field.Type, err = sqltypes.MySQLToType(int64(t), int64(flags))
	if err != nil {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err)
	}
	// Decimals is a byte.
	decimals, _, ok := readByte(colDef, pos)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v decimals failed", index)
	}
	field.Decimals = uint32(decimals)
	// If we didn't get column length or character set,
	// we assume the original row on the other side was encoded from
	// a Field without that data, so we don't return the flags.
	if field.ColumnLength != 0 || field.Charset != 0 {
		field.Flags = uint32(flags)
		// FIXME(alainjobart): This is something the MySQL
		// client library does: If the type is numerical, it
		// adds a NUM_FLAG to the flags. We're doing it here
		// only to be compatible with the C library. Once
		// we're not using that library any more, we'll remove this.
		// See doc.go.
		if IsNum(t) {
			field.Flags |= uint32(querypb.MySqlFlag_NUM_FLAG)
		}
	}
	return nil
}
// readColumnDefinitionType reads the next Column Definition packet like
// readColumnDefinition, but only decodes what is needed to set field.Type.
// The six leading strings, the character set and the column length are
// skipped, and the trailing decimals byte is never read.
// index is the column number and is only used in error messages.
// Returns a SQLError.
func (c *Conn) readColumnDefinitionType(field *querypb.Field, index int) error {
	packet, err := c.readEphemeralPacket()
	if err != nil {
		return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
	}
	defer c.recycleReadPacket()

	// The packet starts with six length-encoded strings we do not need.
	// Skip them in order, naming the one that is truncated in the error.
	offset := 0
	for _, part := range [...]string{"catalog", "schema", "table", "org_table", "name", "org_name"} {
		next, skipped := skipLenEncString(packet, offset)
		if !skipped {
			return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v %s failed", index, part)
		}
		offset = next
	}

	// Next come the length byte of the fixed-length section, the character
	// set (uint16) and the column length (uint32). None of them are kept.
	offset++
	_, offset, ok := readUint16(packet, offset)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v characterSet failed", index)
	}
	if _, offset, ok = readUint32(packet, offset); !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v columnLength failed", index)
	}

	// The MySQL type (one byte) and flags (two bytes) determine field.Type.
	mysqlType, offset, ok := readByte(packet, offset)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v type failed", index)
	}
	flags, _, ok := readUint16(packet, offset)
	if !ok {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v flags failed", index)
	}
	if field.Type, err = sqltypes.MySQLToType(int64(mysqlType), int64(flags)); err != nil {
		return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", mysqlType, flags, index, err)
	}
	// The decimals byte that follows is intentionally left unread.
	return nil
}
// parseRow parses an individual text-protocol row packet into one value per
// entry of fields, appending them to result. A new slice with capacity
// len(fields) is allocated when result is nil.
//
// Each column is either a single NullValue byte (SQL NULL, stored as a zero
// sqltypes.Value) or a value decoded by reader. reader must return the
// decoded bytes, the position just past them, and whether decoding succeeded.
// Decoded bytes are stored with sqltypes.MakeTrusted using the field's Type.
//
// Returns a SQLError (CRMalformedPacket) if the packet ends before all
// columns are read or a column cannot be decoded.
func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte, int) ([]byte, int, bool), result []sqltypes.Value) ([]sqltypes.Value, error) {
	colNumber := len(fields)
	if result == nil {
		result = make([]sqltypes.Value, 0, colNumber)
	}
	pos := 0
	for i := 0; i < colNumber; i++ {
		// Every column takes at least one byte (the NULL marker or a length
		// prefix). Running out of data means the packet is truncated. Checking
		// here avoids an index-out-of-range panic on data[pos].
		if pos >= len(data) {
			return nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding col %v failed: row packet too short", i)
		}
		if data[pos] == NullValue {
			result = append(result, sqltypes.Value{})
			pos++
			continue
		}
		var s []byte
		var ok bool
		s, pos, ok = reader(data, pos)
		if !ok {
			return nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding string failed")
		}
		result = append(result, sqltypes.MakeTrusted(fields[i].Type, s))
	}
	return result, nil
}
// ExecuteFetch runs query and returns its first result, reporting failures
// as a SQLError. Which error code is reported when the server goes away
// depends on the transport and on timing:
//
//  1. The server closed the connection while no command was in flight:
//
//     1.1 Unix socket: WriteComQuery fails with "broken pipe" and
//     CRServerGone (2006) is returned.
//
//     1.2 TCP: WriteComQuery most likely succeeds, readComQueryResponse then
//     fails, and CRServerLost (2013) is returned. Closing a TCP socket on
//     the server side sends a FIN to the client (the server is done
//     writing), but on most platforms no RST. The client therefore cannot
//     tell that it can no longer write, and its write succeeds. That write
//     is what makes the server send an RST back, which arrives a little
//     later, after the client has started waiting for the response. As a
//     result, CRServerGone (2006) is almost never seen with TCP.
//
//  2. The server closed the connection while a command was in flight:
//     readComQueryResponse fails and CRServerLost (2013) is returned.
func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result *sqltypes.Result, err error) {
	first, _, fetchErr := c.ExecuteFetchMulti(query, maxrows, wantfields)
	return first, fetchErr
}
// ExecuteFetchMulti runs query and returns its first result set. It is meant
// for multi-statement queries: when the returned more flag is true, the
// remaining results must be fetched with ReadQueryResult.
// If the returned error is a *SQLError, its Query field is set to query.
func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, more bool, err error) {
	defer func() {
		if err == nil {
			return
		}
		if sqlErr, isSQLErr := err.(*SQLError); isSQLErr {
			sqlErr.Query = query
		}
	}()

	// The query goes out as a COM_QUERY packet.
	if err = c.WriteComQuery(query); err != nil {
		return nil, false, err
	}
	result, more, _, err = c.ReadQueryResult(maxrows, wantfields)
	if err != nil {
		return nil, false, err
	}
	return result, more, nil
}
// ExecuteFetchWithWarningCount runs query and returns its result together
// with the warning count reported by the server.
// If the returned error is a *SQLError, its Query field is set to query.
// Note: this should eventually be folded into the ExecuteFetch API.
func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfields bool) (result *sqltypes.Result, warnings uint16, err error) {
	defer func() {
		// The comma-ok assertion is false for a nil err, so no separate
		// nil check is needed.
		if sqlErr, isSQLErr := err.(*SQLError); isSQLErr {
			sqlErr.Query = query
		}
	}()

	if err = c.WriteComQuery(query); err != nil {
		return nil, 0, err
	}
	result, _, warnings, err = c.ReadQueryResult(maxrows, wantfields)
	return result, warnings, err
}
// ReadQueryResult reads the response to the last written query. The
// response is either a plain OK packet or a result set made of column
// definitions followed by text-protocol rows.
//
// maxrows caps the number of rows accepted. If the server sends more, the
// rest of the result set is drained and a Code_ABORTED error is returned.
// wantfields selects full column definitions (readColumnDefinition) versus
// types only (readColumnDefinitionType). In the latter case Fields is
// cleared before the result is returned.
//
// It returns the result, whether more results follow (the
// ServerMoreResultsExists status flag), and the warning count.
func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) {
	// Get the result.
	colNumber, packetOk, err := c.readComQueryResponse()
	if err != nil {
		return nil, false, 0, err
	}
	more := packetOk.statusFlags&ServerMoreResultsExists != 0
	warnings := packetOk.warnings
	if colNumber == 0 {
		// OK packet, means no results. Just use the numbers.
		return &sqltypes.Result{
			RowsAffected:        packetOk.affectedRows,
			InsertID:            packetOk.lastInsertID,
			SessionStateChanges: packetOk.sessionStateData,
			StatusFlags:         packetOk.statusFlags,
			Info:                packetOk.info,
		}, more, warnings, nil
	}

	// All Field structs live in one slice; Fields points into it.
	fields := make([]querypb.Field, colNumber)
	result := &sqltypes.Result{
		Fields: make([]*querypb.Field, colNumber),
	}

	// Read column headers. One packet per column.
	// Build the fields.
	for i := 0; i < colNumber; i++ {
		result.Fields[i] = &fields[i]
		if wantfields {
			if err := c.readColumnDefinition(result.Fields[i], i); err != nil {
				return nil, false, 0, err
			}
		} else {
			if err := c.readColumnDefinitionType(result.Fields[i], i); err != nil {
				return nil, false, 0, err
			}
		}
	}

	if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
		// EOF is only present here if it's not deprecated.
		data, err := c.readEphemeralPacket()
		if err != nil {
			return nil, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
		}
		if c.isEOFPacket(data) {
			// This is what we expect.
			// Warnings and status flags are ignored.
			c.recycleReadPacket()
			// goto: read row loop
		} else if isErrorPacket(data) {
			defer c.recycleReadPacket()
			return nil, false, 0, ParseErrorPacket(data)
		} else {
			defer c.recycleReadPacket()
			return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data)
		}
	}

	// read each row until EOF or OK packet.
	for {
		data, err := c.readEphemeralPacket()
		if err != nil {
			// Wrap like every other read in this function. A connection
			// dropped mid-result must surface as SQLError(CRServerLost),
			// not as the raw transport error.
			return nil, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
		}

		if c.isEOFPacket(data) {
			// The deferred recycle runs after data has been parsed below.
			defer c.recycleReadPacket()

			// Strip the partial Fields before returning.
			if !wantfields {
				result.Fields = nil
			}

			// The deprecated EOF packets change means that this is either an
			// EOF packet or an OK packet with the EOF type code.
			if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
				var statusFlags uint16
				warnings, statusFlags, err = parseEOFPacket(data)
				if err != nil {
					return nil, false, 0, err
				}
				more = (statusFlags & ServerMoreResultsExists) != 0
				result.StatusFlags = statusFlags
			} else {
				packetOk, err := c.parseOKPacket(data)
				if err != nil {
					return nil, false, 0, err
				}
				warnings = packetOk.warnings
				more = (packetOk.statusFlags & ServerMoreResultsExists) != 0
				result.SessionStateChanges = packetOk.sessionStateData
				result.StatusFlags = packetOk.statusFlags
				result.Info = packetOk.info
			}
			return result, more, warnings, nil

		} else if isErrorPacket(data) {
			defer c.recycleReadPacket()
			// Error packet.
			return nil, false, 0, ParseErrorPacket(data)
		}

		// Check we're not over the limit before we add more.
		if len(result.Rows) == maxrows {
			c.recycleReadPacket()
			if err := c.drainResults(); err != nil {
				return nil, false, 0, err
			}
			return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows)
		}

		// Regular row. Values are copied out of the packet
		// (readLenEncStringAsBytesCopy), so it can be recycled right away.
		row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil)
		if err != nil {
			c.recycleReadPacket()
			return nil, false, 0, err
		}
		result.Rows = append(result.Rows, row)
		c.recycleReadPacket()
	}
}
// drainResults reads and discards the remaining packets of the current
// result set. It stops at the terminating EOF packet (returning nil) or at
// an error packet (returning the parsed error).
// A read failure is reported as SQLError(CRServerLost).
func (c *Conn) drainResults() error {
	for {
		packet, readErr := c.readEphemeralPacket()
		if readErr != nil {
			return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", readErr)
		}
		switch {
		case c.isEOFPacket(packet):
			c.recycleReadPacket()
			return nil
		case isErrorPacket(packet):
			// Parse before handing the buffer back.
			parsed := ParseErrorPacket(packet)
			c.recycleReadPacket()
			return parsed
		default:
			c.recycleReadPacket()
		}
	}
}
// readComQueryResponse reads the first packet of a COM_QUERY response and
// classifies it:
//   - OK packet: returns (0, parsed OK packet, parse error if any);
//   - ERR packet: returns the parsed server error;
//   - 0xfb (LOCAL INFILE request): returns a Code_UNIMPLEMENTED error;
//   - otherwise: the packet must hold exactly one length-encoded integer,
//     the column count n, and (n, empty PacketOK, nil) is returned.
//
// A failed read yields SQLError(CRServerLost); an empty or malformed packet
// yields SQLError(CRMalformedPacket).
func (c *Conn) readComQueryResponse() (int, *PacketOK, error) {
	packet, readErr := c.readEphemeralPacket()
	if readErr != nil {
		return 0, nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", readErr)
	}
	defer c.recycleReadPacket()

	if len(packet) == 0 {
		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "invalid empty COM_QUERY response packet")
	}

	switch packet[0] {
	case OKPacket:
		okPacket, parseErr := c.parseOKPacket(packet)
		return 0, okPacket, parseErr
	case ErrPacket:
		return 0, nil, ParseErrorPacket(packet)
	case 0xfb: // LOCAL INFILE request
		return 0, nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented")
	}

	columns, end, decoded := readLenEncInt(packet, 0)
	switch {
	case !decoded:
		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "cannot get column number")
	case end != len(packet):
		return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extra data in COM_QUERY response")
	}
	return int(columns), &PacketOK{}, nil
}
//
// Server side methods.
//
// parseComQuery returns the query text of a COM_QUERY packet, i.e. every
// byte after the leading command byte.
func (c *Conn) parseComQuery(data []byte) string {
	queryBytes := data[1:]
	return string(queryBytes)
}
// parseComSetOption extracts the uint16 option value that follows the
// command byte of a COM_SET_OPTION packet. The bool is false if the packet
// is too short.
func (c *Conn) parseComSetOption(data []byte) (uint16, bool) {
	option, _, found := readUint16(data, 1)
	return option, found
}
// parseComPrepare returns the statement text of a COM_STMT_PREPARE packet,
// i.e. every byte after the leading command byte.
func (c *Conn) parseComPrepare(data []byte) string {
	statement := data[1:]
	return string(statement)
}
// parseComStmtExecute parses a COM_STMT_EXECUTE packet (data[0] is the
// command byte) for one of the statements registered in prepareData.
//
// The payload is read in this order:
//   - statement ID (uint32);
//   - cursor type flags (byte);
//   - iteration count (uint32, must be 1);
//   - NULL bitmap of (ParamsCount+7)/8 bytes, only when the statement has
//     parameters;
//   - new-params-bound flag (byte);
//   - when that flag is 1, one (type byte, flag byte) pair per parameter;
//   - the parameter values.
//
// Parameter types are stored in prepare.ParamsType, where they are reused
// if a later execution does not re-send them. Decoded values are stored in
// prepare.BindVars under "v1", "v2", ... Parameters that already have a
// non-nil bind variable are skipped (presumably values supplied earlier via
// COM_STMT_SEND_LONG_DATA — verify against the caller).
//
// It returns the statement ID and the cursor type flags. Once the statement
// ID has been read, it is also returned alongside any error.
func (c *Conn) parseComStmtExecute(prepareData map[uint32]*PrepareData, data []byte) (uint32, byte, error) {
	// In the type section, the per-parameter flag byte has only one defined
	// bit: 0x80, meaning the parameter is unsigned.
	const unsignedParamFlag = 0x80

	payload := data[1:]
	var bitMap []byte

	// statement ID
	stmtID, pos, ok := readUint32(payload, 0)
	if !ok {
		return 0, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading statement ID failed")
	}
	prepare, ok := prepareData[stmtID]
	if !ok {
		return 0, 0, NewSQLError(CRCommandsOutOfSync, SSUnknownSQLState, "statement ID is not found from record")
	}

	// cursor type flags
	cursorType, pos, ok := readByte(payload, pos)
	if !ok {
		return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading cursor type flags failed")
	}

	// iteration count
	iterCount, pos, ok := readUint32(payload, pos)
	if !ok {
		return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading iteration count failed")
	}
	if iterCount != uint32(1) {
		return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "iteration count is not equal to 1")
	}

	if prepare.ParamsCount > 0 {
		bitMap, pos, ok = readBytes(payload, pos, int((prepare.ParamsCount+7)/8))
		if !ok {
			return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading NULL-bitmap failed")
		}
	}

	newParamsBoundFlag, pos, ok := readByte(payload, pos)
	if ok && newParamsBoundFlag == 0x01 {
		var mysqlType, flags byte
		for i := uint16(0); i < prepare.ParamsCount; i++ {
			mysqlType, pos, ok = readByte(payload, pos)
			if !ok {
				return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter type failed")
			}
			flags, pos, ok = readByte(payload, pos)
			if !ok {
				return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter flags failed")
			}

			// MySQLToType takes column-definition flags (as in
			// readColumnDefinition), not this parameter flag byte. Passed
			// raw, 0x80 would read as BINARY_FLAG and unsigned integer
			// parameters would be decoded as signed. Translate the
			// unsigned bit explicitly.
			var colFlags int64
			if flags&unsignedParamFlag != 0 {
				colFlags = int64(querypb.MySqlFlag_UNSIGNED_FLAG)
			}

			// convert MySQL type to internal type.
			valType, err := sqltypes.MySQLToType(int64(mysqlType), colFlags)
			if err != nil {
				return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed: %v", mysqlType, flags, err)
			}
			prepare.ParamsType[i] = int32(valType)
		}
	}

	for i := 0; i < len(prepare.ParamsType); i++ {
		var val sqltypes.Value
		parameterID := fmt.Sprintf("v%d", i+1)
		if v, found := prepare.BindVars[parameterID]; found && v != nil {
			// Already has a value; it is not repeated in this packet.
			continue
		}

		// A set bit in the NULL bitmap means the value is NULL and takes no
		// bytes in the payload.
		if (bitMap[i/8] & (1 << uint(i%8))) > 0 {
			val, pos, ok = c.parseStmtArgs(nil, sqltypes.Null, pos)
		} else {
			val, pos, ok = c.parseStmtArgs(payload, querypb.Type(prepare.ParamsType[i]), pos)
		}
		if !ok {
			return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding parameter value failed: %v", prepare.ParamsType[i])
		}

		prepare.BindVars[parameterID] = sqltypes.ValueBindVariable(val)
	}

	return stmtID, cursorType, nil
}
// parseStmtArgs decodes one binary-protocol COM_STMT_EXECUTE parameter value
// of type typ, starting at data[pos]. It returns the decoded value, the
// position just past it, and false if data is too short or typ is not
// supported.
//
// Integers are read at their fixed width (1, 2, 4 or 8 bytes) and returned
// as Int64 values (signed types sign-extended) or Uint64 values. Floats are
// returned as Float64. Date and time values are rendered as VarChar strings
// by parseStmtDatetimeArg and parseStmtTimeArg. String-like types are copied
// out of data and returned as VarBinary.
func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.Value, int, bool) {
	switch typ {
	case sqltypes.Null:
		return sqltypes.NULL, pos, true

	// Fixed-width integers.
	case sqltypes.Int8:
		b, next, ok := readByte(data, pos)
		return sqltypes.NewInt64(int64(int8(b))), next, ok
	case sqltypes.Uint8:
		b, next, ok := readByte(data, pos)
		return sqltypes.NewUint64(uint64(b)), next, ok
	case sqltypes.Uint16:
		u, next, ok := readUint16(data, pos)
		return sqltypes.NewUint64(uint64(u)), next, ok
	case sqltypes.Int16, sqltypes.Year:
		u, next, ok := readUint16(data, pos)
		return sqltypes.NewInt64(int64(int16(u))), next, ok
	case sqltypes.Uint24, sqltypes.Uint32:
		u, next, ok := readUint32(data, pos)
		return sqltypes.NewUint64(uint64(u)), next, ok
	case sqltypes.Int24, sqltypes.Int32:
		u, next, ok := readUint32(data, pos)
		return sqltypes.NewInt64(int64(int32(u))), next, ok
	case sqltypes.Uint64:
		u, next, ok := readUint64(data, pos)
		return sqltypes.NewUint64(u), next, ok
	case sqltypes.Int64:
		u, next, ok := readUint64(data, pos)
		return sqltypes.NewInt64(int64(u)), next, ok

	// IEEE-754 floats, widened to float64.
	case sqltypes.Float32:
		bits, next, ok := readUint32(data, pos)
		return sqltypes.NewFloat64(float64(math.Float32frombits(bits))), next, ok
	case sqltypes.Float64:
		bits, next, ok := readUint64(data, pos)
		return sqltypes.NewFloat64(math.Float64frombits(bits)), next, ok

	// Length-prefixed temporal encodings.
	case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime:
		return parseStmtDatetimeArg(data, pos)
	case sqltypes.Time:
		return parseStmtTimeArg(data, pos)

	// Everything string-like is a length-encoded byte string.
	case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, sqltypes.VarBinary, sqltypes.Char,
		sqltypes.Bit, sqltypes.Enum, sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON:
		raw, next, ok := readLenEncStringAsBytesCopy(data, pos)
		return sqltypes.MakeTrusted(sqltypes.VarBinary, raw), next, ok

	default:
		return sqltypes.NULL, pos, false
	}
}

// parseStmtDatetimeArg decodes a binary-protocol DATE/DATETIME/TIMESTAMP
// parameter at data[pos]. The first byte is the payload length:
//   - 0: empty value, rendered as " ";
//   - 4: year (uint16), month, day;
//   - 7: the above plus hour, minute, second;
//   - 11: the above plus microseconds (uint32).
//
// Components are printed without zero padding; microseconds always use six
// digits. Any other length, or a truncated payload, yields (NULL, 0, false).
func parseStmtDatetimeArg(data []byte, pos int) (sqltypes.Value, int, bool) {
	size, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	if size == 0x00 {
		return sqltypes.NewVarChar(" "), pos, true
	}
	if size != 0x04 && size != 0x07 && size != 0x0b {
		return sqltypes.NULL, 0, false
	}

	year, pos, ok := readUint16(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	month, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	day, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	text := strconv.Itoa(int(year)) + "-" + strconv.Itoa(int(month)) + "-" + strconv.Itoa(int(day))
	if size == 0x04 {
		return sqltypes.NewVarChar(text), pos, true
	}

	hour, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	minute, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	second, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	text += " " + strconv.Itoa(int(hour)) + ":" + strconv.Itoa(int(minute)) + ":" + strconv.Itoa(int(second))
	if size == 0x07 {
		return sqltypes.NewVarChar(text), pos, true
	}

	microSecond, pos, ok := readUint32(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	text += "." + fmt.Sprintf("%06d", microSecond)
	return sqltypes.NewVarChar(text), pos, true
}

// parseStmtTimeArg decodes a binary-protocol TIME parameter at data[pos].
// The first byte is the payload length:
//   - 0: rendered as "00:00:00";
//   - 8: sign byte (0x01 = negative), days (uint32), hour, minute, second;
//   - 12: the above plus microseconds (uint32).
//
// Days are folded into the hour count (days*24 + hour, in uint32
// arithmetic). Components are printed without zero padding; microseconds
// always use six digits. Any other length, or a truncated payload, yields
// (NULL, 0, false).
func parseStmtTimeArg(data []byte, pos int) (sqltypes.Value, int, bool) {
	size, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	if size == 0x00 {
		return sqltypes.NewVarChar("00:00:00"), pos, true
	}
	if size != 0x08 && size != 0x0c {
		return sqltypes.NULL, 0, false
	}

	isNegative, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	days, pos, ok := readUint32(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	hour, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	hours := uint32(hour) + days*uint32(24)
	minute, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	second, pos, ok := readByte(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}

	sign := ""
	if isNegative == 0x01 {
		sign = "-"
	}
	text := sign + strconv.Itoa(int(hours)) + ":" + strconv.Itoa(int(minute)) + ":" + strconv.Itoa(int(second))
	if size == 0x08 {
		return sqltypes.NewVarChar(text), pos, true
	}

	microSecond, pos, ok := readUint32(data, pos)
	if !ok {
		return sqltypes.NULL, 0, false
	}
	text += "." + fmt.Sprintf("%06d", microSecond)
	return sqltypes.NewVarChar(text), pos, true
}
// parseComStmtSendLongData parses a COM_STMT_SEND_LONG_DATA packet. After
// the command byte come the statement ID (uint32), the parameter ID
// (uint16), and the data chunk (the rest of the packet).
// The chunk is returned as a fresh copy, so it stays valid after data is
// reused. The bool is false if the packet is too short for the two IDs.
func (c *Conn) parseComStmtSendLongData(data []byte) (uint32, uint16, []byte, bool) {
	statementID, pos, ok := readUint32(data, 1)
	if !ok {
		return 0, 0, nil, false
	}
	paramID, pos, ok := readUint16(data, pos)
	if !ok {
		return 0, 0, nil, false
	}

	tail := data[pos:]
	chunk := make([]byte, len(tail))
	copy(chunk, tail)
	return statementID, paramID, chunk, true
}
// parseComStmtClose extracts the uint32 statement ID that follows the
// command byte of a COM_STMT_CLOSE packet. The bool is false if the packet
// is too short.
func (c *Conn) parseComStmtClose(data []byte) (uint32, bool) {
	stmtID, _, found := readUint32(data, 1)
	return stmtID, found
}
// parseComStmtReset extracts the uint32 statement ID that follows the
// command byte of a COM_STMT_RESET packet. The bool is false if the packet
// is too short.
func (c *Conn) parseComStmtReset(data []byte) (uint32, bool) {
	stmtID, _, found := readUint32(data, 1)
	return stmtID, found
}
// parseComInitDB returns the database name carried by a COM_INIT_DB packet,
// i.e. every byte after the leading command byte.
func (c *Conn) parseComInitDB(data []byte) string {
	dbName := data[1:]
	return string(dbName)
}
// sendColumnCount writes the packet that opens a result set: a single
// length-encoded integer holding the number of columns.
func (c *Conn) sendColumnCount(count uint64) error {
	buf, start := c.startEphemeralPacketWithHeader(lenEncIntSize(count))
	writeLenEncInt(buf, start, count)
	return c.writeEphemeralPacket()
}
// writeColumnDefinition sends one Column Definition packet describing field.
// The packet holds the catalog "def", schema, table, org_table, name and
// org_name as length-encoded strings. These are followed by the fixed-length
// section: a 0x0c length byte, character set (uint16), column length
// (uint32), type (byte), flags (uint16), decimals (byte) and a two-byte
// zero filler.
//
// The MySQL type and flags come from sqltypes.TypeToMySQL(field.Type).
// A non-zero field.Flags overrides the derived flags.
func (c *Conn) writeColumnDefinition(field *querypb.Field) error {
	length := 4 + // lenEncStringSize("def")
		lenEncStringSize(field.Database) +
		lenEncStringSize(field.Table) +
		lenEncStringSize(field.OrgTable) +
		lenEncStringSize(field.Name) +
		lenEncStringSize(field.OrgName) +
		1 + // length of fixed length fields
		2 + // character set
		4 + // column length
		1 + // type
		2 + // flags
		1 + // decimals
		2 // filler

	// Get the type and the flags back. If the Field contains
	// non-zero flags, we use them. Otherwise use the flags we
	// derive from the type.
	typ, flags := sqltypes.TypeToMySQL(field.Type)
	if field.Flags != 0 {
		flags = int64(field.Flags)
	}

	data, pos := c.startEphemeralPacketWithHeader(length)

	pos = writeLenEncString(data, pos, "def") // Always the same.
	pos = writeLenEncString(data, pos, field.Database)
	pos = writeLenEncString(data, pos, field.Table)
	pos = writeLenEncString(data, pos, field.OrgTable)
	pos = writeLenEncString(data, pos, field.Name)
	pos = writeLenEncString(data, pos, field.OrgName)
	pos = writeByte(data, pos, 0x0c)
	pos = writeUint16(data, pos, uint16(field.Charset))
	pos = writeUint32(data, pos, field.ColumnLength)
	pos = writeByte(data, pos, byte(typ))
	pos = writeUint16(data, pos, uint16(flags))
	pos = writeByte(data, pos, byte(field.Decimals))
	pos = writeUint16(data, pos, uint16(0x0000))

	if pos != len(data) {
		// Release the ephemeral packet started above before bailing out,
		// as writeBinaryRow does on its error path, instead of leaving
		// it unsent and unrecycled.
		c.recycleWritePacket()
		return vterrors.Errorf(vtrpc.Code_INTERNAL, "packing of column definition used %v bytes instead of %v", pos, len(data))
	}

	return c.writeEphemeralPacket()
}
// writeRow sends one text-protocol result row. Each NULL value is written as
// a single NullValue byte. Every other value is written as a length-encoded
// string of its raw bytes.
func (c *Conn) writeRow(row []sqltypes.Value) error {
	// First pass: compute the exact payload size.
	size := 0
	for _, v := range row {
		if v.IsNull() {
			size++ // NullValue marker
			continue
		}
		n := len(v.Raw())
		size += lenEncIntSize(uint64(n)) + n
	}

	// Second pass: encode the values.
	buf, offset := c.startEphemeralPacketWithHeader(size)
	for _, v := range row {
		if v.IsNull() {
			offset = writeByte(buf, offset, NullValue)
			continue
		}
		raw := v.Raw()
		offset = writeLenEncInt(buf, offset, uint64(len(raw)))
		offset += copy(buf[offset:], raw)
	}
	return c.writeEphemeralPacket()
}
// writeFields sends the column header of a Result: the column count, one
// Column Definition packet per field, and then, unless the client
// negotiated CapabilityClientDeprecateEOF, an EOF packet.
// Call it only when the result has valid columns.
func (c *Conn) writeFields(result *sqltypes.Result) error {
	if err := c.sendColumnCount(uint64(len(result.Fields))); err != nil {
		return err
	}
	for _, f := range result.Fields {
		if err := c.writeColumnDefinition(f); err != nil {
			return err
		}
	}
	if c.Capabilities&CapabilityClientDeprecateEOF != 0 {
		// The client asked for no EOF packet after the field list.
		return nil
	}
	if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil {
		return err
	}
	return nil
}
// writeRows sends every row of result in the text protocol, stopping at the
// first write error.
func (c *Conn) writeRows(result *sqltypes.Result) error {
	for i := range result.Rows {
		if err := c.writeRow(result.Rows[i]); err != nil {
			return err
		}
	}
	return nil
}
// writeEndResult finishes sending a Result. It writes an EOF packet or, when
// the client negotiated CapabilityClientDeprecateEOF, an OK packet with the
// EOF header that also carries affectedRows and lastInsertID.
// If more is true, ServerMoreResultsExists is added to the status flags to
// announce that further results follow.
func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warnings uint16) error {
	status := c.StatusFlags
	if more {
		status |= ServerMoreResultsExists
	}

	if c.Capabilities&CapabilityClientDeprecateEOF == 0 {
		if err := c.writeEOFPacket(status, warnings); err != nil {
			return err
		}
		return nil
	}

	// Writing the OK packet flushes as well.
	terminator := &PacketOK{
		affectedRows: affectedRows,
		lastInsertID: lastInsertID,
		statusFlags:  status,
		warnings:     warnings,
	}
	if err := c.writeOKPacketWithEOFHeader(terminator); err != nil {
		return err
	}
	return nil
}
// PacketComStmtPrepareOK contains the COM_STMT_PREPARE_OK packet details.
// writePrepare encodes it as status, stmtID, numCols, numParams, one
// reserved zero byte, then warningCount (12 bytes of payload in total).
type PacketComStmtPrepareOK struct {
	status       uint8  // leading status byte; writePrepare sets it to OKPacket
	stmtID       uint32 // ID of the prepared statement
	numCols      uint16 // number of result-set columns
	numParams    uint16 // number of statement parameters
	warningCount uint16 // number of warnings; writePrepare always sends 0
}
// writePrepare sends the response to a COM_STMT_PREPARE:
//   - the 12-byte COM_STMT_PREPARE_OK packet;
//   - one placeholder column definition (name "?", VarBinary, charset 63)
//     per parameter;
//   - one column definition per field in fld.
//
// Each non-empty definition list is followed by an EOF packet unless the
// client negotiated CapabilityClientDeprecateEOF.
// Side effects: "'?'" is replaced by "?" in each field's Name, and the
// resulting names are recorded in prepare.ColumnNames.
func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error {
	paramsCount := prepare.ParamsCount
	columnCount := len(fld)
	if columnCount > 0 {
		prepare.ColumnNames = make([]string, columnCount)
	}

	resp := PacketComStmtPrepareOK{
		status:       OKPacket,
		stmtID:       prepare.StatementID,
		numCols:      uint16(columnCount),
		numParams:    paramsCount,
		warningCount: 0,
	}
	buf, start := c.startEphemeralPacketWithHeader(12)
	enc := &coder{data: buf, pos: start}
	enc.writeByte(resp.status)
	enc.writeUint32(resp.stmtID)
	enc.writeUint16(resp.numCols)
	enc.writeUint16(resp.numParams)
	enc.writeByte(0x00) // reserved filler byte
	enc.writeUint16(resp.warningCount)
	if err := c.writeEphemeralPacket(); err != nil {
		return err
	}

	// endList writes the EOF that closes a definition list, if the client
	// still expects one.
	endList := func() error {
		if c.Capabilities&CapabilityClientDeprecateEOF != 0 {
			return nil
		}
		return c.writeEOFPacket(c.StatusFlags, 0)
	}

	for i := uint16(0); i < paramsCount; i++ {
		placeholder := &querypb.Field{
			Name:    "?",
			Type:    sqltypes.VarBinary,
			Charset: 63,
		}
		if err := c.writeColumnDefinition(placeholder); err != nil {
			return err
		}
	}
	if paramsCount > 0 {
		if err := endList(); err != nil {
			return err
		}
	}

	for i, f := range fld {
		f.Name = strings.ReplaceAll(f.Name, "'?'", "?")
		prepare.ColumnNames[i] = f.Name
		if err := c.writeColumnDefinition(f); err != nil {
			return err
		}
	}
	if columnCount > 0 {
		if err := endList(); err != nil {
			return err
		}
	}
	return nil
}
// writeBinaryRow sends one result row in the binary protocol (the format
// used for prepared-statement results). The payload is a 0x00 header byte,
// then a NULL bitmap of (len(fields)+7+2)/8 bytes whose bit numbering starts
// at offset 2, then the val2MySQL encoding of every non-NULL value in order.
// The bitmap is located at a fixed offset past packetHeaderSize in the
// packet buffer.
func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) error {
	const bitmapOffset = 2
	bitmapLen := (len(fields) + 7 + bitmapOffset) / 8

	// Payload size: header byte, NULL bitmap, encoded non-NULL values.
	size := 1 + bitmapLen
	for _, v := range row {
		if v.IsNull() {
			continue
		}
		n, err := val2MySQLLen(v)
		if err != nil {
			return fmt.Errorf("internal value %v get MySQL value length error: %v", v, err)
		}
		size += n
	}

	buf, offset := c.startEphemeralPacketWithHeader(size)
	offset = writeByte(buf, offset, 0x00)
	// The packet buffer may be reused, so clear the bitmap explicitly.
	for i := 0; i < bitmapLen; i++ {
		offset = writeByte(buf, offset, 0x00)
	}

	for i, v := range row {
		if v.IsNull() {
			bit := i + bitmapOffset
			buf[packetHeaderSize+1+bit/8] |= 1 << uint(bit%8)
			continue
		}
		encoded, err := val2MySQL(v)
		if err != nil {
			c.recycleWritePacket()
			return fmt.Errorf("internal value %v to MySQL value error: %v", v, err)
		}
		offset += copy(buf[offset:], encoded)
	}
	return c.writeEphemeralPacket()
}
// writeBinaryRows sends every row of result in the binary protocol, using
// result.Fields to size each row's NULL bitmap. It stops at the first error.
func (c *Conn) writeBinaryRows(result *sqltypes.Result) error {
	for i := range result.Rows {
		if err := c.writeBinaryRow(result.Fields, result.Rows[i]); err != nil {
			return err
		}
	}
	return nil
}
// val2MySQL encodes v into the binary form that writeBinaryRow copies into a
// row packet:
//
//   - NULL encodes to nil; writeBinaryRow records NULLs in the row bitmap.
//   - Integer and floating-point types are parsed from their text form and
//     written at a fixed width of 1, 2, 4 or 8 bytes.
//   - DATE/DATETIME/TIMESTAMP values are converted by val2MySQLDatetime and
//     TIME values by val2MySQLTime.
//   - String-like types (DECIMAL, TEXT, BLOB, JSON, ...) are written as a
//     length-encoded integer followed by the raw bytes.
//   - Any other type is copied verbatim.
//
// Whenever no error is returned, len(result) == val2MySQLLen(v).
func val2MySQL(v sqltypes.Value) ([]byte, error) {
	var out []byte
	switch v.Type() {
	case sqltypes.Null:
		// NULL has no payload of its own.
	case sqltypes.Int8:
		val, err := strconv.ParseInt(v.ToString(), 10, 8)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 1)
		writeByte(out, 0, uint8(val))
	case sqltypes.Uint8:
		val, err := strconv.ParseUint(v.ToString(), 10, 8)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 1)
		writeByte(out, 0, uint8(val))
	case sqltypes.Uint16:
		val, err := strconv.ParseUint(v.ToString(), 10, 16)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 2)
		writeUint16(out, 0, uint16(val))
	case sqltypes.Int16, sqltypes.Year:
		val, err := strconv.ParseInt(v.ToString(), 10, 16)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 2)
		writeUint16(out, 0, uint16(val))
	case sqltypes.Uint24, sqltypes.Uint32:
		val, err := strconv.ParseUint(v.ToString(), 10, 32)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 4)
		writeUint32(out, 0, uint32(val))
	case sqltypes.Int24, sqltypes.Int32:
		val, err := strconv.ParseInt(v.ToString(), 10, 32)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 4)
		writeUint32(out, 0, uint32(val))
	case sqltypes.Float32:
		val, err := strconv.ParseFloat(v.ToString(), 32)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 4)
		writeUint32(out, 0, math.Float32bits(float32(val)))
	case sqltypes.Uint64:
		val, err := strconv.ParseUint(v.ToString(), 10, 64)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 8)
		writeUint64(out, 0, val)
	case sqltypes.Int64:
		val, err := strconv.ParseInt(v.ToString(), 10, 64)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 8)
		writeUint64(out, 0, uint64(val))
	case sqltypes.Float64:
		val, err := strconv.ParseFloat(v.ToString(), 64)
		if err != nil {
			return []byte{}, err
		}
		out = make([]byte, 8)
		writeUint64(out, 0, math.Float64bits(val))
	case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime:
		return val2MySQLDatetime(v.Raw())
	case sqltypes.Time:
		return val2MySQLTime(v.Raw())
	case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar,
		sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum,
		sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON:
		raw := v.Raw()
		out = make([]byte, lenEncIntSize(uint64(len(raw)))+len(raw))
		pos := writeLenEncInt(out, 0, uint64(len(raw)))
		copy(out[pos:], raw)
	default:
		out = make([]byte, len(v.Raw()))
		copy(out, v.Raw())
	}
	return out, nil
}

// val2MySQLDatetime converts the text of a DATE, DATETIME or TIMESTAMP value,
// expected as "YYYY-MM-DD", "YYYY-MM-DD hh:mm:ss" or
// "YYYY-MM-DD hh:mm:ss.ffffff", into its binary form.
//
// The binary form is a length byte, then that many bytes holding year(2),
// month and day, optionally followed by hour, minute and second, optionally
// followed by microsecond(4). The layout depends only on the text length:
//
//   - more than 19 bytes selects 11 bytes;
//   - more than 10 bytes selects 7 bytes;
//   - any other non-empty text selects 4 bytes;
//   - empty text encodes as the single byte 0x00.
//
// Fractional seconds are right-padded with '0' to six digits, or truncated
// to six digits if longer. Text too short for the fixed-offset fields of the
// layout its length selects is rejected with an error instead of panicking.
func val2MySQLDatetime(raw []byte) ([]byte, error) {
	n := len(raw)
	var size int
	switch {
	case n > 19:
		size = 11
	case n > 10:
		size = 7
	case n > 0:
		size = 4
	default:
		return []byte{0x00}, nil
	}
	// A bare date ends with raw[8:], and a date-time without fraction ends
	// with raw[17:]. Both slices panic when the text is shorter than their
	// start offset.
	if (size == 7 && n < 17) || (size == 4 && n < 8) {
		return []byte{}, fmt.Errorf("incorrect datetime value %q", raw)
	}

	var err error
	// field parses raw[lo:hi] as an unsigned decimal of at most bitSize bits.
	// Once a parse has failed it does nothing, so err keeps the first failure.
	field := func(lo, hi, bitSize int) uint64 {
		if err != nil {
			return 0
		}
		var val uint64
		val, err = strconv.ParseUint(string(raw[lo:hi]), 10, bitSize)
		return val
	}

	// The last field of each layout extends to the end of the text.
	dayEnd, secondEnd := 10, 19
	if size == 4 {
		dayEnd = n
	}
	if size == 7 {
		secondEnd = n
	}

	year := field(0, 4, 16)
	month := field(5, 7, 8)
	day := field(8, dayEnd, 8)
	var hour, minute, second, microSecond uint64
	if size >= 7 {
		hour = field(11, 13, 8)
		minute = field(14, 16, 8)
		second = field(17, secondEnd, 8)
	}
	if size == 11 && err == nil {
		digits := []byte("000000")
		copy(digits, raw[20:])
		microSecond, err = strconv.ParseUint(string(digits), 10, 32)
	}
	if err != nil {
		return []byte{}, err
	}

	out := make([]byte, 1+size)
	pos := writeByte(out, 0, byte(size))
	pos = writeUint16(out, pos, uint16(year))
	pos = writeByte(out, pos, byte(month))
	pos = writeByte(out, pos, byte(day))
	if size >= 7 {
		pos = writeByte(out, pos, byte(hour))
		pos = writeByte(out, pos, byte(minute))
		pos = writeByte(out, pos, byte(second))
	}
	if size == 11 {
		writeUint32(out, pos, uint32(microSecond))
	}
	return out, nil
}

// val2MySQLTime converts the text of a TIME value, "[-]h:mm:ss[.ffffff]", into
// its binary form. The hour field may exceed 24; it is split into days and
// hours.
//
// The binary form is:
//
//   - a length byte;
//   - a sign byte, 0x01 when the text starts with '-';
//   - days(4), hours, minutes and seconds;
//   - microseconds(4), only when the text contains '.'.
//
// "00:00:00" encodes as the single byte 0x00, and empty text is an error.
func val2MySQLTime(raw []byte) ([]byte, error) {
	text := string(raw)
	if text == "00:00:00" {
		return []byte{0x00}, nil
	}
	if text == "" {
		return []byte{}, fmt.Errorf("incorrect time value")
	}
	withFraction := strings.Contains(text, ".")

	parts := strings.Split(text, ":")
	if len(parts) != 3 {
		return []byte{}, fmt.Errorf("incorrect time value, ':' is not found")
	}
	secondText, fractionText := parts[2], ""
	if withFraction {
		sub := strings.Split(parts[2], ".")
		if len(sub) != 2 {
			return []byte{}, fmt.Errorf("incorrect time value, '.' is not found")
		}
		secondText, fractionText = sub[0], sub[1]
	}

	negative := strings.HasPrefix(parts[0], "-")
	hourText := parts[0]
	if negative {
		hourText = hourText[1:]
	}
	h, err := strconv.ParseUint(hourText, 10, 32)
	if err != nil {
		return []byte{}, err
	}
	minutes, err := strconv.ParseUint(parts[1], 10, 8)
	if err != nil {
		return []byte{}, err
	}
	seconds, err := strconv.ParseUint(secondText, 10, 8)
	if err != nil {
		return []byte{}, err
	}
	var microSeconds uint64
	if withFraction {
		// Right-pad (or truncate) the fraction to exactly six digits.
		digits := []byte("000000")
		copy(digits, fractionText)
		if microSeconds, err = strconv.ParseUint(string(digits), 10, 32); err != nil {
			return []byte{}, err
		}
	}

	size := 8
	if withFraction {
		size = 12
	}
	sign := byte(0x00)
	if negative {
		sign = 0x01
	}
	out := make([]byte, 1+size)
	pos := writeByte(out, 0, byte(size))
	pos = writeByte(out, pos, sign)
	pos = writeUint32(out, pos, uint32(h)/24)
	pos = writeByte(out, pos, byte(uint32(h)%24))
	pos = writeByte(out, pos, byte(minutes))
	pos = writeByte(out, pos, byte(seconds))
	if withFraction {
		writeUint32(out, pos, uint32(microSeconds))
	}
	return out, nil
}
// val2MySQLLen reports how many bytes val2MySQL produces for v, so the
// caller can size a row packet before encoding into it.
//
// The size is fixed per type for integers and floats. It is derived from the
// text length for DATE/DATETIME/TIMESTAMP, and from the text itself for TIME.
// String-like types take a length-encoded prefix plus the raw bytes, and any
// other type takes its raw length. An empty TIME value is an error.
func val2MySQLLen(v sqltypes.Value) (int, error) {
	raw := v.Raw()
	switch v.Type() {
	case sqltypes.Null:
		return 0, nil
	case sqltypes.Int8, sqltypes.Uint8:
		return 1, nil
	case sqltypes.Uint16, sqltypes.Int16, sqltypes.Year:
		return 2, nil
	case sqltypes.Uint24, sqltypes.Uint32, sqltypes.Int24, sqltypes.Int32, sqltypes.Float32:
		return 4, nil
	case sqltypes.Uint64, sqltypes.Int64, sqltypes.Float64:
		return 8, nil
	case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime:
		switch n := len(raw); {
		case n > 19:
			return 12, nil
		case n > 10:
			return 8, nil
		case n > 0:
			return 5, nil
		}
		return 1, nil
	case sqltypes.Time:
		switch text := string(raw); {
		case text == "00:00:00":
			return 1, nil
		case strings.Contains(text, "."):
			return 13, nil
		case text != "":
			return 9, nil
		}
		return 0, fmt.Errorf("incorrect time value")
	case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar,
		sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum,
		sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON:
		return lenEncIntSize(uint64(len(raw))) + len(raw), nil
	}
	return len(raw), nil
}