Merge pull request #1383 from michael-berlin/go_sql_driver_shard

vitessdriver: Add helper methods and refactor tests.
This commit is contained in:
Michael Berlin 2015-12-09 21:48:24 -08:00
Родитель 1a427288f3 50b9021a2c
Коммит 45544ca7c5
6 изменённых файлов: 496 добавлений и 253 удалений

Просмотреть файл

@ -10,14 +10,12 @@
package main
import (
"database/sql"
"flag"
"fmt"
"os"
"time"
// import the 'vitess' sql driver
_ "github.com/youtube/vitess/go/vt/vitessdriver"
"github.com/youtube/vitess/go/vt/vitessdriver"
)
var (
@ -25,16 +23,14 @@ var (
)
func main() {
keyspace := "test_keyspace"
timeout := (10 * time.Second).Nanoseconds()
shard := "0"
flag.Parse()
keyspace := "test_keyspace"
shard := "0"
timeout := 10 * time.Second
// Connect to vtgate.
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "keyspace": "%s", "shard": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`,
*server, keyspace, shard, "master", false, timeout)
db, err := sql.Open("vitess", connStr)
db, err := vitessdriver.OpenShard(*server, keyspace, shard, "master", timeout)
if err != nil {
fmt.Printf("client error: %v\n", err)
os.Exit(1)
@ -82,9 +78,7 @@ func main() {
// Note that this may be behind master due to replication lag.
fmt.Println("Reading from replica...")
connStr = fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "keyspace": "%s", "shard": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`,
*server, keyspace, shard, "replica", false, timeout)
dbr, err := sql.Open("vitess", connStr)
dbr, err := vitessdriver.OpenShard(*server, keyspace, shard, "replica", timeout)
if err != nil {
fmt.Printf("client error: %v\n", err)
os.Exit(1)

Просмотреть файл

@ -5,7 +5,6 @@
package main
import (
"database/sql"
"encoding/json"
"flag"
"fmt"
@ -17,8 +16,7 @@ import (
"github.com/youtube/vitess/go/exit"
"github.com/youtube/vitess/go/vt/logutil"
// import the 'vitess' sql driver
_ "github.com/youtube/vitess/go/vt/vitessdriver"
"github.com/youtube/vitess/go/vt/vitessdriver"
)
var (
@ -106,8 +104,15 @@ func main() {
exit.Return(1)
}
connStr := fmt.Sprintf(`{"address": "%s", "keyspace": "%s", "shard": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`, *server, *keyspace, *shard, *tabletType, *streaming, int64(30*(*timeout)))
db, err := sql.Open("vitess", connStr)
c := vitessdriver.Configuration{
Address: *server,
Keyspace: *keyspace,
Shard: *shard,
TabletType: *tabletType,
Timeout: *timeout,
Streaming: *streaming,
}
db, err := vitessdriver.OpenWithConfiguration(c)
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)

Просмотреть файл

@ -25,32 +25,91 @@ func init() {
sql.Register("vitess", drv{})
}
// TODO(mberlin): Add helper methods.
// Open is a Vitess helper function for sql.Open().
//
// It opens a database connection to vtgate running at "address".
//
// Note that this is the vtgate v3 mode and requires a loaded VSchema.
func Open(address, tabletType string, timeout time.Duration) (*sql.DB, error) {
return OpenShard(address, "" /* keyspace */, "" /* shard */, tabletType, timeout)
}
// OpenShard connects to vtgate running at "address".
//
// Unlike Open(), all queries will target a specific shard in a given keyspace
// ("fallback" mode to vtgate v1).
//
// This mode is recommended when you want to try out Vitess initially because it
// does not require defining a VSchema. Just replace the MySQL/MariaDB driver
// invocation in your application with the Vitess driver.
func OpenShard(address, keyspace, shard, tabletType string, timeout time.Duration) (*sql.DB, error) {
c := newDefaultConfiguration()
c.Address = address
c.Keyspace = keyspace
c.Shard = shard
c.TabletType = tabletType
c.Timeout = timeout
return OpenWithConfiguration(c)
}
// OpenForStreaming is the same as Open() but uses streaming RPCs to retrieve
// the results.
//
// The streaming mode is recommended for large results.
func OpenForStreaming(address, tabletType string, timeout time.Duration) (*sql.DB, error) {
return OpenShardForStreaming(address, "" /* keyspace */, "" /* shard */, tabletType, timeout)
}
// OpenShardForStreaming is the same as OpenShard() but uses streaming RPCs to
// retrieve the results.
//
// The streaming mode is recommended for large results.
func OpenShardForStreaming(address, keyspace, shard, tabletType string, timeout time.Duration) (*sql.DB, error) {
c := newDefaultConfiguration()
c.Address = address
c.Keyspace = keyspace
c.Shard = shard
c.TabletType = tabletType
c.Timeout = timeout
c.Streaming = true
return OpenWithConfiguration(c)
}
// OpenWithConfiguration is the generic Vitess helper function for sql.Open().
//
// It allows to pass in a Configuration struct to control all possible
// settings of the Vitess Go SQL driver.
func OpenWithConfiguration(c Configuration) (*sql.DB, error) {
jsonBytes, err := json.Marshal(c)
if err != nil {
return nil, err
}
return sql.Open("vitess", string(jsonBytes))
}
type drv struct {
}
// Open must be called with a JSON string that looks like this:
// Open implements the database/sql/driver.Driver interface.
//
// For "name", the Vitess driver requires that a JSON object is passed in.
//
// Instead of using this call and passing in a hand-crafted JSON string, it's
// recommended to use the public Vitess helper functions like
// Open(), OpenShard() or OpenWithConfiguration() instead. These will generate
// the required JSON string behind the scenes for you.
//
// Example for a JSON string:
//
// {"protocol": "gorpc", "address": "localhost:1111", "tablet_type": "master", "timeout": 1000000000}
//
// protocol specifies the rpc protocol to use.
// address specifies the address for the VTGate to connect to.
// tablet_type represents the consistency level of your operations.
// For example "replica" means eventually consistent reads, while
// "master" supports transactions and gives you read-after-write consistency.
// timeout is specified in nanoseconds. It applies for all operations.
//
// If you want to execute queries which are not supported by vtgate v3, you can
// run queries against a specific keyspace and shard.
// Therefore, add the fields "keyspace" and "shard" to the JSON string. Example:
//
// {"protocol": "gorpc", "address": "localhost:1111", "keyspace": "ks1", "shard": "0", "tablet_type": "master", "timeout": 1000000000}
// For a description of the available fields, see the Configuration struct.
// Note: In the JSON string, timeout has to be specified in nanoseconds.
//
// Note that this function will always create a connection to vtgate i.e. there
// is no need to call DB.Ping() to verify the connection.
func (d drv) Open(name string) (driver.Conn, error) {
c := &conn{TabletType: "master"}
c := &conn{Configuration: newDefaultConfiguration()}
err := json.Unmarshal([]byte(name), c)
if err != nil {
return nil, err
@ -61,7 +120,7 @@ func (d drv) Open(name string) (driver.Conn, error) {
if c.useExecuteShards() {
log.Infof("Sending queries only to keyspace/shard: %v/%v", c.Keyspace, c.Shard)
}
c.tabletType, err = topoproto.ParseTabletType(c.TabletType)
c.tabletTypeProto, err = topoproto.ParseTabletType(c.TabletType)
if err != nil {
return nil, err
}
@ -72,23 +131,63 @@ func (d drv) Open(name string) (driver.Conn, error) {
return c, nil
}
type conn struct {
// Configuration holds all Vitess driver settings.
//
// Fields with documented default values do not have to be set explicitly.
type Configuration struct {
// Protocol is the name of the vtgate RPC client implementation.
// Note: In open-source "grpc" is the recommended implementation.
//
// Default: "grpc"
Protocol string
Address string
// Keyspace of a specific keyspace/shard to target. Disables vtgate v3.
// If Keyspace and Shard are not empty, vtgate v2 instead of v3 will be used
// Address must point to a vtgate instance.
//
// Format: hostname:port
Address string
// Keyspace of a specific keyspace and shard to target. Disables vtgate v3.
//
// If Keyspace and Shard are not empty, vtgate v1 instead of v3 will be used
// and all requests will be sent only to that particular shard.
// This functionality is meant for initial migrations from MySQL/MariaDB to Vitess.
Keyspace string
// Shard of a specific keyspace/shard to target. Disables vtgate v3.
Shard string
TabletType string `json:"tablet_type"`
Streaming bool
Timeout time.Duration
// Shard of a specific keyspace and shard to target. Disables vtgate v3.
Shard string
tabletType topodatapb.TabletType
vtgateConn *vtgateconn.VTGateConn
tx *vtgateconn.VTGateTx
// TabletType is the type of tablet you want to access and affects the
// freshness of read data.
//
// For example, "replica" means eventually consistent reads, while
// "master" supports transactions and gives you read-after-write consistency.
//
// Default: "master"
// Allowed values: "master", "replica", "rdonly"
TabletType string `json:"tablet_type"`
// Streaming is true when streaming RPCs are used.
// Recommended for large results.
// Default: false
Streaming bool
// Timeout after which a pending query will be aborted.
Timeout time.Duration
}
func newDefaultConfiguration() Configuration {
return Configuration{
Protocol: "grpc",
TabletType: "master",
Streaming: false,
}
}
type conn struct {
Configuration
// tabletTypeProto is the protobof enum value of the string Configuration.TabletType.
tabletTypeProto topodatapb.TabletType
vtgateConn *vtgateconn.VTGateConn
tx *vtgateconn.VTGateTx
}
func (c *conn) dial() error {
@ -162,6 +261,7 @@ func (s *stmt) Close() error {
}
func (s *stmt) NumInput() int {
// -1 = Golang sql won't sanity check argument counts before Exec or Query.
return -1
}
@ -188,9 +288,9 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
var errFunc vtgateconn.ErrFunc
var err error
if s.c.useExecuteShards() {
qrc, errFunc, err = s.c.vtgateConn.StreamExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletType)
qrc, errFunc, err = s.c.vtgateConn.StreamExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletTypeProto)
} else {
qrc, errFunc, err = s.c.vtgateConn.StreamExecute(ctx, s.query, makeBindVars(args), s.c.tabletType)
qrc, errFunc, err = s.c.vtgateConn.StreamExecute(ctx, s.query, makeBindVars(args), s.c.tabletTypeProto)
}
if err != nil {
return nil, err
@ -210,16 +310,16 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
func (s *stmt) executeVitess(ctx context.Context, args []driver.Value) (*sqltypes.Result, error) {
if s.c.tx != nil {
if s.c.useExecuteShards() {
return s.c.tx.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletType, false /* notInTransaction */)
return s.c.tx.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletTypeProto, false /* notInTransaction */)
}
return s.c.tx.Execute(ctx, s.query, makeBindVars(args), s.c.tabletType, false /* notInTransaction */)
return s.c.tx.Execute(ctx, s.query, makeBindVars(args), s.c.tabletTypeProto, false /* notInTransaction */)
}
// Non-transactional case.
if s.c.useExecuteShards() {
return s.c.vtgateConn.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletType)
return s.c.vtgateConn.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletTypeProto)
}
return s.c.vtgateConn.Execute(ctx, s.query, makeBindVars(args), s.c.tabletType)
return s.c.vtgateConn.Execute(ctx, s.query, makeBindVars(args), s.c.tabletTypeProto)
}
func makeBindVars(args []driver.Value) map[string]interface{} {

Просмотреть файл

@ -5,10 +5,7 @@
package vitessdriver
import (
"database/sql"
"database/sql/driver"
"fmt"
"io"
"net"
"os"
"reflect"
@ -27,17 +24,21 @@ var (
testAddress string
)
// TestMain tests the Vitess Go SQL driver.
//
// Note that the queries used in the test are not valid SQL queries and don't
// have to be. The main point here is to test the interactions against a
// vtgate implementation (here: fakeVTGateService from fakeserver_test.go).
func TestMain(m *testing.M) {
// fake service
service := CreateFakeServer()
// listen on a random port
// listen on a random port.
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(fmt.Sprintf("Cannot listen: %v", err))
}
// Create a gRPC server and listen on the port
// Create a gRPC server and listen on the port.
server := grpc.NewServer()
grpcvtgateservice.RegisterForTest(server, service)
go server.Serve(listener)
@ -46,38 +47,22 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func TestDriver(t *testing.T) {
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`, testAddress, int64(30*time.Second))
db, err := sql.Open("vitess", connStr)
if err != nil {
t.Fatal(err)
}
r, err := db.Query("request1", int64(0))
if err != nil {
t.Fatal(err)
}
count := 0
for r.Next() {
count++
}
if count != 2 {
t.Errorf("count: %d, want 2", count)
}
_ = db.Close()
}
func TestDial(t *testing.T) {
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "replica", "timeout": %d}`, testAddress, int64(30*time.Second))
func TestOpen(t *testing.T) {
connStr := fmt.Sprintf(`{"address": "%s", "tablet_type": "replica", "timeout": %d}`, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
if err != nil {
t.Fatal(err)
}
defer c.Close()
wantc := &conn{
Protocol: "grpc",
TabletType: "replica",
Streaming: false,
Timeout: 30 * time.Second,
tabletType: topodatapb.TabletType_REPLICA,
Configuration: Configuration{
Protocol: "grpc",
TabletType: "replica",
Streaming: false,
Timeout: 30 * time.Second,
},
tabletTypeProto: topodatapb.TabletType_REPLICA,
}
newc := *(c.(*conn))
newc.Address = ""
@ -85,16 +70,67 @@ func TestDial(t *testing.T) {
if !reflect.DeepEqual(&newc, wantc) {
t.Errorf("conn: %+v, want %+v", &newc, wantc)
}
_ = c.Close()
}
_, err = drv{}.Open(`{"protocol": "none"}`)
func TestOpenShard(t *testing.T) {
connStr := fmt.Sprintf(`{"address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "replica", "timeout": %d}`, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
if err != nil {
t.Fatal(err)
}
defer c.Close()
wantc := &conn{
Configuration: Configuration{
Protocol: "grpc",
Keyspace: "ks1",
Shard: "0",
TabletType: "replica",
Streaming: false,
Timeout: 30 * time.Second,
},
tabletTypeProto: topodatapb.TabletType_REPLICA,
}
newc := *(c.(*conn))
newc.Address = ""
newc.vtgateConn = nil
if !reflect.DeepEqual(&newc, wantc) {
t.Errorf("conn: %+v, want %+v", &newc, wantc)
}
}
func TestOpen_UnregisteredProtocol(t *testing.T) {
_, err := drv{}.Open(`{"protocol": "none"}`)
want := "no dialer registered for VTGate protocol none"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, want %s", err, want)
}
}
_, err = drv{}.Open(`{`)
want = "unexpected end of JSON input"
func TestOpen_InvalidJson(t *testing.T) {
_, err := drv{}.Open(`{`)
want := "unexpected end of JSON input"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, want %s", err, want)
}
}
func TestOpen_KeyspaceAndShardBelongTogether(t *testing.T) {
_, err := drv{}.Open(`{"keyspace": "ks1"}`)
want := "Always set both keyspace and shard or leave both empty."
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, want %s", err, want)
}
_, err = drv{}.Open(`{"shard": "0"}`)
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, want %s", err, want)
}
}
func TestOpen_ValidTabletTypeRequired(t *testing.T) {
_, err := drv{}.Open(`{"tablet_type": "foobar"}`)
want := "unknown TabletType foobar"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, want %s", err, want)
}
@ -102,59 +138,85 @@ func TestDial(t *testing.T) {
func TestExec(t *testing.T) {
var testcases = []struct {
dataSourceName string
requestName string
desc string
config Configuration
requestName string
}{
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`,
requestName: "request1",
desc: "vtgate v3",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
TabletType: "rdonly",
Timeout: 30 * time.Second,
},
requestName: "request1",
},
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "rdonly", "timeout": %d}`,
requestName: "request1SpecificShard",
desc: "vtgate v1",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
Keyspace: "ks1",
Shard: "0",
TabletType: "rdonly",
Timeout: 30 * time.Second,
},
requestName: "request1SpecificShard",
},
}
for _, tc := range testcases {
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
db, err := OpenWithConfiguration(tc.config)
if err != nil {
t.Fatal(err)
t.Fatalf("%v: %v", tc.desc, err)
}
s, _ := c.Prepare(tc.requestName)
if ni := s.NumInput(); ni != -1 {
t.Errorf("got %d, want -1", ni)
}
r, err := s.Exec([]driver.Value{int64(0)})
defer db.Close()
s, err := db.Prepare(tc.requestName)
if err != nil {
t.Error(err)
t.Fatalf("%v: %v", tc.desc, err)
}
defer s.Close()
r, err := s.Exec(int64(0))
if err != nil {
t.Fatalf("%v: %v", tc.desc, err)
}
if v, _ := r.LastInsertId(); v != 72 {
t.Errorf("insert id: %d, want 72", v)
t.Fatalf("%v: insert id: %d, want 72", tc.desc, v)
}
if v, _ := r.RowsAffected(); v != 123 {
t.Errorf("rows affected: %d, want 123", v)
t.Fatalf("%v: rows affected: %d, want 123", tc.desc, v)
}
_ = s.Close()
s, _ = c.Prepare("none")
_, err = s.Exec(nil)
s2, err := db.Prepare("none")
if err != nil {
t.Fatalf("%v: %v", tc.desc, err)
}
defer s2.Close()
_, err = s2.Exec(nil)
want := "no match for: none"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, does not contain %s", err, want)
t.Errorf("%v: err: %v, does not contain %s", tc.desc, err, want)
}
_ = c.Close()
}
}
func TestExecStreamingNotAllowed(t *testing.T) {
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "streaming": true, "timeout": %d}`, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
db, err := OpenForStreaming(testAddress, "rdonly", 30*time.Second)
if err != nil {
t.Fatal(err)
}
s, _ := c.Prepare("request1")
_, err = s.Exec([]driver.Value{int64(0)})
s, err := db.Prepare("request1")
if err != nil {
t.Fatal(err)
}
defer s.Close()
_, err = s.Exec(int64(0))
want := "Exec not allowed for streaming connections"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, does not contain %s", err, want)
@ -163,195 +225,244 @@ func TestExecStreamingNotAllowed(t *testing.T) {
func TestQuery(t *testing.T) {
var testcases = []struct {
dataSourceName string
requestName string
desc string
config Configuration
requestName string
}{
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`,
requestName: "request1",
desc: "non-streaming, vtgate v3",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
TabletType: "rdonly",
Timeout: 30 * time.Second,
},
requestName: "request1",
},
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "rdonly", "timeout": %d}`,
requestName: "request1SpecificShard",
desc: "non-streaming, vtgate v1",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
Keyspace: "ks1",
Shard: "0",
TabletType: "rdonly",
Timeout: 30 * time.Second,
},
requestName: "request1SpecificShard",
},
{
desc: "streaming, vtgate v3",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
TabletType: "rdonly",
Timeout: 30 * time.Second,
Streaming: true,
},
requestName: "request1",
},
{
desc: "streaming, vtgate v1",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
Keyspace: "ks1",
Shard: "0",
TabletType: "rdonly",
Timeout: 30 * time.Second,
Streaming: true,
},
requestName: "request1SpecificShard",
},
}
for _, tc := range testcases {
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
db, err := OpenWithConfiguration(tc.config)
if err != nil {
t.Fatal(err)
t.Fatalf("%v: %v", tc.desc, err)
}
s, _ := c.Prepare(tc.requestName)
r, err := s.Query([]driver.Value{int64(0)})
defer db.Close()
s, err := db.Prepare(tc.requestName)
if err != nil {
t.Fatal(err)
t.Fatalf("%v: %v", tc.desc, err)
}
defer s.Close()
r, err := s.Query(int64(0))
if err != nil {
t.Fatalf("%v: %v", tc.desc, err)
}
cols, err := r.Columns()
if err != nil {
t.Fatalf("%v: %v", tc.desc, err)
}
cols := r.Columns()
wantCols := []string{
"field1",
"field2",
}
if !reflect.DeepEqual(cols, wantCols) {
t.Fatalf("cols: %v, want %v", cols, wantCols)
t.Fatalf("%v: cols: %v, want %v", tc.desc, cols, wantCols)
}
row := make([]driver.Value, 2)
count := 0
for {
err = r.Next(row)
wantValues := []struct {
field1 int16
field2 string
}{{1, "value1"}, {2, "value2"}}
for r.Next() {
var field1 int16
var field2 string
err := r.Scan(&field1, &field2)
if err != nil {
if err == io.EOF {
break
}
t.Error(err)
t.Fatalf("%v: %v", tc.desc, err)
}
if want := wantValues[count].field1; field1 != want {
t.Fatalf("%v: wrong value for field1: got: %v want: %v", tc.desc, field1, want)
}
if want := wantValues[count].field2; field2 != want {
t.Fatalf("%v: wrong value for field2: got: %v want: %v", tc.desc, field2, want)
}
count++
}
if count != 2 {
t.Errorf("count: %d, want 2", count)
if count != len(wantValues) {
t.Errorf("%v: count: %d, want %d", tc.desc, count, len(wantValues))
}
_ = s.Close()
s, _ = c.Prepare("none")
_, err = s.Query(nil)
s2, err := db.Prepare("none")
if err != nil {
t.Fatalf("%v: %v", tc.desc, err)
}
defer s2.Close()
rows, err := s2.Query(nil)
want := "no match for: none"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, does not contain %s", err, want)
}
_ = c.Close()
}
}
func TestQueryStreaming(t *testing.T) {
var testcases = []struct {
dataSourceName string
requestName string
}{
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`,
requestName: "request1",
},
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "rdonly", "timeout": %d}`,
requestName: "request1SpecificShard",
},
}
for _, tc := range testcases {
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
if err != nil {
t.Fatal(err)
}
s, _ := c.Prepare(tc.requestName)
r, err := s.Query([]driver.Value{int64(0)})
if err != nil {
t.Fatal(err)
}
cols := r.Columns()
wantCols := []string{
"field1",
"field2",
}
if !reflect.DeepEqual(cols, wantCols) {
t.Fatalf("cols: %v, want %v", cols, wantCols)
}
row := make([]driver.Value, 2)
count := 0
for {
err = r.Next(row)
if err != nil {
if err == io.EOF {
break
}
t.Fatal(err)
if tc.config.Streaming && err == nil {
// gRPC requires to consume the stream first before the error becomes visible.
if rows.Next() {
t.Fatalf("%v: query should not have returned anything but did.", tc.desc)
}
count++
err = rows.Err()
}
if count != 2 {
t.Errorf("count: %d, want 2", count)
if err == nil || !strings.Contains(err.Error(), want) {
t.Fatalf("%v: err: %v, does not contain %s", tc.desc, err, want)
}
_ = s.Close()
_ = c.Close()
}
}
func TestTx(t *testing.T) {
var testcases = []struct {
dataSourceName string
requestName string
desc string
config Configuration
requestName string
}{
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "master", "timeout": %d}`,
requestName: "txRequest",
desc: "vtgate v3",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
TabletType: "master",
Timeout: 30 * time.Second,
},
requestName: "txRequest",
},
{
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "master", "timeout": %d}`,
requestName: "txRequestSpecificShard",
desc: "vtgate v1",
config: Configuration{
Protocol: "grpc",
Address: testAddress,
Keyspace: "ks1",
Shard: "0",
TabletType: "master",
Timeout: 30 * time.Second,
},
requestName: "txRequestSpecificShard",
},
}
for _, tc := range testcases {
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
if err != nil {
t.Fatalf("%v: %v", tc.requestName, err)
}
tx, err := c.Begin()
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
s, _ := c.Prepare(tc.requestName)
_, err = s.Exec([]driver.Value{int64(0)})
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
err = tx.Commit()
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
err = tx.Commit()
want := "commit: not in transaction"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("case: %v err: %v, does not contain %s", tc.requestName, err, want)
}
_ = c.Close()
testTxCommit(t, tc.config, tc.desc, tc.requestName)
c, err = drv{}.Open(connStr)
if err != nil {
t.Fatalf("%v: %v", tc.requestName, err)
}
tx, err = c.Begin()
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
s, _ = c.Prepare(tc.requestName)
_, err = s.Query([]driver.Value{int64(0)})
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
err = tx.Rollback()
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
err = tx.Rollback()
if err != nil {
t.Errorf("%v: %v", tc.requestName, err)
}
_ = c.Close()
testTxRollback(t, tc.config, tc.desc, tc.requestName)
}
}
func testTxCommit(t *testing.T, c Configuration, desc, requestName string) {
db, err := OpenWithConfiguration(c)
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
s, err := tx.Prepare(requestName)
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
_, err = s.Exec(int64(0))
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
// Commit on committed transaction is caught by Golang sql package.
// We actually don't have to cover this in our code.
err = tx.Commit()
want := "sql: Transaction has already been committed or rolled back"
if err == nil || !strings.Contains(err.Error(), want) {
t.Fatalf("%v: err: %v, does not contain %s", desc, err, want)
}
}
func testTxRollback(t *testing.T, c Configuration, desc, requestName string) {
db, err := OpenWithConfiguration(c)
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
s, err := tx.Prepare(requestName)
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
_, err = s.Query(int64(0))
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
err = tx.Rollback()
if err != nil {
t.Fatalf("%v: %v", desc, err)
}
// Rollback on rolled back transaction is caught by Golang sql package.
// We actually don't have to cover this in our code.
err = tx.Rollback()
want := "sql: Transaction has already been committed or rolled back"
if err == nil || !strings.Contains(err.Error(), want) {
t.Fatalf("%v: err: %v, does not contain %s", desc, err, want)
}
}
func TestTxExecStreamingNotAllowed(t *testing.T) {
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "streaming": true, "timeout": %d}`, testAddress, int64(30*time.Second))
c, err := drv{}.Open(connStr)
db, err := OpenForStreaming(testAddress, "rdonly", 30*time.Second)
if err != nil {
t.Fatal(err)
}
_, err = c.Begin()
defer db.Close()
_, err = db.Begin()
want := "transaction not allowed for streaming connection"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, does not contain %s", err, want)
}
_ = c.Close()
}

