vitess-gh/go/mysql/client.go

789 строки
26 KiB
Go

/*
Copyright 2019 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"time"
"vitess.io/vitess/go/mysql/collations"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vttls"
)
// connectResult is used by Connect.
type connectResult struct {
c *Conn
err error
}
// Connect creates a connection to a server.
// It then handles the initial handshake.
//
// If context is canceled before the end of the process, this function
// will return nil, ctx.Err().
//
// FIXME(alainjobart) once we have more of a server side, add test cases
// to cover all failure scenarios.
func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
if params.ConnectTimeoutMs != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(params.ConnectTimeoutMs)*time.Millisecond)
defer cancel()
}
netProto := "tcp"
addr := ""
if params.UnixSocket != "" {
netProto = "unix"
addr = params.UnixSocket
} else {
addr = net.JoinHostPort(params.Host, fmt.Sprintf("%v", params.Port))
}
// Start a background connection routine. It first
// establishes a network connection, returns it on the channel,
// then starts the negotiation, and returns the result on the channel.
// It can send on the channel, before closing it:
// - a connectResult with an error and nothing else (when dial fails).
// - a connectResult with a *Conn and no error, then another one
// with possibly an error.
status := make(chan connectResult)
go func() {
defer close(status)
var err error
var conn net.Conn
// Cap the Dial time with the context deadline, plus a
// few seconds. We want to reclaim resources quickly
// and not let this go routine stuck in Dial forever.
//
// We add a few seconds so we detect the context is
// Done() before timing out the Dial. That way we'll
// return the right error to the client (ctx.Err(), vs
// DialTimeout() error).
if deadline, ok := ctx.Deadline(); ok {
timeout := time.Until(deadline) + 5*time.Second
conn, err = net.DialTimeout(netProto, addr, timeout)
} else {
conn, err = net.Dial(netProto, addr)
}
if err != nil {
// If we get an error, the connection to a Unix socket
// should return a 2002, but for a TCP socket it
// should return a 2003.
if netProto == "tcp" {
status <- connectResult{
err: NewSQLError(CRConnHostError, SSUnknownSQLState, "net.Dial(%v) failed: %v", addr, err),
}
} else {
status <- connectResult{
err: NewSQLError(CRConnectionError, SSUnknownSQLState, "net.Dial(%v) to local server failed: %v", addr, err),
}
}
return
}
// Send the connection back, so the other side can close it.
c := newConn(conn)
status <- connectResult{
c: c,
}
// During the handshake, and if the context is
// canceled, the connection will be closed. That will
// make any read or write just return with an error
// right away.
status <- connectResult{
err: c.clientHandshake(params),
}
}()
// Wait on the context and the status, for the connection to happen.
var c *Conn
select {
case <-ctx.Done():
// The background routine may send us a few things,
// wait for them and terminate them properly in the
// background.
go func() {
dialCR := <-status // This one can take a while.
if dialCR.err != nil {
// Dial failed, nothing else to do.
return
}
// Dial worked, close the connection, wait for the end.
// We wait as not to leave a channel with an unread value.
dialCR.c.Close()
<-status
}()
return nil, ctx.Err()
case cr := <-status:
if cr.err != nil {
// Dial failed, no connection was ever established.
return nil, cr.err
}
// Dial worked, we have a connection. Keep going.
c = cr.c
}
// Wait for the end of the handshake.
select {
case <-ctx.Done():
// We are interrupted. Close the connection, wait for
// the handshake to finish in the background.
c.Close()
go func() {
// Since we closed the connection, this one should be fast.
// We wait as not to leave a channel with an unread value.
<-status
}()
return nil, ctx.Err()
case cr := <-status:
if cr.err != nil {
c.Close()
return nil, cr.err
}
}
return c, nil
}
// Ping implements mysql ping command.
func (c *Conn) Ping() error {
// This is a new command, need to reset the sequence.
c.sequence = 0
data, pos := c.startEphemeralPacketWithHeader(1)
data[pos] = ComPing
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err)
}
data, err := c.readEphemeralPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()
switch data[0] {
case OKPacket:
return nil
case ErrPacket:
return ParseErrorPacket(data)
}
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected packet type: %d", data[0])
}
// clientHandshake handles the client side of the handshake.
// Note the connection can be closed while this is running.
// Returns a SQLError.
func (c *Conn) clientHandshake(params *ConnParams) error {
// if EnableQueryInfo is set, make sure that all queries starting with the handshake
// will actually process the INFO fields in QUERY_OK packets
if params.EnableQueryInfo {
c.enableQueryInfo = true
}
// Wait for the server initial handshake packet, and parse it.
data, err := c.readPacket()
if err != nil {
return NewSQLError(CRServerLost, "", "initial packet read failed: %v", err)
}
capabilities, salt, err := c.parseInitialHandshakePacket(data)
if err != nil {
return err
}
c.fillFlavor(params)
c.salt = salt
// Sanity check.
if capabilities&CapabilityClientProtocol41 == 0 {
return NewSQLError(CRVersionError, SSUnknownSQLState, "cannot connect to servers earlier than 4.1")
}
// Remember a subset of the capabilities, so we can use them
// later in the protocol.
c.Capabilities = 0
if !params.DisableClientDeprecateEOF {
c.Capabilities = capabilities & (CapabilityClientDeprecateEOF)
}
charset, err := collations.Local().ParseConnectionCharset(params.Charset)
if err != nil {
return err
}
// Handle switch to SSL if necessary.
if params.SslEnabled() {
// If client asked for SSL, but server doesn't support it,
// stop right here.
if params.SslRequired() && capabilities&CapabilityClientSSL == 0 {
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "server doesn't support SSL but client asked for it")
}
// The ServerName to verify depends on what the hostname is.
// We use the params's ServerName if specified. Otherwise:
// - If using a socket, we use "localhost".
// - If it is an IP address, we need to prefix it with 'IP:'.
// - If not, we can just use it as is.
serverName := "localhost"
if params.ServerName != "" {
serverName = params.ServerName
} else if params.Host != "" {
if net.ParseIP(params.Host) != nil {
serverName = "IP:" + params.Host
} else {
serverName = params.Host
}
}
tlsVersion, err := vttls.TLSVersionToNumber(params.TLSMinVersion)
if err != nil {
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error parsing minimal TLS version: %v", err)
}
// Build the TLS config.
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, params.SslCrl, serverName, tlsVersion)
if err != nil {
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error loading client cert and ca: %v", err)
}
// Send the SSLRequest packet.
if err := c.writeSSLRequest(capabilities, charset, params); err != nil {
return err
}
// Switch to SSL.
conn := tls.Client(c.conn, clientConfig)
c.conn = conn
c.bufferedReader.Reset(conn)
c.Capabilities |= CapabilityClientSSL
}
// Password encryption.
var scrambledPassword []byte
if c.authPluginName == CachingSha2Password {
scrambledPassword = ScrambleCachingSha2Password(salt, []byte(params.Pass))
} else {
scrambledPassword = ScrambleMysqlNativePassword(salt, []byte(params.Pass))
}
// Client Session Tracking Capability.
if capabilities&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
// If the server also supports it, we will have enabled
// it so we also add it to our capabilities.
c.Capabilities |= CapabilityClientSessionTrack
} else if params.Flags&CapabilityClientSessionTrack == CapabilityClientSessionTrack {
// If client asked for ClientSessionTrack, but server doesn't support it,
// stop right here.
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "server doesn't support ClientSessionTrack but client asked for it")
}
// Build and send our handshake response 41.
// Note this one will never have SSL flag on.
if err := c.writeHandshakeResponse41(capabilities, scrambledPassword, charset, params); err != nil {
return err
}
// Read the server response.
if err := c.handleAuthResponse(params); err != nil {
return err
}
// If the server didn't support DbName in its handshake, set
// it now. This is what the 'mysql' client does.
if capabilities&CapabilityClientConnectWithDB == 0 && params.DbName != "" {
// Write the packet.
if err := c.writeComInitDB(params.DbName); err != nil {
return err
}
// Wait for response, should be OK.
response, err := c.readPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
switch response[0] {
case OKPacket:
// OK packet, we are authenticated.
return nil
case ErrPacket:
return ParseErrorPacket(response)
default:
// FIXME(alainjobart) handle extra auth cases and so on.
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response is asking for more information, not implemented yet: %v", response)
}
}
return nil
}
// parseInitialHandshakePacket parses the initial handshake from the server.
// It returns a SQLError with the right code.
func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) {
pos := 0
// Protocol version.
pver, pos, ok := readByte(data, pos)
if !ok {
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version")
}
// Server is allowed to immediately send ERR packet
if pver == ErrPacket {
errorCode, pos, _ := readUint16(data, pos)
// Normally there would be a 1-byte sql_state_marker field and a 5-byte
// sql_state field here, but docs say these will not be present in this case.
errorMsg, _, _ := readEOFString(data, pos)
return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg)
}
if pver != protocolVersion {
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver)
}
// Read the server version.
c.ServerVersion, pos, ok = readNullString(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no server version")
}
// Read the connection id.
c.ConnectionID, pos, ok = readUint32(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no connection id")
}
// Read the first part of the auth-plugin-data
authPluginData, pos, ok := readBytes(data, pos, 8)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no auth-plugin-data-part-1")
}
// One byte filler, 0. We don't really care about the value.
_, pos, ok = readByte(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no filler")
}
// Lower 2 bytes of the capability flags.
capLower, pos, ok := readUint16(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (lower 2 bytes)")
}
var capabilities = uint32(capLower)
// The packet can end here.
if pos == len(data) {
return capabilities, authPluginData, nil
}
// Character set.
characterSet, pos, ok := readByte(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no character set")
}
c.CharacterSet = collations.ID(characterSet)
// Status flags. Ignored.
_, pos, ok = readUint16(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no status flags")
}
// Upper 2 bytes of the capability flags.
capUpper, pos, ok := readUint16(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no capability flags (upper 2 bytes)")
}
capabilities += uint32(capUpper) << 16
// Length of auth-plugin-data, or 0.
// Only with CLIENT_PLUGIN_AUTH capability.
var authPluginDataLength byte
if capabilities&CapabilityClientPluginAuth != 0 {
authPluginDataLength, pos, ok = readByte(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no length of auth-plugin-data")
}
} else {
// One byte filler, 0. We don't really care about the value.
_, pos, ok = readByte(data, pos)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no length of auth-plugin-data filler")
}
}
// 10 reserved 0 bytes.
pos += 10
if capabilities&CapabilityClientSecureConnection != 0 {
// The next part of the auth-plugin-data.
// The length is max(13, length of auth-plugin-data - 8).
l := int(authPluginDataLength) - 8
if l > 13 {
l = 13
}
var authPluginDataPart2 []byte
authPluginDataPart2, pos, ok = readBytes(data, pos, l)
if !ok {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no auth-plugin-data-part-2")
}
// The last byte has to be 0, and is not part of the data.
if authPluginDataPart2[l-1] != 0 {
return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "parseInitialHandshakePacket: auth-plugin-data-part-2 is not 0 terminated")
}
authPluginData = append(authPluginData, authPluginDataPart2[0:l-1]...)
}
// Auth-plugin name.
if capabilities&CapabilityClientPluginAuth != 0 {
authPluginName, _, ok := readNullString(data, pos)
if !ok {
// Fallback for versions prior to 5.5.10 and
// 5.6.2 that don't have a null terminated string.
authPluginName = string(data[pos : len(data)-1])
}
c.authPluginName = AuthMethodDescription(authPluginName)
}
return capabilities, authPluginData, nil
}
// writeSSLRequest writes the SSLRequest packet. It's just a truncated
// HandshakeResponse41.
func (c *Conn) writeSSLRequest(capabilities uint32, characterSet uint8, params *ConnParams) error {
// Build our flags, with CapabilityClientSSL.
capabilityFlags := CapabilityFlagsSsl |
// If the server supported
// CapabilityClientDeprecateEOF, we also support it.
c.Capabilities&CapabilityClientDeprecateEOF |
// If the server supported
// CapabilityClientSessionTrack, we also support it.
c.Capabilities&CapabilityClientSessionTrack |
// Pass-through ClientFoundRows flag.
CapabilityClientFoundRows&uint32(params.Flags)
length :=
4 + // Client capability flags.
4 + // Max-packet size.
1 + // Character set.
23 // Reserved.
// Add the DB name if the server supports it.
if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) {
capabilityFlags |= CapabilityClientConnectWithDB
}
data, pos := c.startEphemeralPacketWithHeader(length)
// Client capability flags.
pos = writeUint32(data, pos, capabilityFlags)
// Max-packet size, always 0. See doc.go.
pos = writeZeroes(data, pos, 4)
// Character set.
_ = writeByte(data, pos, characterSet)
// And send it as is.
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "cannot send SSLRequest: %v", err)
}
return nil
}
// CapabilityFlags are client capability flag sent to mysql on connect
const CapabilityFlags uint32 = CapabilityClientLongPassword |
CapabilityClientLongFlag |
CapabilityClientProtocol41 |
CapabilityClientTransactions |
CapabilityClientSecureConnection |
CapabilityClientMultiStatements |
CapabilityClientMultiResults |
CapabilityClientPluginAuth |
CapabilityClientPluginAuthLenencClientData
// CapabilityFlagsSsl signals that we can handle SSL as well
const CapabilityFlagsSsl = CapabilityFlags |
CapabilityClientSSL
// writeHandshakeResponse41 writes the handshake response.
// Returns a SQLError.
func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword []byte, characterSet uint8, params *ConnParams) error {
// Build our flags.
capabilityFlags := CapabilityFlags |
// If the server supported
// CapabilityClientDeprecateEOF, we also support it.
c.Capabilities&CapabilityClientDeprecateEOF |
// Pass-through ClientFoundRows flag.
CapabilityClientFoundRows&uint32(params.Flags) |
// If the server supported
// CapabilityClientSessionTrack, we also support it.
c.Capabilities&CapabilityClientSessionTrack
// FIXME(alainjobart) add multi statement.
length :=
4 + // Client capability flags.
4 + // Max-packet size.
1 + // Character set.
23 + // Reserved.
lenNullString(params.Uname) +
// length of scrambled password is handled below.
len(scrambledPassword) +
len(c.authPluginName) +
1 // terminating zero.
// Add the DB name if the server supports it.
if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) {
capabilityFlags |= CapabilityClientConnectWithDB
length += lenNullString(params.DbName)
}
if capabilities&CapabilityClientPluginAuthLenencClientData != 0 {
length += lenEncIntSize(uint64(len(scrambledPassword)))
} else {
length++
}
data, pos := c.startEphemeralPacketWithHeader(length)
// Client capability flags.
pos = writeUint32(data, pos, capabilityFlags)
// Max-packet size, always 0. See doc.go.
pos = writeZeroes(data, pos, 4)
// Character set.
pos = writeByte(data, pos, characterSet)
// 23 reserved bytes, all 0.
pos = writeZeroes(data, pos, 23)
// Username
pos = writeNullString(data, pos, params.Uname)
// Scrambled password. The length is encoded as variable length if
// CapabilityClientPluginAuthLenencClientData is set.
if capabilities&CapabilityClientPluginAuthLenencClientData != 0 {
pos = writeLenEncInt(data, pos, uint64(len(scrambledPassword)))
} else {
data[pos] = byte(len(scrambledPassword))
pos++
}
pos += copy(data[pos:], scrambledPassword)
// DbName, only if server supports it.
if params.DbName != "" && (capabilities&CapabilityClientConnectWithDB != 0) {
pos = writeNullString(data, pos, params.DbName)
c.schemaName = params.DbName
}
// Assume native client during response
pos = writeNullString(data, pos, string(c.authPluginName))
// Sanity-check the length.
if pos != len(data) {
return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "writeHandshakeResponse41: only packed %v bytes, out of %v allocated", pos, len(data))
}
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "cannot send HandshakeResponse41: %v", err)
}
return nil
}
// handleAuthResponse parses server's response after client sends the password for authentication
// and handles next steps for AuthSwitchRequestPacket and AuthMoreDataPacket.
func (c *Conn) handleAuthResponse(params *ConnParams) error {
response, err := c.readPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
switch response[0] {
case OKPacket:
// OK packet, we are authenticated. Save the user, keep going.
c.User = params.Uname
case AuthSwitchRequestPacket:
// Server is asking to use a different auth method
if err = c.handleAuthSwitchPacket(params, response); err != nil {
return err
}
case AuthMoreDataPacket:
// Server is requesting more data - maybe un-scrambled password
if err := c.handleAuthMoreDataPacket(response[1], params); err != nil {
return err
}
case ErrPacket:
return ParseErrorPacket(response)
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "initial server response cannot be parsed: %v", response)
}
return nil
}
// handleAuthSwitchPacket scrambles password for the plugin requested by the server and retries authentication
func (c *Conn) handleAuthSwitchPacket(params *ConnParams, response []byte) error {
var err error
var salt []byte
c.authPluginName, salt, err = parseAuthSwitchRequest(response)
if err != nil {
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err)
}
if salt != nil {
c.salt = salt
}
switch c.authPluginName {
case MysqlClearPassword:
if err := c.writeClearTextPassword(params); err != nil {
return err
}
case MysqlNativePassword:
scrambledPassword := ScrambleMysqlNativePassword(c.salt, []byte(params.Pass))
if err := c.writeScrambledPassword(scrambledPassword); err != nil {
return err
}
case CachingSha2Password:
scrambledPassword := ScrambleCachingSha2Password(c.salt, []byte(params.Pass))
if err := c.writeScrambledPassword(scrambledPassword); err != nil {
return err
}
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", c.authPluginName)
}
// The response could be an OKPacket, AuthMoreDataPacket or ErrPacket
return c.handleAuthResponse(params)
}
// handleAuthMoreDataPacket handles response of CachingSha2Password authentication and sends full password to the
// server if requested
func (c *Conn) handleAuthMoreDataPacket(data byte, params *ConnParams) error {
switch data {
case CachingSha2FastAuth:
// User credentials are verified using the cache ("Fast" path).
// Next packet should be an OKPacket
return c.handleAuthResponse(params)
case CachingSha2FullAuth:
// User credentials are not cached, we have to exchange full password.
if c.Capabilities&CapabilityClientSSL > 0 || params.UnixSocket != "" {
// If we are using an SSL connection or Unix socket, write clear text password
if err := c.writeClearTextPassword(params); err != nil {
return err
}
} else {
// If we are not using an SSL connection or Unix socket, we have to fetch a public key
// from the server to encrypt password
pub, err := c.requestPublicKey()
if err != nil {
return err
}
// Encrypt password with public key
enc, err := EncryptPasswordWithPublicKey(c.salt, []byte(params.Pass), pub)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error encrypting password with public key: %v", err)
}
// Write encrypted password
if err := c.writeScrambledPassword(enc); err != nil {
return err
}
}
// Next packet should either be an OKPacket or ErrPacket
return c.handleAuthResponse(params)
default:
return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse AuthMoreDataPacket: %v", data)
}
}
func parseAuthSwitchRequest(data []byte) (AuthMethodDescription, []byte, error) {
pos := 1
pluginName, pos, ok := readNullString(data, pos)
if !ok {
return "", nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot get plugin name from AuthSwitchRequest: %v", data)
}
// If this was a request with a salt in it, max 20 bytes
salt := data[pos:]
if len(salt) > 20 {
salt = salt[:20]
}
return AuthMethodDescription(pluginName), salt, nil
}
// requestPublicKey requests a public key from the server
func (c *Conn) requestPublicKey() (rsaKey *rsa.PublicKey, err error) {
// get public key from server
data, pos := c.startEphemeralPacketWithHeader(1)
data[pos] = 0x02
if err := c.writeEphemeralPacket(); err != nil {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error sending public key request packet: %v", err)
}
response, err := c.readPacket()
if err != nil {
return nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
// Server should respond with a AuthMoreDataPacket containing the public key
if response[0] != AuthMoreDataPacket {
return nil, ParseErrorPacket(response)
}
block, _ := pem.Decode(response[1:])
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to parse public key from server: %v", err)
}
return pub.(*rsa.PublicKey), nil
}
// writeClearTextPassword writes the clear text password.
// Returns a SQLError.
func (c *Conn) writeClearTextPassword(params *ConnParams) error {
length := len(params.Pass) + 1
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeNullString(data, pos, params.Pass)
// Sanity check.
if pos != len(data) {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building ClearTextPassword packet: got %v bytes expected %v", pos, len(data))
}
return c.writeEphemeralPacket()
}
// writeScrambledPassword writes the encrypted mysql_native_password format
// Returns a SQLError.
func (c *Conn) writeScrambledPassword(scrambledPassword []byte) error {
data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword))
pos += copy(data[pos:], scrambledPassword)
// Sanity check.
if pos != len(data) {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error building %v packet: got %v bytes expected %v", c.authPluginName, pos, len(data))
}
return c.writeEphemeralPacket()
}