feat: move collation and collation env to mysql.Conn instead of connParams

Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
This commit is contained in:
Florent Poinsard 2021-11-12 11:40:31 +01:00
Родитель 303c3d1b78
Коммит a56fe1d6ff
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 87A9DEBFB0824A2D
6 изменённых файлов: 69 добавлений и 73 удалений

Просмотреть файл

@ -218,8 +218,6 @@ func (c *Conn) Ping() error {
// allows us to make informed decisions around charset's default collation
// depending on the MySQL/MariaDB version we are using.
func setCollationForConnection(c *Conn, params *ConnParams) error {
var collStr string
// Once we have done the initial handshake with MySQL, we receive the server version
// string. This string is critical as it enables the instantiation of a new collation
// environment variable.
@ -230,21 +228,18 @@ func setCollationForConnection(c *Conn, params *ConnParams) error {
return err
}
// The collation environment is stored inside the connection parameters struct.
// We will use it to verify that execution requests issued by VTGate match the
// same collation as the one used to communicate with MySQL.
params.CollationEnvironment = env
var coll collations.Collation
charset := params.Charset
// if there is no collation or charset, we default to utf8mb4
if params.Collation == "" && params.Charset == "" {
params.Charset = "utf8mb4"
if params.Collation == "" && charset == "" {
charset = "utf8mb4"
}
var coll collations.Collation
if params.Collation == "" {
// If there is no collation we will just use the charset's default collation
// otherwise we directly use the given collation.
coll = env.DefaultCollationForCharset(params.Charset)
coll = env.DefaultCollationForCharset(charset)
} else {
// Here we call the collations API to ensure the collation/charset exist
// and is supported by Vitess.
@ -254,16 +249,19 @@ func setCollationForConnection(c *Conn, params *ConnParams) error {
// The given collation is most likely unknown or unsupported, we need to fail.
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot resolve collation: '%s'", params.Collation)
}
collStr = coll.Name()
params.Collation = collStr
// We send a query to MySQL to set the connection's collation.
// See: https://dev.mysql.com/doc/refman/8.0/en/charset-connection.html
querySetCollation := fmt.Sprintf("SET collation_connection = %s;", collStr)
querySetCollation := fmt.Sprintf("SET collation_connection = %s;", coll.Name())
_, err = c.ExecuteFetch(querySetCollation, 1, false)
if err != nil {
return err
}
// The collation environment is stored inside the connection parameters struct.
// We will use it to verify that execution requests issued by VTGate match the
// same collation as the one used to communicate with MySQL.
c.CollationEnvironment = env
c.Collation = coll.ID()
return nil
}

Просмотреть файл

@ -29,6 +29,8 @@ import (
"sync"
"time"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/bucketpool"
@ -36,7 +38,7 @@ import (
"vitess.io/vitess/go/sync2"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
)
@ -186,6 +188,17 @@ type Conn struct {
// See the values in constants.go.
CharacterSet uint8
// Collation defines the collation for this connection, it has the same
// value as the collation_connection variable of MySQL.
// Its value is set after we send the initial "SET collation_connection"
// query to MySQL after the handshake is done.
Collation collations.ID
// CollationEnvironment defines the collation environment used by this
// connection. We set its value using the ServerVersion we receive from
// MySQL after the handshake.
CollationEnvironment *collations.Environment
// Packet encoding variables.
sequence uint8
}
@ -346,7 +359,7 @@ func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
sequence := uint8(c.header[3])
if sequence != c.sequence {
return 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "invalid sequence, expected %v got %v", c.sequence, sequence)
return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence, expected %v got %v", c.sequence, sequence)
}
c.sequence++
@ -364,7 +377,7 @@ func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
// it most likely will be io.EOF.
func (c *Conn) readEphemeralPacket() ([]byte, error) {
if c.currentEphemeralPolicy != ephemeralUnused {
panic(vterrors.Errorf(vtrpc.Code_INTERNAL, "readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
}
r := c.getReader()
@ -424,7 +437,7 @@ func (c *Conn) readEphemeralPacket() ([]byte, error) {
// This function usually shouldn't be used - use readEphemeralPacket.
func (c *Conn) readEphemeralPacketDirect() ([]byte, error) {
if c.currentEphemeralPolicy != ephemeralUnused {
panic(vterrors.Errorf(vtrpc.Code_INTERNAL, "readEphemeralPacketDirect: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
}
var r io.Reader = c.conn
@ -449,7 +462,7 @@ func (c *Conn) readEphemeralPacketDirect() ([]byte, error) {
return *c.currentEphemeralBuffer, nil
}
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "readEphemeralPacketDirect doesn't support more than one packet")
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "readEphemeralPacketDirect doesn't support more than one packet")
}
// recycleReadPacket recycles the read packet. It needs to be called
@ -457,7 +470,7 @@ func (c *Conn) readEphemeralPacketDirect() ([]byte, error) {
func (c *Conn) recycleReadPacket() {
if c.currentEphemeralPolicy != ephemeralRead {
// Programming error.
panic(vterrors.Errorf(vtrpc.Code_INTERNAL, "trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
}
if c.currentEphemeralBuffer != nil {
// We are using the pool, put the buffer back in.
@ -570,7 +583,7 @@ func (c *Conn) writePacket(data []byte) error {
if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil {
return vterrors.Wrapf(err, "Write(packet) failed")
} else if n != (toBeSent + packetHeaderSize) {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize))
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize))
}
// restore the first 4 bytes once the network send is done
@ -591,7 +604,7 @@ func (c *Conn) writePacket(data []byte) error {
if n, err := w.Write(header[:]); err != nil {
return vterrors.Wrapf(err, "Write(empty header) failed")
} else if n != packetHeaderSize {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n)
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n)
}
c.sequence++
}
@ -624,7 +637,7 @@ func (c *Conn) writeEphemeralPacket() error {
}
case ephemeralUnused, ephemeralRead:
// Programming error.
panic(vterrors.Errorf(vtrpc.Code_INTERNAL, "conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), c.currentEphemeralPolicy))
panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), c.currentEphemeralPolicy))
}
return nil
@ -635,7 +648,7 @@ func (c *Conn) writeEphemeralPacket() error {
func (c *Conn) recycleWritePacket() {
if c.currentEphemeralPolicy != ephemeralWrite {
// Programming error.
panic(vterrors.Errorf(vtrpc.Code_INTERNAL, "trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
}
// Release our reference so the buffer can be gced
bufPool.Put(c.currentEphemeralBuffer)
@ -1359,7 +1372,7 @@ func parseEOFPacket(data []byte) (warnings uint16, statusFlags uint16, err error
// The status flag is in position 4 & 5
statusFlags, _, ok := readUint16(data, 3)
if !ok {
return 0, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data)
return 0, 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid EOF packet statusFlags: %v", data)
}
return warnings, statusFlags, nil
}
@ -1384,7 +1397,7 @@ func (c *Conn) parseOKPacket(in []byte) (*PacketOK, error) {
packetOK := &PacketOK{}
fail := func(format string, args ...interface{}) (*PacketOK, error) {
return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, format, args...)
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, format, args...)
}
// Affected rows.
@ -1511,3 +1524,24 @@ func (c *Conn) IsUnixSocket() bool {
func (c *Conn) GetRawConn() net.Conn {
return c.conn
}
// MatchCollation returns nil if the given collations.ID matches with the connection's
// collation, otherwise it returns an error explaining why it does not match.
// We do the comparison all the way down in the Connector to use mysql.ConnParams
// collations environment to achieve the collation lookup using the same server version.
func (c *Conn) MatchCollation(collationID collations.ID) error {
// The collation environment of a connection parameter should never be nil, if we fail
// to create it we already errored out when initializing the connection with MySQL.
if c.CollationEnvironment == nil {
return vterrors.New(vtrpcpb.Code_INTERNAL, "No collation environment for this connection")
}
coll := c.CollationEnvironment.LookupByID(collationID)
if coll == nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "QueryOption's Collation is unknown (collation ID: %d)", collationID)
}
if coll.ID() != c.Collation {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "QueryOption ('%v') and VTTablet ('%v') charsets do not match", coll.Name(), c.CollationEnvironment.LookupByID(c.Collation).Name())
}
return nil
}

Просмотреть файл

@ -17,23 +17,21 @@ limitations under the License.
package mysql
import (
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/vt/vttls"
)
// ConnParams contains all the parameters to use to connect to mysql.
type ConnParams struct {
Host string `json:"host"`
Port int `json:"port"`
Uname string `json:"uname"`
Pass string `json:"pass"`
DbName string `json:"dbname"`
UnixSocket string `json:"unix_socket"`
Charset string `json:"charset"`
Collation string `json:"collation"`
Flags uint64 `json:"flags"`
Flavor string `json:"flavor,omitempty"`
CollationEnvironment *collations.Environment `json:"collation_environment,omitempty"`
Host string `json:"host"`
Port int `json:"port"`
Uname string `json:"uname"`
Pass string `json:"pass"`
DbName string `json:"dbname"`
UnixSocket string `json:"unix_socket"`
Charset string `json:"charset"`
Collation string `json:"collation"`
Flags uint64 `json:"flags"`
Flavor string `json:"flavor,omitempty"`
// The following SSL flags control the SSL behavior.
//

Просмотреть файл

@ -409,7 +409,7 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) {
// when creating a connection to the database, we send an initial query to set the connection's
// collation, we want to skip the query check if we get such initial query.
// this is done to ease the test readability.
if query == "SET collation_connection = utf8mb4_general_ci" {
if strings.HasPrefix(query, "SET collation_connection =") {
return &sqltypes.Result{}, nil
}

Просмотреть файл

@ -25,7 +25,6 @@ import (
"encoding/json"
"flag"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/vt/vttls"
"vitess.io/vitess/go/mysql"
@ -186,38 +185,6 @@ func New(mcp *mysql.ConnParams) Connector {
}
}
// SetCollationInformation sets the charset, collation and collation environment of the
// connection.
func (c *Connector) SetCollationInformation(charset, collation string, env *collations.Environment) {
if c.connParams.Charset != charset && c.connParams.Collation != collation {
c.connParams.Collation = collation
c.connParams.Charset = charset
c.connParams.CollationEnvironment = env
}
}
// MatchCollation returns nil if the given collations.ID matches with the connection's
// collation, otherwise it returns an error explaining why it does not match.
// We do the comparison all the way down in the Connector to use mysql.ConnParams
// collations environment to achieve the collation lookup using the same server version.
func (c Connector) MatchCollation(collationID collations.ID) error {
// The collation environment of a connection parameter should never be nil, if we fail
// to create it we already errored out when initializing the connection with MySQL.
if c.connParams.CollationEnvironment == nil {
return vterrors.New(vtrpcpb.Code_INTERNAL, "No collation environment for this connection")
}
coll := c.connParams.CollationEnvironment.LookupByID(collationID)
if coll == nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "QueryOption's Collation is unknown (collation ID: %d)", collationID)
}
if coll.Name() != c.connParams.Collation {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "QueryOption ('%v') and VTTablet ('%v') charsets do not match", coll.Name(), c.connParams.Collation)
}
return nil
}
// Connect will invoke the mysql.connect method and return a connection
func (c *Connector) Connect(ctx context.Context) (*mysql.Conn, error) {
params, err := c.MysqlParams()
@ -228,7 +195,6 @@ func (c *Connector) Connect(ctx context.Context) (*mysql.Conn, error) {
if err != nil {
return nil, err
}
c.SetCollationInformation(params.Charset, params.Collation, params.CollationEnvironment)
return conn, nil
}

Просмотреть файл

@ -445,5 +445,5 @@ func (dbc *DBConn) setDeadline(ctx context.Context) (chan bool, *sync.WaitGroup)
// collation matches with the given collation ID.
// If it does not match an error will be returned explaining why.
func (dbc *DBConn) MatchCollation(collationID collations.ID) error {
return dbc.info.MatchCollation(collationID)
return dbc.conn.MatchCollation(collationID)
}