Просмотреть файл

@ -143,6 +143,38 @@ func (f *fakeVTGateService) StreamExecute(ctx context.Context, sql string, bindV
// StreamExecuteShards is part of the VTGateService interface
func (f *fakeVTGateService) StreamExecuteShards(ctx context.Context, sql string, bindVariables map[string]interface{}, keyspace string, shards []string, tabletType topodatapb.TabletType, sendReply func(*sqltypes.Result) error) error {
execCase, ok := execSpecificShardMap[sql]
if !ok {
return fmt.Errorf("no match for: %s", sql)
}
query := &queryExecuteSpecificShard{
queryExecute: queryExecute{
SQL: sql,
BindVariables: bindVariables,
TabletType: tabletType,
},
Keyspace: keyspace,
Shard: shards[0],
}
if !reflect.DeepEqual(query, execCase.execQuery) {
return fmt.Errorf("request mismatch: got %+v, want %+v", query, execCase.execQuery)
}
if execCase.result != nil {
result := &sqltypes.Result{
Fields: execCase.result.Fields,
}
if err := sendReply(result); err != nil {
return err
}
for _, row := range execCase.result.Rows {
result := &sqltypes.Result{
Rows: [][]sqltypes.Value{row},
}
if err := sendReply(result); err != nil {
return err
}
}
}
return nil
}

Просмотреть файл

@ -228,7 +228,8 @@ def tearDownModule():
if utils.options.skip_teardown:
return
logging.debug('Tearing down the servers and setup')
keyspace_env.teardown()
if keyspace_env:
keyspace_env.teardown()
environment.topo_server().teardown()