зеркало из https://github.com/github/vitess-gh.git
384 строки
7.9 KiB
Go
384 строки
7.9 KiB
Go
//go:build gofuzz
|
|
// +build gofuzz
|
|
|
|
/*
|
|
Copyright 2021 The Vitess Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path"
|
|
"sync"
|
|
"time"
|
|
|
|
gofuzzheaders "github.com/AdaLogics/go-fuzz-headers"
|
|
|
|
"vitess.io/vitess/go/sqltypes"
|
|
querypb "vitess.io/vitess/go/vt/proto/query"
|
|
"vitess.io/vitess/go/vt/tlstest"
|
|
"vitess.io/vitess/go/vt/vttls"
|
|
)
|
|
|
|
func createFuzzingSocketPair() (net.Listener, *Conn, *Conn) {
|
|
// Create a listener.
|
|
listener, err := net.Listen("tcp", "127.0.0.1:")
|
|
if err != nil {
|
|
fmt.Println("We got an error early on")
|
|
return nil, nil, nil
|
|
}
|
|
addr := listener.Addr().String()
|
|
listener.(*net.TCPListener).SetDeadline(time.Now().Add(10 * time.Second))
|
|
|
|
// Dial a client, Accept a server.
|
|
wg := sync.WaitGroup{}
|
|
|
|
var clientConn net.Conn
|
|
var clientErr error
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
clientConn, clientErr = net.DialTimeout("tcp", addr, 10*time.Second)
|
|
}()
|
|
|
|
var serverConn net.Conn
|
|
var serverErr error
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
serverConn, serverErr = listener.Accept()
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
if clientErr != nil {
|
|
return nil, nil, nil
|
|
}
|
|
if serverErr != nil {
|
|
return nil, nil, nil
|
|
}
|
|
|
|
// Create a Conn on both sides.
|
|
cConn := newConn(clientConn)
|
|
sConn := newConn(serverConn)
|
|
|
|
return listener, sConn, cConn
|
|
}
|
|
|
|
type fuzztestRun struct {
|
|
UnimplementedHandler
|
|
}
|
|
|
|
func (t fuzztestRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
|
|
return nil
|
|
}
|
|
|
|
func (t fuzztestRun) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (t fuzztestRun) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error {
|
|
return nil
|
|
}
|
|
|
|
func (t fuzztestRun) WarningCount(c *Conn) uint16 {
|
|
return 0
|
|
}
|
|
|
|
var _ Handler = (*fuzztestRun)(nil)
|
|
|
|
type fuzztestConn struct {
|
|
writeToPass []bool
|
|
pos int
|
|
queryPacket []byte
|
|
}
|
|
|
|
func (t fuzztestConn) Read(b []byte) (n int, err error) {
|
|
for i := 0; i < len(b) && i < len(t.queryPacket); i++ {
|
|
b[i] = t.queryPacket[i]
|
|
}
|
|
return len(b), nil
|
|
}
|
|
|
|
func (t fuzztestConn) Write(b []byte) (n int, err error) {
|
|
t.pos = t.pos + 1
|
|
if t.writeToPass[t.pos] {
|
|
return 0, nil
|
|
}
|
|
return 0, fmt.Errorf("error in writing to connection")
|
|
}
|
|
|
|
func (t fuzztestConn) Close() error {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (t fuzztestConn) LocalAddr() net.Addr {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (t fuzztestConn) RemoteAddr() net.Addr {
|
|
return fuzzmockAddress{s: "a"}
|
|
}
|
|
|
|
func (t fuzztestConn) SetDeadline(t1 time.Time) error {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (t fuzztestConn) SetReadDeadline(t1 time.Time) error {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (t fuzztestConn) SetWriteDeadline(t1 time.Time) error {
|
|
panic("implement me")
|
|
}
|
|
|
|
var _ net.Conn = (*fuzztestConn)(nil)
|
|
|
|
// fuzzmockAddress is a placeholder net.Addr backed by a single string.
type fuzzmockAddress struct {
	s string
}

// Network returns the stored string as the network name.
func (addr fuzzmockAddress) Network() string {
	return addr.s
}

// String returns the stored string as the address.
func (addr fuzzmockAddress) String() string {
	return addr.s
}

// Compile-time check that *fuzzmockAddress satisfies net.Addr.
var _ net.Addr = (*fuzzmockAddress)(nil)
|
|
|
|
// Fuzzers begin here:
|
|
func FuzzWritePacket(data []byte) int {
|
|
if len(data) < 10 {
|
|
return -1
|
|
}
|
|
listener, sConn, cConn := createFuzzingSocketPair()
|
|
defer func() {
|
|
listener.Close()
|
|
sConn.Close()
|
|
cConn.Close()
|
|
}()
|
|
|
|
err := cConn.writePacket(data)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
_, err = sConn.ReadPacket()
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return 1
|
|
}
|
|
|
|
func FuzzHandleNextCommand(data []byte) int {
|
|
if len(data) < 10 {
|
|
return -1
|
|
}
|
|
sConn := newConn(fuzztestConn{
|
|
writeToPass: []bool{false},
|
|
pos: -1,
|
|
queryPacket: data,
|
|
})
|
|
sConn.PrepareData = map[uint32]*PrepareData{}
|
|
|
|
handler := &fuzztestRun{}
|
|
_ = sConn.handleNextCommand(handler)
|
|
return 1
|
|
}
|
|
|
|
func FuzzReadQueryResults(data []byte) int {
|
|
listener, sConn, cConn := createFuzzingSocketPair()
|
|
defer func() {
|
|
listener.Close()
|
|
sConn.Close()
|
|
cConn.Close()
|
|
}()
|
|
err := cConn.WriteComQuery(string(data))
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
handler := &fuzztestRun{}
|
|
_ = sConn.handleNextCommand(handler)
|
|
_, _, _, err = cConn.ReadQueryResult(100, true)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return 1
|
|
}
|
|
|
|
type fuzzTestHandler struct {
|
|
UnimplementedHandler
|
|
|
|
mu sync.Mutex
|
|
lastConn *Conn
|
|
result *sqltypes.Result
|
|
err error
|
|
warnings uint16
|
|
}
|
|
|
|
func (th *fuzzTestHandler) LastConn() *Conn {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
return th.lastConn
|
|
}
|
|
|
|
func (th *fuzzTestHandler) Result() *sqltypes.Result {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
return th.result
|
|
}
|
|
|
|
func (th *fuzzTestHandler) SetErr(err error) {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
th.err = err
|
|
}
|
|
|
|
func (th *fuzzTestHandler) Err() error {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
return th.err
|
|
}
|
|
|
|
func (th *fuzzTestHandler) SetWarnings(count uint16) {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
th.warnings = count
|
|
}
|
|
|
|
func (th *fuzzTestHandler) NewConnection(c *Conn) {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
th.lastConn = c
|
|
}
|
|
|
|
func (th *fuzzTestHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
|
|
|
|
return nil
|
|
}
|
|
|
|
func (th *fuzzTestHandler) ComPrepare(c *Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (th *fuzzTestHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error {
|
|
return nil
|
|
}
|
|
|
|
func (th *fuzzTestHandler) ComResetConnection(c *Conn) {
|
|
|
|
}
|
|
|
|
func (th *fuzzTestHandler) WarningCount(c *Conn) uint16 {
|
|
th.mu.Lock()
|
|
defer th.mu.Unlock()
|
|
return th.warnings
|
|
}
|
|
|
|
func (c *Conn) writeFuzzedPacket(packet []byte) {
|
|
c.sequence = 0
|
|
data, pos := c.startEphemeralPacketWithHeader(len(packet) + 1)
|
|
copy(data[pos:], packet)
|
|
_ = c.writeEphemeralPacket()
|
|
}
|
|
|
|
func FuzzTLSServer(data []byte) int {
|
|
if len(data) < 40 {
|
|
return -1
|
|
}
|
|
// totalQueries is the number of queries the fuzzer
|
|
// makes in each fuzz iteration
|
|
totalQueries := 20
|
|
var queries [][]byte
|
|
c := gofuzzheaders.NewConsumer(data)
|
|
for i := 0; i < totalQueries; i++ {
|
|
query, err := c.GetBytes()
|
|
if err != nil {
|
|
return -1
|
|
}
|
|
if len(query) < 40 {
|
|
continue
|
|
}
|
|
queries = append(queries, query)
|
|
}
|
|
|
|
th := &fuzzTestHandler{}
|
|
|
|
authServer := NewAuthServerStatic("", "", 0)
|
|
authServer.entries["user1"] = []*AuthServerStaticEntry{{
|
|
Password: "password1",
|
|
}}
|
|
defer authServer.close()
|
|
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false)
|
|
if err != nil {
|
|
return -1
|
|
}
|
|
defer l.Close()
|
|
|
|
host := l.Addr().(*net.TCPAddr).IP.String()
|
|
port := l.Addr().(*net.TCPAddr).Port
|
|
root, err := os.MkdirTemp("", "TestTLSServer")
|
|
if err != nil {
|
|
return -1
|
|
}
|
|
defer os.RemoveAll(root)
|
|
tlstest.CreateCA(root)
|
|
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
|
|
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
|
|
|
|
serverConfig, err := vttls.ServerConfig(
|
|
path.Join(root, "server-cert.pem"),
|
|
path.Join(root, "server-key.pem"),
|
|
path.Join(root, "ca-cert.pem"),
|
|
"",
|
|
"",
|
|
tls.VersionTLS12)
|
|
if err != nil {
|
|
return -1
|
|
}
|
|
l.TLSConfig.Store(serverConfig)
|
|
go l.Accept()
|
|
|
|
connCountByTLSVer.ResetAll()
|
|
// Setup the right parameters.
|
|
params := &ConnParams{
|
|
Host: host,
|
|
Port: port,
|
|
Uname: "user1",
|
|
Pass: "password1",
|
|
// SSL flags.
|
|
SslMode: vttls.VerifyIdentity,
|
|
SslCa: path.Join(root, "ca-cert.pem"),
|
|
SslCert: path.Join(root, "client-cert.pem"),
|
|
SslKey: path.Join(root, "client-key.pem"),
|
|
ServerName: "server.example.com",
|
|
}
|
|
conn, err := Connect(context.Background(), params)
|
|
if err != nil {
|
|
return -1
|
|
}
|
|
|
|
for i := 0; i < len(queries); i++ {
|
|
conn.writeFuzzedPacket(queries[i])
|
|
}
|
|
return 1
|
|
}
|