зеркало из https://github.com/github/vitess-gh.git
Merge pull request #5149 from adsr/imm_handshake_err
Handle case where mysqld replies to Initial Handshake Packet with an ERR packet
This commit is contained in:
Коммит
3ce4e0584d
|
@ -373,6 +373,16 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error)
|
|||
if !ok {
|
||||
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version")
|
||||
}
|
||||
|
||||
// Server is allowed to immediately send ERR packet
|
||||
if pver == ErrPacket {
|
||||
errorCode, pos, _ := readUint16(data, pos)
|
||||
// Normally there would be a 1-byte sql_state_marker field and a 5-byte
|
||||
// sql_state field here, but docs say these will not be present in this case.
|
||||
errorMsg, pos, _ := readEOFString(data, pos)
|
||||
return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg)
|
||||
}
|
||||
|
||||
if pver != protocolVersion {
|
||||
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver)
|
||||
}
|
||||
|
|
|
@ -80,6 +80,10 @@ func lenNullString(value string) int {
|
|||
return len(value) + 1
|
||||
}
|
||||
|
||||
func lenEOFString(value string) int {
|
||||
return len(value)
|
||||
}
|
||||
|
||||
func writeNullString(data []byte, pos int, value string) int {
|
||||
pos += copy(data[pos:], value)
|
||||
data[pos] = 0
|
||||
|
@ -180,6 +184,10 @@ func readNullString(data []byte, pos int) (string, int, bool) {
|
|||
return string(data[pos : pos+end]), pos + end + 1, true
|
||||
}
|
||||
|
||||
func readEOFString(data []byte, pos int) (string, int, bool) {
|
||||
return string(data[pos:]), len(data) - pos, true
|
||||
}
|
||||
|
||||
func readUint16(data []byte, pos int) (uint16, int, bool) {
|
||||
if pos+1 >= len(data) {
|
||||
return 0, 0, false
|
||||
|
|
|
@ -190,21 +190,25 @@ func TestEncString(t *testing.T) {
|
|||
value string
|
||||
lenEncoded []byte
|
||||
nullEncoded []byte
|
||||
eofEncoded []byte
|
||||
}{
|
||||
{
|
||||
"",
|
||||
[]byte{0x00},
|
||||
[]byte{0x00},
|
||||
[]byte{},
|
||||
},
|
||||
{
|
||||
"a",
|
||||
[]byte{0x01, 'a'},
|
||||
[]byte{'a', 0x00},
|
||||
[]byte{'a'},
|
||||
},
|
||||
{
|
||||
"0123456789",
|
||||
[]byte{0x0a, '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'},
|
||||
[]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 0x00},
|
||||
[]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
|
@ -220,6 +224,11 @@ func TestEncString(t *testing.T) {
|
|||
t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.nullEncoded), test.value)
|
||||
}
|
||||
|
||||
// Check lenEOFString
|
||||
if got := lenEOFString(test.value); got != len(test.eofEncoded) {
|
||||
t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.eofEncoded), test.value)
|
||||
}
|
||||
|
||||
// Check successful encoding.
|
||||
data := make([]byte, len(test.lenEncoded))
|
||||
pos := writeLenEncString(data, 0, test.value)
|
||||
|
@ -319,16 +328,21 @@ func TestEncString(t *testing.T) {
|
|||
}
|
||||
|
||||
// EOF encoded tests.
|
||||
// We use the nullEncoded value, removing the 0 at the end.
|
||||
|
||||
// Check successful encoding.
|
||||
data = make([]byte, len(test.nullEncoded)-1)
|
||||
data = make([]byte, len(test.eofEncoded))
|
||||
pos = writeEOFString(data, 0, test.value)
|
||||
if pos != len(test.nullEncoded)-1 {
|
||||
t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.nullEncoded)-1)
|
||||
if pos != len(test.eofEncoded) {
|
||||
t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.eofEncoded))
|
||||
}
|
||||
if !bytes.Equal(data, test.nullEncoded[:len(test.nullEncoded)-1]) {
|
||||
t.Errorf("unexpected nullEncoded value for %v, got %v expected %v", test.value, data, test.nullEncoded)
|
||||
if !bytes.Equal(data, test.eofEncoded[:len(test.eofEncoded)]) {
|
||||
t.Errorf("unexpected eofEncoded value for %v, got %v expected %v", test.value, data, test.eofEncoded)
|
||||
}
|
||||
|
||||
// Check successful decoding.
|
||||
got, pos, ok = readEOFString(test.eofEncoded, 0)
|
||||
if !ok || got != test.value || pos != len(test.eofEncoded) {
|
||||
t.Errorf("readEOFString returned %v/%v/%v but expected %v/%v/%v", got, pos, ok, test.value, len(test.eofEncoded), true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче