add gtid in ok packet as client session tracking information

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
This commit is contained in:
Harshit Gangal 2020-10-12 18:09:19 +05:30
Родитель 84324e6cbc
Коммит 22fb28eb49
4 изменённых файлов: 146 добавлений и 183 удалений

Просмотреть файл

@ -685,6 +685,24 @@ func (c *Conn) IsClosed() bool {
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacket(packetOk *PacketOK) error {
return c.writeOKPacketWithHeader(packetOk, OKPacket)
}
// writeOKPacketWithEOFHeader writes an OK packet with an EOF header.
// This is used at the end of a result set if
// CapabilityClientDeprecateEOF is set.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacketWithEOFHeader(packetOk *PacketOK) error {
return c.writeOKPacketWithHeader(packetOk, EOFPacket)
}
// writeOKPacketWithEOFHeader writes an OK packet with an EOF header.
// This is used at the end of a result set if
// CapabilityClientDeprecateEOF is set.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) error {
length := 1 + // OKPacket
lenEncIntSize(packetOk.affectedRows) +
lenEncIntSize(packetOk.lastInsertID)
@ -694,127 +712,74 @@ func (c *Conn) writeOKPacket(packetOk *PacketOK) error {
length += 2 // status_flags
}
var gtidData []byte
if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
length += lenEncStringSize(packetOk.info) // info
if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
length += 1 + // total length
1 // type
gtidData = getLenEncString([]byte(packetOk.sessionStateData))
gtidData = append([]byte{0x00}, gtidData...)
gtidData = getLenEncString(gtidData)
gtidData = append([]byte{0x03}, gtidData...)
gtidData = append(getLenEncInt(uint64(len(gtidData))), gtidData...)
length += len(gtidData)
}
} else {
length += len(packetOk.info) // info
}
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeByte(data, pos, OKPacket) //header - OK or EOF
pos = writeLenEncInt(data, pos, packetOk.affectedRows)
pos = writeLenEncInt(data, pos, packetOk.lastInsertID)
bytes, pos := c.startEphemeralPacketWithHeader(length)
data := &coder{data: bytes, pos: pos}
data.writeByte(headerType) //header - OK or EOF
data.writeLenEncInt(packetOk.affectedRows)
data.writeLenEncInt(packetOk.lastInsertID)
if c.Capabilities&CapabilityClientProtocol41 == CapabilityClientProtocol41 {
pos = writeUint16(data, pos, packetOk.statusFlags)
pos = writeUint16(data, pos, packetOk.warnings)
data.writeUint16(packetOk.statusFlags)
data.writeUint16(packetOk.warnings)
} else if c.Capabilities&CapabilityClientTransactions == CapabilityClientTransactions {
pos = writeUint16(data, pos, packetOk.statusFlags)
data.writeUint16(packetOk.statusFlags)
}
if c.Capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
pos = writeLenEncString(data, pos, packetOk.info)
data.writeLenEncString(packetOk.info)
if packetOk.statusFlags&ServerSessionStateChanged == ServerSessionStateChanged {
pos = writeByte(data, pos, 0)
_ = writeByte(data, pos, packetOk.sessionStateChangeType)
data.writeEOFString(string(gtidData))
}
} else {
_ = writeEOFString(data, pos, packetOk.info)
data.writeEOFString(packetOk.info)
}
return c.writeEphemeralPacket()
}
/* writeOKPacketWithGTIDs writes an OK packet according to https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
+--------------------+-----------------------+-----------------------------------+
| Type | Name | Description |
+--------------------+-----------------------+-----------------------------------+
| int<1> | header | [00] or [fe] the OK packet header |
+--------------------+-----------------------+-----------------------------------+
| int<lenenc> | affected_rows | affected rows |
+--------------------+-----------------------+-----------------------------------+
| int<lenenc> | last_insert_id | last insert-id |
+--------------------+-----------------------+-----------------------------------+
| int<2> | status_flags | see possible status_flags here |
+--------------------+-----------------------+-----------------------------------+
| int<2> | warnings | number of warnings |
+--------------------+-----------------------+-----------------------------------+
| if capabilities & CLIENT_SESSION_TRACK { |
+--------------------+-----------------------+-----------------------------------+
| string<lenenc> | info | human readable information |
+--------------------+-----------------------+-----------------------------------+
| if status_flags & SERVER_SESSION_STATE_CHANGED { |
+--------------------+-----------------------+-----------------------------------+
| string<lenenc> | session_state_changes | session state info - GTIDs here |
+--------------------+-----------------------+-----------------------------------+
| } |
+--------------------------------------------------------------------------------+
| } else { |
+--------------------+-----------------------+-----------------------------------+
| string<lenenc> | info | human readable information |
+--------------------+-----------------------+-----------------------------------+
| } |
+--------------------------------------------------------------------------------+
*/
func (c *Conn) writeOKPacketWithGTIDs(affectedRows, lastInsertID uint64, flags uint16, warnings uint16, gtids string) error {
sessionStateSize := 3 + //- size and encoding spec
lenEncStringSize(gtids) // the actual gtids
sessionStateSizePlusSize := lenEncIntSize(uint64(sessionStateSize))
length := 1 + // OKPacket
lenEncIntSize(affectedRows) +
lenEncIntSize(lastInsertID) +
2 + // flags
2 + // warnings
lenEncStringSize("") + // human readable status information
sessionStateSizePlusSize + sessionStateSize //session state info
bytes, pos := c.startEphemeralPacketWithHeader(length)
data := &coder{
data: bytes,
pos: pos,
}
data.writeByte(OKPacket)
data.writeLenEncInt(affectedRows)
data.writeLenEncInt(lastInsertID)
if gtids != "" {
flags |= ServerSessionStateChanged
}
data.writeUint16(flags)
data.writeUint16(warnings)
// add session state change tracking
data.writeLenEncString("") // human readable info
data.writeLenEncInt(uint64(sessionStateSizePlusSize)) // total length of session state change info
data.writeByte(SessionTrackGtids)
data.writeByte(byte(sessionStateSizePlusSize + sessionStateSize)) // total length of session state change info
data.writeByte(byte(sessionStateSize)) // gtid encoding spec - only text available today
data.writeLenEncString(gtids)
return c.writeEphemeralPacket()
func getLenEncString(value []byte) []byte {
data := getLenEncInt(uint64(len(value)))
return append(data, value...)
}
// writeOKPacketWithEOFHeader writes an OK packet with an EOF header.
// This is used at the end of a result set if
// CapabilityClientDeprecateEOF is set.
// Server -> Client.
// This method returns a generic error, not a SQLError.
func (c *Conn) writeOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, flags uint16, warnings uint16) error {
length := 1 + // EOFPacket
lenEncIntSize(affectedRows) +
lenEncIntSize(lastInsertID) +
2 + // flags
2 // warnings
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeByte(data, pos, EOFPacket)
pos = writeLenEncInt(data, pos, affectedRows)
pos = writeLenEncInt(data, pos, lastInsertID)
pos = writeUint16(data, pos, flags)
_ = writeUint16(data, pos, warnings)
return c.writeEphemeralPacket()
func getLenEncInt(i uint64) []byte {
var data []byte
switch {
case i < 251:
data = append(data, byte(i))
case i < 1<<16:
data = append(data, 0xfc)
data = append(data, byte(i))
data = append(data, byte(i>>8))
case i < 1<<24:
data = append(data, 0xfd)
data = append(data, byte(i))
data = append(data, byte(i>>8))
data = append(data, byte(i>>16))
default:
data = append(data, 0xfe)
data = append(data, byte(i))
data = append(data, byte(i>>8))
data = append(data, byte(i>>16))
data = append(data, byte(i>>24))
data = append(data, byte(i>>32))
data = append(data, byte(i>>40))
data = append(data, byte(i>>48))
data = append(data, byte(i>>56))
}
return data
}
func (c *Conn) writeErrorAndLog(errorCode uint16, sqlState string, format string, args ...interface{}) bool {
@ -1051,7 +1016,15 @@ func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool
if len(qr.Fields) == 0 {
sendFinished = true
// We should not send any more packets after this.
return c.writeOKPacketWithGTIDs(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0, qr.SessionStateChanges)
ok := PacketOK{
affectedRows: qr.RowsAffected,
lastInsertID: qr.InsertID,
statusFlags: c.StatusFlags,
warnings: 0,
info: "",
sessionStateData: qr.SessionStateChanges,
}
return c.writeOKPacket(&ok)
}
if err := c.writeFields(qr); err != nil {
return err
@ -1275,7 +1248,15 @@ func (c *Conn) execQuery(query string, handler Handler, more bool) execResult {
// We should not send any more packets after this, but make sure
// to extract the affected rows and last insert id from the result
// struct here since clients expect it.
return c.writeOKPacketWithGTIDs(qr.RowsAffected, qr.InsertID, flag, handler.WarningCount(c), qr.SessionStateChanges)
ok := PacketOK{
affectedRows: qr.RowsAffected,
lastInsertID: qr.InsertID,
statusFlags: flag,
warnings: handler.WarningCount(c),
info: "",
sessionStateData: qr.SessionStateChanges,
}
return c.writeOKPacket(&ok)
}
if err := c.writeFields(qr); err != nil {
return err
@ -1365,14 +1346,8 @@ type PacketOK struct {
warnings uint16
info string
sessionStateChangeType uint8
sessionStateChangeValue interface{}
}
// TrackSystemVariable contains the name and values of system variables
type TrackSystemVariable struct {
names string
values string
// at the moment, we only store GTID information in this field
sessionStateData string
}
func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
@ -1427,62 +1402,25 @@ func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
return fail("invalid OK packet session state change length: %v", data)
}
sscType, ok := data.readByte()
if !ok || sscType != SessionTrackGtids {
return fail("invalid OK packet session state change type: %v", sscType)
}
// Move past the total length of the changed entity: 1 byte
_, ok = data.readByte()
if !ok {
return fail("invalid OK packet session state change type: %v", data)
return fail("invalid OK packet gtids length: %v", data)
}
packetOK.sessionStateChangeType = sscType
switch sscType {
case SessionTrackSystemVariables:
// Move past the total length of the changed entity: 1 byte
_, ok := data.readByte()
if !ok {
return fail("invalid OK packet system variables length: %v", data)
}
names, ok := data.readLenEncString()
if !ok {
return fail("invalid OK packet system variables names: %v", data)
}
values, ok := data.readLenEncString()
if !ok {
return fail("invalid OK packet system variables values: %v", data)
}
packetOK.sessionStateChangeValue = &TrackSystemVariable{names: names, values: values}
case SessionTrackSchema:
// Move past the total length of the changed entity: 1 byte
_, ok := data.readByte()
if !ok {
return fail("invalid OK packet schema length: %v", data)
}
schema, ok := data.readLenEncString()
if !ok {
return fail("invalid OK packet schema: %v", data)
}
packetOK.sessionStateChangeValue = schema
case SessionTrackStateChange:
tracked, ok := data.readLenEncString()
if !ok {
return fail("invalid OK packet tracked: %v", data)
}
packetOK.sessionStateChangeValue = tracked
case SessionTrackGtids:
// Move past the total length of the changed entity: 1 byte
_, ok := data.readByte()
if !ok {
return fail("invalid OK packet gtids length: %v", data)
}
// read (and ignore for now) the GTIDS encoding specification code: 1 byte
_, ok = data.readByte()
if !ok {
return fail("invalid OK packet gtids type: %v", data)
}
gtids, ok := data.readLenEncString()
if !ok {
return fail("invalid OK packet gtids: %v", data)
}
packetOK.sessionStateChangeValue = gtids
default:
fail("invalid OK packet session state change: %v", data)
// read (and ignore for now) the GTIDS encoding specification code: 1 byte
_, ok = data.readByte()
if !ok {
return fail("invalid OK packet gtids type: %v", data)
}
gtids, ok := data.readLenEncString()
if !ok {
return fail("invalid OK packet gtids: %v", data)
}
packetOK.sessionStateData = gtids
}
} else {
// info

Просмотреть файл

@ -246,8 +246,17 @@ func TestBasicPackets(t *testing.T) {
assert.EqualValues(78, packetOk.warnings)
// Write OK packet with affected GTIDs, read it, compare.
gtids := "foo-bar"
err = sConn.writeOKPacketWithGTIDs(23, 45, 67, 89, gtids)
sConn.Capabilities |= CapabilityClientSessionTrack
cConn.Capabilities |= CapabilityClientSessionTrack
ok := PacketOK{
affectedRows: 23,
lastInsertID: 45,
statusFlags: 67 | ServerSessionStateChanged,
warnings: 89,
info: "",
sessionStateData: "foo-bar",
}
err = sConn.writeOKPacket(&ok)
require.NoError(err)
data, err = cConn.ReadPacket()
@ -255,18 +264,22 @@ func TestBasicPackets(t *testing.T) {
require.NotEmpty(data)
assert.EqualValues(data[0], OKPacket, "OKPacket")
cConn.Capabilities = CapabilityFlags
packetOk, err = cConn.parseOKPacket(data)
require.NoError(err)
assert.EqualValues(23, packetOk.affectedRows)
assert.EqualValues(45, packetOk.lastInsertID)
assert.EqualValues(67|ServerSessionStateChanged, packetOk.statusFlags)
assert.EqualValues(ServerSessionStateChanged, packetOk.statusFlags&ServerSessionStateChanged)
assert.EqualValues(89, packetOk.warnings)
assert.EqualValues(SessionTrackGtids, packetOk.sessionStateChangeType)
// TODO harshit: fix this assert.EqualValues("SessionTrackGtids", packetOk.sessionStateChangeValue)
assert.EqualValues("foo-bar", packetOk.sessionStateData)
// Write OK packet with EOF header, read it, compare.
err = sConn.writeOKPacketWithEOFHeader(12, 34, 56, 78)
ok = PacketOK{
affectedRows: 12,
lastInsertID: 34,
statusFlags: 56,
warnings: 78,
}
err = sConn.writeOKPacketWithEOFHeader(&ok)
require.NoError(err)
data, err = cConn.ReadPacket()
@ -323,8 +336,9 @@ func TestOkPackets(t *testing.T) {
}()
testCases := []struct {
data string
cc uint32
data string
cc uint32
expectedErr string
}{{
data: `
00000000 00 00 00 02 00 00 00 |.......|`,
@ -341,18 +355,21 @@ func TestOkPackets(t *testing.T) {
00000030 61 3a 32 |a:2|`,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
}, {
data: `00000000 00 00 00 02 40 00 00 00 07 01 05 04 74 65 73 74 |....@.......test|`,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
data: `00000000 00 00 00 02 40 00 00 00 07 01 05 04 74 65 73 74 |....@.......test|`,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
expectedErr: "invalid OK packet session state change type: 1",
}, {
data: `
00000000 00 00 00 00 40 00 00 00 14 00 0f 0a 61 75 74 6f |....@.......auto|
00000010 63 6f 6d 6d 69 74 03 4f 46 46 02 01 31 |commit.OFF..1|`,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
expectedErr: "invalid OK packet session state change type: 0",
}, {
data: `
00000000 00 00 00 00 40 00 00 00 0a 01 05 04 74 65 73 74 |....@.......test|
00000010 02 01 31 |..1|`,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
cc: CapabilityClientProtocol41 | CapabilityClientTransactions | CapabilityClientSessionTrack,
expectedErr: "invalid OK packet session state change type: 1",
}}
for i, testCase := range testCases {
@ -363,6 +380,11 @@ func TestOkPackets(t *testing.T) {
sConn.Capabilities = testCase.cc
// parse the packet
packetOk, err := cConn.parseOKPacket(data)
if testCase.expectedErr != "" {
require.Error(t, err)
require.Equal(t, testCase.expectedErr, err.Error())
return
}
require.NoError(t, err, "failed to parse OK packet")
// write the ok packet from server

Просмотреть файл

@ -319,12 +319,6 @@ func (d *coder) readByte() (byte, bool) {
return res, ok
}
//func (d *coder) skipLenEncString() bool {
// newPos, ok := skipLenEncString(d.data, d.pos)
// d.pos = newPos
// return ok
//}
func (d *coder) readLenEncString() (string, bool) {
res, newPos, ok := readLenEncString(d.data, d.pos)
d.pos = newPos
@ -358,3 +352,7 @@ func (d *coder) writeLenEncString(value string) {
newPos := writeLenEncString(d.data, d.pos, value)
d.pos = newPos
}
func (d *coder) writeEOFString(value string) {
d.pos += copy(d.data[d.pos:], value)
}

Просмотреть файл

@ -996,7 +996,12 @@ func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warn
}
} else {
// This will flush too.
if err := c.writeOKPacketWithEOFHeader(affectedRows, lastInsertID, flags, warnings); err != nil {
if err := c.writeOKPacketWithEOFHeader(&PacketOK{
affectedRows: affectedRows,
lastInsertID: lastInsertID,
statusFlags: flags,
warnings: warnings,
}); err != nil {
return err
}
}