зеркало из https://github.com/github/vitess-gh.git
Merge pull request #1383 from michael-berlin/go_sql_driver_shard
vitessdriver: Add helper methods and refactor tests.
This commit is contained in:
Коммит
45544ca7c5
|
@ -10,14 +10,12 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
// import the 'vitess' sql driver
|
||||
_ "github.com/youtube/vitess/go/vt/vitessdriver"
|
||||
"github.com/youtube/vitess/go/vt/vitessdriver"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -25,16 +23,14 @@ var (
|
|||
)
|
||||
|
||||
func main() {
|
||||
keyspace := "test_keyspace"
|
||||
timeout := (10 * time.Second).Nanoseconds()
|
||||
shard := "0"
|
||||
|
||||
flag.Parse()
|
||||
|
||||
keyspace := "test_keyspace"
|
||||
shard := "0"
|
||||
timeout := 10 * time.Second
|
||||
|
||||
// Connect to vtgate.
|
||||
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "keyspace": "%s", "shard": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`,
|
||||
*server, keyspace, shard, "master", false, timeout)
|
||||
db, err := sql.Open("vitess", connStr)
|
||||
db, err := vitessdriver.OpenShard(*server, keyspace, shard, "master", timeout)
|
||||
if err != nil {
|
||||
fmt.Printf("client error: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
@ -82,9 +78,7 @@ func main() {
|
|||
// Note that this may be behind master due to replication lag.
|
||||
fmt.Println("Reading from replica...")
|
||||
|
||||
connStr = fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "keyspace": "%s", "shard": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`,
|
||||
*server, keyspace, shard, "replica", false, timeout)
|
||||
dbr, err := sql.Open("vitess", connStr)
|
||||
dbr, err := vitessdriver.OpenShard(*server, keyspace, shard, "replica", timeout)
|
||||
if err != nil {
|
||||
fmt.Printf("client error: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
@ -17,8 +16,7 @@ import (
|
|||
"github.com/youtube/vitess/go/exit"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
|
||||
// import the 'vitess' sql driver
|
||||
_ "github.com/youtube/vitess/go/vt/vitessdriver"
|
||||
"github.com/youtube/vitess/go/vt/vitessdriver"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -106,8 +104,15 @@ func main() {
|
|||
exit.Return(1)
|
||||
}
|
||||
|
||||
connStr := fmt.Sprintf(`{"address": "%s", "keyspace": "%s", "shard": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`, *server, *keyspace, *shard, *tabletType, *streaming, int64(30*(*timeout)))
|
||||
db, err := sql.Open("vitess", connStr)
|
||||
c := vitessdriver.Configuration{
|
||||
Address: *server,
|
||||
Keyspace: *keyspace,
|
||||
Shard: *shard,
|
||||
TabletType: *tabletType,
|
||||
Timeout: *timeout,
|
||||
Streaming: *streaming,
|
||||
}
|
||||
db, err := vitessdriver.OpenWithConfiguration(c)
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
|
|
|
@ -25,32 +25,91 @@ func init() {
|
|||
sql.Register("vitess", drv{})
|
||||
}
|
||||
|
||||
// TODO(mberlin): Add helper methods.
|
||||
// Open is a Vitess helper function for sql.Open().
|
||||
//
|
||||
// It opens a database connection to vtgate running at "address".
|
||||
//
|
||||
// Note that this is the vtgate v3 mode and requires a loaded VSchema.
|
||||
func Open(address, tabletType string, timeout time.Duration) (*sql.DB, error) {
|
||||
return OpenShard(address, "" /* keyspace */, "" /* shard */, tabletType, timeout)
|
||||
}
|
||||
|
||||
// OpenShard connects to vtgate running at "address".
|
||||
//
|
||||
// Unlike Open(), all queries will target a specific shard in a given keyspace
|
||||
// ("fallback" mode to vtgate v1).
|
||||
//
|
||||
// This mode is recommended when you want to try out Vitess initially because it
|
||||
// does not require defining a VSchema. Just replace the MySQL/MariaDB driver
|
||||
// invocation in your application with the Vitess driver.
|
||||
func OpenShard(address, keyspace, shard, tabletType string, timeout time.Duration) (*sql.DB, error) {
|
||||
c := newDefaultConfiguration()
|
||||
c.Address = address
|
||||
c.Keyspace = keyspace
|
||||
c.Shard = shard
|
||||
c.TabletType = tabletType
|
||||
c.Timeout = timeout
|
||||
return OpenWithConfiguration(c)
|
||||
}
|
||||
|
||||
// OpenForStreaming is the same as Open() but uses streaming RPCs to retrieve
|
||||
// the results.
|
||||
//
|
||||
// The streaming mode is recommended for large results.
|
||||
func OpenForStreaming(address, tabletType string, timeout time.Duration) (*sql.DB, error) {
|
||||
return OpenShardForStreaming(address, "" /* keyspace */, "" /* shard */, tabletType, timeout)
|
||||
}
|
||||
|
||||
// OpenShardForStreaming is the same as OpenShard() but uses streaming RPCs to
|
||||
// retrieve the results.
|
||||
//
|
||||
// The streaming mode is recommended for large results.
|
||||
func OpenShardForStreaming(address, keyspace, shard, tabletType string, timeout time.Duration) (*sql.DB, error) {
|
||||
c := newDefaultConfiguration()
|
||||
c.Address = address
|
||||
c.Keyspace = keyspace
|
||||
c.Shard = shard
|
||||
c.TabletType = tabletType
|
||||
c.Timeout = timeout
|
||||
c.Streaming = true
|
||||
return OpenWithConfiguration(c)
|
||||
}
|
||||
|
||||
// OpenWithConfiguration is the generic Vitess helper function for sql.Open().
|
||||
//
|
||||
// It allows to pass in a Configuration struct to control all possible
|
||||
// settings of the Vitess Go SQL driver.
|
||||
func OpenWithConfiguration(c Configuration) (*sql.DB, error) {
|
||||
jsonBytes, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.Open("vitess", string(jsonBytes))
|
||||
}
|
||||
|
||||
type drv struct {
|
||||
}
|
||||
|
||||
// Open must be called with a JSON string that looks like this:
|
||||
// Open implements the database/sql/driver.Driver interface.
|
||||
//
|
||||
// For "name", the Vitess driver requires that a JSON object is passed in.
|
||||
//
|
||||
// Instead of using this call and passing in a hand-crafted JSON string, it's
|
||||
// recommended to use the public Vitess helper functions like
|
||||
// Open(), OpenShard() or OpenWithConfiguration() instead. These will generate
|
||||
// the required JSON string behind the scenes for you.
|
||||
//
|
||||
// Example for a JSON string:
|
||||
//
|
||||
// {"protocol": "gorpc", "address": "localhost:1111", "tablet_type": "master", "timeout": 1000000000}
|
||||
//
|
||||
// protocol specifies the rpc protocol to use.
|
||||
// address specifies the address for the VTGate to connect to.
|
||||
// tablet_type represents the consistency level of your operations.
|
||||
// For example "replica" means eventually consistent reads, while
|
||||
// "master" supports transactions and gives you read-after-write consistency.
|
||||
// timeout is specified in nanoseconds. It applies for all operations.
|
||||
//
|
||||
// If you want to execute queries which are not supported by vtgate v3, you can
|
||||
// run queries against a specific keyspace and shard.
|
||||
// Therefore, add the fields "keyspace" and "shard" to the JSON string. Example:
|
||||
//
|
||||
// {"protocol": "gorpc", "address": "localhost:1111", "keyspace": "ks1", "shard": "0", "tablet_type": "master", "timeout": 1000000000}
|
||||
// For a description of the available fields, see the Configuration struct.
|
||||
// Note: In the JSON string, timeout has to be specified in nanoseconds.
|
||||
//
|
||||
// Note that this function will always create a connection to vtgate i.e. there
|
||||
// is no need to call DB.Ping() to verify the connection.
|
||||
func (d drv) Open(name string) (driver.Conn, error) {
|
||||
c := &conn{TabletType: "master"}
|
||||
c := &conn{Configuration: newDefaultConfiguration()}
|
||||
err := json.Unmarshal([]byte(name), c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -61,7 +120,7 @@ func (d drv) Open(name string) (driver.Conn, error) {
|
|||
if c.useExecuteShards() {
|
||||
log.Infof("Sending queries only to keyspace/shard: %v/%v", c.Keyspace, c.Shard)
|
||||
}
|
||||
c.tabletType, err = topoproto.ParseTabletType(c.TabletType)
|
||||
c.tabletTypeProto, err = topoproto.ParseTabletType(c.TabletType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -72,23 +131,63 @@ func (d drv) Open(name string) (driver.Conn, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
// Configuration holds all Vitess driver settings.
|
||||
//
|
||||
// Fields with documented default values do not have to be set explicitly.
|
||||
type Configuration struct {
|
||||
// Protocol is the name of the vtgate RPC client implementation.
|
||||
// Note: In open-source "grpc" is the recommended implementation.
|
||||
//
|
||||
// Default: "grpc"
|
||||
Protocol string
|
||||
Address string
|
||||
// Keyspace of a specific keyspace/shard to target. Disables vtgate v3.
|
||||
// If Keyspace and Shard are not empty, vtgate v2 instead of v3 will be used
|
||||
|
||||
// Address must point to a vtgate instance.
|
||||
//
|
||||
// Format: hostname:port
|
||||
Address string
|
||||
|
||||
// Keyspace of a specific keyspace and shard to target. Disables vtgate v3.
|
||||
//
|
||||
// If Keyspace and Shard are not empty, vtgate v1 instead of v3 will be used
|
||||
// and all requests will be sent only to that particular shard.
|
||||
// This functionality is meant for initial migrations from MySQL/MariaDB to Vitess.
|
||||
Keyspace string
|
||||
// Shard of a specific keyspace/shard to target. Disables vtgate v3.
|
||||
Shard string
|
||||
TabletType string `json:"tablet_type"`
|
||||
Streaming bool
|
||||
Timeout time.Duration
|
||||
// Shard of a specific keyspace and shard to target. Disables vtgate v3.
|
||||
Shard string
|
||||
|
||||
tabletType topodatapb.TabletType
|
||||
vtgateConn *vtgateconn.VTGateConn
|
||||
tx *vtgateconn.VTGateTx
|
||||
// TabletType is the type of tablet you want to access and affects the
|
||||
// freshness of read data.
|
||||
//
|
||||
// For example, "replica" means eventually consistent reads, while
|
||||
// "master" supports transactions and gives you read-after-write consistency.
|
||||
//
|
||||
// Default: "master"
|
||||
// Allowed values: "master", "replica", "rdonly"
|
||||
TabletType string `json:"tablet_type"`
|
||||
|
||||
// Streaming is true when streaming RPCs are used.
|
||||
// Recommended for large results.
|
||||
// Default: false
|
||||
Streaming bool
|
||||
|
||||
// Timeout after which a pending query will be aborted.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func newDefaultConfiguration() Configuration {
|
||||
return Configuration{
|
||||
Protocol: "grpc",
|
||||
TabletType: "master",
|
||||
Streaming: false,
|
||||
}
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
Configuration
|
||||
// tabletTypeProto is the protobof enum value of the string Configuration.TabletType.
|
||||
tabletTypeProto topodatapb.TabletType
|
||||
vtgateConn *vtgateconn.VTGateConn
|
||||
tx *vtgateconn.VTGateTx
|
||||
}
|
||||
|
||||
func (c *conn) dial() error {
|
||||
|
@ -162,6 +261,7 @@ func (s *stmt) Close() error {
|
|||
}
|
||||
|
||||
func (s *stmt) NumInput() int {
|
||||
// -1 = Golang sql won't sanity check argument counts before Exec or Query.
|
||||
return -1
|
||||
}
|
||||
|
||||
|
@ -188,9 +288,9 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
var errFunc vtgateconn.ErrFunc
|
||||
var err error
|
||||
if s.c.useExecuteShards() {
|
||||
qrc, errFunc, err = s.c.vtgateConn.StreamExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletType)
|
||||
qrc, errFunc, err = s.c.vtgateConn.StreamExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletTypeProto)
|
||||
} else {
|
||||
qrc, errFunc, err = s.c.vtgateConn.StreamExecute(ctx, s.query, makeBindVars(args), s.c.tabletType)
|
||||
qrc, errFunc, err = s.c.vtgateConn.StreamExecute(ctx, s.query, makeBindVars(args), s.c.tabletTypeProto)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -210,16 +310,16 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
func (s *stmt) executeVitess(ctx context.Context, args []driver.Value) (*sqltypes.Result, error) {
|
||||
if s.c.tx != nil {
|
||||
if s.c.useExecuteShards() {
|
||||
return s.c.tx.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletType, false /* notInTransaction */)
|
||||
return s.c.tx.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletTypeProto, false /* notInTransaction */)
|
||||
}
|
||||
return s.c.tx.Execute(ctx, s.query, makeBindVars(args), s.c.tabletType, false /* notInTransaction */)
|
||||
return s.c.tx.Execute(ctx, s.query, makeBindVars(args), s.c.tabletTypeProto, false /* notInTransaction */)
|
||||
}
|
||||
|
||||
// Non-transactional case.
|
||||
if s.c.useExecuteShards() {
|
||||
return s.c.vtgateConn.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletType)
|
||||
return s.c.vtgateConn.ExecuteShards(ctx, s.query, s.c.Keyspace, []string{s.c.Shard}, makeBindVars(args), s.c.tabletTypeProto)
|
||||
}
|
||||
return s.c.vtgateConn.Execute(ctx, s.query, makeBindVars(args), s.c.tabletType)
|
||||
return s.c.vtgateConn.Execute(ctx, s.query, makeBindVars(args), s.c.tabletTypeProto)
|
||||
}
|
||||
|
||||
func makeBindVars(args []driver.Value) map[string]interface{} {
|
||||
|
|
|
@ -5,10 +5,7 @@
|
|||
package vitessdriver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -27,17 +24,21 @@ var (
|
|||
testAddress string
|
||||
)
|
||||
|
||||
// TestMain tests the Vitess Go SQL driver.
|
||||
//
|
||||
// Note that the queries used in the test are not valid SQL queries and don't
|
||||
// have to be. The main point here is to test the interactions against a
|
||||
// vtgate implementation (here: fakeVTGateService from fakeserver_test.go).
|
||||
func TestMain(m *testing.M) {
|
||||
// fake service
|
||||
service := CreateFakeServer()
|
||||
|
||||
// listen on a random port
|
||||
// listen on a random port.
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Cannot listen: %v", err))
|
||||
}
|
||||
|
||||
// Create a gRPC server and listen on the port
|
||||
// Create a gRPC server and listen on the port.
|
||||
server := grpc.NewServer()
|
||||
grpcvtgateservice.RegisterForTest(server, service)
|
||||
go server.Serve(listener)
|
||||
|
@ -46,38 +47,22 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestDriver(t *testing.T) {
|
||||
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`, testAddress, int64(30*time.Second))
|
||||
db, err := sql.Open("vitess", connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, err := db.Query("request1", int64(0))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count := 0
|
||||
for r.Next() {
|
||||
count++
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("count: %d, want 2", count)
|
||||
}
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "replica", "timeout": %d}`, testAddress, int64(30*time.Second))
|
||||
func TestOpen(t *testing.T) {
|
||||
connStr := fmt.Sprintf(`{"address": "%s", "tablet_type": "replica", "timeout": %d}`, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
wantc := &conn{
|
||||
Protocol: "grpc",
|
||||
TabletType: "replica",
|
||||
Streaming: false,
|
||||
Timeout: 30 * time.Second,
|
||||
tabletType: topodatapb.TabletType_REPLICA,
|
||||
Configuration: Configuration{
|
||||
Protocol: "grpc",
|
||||
TabletType: "replica",
|
||||
Streaming: false,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
tabletTypeProto: topodatapb.TabletType_REPLICA,
|
||||
}
|
||||
newc := *(c.(*conn))
|
||||
newc.Address = ""
|
||||
|
@ -85,16 +70,67 @@ func TestDial(t *testing.T) {
|
|||
if !reflect.DeepEqual(&newc, wantc) {
|
||||
t.Errorf("conn: %+v, want %+v", &newc, wantc)
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
_, err = drv{}.Open(`{"protocol": "none"}`)
|
||||
func TestOpenShard(t *testing.T) {
|
||||
connStr := fmt.Sprintf(`{"address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "replica", "timeout": %d}`, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
wantc := &conn{
|
||||
Configuration: Configuration{
|
||||
Protocol: "grpc",
|
||||
Keyspace: "ks1",
|
||||
Shard: "0",
|
||||
TabletType: "replica",
|
||||
Streaming: false,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
tabletTypeProto: topodatapb.TabletType_REPLICA,
|
||||
}
|
||||
newc := *(c.(*conn))
|
||||
newc.Address = ""
|
||||
newc.vtgateConn = nil
|
||||
if !reflect.DeepEqual(&newc, wantc) {
|
||||
t.Errorf("conn: %+v, want %+v", &newc, wantc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpen_UnregisteredProtocol(t *testing.T) {
|
||||
_, err := drv{}.Open(`{"protocol": "none"}`)
|
||||
want := "no dialer registered for VTGate protocol none"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, want %s", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = drv{}.Open(`{`)
|
||||
want = "unexpected end of JSON input"
|
||||
func TestOpen_InvalidJson(t *testing.T) {
|
||||
_, err := drv{}.Open(`{`)
|
||||
want := "unexpected end of JSON input"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, want %s", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpen_KeyspaceAndShardBelongTogether(t *testing.T) {
|
||||
_, err := drv{}.Open(`{"keyspace": "ks1"}`)
|
||||
want := "Always set both keyspace and shard or leave both empty."
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, want %s", err, want)
|
||||
}
|
||||
|
||||
_, err = drv{}.Open(`{"shard": "0"}`)
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, want %s", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpen_ValidTabletTypeRequired(t *testing.T) {
|
||||
_, err := drv{}.Open(`{"tablet_type": "foobar"}`)
|
||||
want := "unknown TabletType foobar"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, want %s", err, want)
|
||||
}
|
||||
|
@ -102,59 +138,85 @@ func TestDial(t *testing.T) {
|
|||
|
||||
func TestExec(t *testing.T) {
|
||||
var testcases = []struct {
|
||||
dataSourceName string
|
||||
requestName string
|
||||
desc string
|
||||
config Configuration
|
||||
requestName string
|
||||
}{
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`,
|
||||
requestName: "request1",
|
||||
desc: "vtgate v3",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
TabletType: "rdonly",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
requestName: "request1",
|
||||
},
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "rdonly", "timeout": %d}`,
|
||||
requestName: "request1SpecificShard",
|
||||
desc: "vtgate v1",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
Keyspace: "ks1",
|
||||
Shard: "0",
|
||||
TabletType: "rdonly",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
requestName: "request1SpecificShard",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
db, err := OpenWithConfiguration(tc.config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
s, _ := c.Prepare(tc.requestName)
|
||||
if ni := s.NumInput(); ni != -1 {
|
||||
t.Errorf("got %d, want -1", ni)
|
||||
}
|
||||
r, err := s.Exec([]driver.Value{int64(0)})
|
||||
defer db.Close()
|
||||
|
||||
s, err := db.Prepare(tc.requestName)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
r, err := s.Exec(int64(0))
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
if v, _ := r.LastInsertId(); v != 72 {
|
||||
t.Errorf("insert id: %d, want 72", v)
|
||||
t.Fatalf("%v: insert id: %d, want 72", tc.desc, v)
|
||||
}
|
||||
if v, _ := r.RowsAffected(); v != 123 {
|
||||
t.Errorf("rows affected: %d, want 123", v)
|
||||
t.Fatalf("%v: rows affected: %d, want 123", tc.desc, v)
|
||||
}
|
||||
_ = s.Close()
|
||||
|
||||
s, _ = c.Prepare("none")
|
||||
_, err = s.Exec(nil)
|
||||
s2, err := db.Prepare("none")
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
defer s2.Close()
|
||||
|
||||
_, err = s2.Exec(nil)
|
||||
want := "no match for: none"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, does not contain %s", err, want)
|
||||
t.Errorf("%v: err: %v, does not contain %s", tc.desc, err, want)
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecStreamingNotAllowed(t *testing.T) {
|
||||
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "streaming": true, "timeout": %d}`, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
db, err := OpenForStreaming(testAddress, "rdonly", 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s, _ := c.Prepare("request1")
|
||||
_, err = s.Exec([]driver.Value{int64(0)})
|
||||
|
||||
s, err := db.Prepare("request1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
_, err = s.Exec(int64(0))
|
||||
want := "Exec not allowed for streaming connections"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, does not contain %s", err, want)
|
||||
|
@ -163,195 +225,244 @@ func TestExecStreamingNotAllowed(t *testing.T) {
|
|||
|
||||
func TestQuery(t *testing.T) {
|
||||
var testcases = []struct {
|
||||
dataSourceName string
|
||||
requestName string
|
||||
desc string
|
||||
config Configuration
|
||||
requestName string
|
||||
}{
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`,
|
||||
requestName: "request1",
|
||||
desc: "non-streaming, vtgate v3",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
TabletType: "rdonly",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
requestName: "request1",
|
||||
},
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "rdonly", "timeout": %d}`,
|
||||
requestName: "request1SpecificShard",
|
||||
desc: "non-streaming, vtgate v1",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
Keyspace: "ks1",
|
||||
Shard: "0",
|
||||
TabletType: "rdonly",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
requestName: "request1SpecificShard",
|
||||
},
|
||||
{
|
||||
desc: "streaming, vtgate v3",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
TabletType: "rdonly",
|
||||
Timeout: 30 * time.Second,
|
||||
Streaming: true,
|
||||
},
|
||||
requestName: "request1",
|
||||
},
|
||||
{
|
||||
desc: "streaming, vtgate v1",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
Keyspace: "ks1",
|
||||
Shard: "0",
|
||||
TabletType: "rdonly",
|
||||
Timeout: 30 * time.Second,
|
||||
Streaming: true,
|
||||
},
|
||||
requestName: "request1SpecificShard",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
db, err := OpenWithConfiguration(tc.config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
s, _ := c.Prepare(tc.requestName)
|
||||
r, err := s.Query([]driver.Value{int64(0)})
|
||||
defer db.Close()
|
||||
|
||||
s, err := db.Prepare(tc.requestName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
r, err := s.Query(int64(0))
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
cols, err := r.Columns()
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
cols := r.Columns()
|
||||
wantCols := []string{
|
||||
"field1",
|
||||
"field2",
|
||||
}
|
||||
if !reflect.DeepEqual(cols, wantCols) {
|
||||
t.Fatalf("cols: %v, want %v", cols, wantCols)
|
||||
t.Fatalf("%v: cols: %v, want %v", tc.desc, cols, wantCols)
|
||||
}
|
||||
row := make([]driver.Value, 2)
|
||||
count := 0
|
||||
for {
|
||||
err = r.Next(row)
|
||||
wantValues := []struct {
|
||||
field1 int16
|
||||
field2 string
|
||||
}{{1, "value1"}, {2, "value2"}}
|
||||
for r.Next() {
|
||||
var field1 int16
|
||||
var field2 string
|
||||
err := r.Scan(&field1, &field2)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
t.Error(err)
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
if want := wantValues[count].field1; field1 != want {
|
||||
t.Fatalf("%v: wrong value for field1: got: %v want: %v", tc.desc, field1, want)
|
||||
}
|
||||
if want := wantValues[count].field2; field2 != want {
|
||||
t.Fatalf("%v: wrong value for field2: got: %v want: %v", tc.desc, field2, want)
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("count: %d, want 2", count)
|
||||
if count != len(wantValues) {
|
||||
t.Errorf("%v: count: %d, want %d", tc.desc, count, len(wantValues))
|
||||
}
|
||||
_ = s.Close()
|
||||
|
||||
s, _ = c.Prepare("none")
|
||||
_, err = s.Query(nil)
|
||||
s2, err := db.Prepare("none")
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.desc, err)
|
||||
}
|
||||
defer s2.Close()
|
||||
|
||||
rows, err := s2.Query(nil)
|
||||
want := "no match for: none"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, does not contain %s", err, want)
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryStreaming(t *testing.T) {
|
||||
var testcases = []struct {
|
||||
dataSourceName string
|
||||
requestName string
|
||||
}{
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "timeout": %d}`,
|
||||
requestName: "request1",
|
||||
},
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "rdonly", "timeout": %d}`,
|
||||
requestName: "request1SpecificShard",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s, _ := c.Prepare(tc.requestName)
|
||||
r, err := s.Query([]driver.Value{int64(0)})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cols := r.Columns()
|
||||
wantCols := []string{
|
||||
"field1",
|
||||
"field2",
|
||||
}
|
||||
if !reflect.DeepEqual(cols, wantCols) {
|
||||
t.Fatalf("cols: %v, want %v", cols, wantCols)
|
||||
}
|
||||
row := make([]driver.Value, 2)
|
||||
count := 0
|
||||
for {
|
||||
err = r.Next(row)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
t.Fatal(err)
|
||||
if tc.config.Streaming && err == nil {
|
||||
// gRPC requires to consume the stream first before the error becomes visible.
|
||||
if rows.Next() {
|
||||
t.Fatalf("%v: query should not have returned anything but did.", tc.desc)
|
||||
}
|
||||
count++
|
||||
err = rows.Err()
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("count: %d, want 2", count)
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Fatalf("%v: err: %v, does not contain %s", tc.desc, err, want)
|
||||
}
|
||||
_ = s.Close()
|
||||
_ = c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestTx(t *testing.T) {
|
||||
var testcases = []struct {
|
||||
dataSourceName string
|
||||
requestName string
|
||||
desc string
|
||||
config Configuration
|
||||
requestName string
|
||||
}{
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "tablet_type": "master", "timeout": %d}`,
|
||||
requestName: "txRequest",
|
||||
desc: "vtgate v3",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
TabletType: "master",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
requestName: "txRequest",
|
||||
},
|
||||
{
|
||||
dataSourceName: `{"protocol": "grpc", "address": "%s", "keyspace": "ks1", "shard": "0", "tablet_type": "master", "timeout": %d}`,
|
||||
requestName: "txRequestSpecificShard",
|
||||
desc: "vtgate v1",
|
||||
config: Configuration{
|
||||
Protocol: "grpc",
|
||||
Address: testAddress,
|
||||
Keyspace: "ks1",
|
||||
Shard: "0",
|
||||
TabletType: "master",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
requestName: "txRequestSpecificShard",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
connStr := fmt.Sprintf(tc.dataSourceName, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
tx, err := c.Begin()
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
s, _ := c.Prepare(tc.requestName)
|
||||
_, err = s.Exec([]driver.Value{int64(0)})
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
err = tx.Commit()
|
||||
want := "commit: not in transaction"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("case: %v err: %v, does not contain %s", tc.requestName, err, want)
|
||||
}
|
||||
_ = c.Close()
|
||||
testTxCommit(t, tc.config, tc.desc, tc.requestName)
|
||||
|
||||
c, err = drv{}.Open(connStr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
tx, err = c.Begin()
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
s, _ = c.Prepare(tc.requestName)
|
||||
_, err = s.Query([]driver.Value{int64(0)})
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Errorf("%v: %v", tc.requestName, err)
|
||||
}
|
||||
_ = c.Close()
|
||||
testTxRollback(t, tc.config, tc.desc, tc.requestName)
|
||||
}
|
||||
}
|
||||
|
||||
func testTxCommit(t *testing.T, c Configuration, desc, requestName string) {
|
||||
db, err := OpenWithConfiguration(c)
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
|
||||
s, err := tx.Prepare(requestName)
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
|
||||
_, err = s.Exec(int64(0))
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
// Commit on committed transaction is caught by Golang sql package.
|
||||
// We actually don't have to cover this in our code.
|
||||
err = tx.Commit()
|
||||
want := "sql: Transaction has already been committed or rolled back"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Fatalf("%v: err: %v, does not contain %s", desc, err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func testTxRollback(t *testing.T, c Configuration, desc, requestName string) {
|
||||
db, err := OpenWithConfiguration(c)
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
s, err := tx.Prepare(requestName)
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
_, err = s.Query(int64(0))
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatalf("%v: %v", desc, err)
|
||||
}
|
||||
// Rollback on rolled back transaction is caught by Golang sql package.
|
||||
// We actually don't have to cover this in our code.
|
||||
err = tx.Rollback()
|
||||
want := "sql: Transaction has already been committed or rolled back"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Fatalf("%v: err: %v, does not contain %s", desc, err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxExecStreamingNotAllowed(t *testing.T) {
|
||||
connStr := fmt.Sprintf(`{"protocol": "grpc", "address": "%s", "tablet_type": "rdonly", "streaming": true, "timeout": %d}`, testAddress, int64(30*time.Second))
|
||||
c, err := drv{}.Open(connStr)
|
||||
db, err := OpenForStreaming(testAddress, "rdonly", 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = c.Begin()
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Begin()
|
||||
want := "transaction not allowed for streaming connection"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, does not contain %s", err, want)
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
|
|
@ -143,6 +143,38 @@ func (f *fakeVTGateService) StreamExecute(ctx context.Context, sql string, bindV
|
|||
|
||||
// StreamExecuteShards is part of the VTGateService interface
|
||||
func (f *fakeVTGateService) StreamExecuteShards(ctx context.Context, sql string, bindVariables map[string]interface{}, keyspace string, shards []string, tabletType topodatapb.TabletType, sendReply func(*sqltypes.Result) error) error {
|
||||
execCase, ok := execSpecificShardMap[sql]
|
||||
if !ok {
|
||||
return fmt.Errorf("no match for: %s", sql)
|
||||
}
|
||||
query := &queryExecuteSpecificShard{
|
||||
queryExecute: queryExecute{
|
||||
SQL: sql,
|
||||
BindVariables: bindVariables,
|
||||
TabletType: tabletType,
|
||||
},
|
||||
Keyspace: keyspace,
|
||||
Shard: shards[0],
|
||||
}
|
||||
if !reflect.DeepEqual(query, execCase.execQuery) {
|
||||
return fmt.Errorf("request mismatch: got %+v, want %+v", query, execCase.execQuery)
|
||||
}
|
||||
if execCase.result != nil {
|
||||
result := &sqltypes.Result{
|
||||
Fields: execCase.result.Fields,
|
||||
}
|
||||
if err := sendReply(result); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, row := range execCase.result.Rows {
|
||||
result := &sqltypes.Result{
|
||||
Rows: [][]sqltypes.Value{row},
|
||||
}
|
||||
if err := sendReply(result); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -228,7 +228,8 @@ def tearDownModule():
|
|||
if utils.options.skip_teardown:
|
||||
return
|
||||
logging.debug('Tearing down the servers and setup')
|
||||
keyspace_env.teardown()
|
||||
if keyspace_env:
|
||||
keyspace_env.teardown()
|
||||
|
||||
environment.topo_server().teardown()
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче