Merge pull request #5149 from adsr/imm_handshake_err

Handle case where mysqld replies to Initial Handshake Packet with an ERR packet
This commit is contained in:
Sugu Sougoumarane 2019-08-29 19:16:43 -07:00 коммит произвёл GitHub
Родитель f2ff0ae813 7504082368
Коммит 3ce4e0584d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 38 добавлений и 6 удалений

Просмотреть файл

@ -373,6 +373,16 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error)
if !ok {
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version")
}
// Server is allowed to immediately send ERR packet
if pver == ErrPacket {
errorCode, pos, _ := readUint16(data, pos)
// Normally there would be a 1-byte sql_state_marker field and a 5-byte
// sql_state field here, but docs say these will not be present in this case.
errorMsg, pos, _ := readEOFString(data, pos)
return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg)
}
if pver != protocolVersion {
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver)
}

Просмотреть файл

@ -80,6 +80,10 @@ func lenNullString(value string) int {
return len(value) + 1
}
func lenEOFString(value string) int {
return len(value)
}
func writeNullString(data []byte, pos int, value string) int {
pos += copy(data[pos:], value)
data[pos] = 0
@ -180,6 +184,10 @@ func readNullString(data []byte, pos int) (string, int, bool) {
return string(data[pos : pos+end]), pos + end + 1, true
}
func readEOFString(data []byte, pos int) (string, int, bool) {
return string(data[pos:]), len(data) - pos, true
}
func readUint16(data []byte, pos int) (uint16, int, bool) {
if pos+1 >= len(data) {
return 0, 0, false

Просмотреть файл

@ -190,21 +190,25 @@ func TestEncString(t *testing.T) {
value string
lenEncoded []byte
nullEncoded []byte
eofEncoded []byte
}{
{
"",
[]byte{0x00},
[]byte{0x00},
[]byte{},
},
{
"a",
[]byte{0x01, 'a'},
[]byte{'a', 0x00},
[]byte{'a'},
},
{
"0123456789",
[]byte{0x0a, '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'},
[]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 0x00},
[]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'},
},
}
for _, test := range tests {
@ -220,6 +224,11 @@ func TestEncString(t *testing.T) {
t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.nullEncoded), test.value)
}
// Check lenEOFString
if got := lenEOFString(test.value); got != len(test.eofEncoded) {
t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.eofEncoded), test.value)
}
// Check successful encoding.
data := make([]byte, len(test.lenEncoded))
pos := writeLenEncString(data, 0, test.value)
@ -319,16 +328,21 @@ func TestEncString(t *testing.T) {
}
// EOF encoded tests.
// We use the nullEncoded value, removing the 0 at the end.
// Check successful encoding.
data = make([]byte, len(test.nullEncoded)-1)
data = make([]byte, len(test.eofEncoded))
pos = writeEOFString(data, 0, test.value)
if pos != len(test.nullEncoded)-1 {
t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.nullEncoded)-1)
if pos != len(test.eofEncoded) {
t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.eofEncoded))
}
if !bytes.Equal(data, test.nullEncoded[:len(test.nullEncoded)-1]) {
t.Errorf("unexpected nullEncoded value for %v, got %v expected %v", test.value, data, test.nullEncoded)
if !bytes.Equal(data, test.eofEncoded[:len(test.eofEncoded)]) {
t.Errorf("unexpected eofEncoded value for %v, got %v expected %v", test.value, data, test.eofEncoded)
}
// Check successful decoding.
got, pos, ok = readEOFString(test.eofEncoded, 0)
if !ok || got != test.value || pos != len(test.eofEncoded) {
t.Errorf("readEOFString returned %v/%v/%v but expected %v/%v/%v", got, pos, ok, test.value, len(test.eofEncoded), true)
}
}
}