Merge branch 'master' into resharding

This commit is contained in:
Alain Jobart 2015-04-17 13:13:34 -07:00
Родитель 300174be96 1a3f9122ec
Коммит 364907137a
58 изменённых файлов: 1542 добавлений и 463 удалений

Просмотреть файл

@ -0,0 +1,5 @@
{
"table1": {
1:2
}
}

Просмотреть файл

@ -0,0 +1,6 @@
{
"test_table": {
"READER": "vt",
"WRITER": "vt"
}
}

Просмотреть файл

@ -17,6 +17,7 @@ import (
"github.com/youtube/vitess/go/vt/mysqlctl"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
"github.com/youtube/vitess/go/vt/tabletserver"
// import mysql to register mysql connection function
@ -76,6 +77,7 @@ func main() {
log.Infof("schemaOverrides: %s\n", data)
if *tableAclConfig != "" {
tableacl.Register("simpleacl", &simpleacl.Factory{})
tableacl.Init(*tableAclConfig)
}
qsc := tabletserver.NewQueryServiceControl()

Просмотреть файл

@ -15,6 +15,7 @@ import (
"github.com/youtube/vitess/go/vt/mysqlctl"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
"github.com/youtube/vitess/go/vt/tabletmanager"
"github.com/youtube/vitess/go/vt/tabletmanager/actionnode"
"github.com/youtube/vitess/go/vt/tabletserver"
@ -85,6 +86,7 @@ func main() {
dbcfgs.App.EnableRowcache = *enableRowcache
if *tableAclConfig != "" {
tableacl.Register("simpleacl", &simpleacl.Factory{})
tableacl.Init(*tableAclConfig)
}

Просмотреть файл

@ -38,3 +38,14 @@ func FromContext(ctx context.Context) (CallInfo, bool) {
ci, ok := ctx.Value(callInfoKey).(CallInfo)
return ci, ok
}
// HTMLFromContext returns that value of HTML() from the context, or "" if we're
// not able to recover one
func HTMLFromContext(ctx context.Context) template.HTML {
var h template.HTML
ci, ok := FromContext(ctx)
if ok {
return ci.HTML()
}
return h
}

Просмотреть файл

@ -6,9 +6,7 @@ import (
log "github.com/golang/glog"
)
// ConsoleLogger is a Logger that uses glog directly to log.
// We can't specify the depth of the stack trace,
// So we just find it and add it to the message.
// ConsoleLogger is a Logger that uses glog directly to log, at the right level.
type ConsoleLogger struct{}
// NewConsoleLogger returns a simple ConsoleLogger
@ -18,26 +16,17 @@ func NewConsoleLogger() ConsoleLogger {
// Infof is part of the Logger interface
func (cl ConsoleLogger) Infof(format string, v ...interface{}) {
file, line := fileAndLine(3)
vals := []interface{}{file, line}
vals = append(vals, v...)
log.Infof("%v:%v] "+format, vals...)
log.InfoDepth(2, fmt.Sprintf(format, v...))
}
// Warningf is part of the Logger interface
func (cl ConsoleLogger) Warningf(format string, v ...interface{}) {
file, line := fileAndLine(3)
vals := []interface{}{file, line}
vals = append(vals, v...)
log.Warningf("%v:%v] "+format, vals...)
log.WarningDepth(2, fmt.Sprintf(format, v...))
}
// Errorf is part of the Logger interface
func (cl ConsoleLogger) Errorf(format string, v ...interface{}) {
file, line := fileAndLine(3)
vals := []interface{}{file, line}
vals = append(vals, v...)
log.Errorf("%v:%v] "+format, vals...)
log.ErrorDepth(2, fmt.Sprintf(format, v...))
}
// Printf is part of the Logger interface

Просмотреть файл

@ -1,6 +1,7 @@
package logutil
import (
"fmt"
"sync"
"time"
@ -29,12 +30,12 @@ func NewThrottledLogger(name string, maxInterval time.Duration) *ThrottledLogger
}
}
type logFunc func(string, ...interface{})
type logFunc func(int, ...interface{})
var (
infof = log.Infof
warningf = log.Warningf
errorf = log.Errorf
infoDepth = log.InfoDepth
warningDepth = log.WarningDepth
errorDepth = log.ErrorDepth
)
func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
@ -45,7 +46,7 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
logWaitTime := tl.maxInterval - (now.Sub(tl.lastlogTime))
if logWaitTime < 0 {
tl.lastlogTime = now
logF(tl.name+":"+format, v...)
logF(2, fmt.Sprintf(tl.name+":"+format, v...))
return
}
// If this is the first message to be skipped, start a goroutine
@ -55,7 +56,9 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
time.Sleep(d)
tl.mu.Lock()
defer tl.mu.Unlock()
logF("%v: skipped %v log messages", tl.name, tl.skippedCount)
// Because of the go func(), we lose the stack trace,
// so we just use the current line for this.
logF(0, fmt.Sprintf("%v: skipped %v log messages", tl.name, tl.skippedCount))
tl.skippedCount = 0
}(logWaitTime)
}
@ -64,15 +67,15 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
// Infof logs an info if not throttled.
func (tl *ThrottledLogger) Infof(format string, v ...interface{}) {
tl.log(infof, format, v...)
tl.log(infoDepth, format, v...)
}
// Warningf logs a warning if not throttled.
func (tl *ThrottledLogger) Warningf(format string, v ...interface{}) {
tl.log(warningf, format, v...)
tl.log(warningDepth, format, v...)
}
// Errorf logs an error if not throttled.
func (tl *ThrottledLogger) Errorf(format string, v ...interface{}) {
tl.log(errorf, format, v...)
tl.log(errorDepth, format, v...)
}

Просмотреть файл

@ -9,8 +9,8 @@ import (
func TestThrottledLogger(t *testing.T) {
// Install a fake log func for testing.
log := make(chan string)
infof = func(format string, args ...interface{}) {
log <- fmt.Sprintf(format, args...)
infoDepth = func(depth int, args ...interface{}) {
log <- fmt.Sprint(args...)
}
interval := 100 * time.Millisecond
tl := NewThrottledLogger("name", interval)

21
go/vt/tableacl/acl/acl.go Normal file
Просмотреть файл

@ -0,0 +1,21 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acl
// ACL is an interface for Access Control List.
type ACL interface {
// IsMember checks the membership of a principal in this ACL.
IsMember(principal string) bool
}
// Factory is responsible to create new ACL instance.
type Factory interface {
// New creates a new ACL instance.
New(entries []string) (ACL, error)
// All returns an ACL instance that contains all users.
All() ACL
// AllString returns a string representation of all users.
AllString() string
}

Просмотреть файл

@ -0,0 +1,47 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tableacl
import "testing"
func TestRoleName(t *testing.T) {
if READER.Name() != roleNames[READER] {
t.Fatalf("role READER does not return expected name, expected: %s, actual: %s", roleNames[READER], READER.Name())
}
if WRITER.Name() != roleNames[WRITER] {
t.Fatalf("role WRITER does not return expected name, expected: %s, actual: %s", roleNames[WRITER], WRITER.Name())
}
if ADMIN.Name() != roleNames[ADMIN] {
t.Fatalf("role ADMIN does not return expected name, expected: %s, actual: %s", roleNames[ADMIN], ADMIN.Name())
}
unknownRole := Role(-1)
if unknownRole.Name() != "" {
t.Fatalf("role is not defined, expected to get an empty string but got: %s", unknownRole.Name())
}
}
func TestRoleByName(t *testing.T) {
if role, ok := RoleByName("unknown"); ok {
t.Fatalf("should not find a valid role for invalid name, but got role: %s", role.Name())
}
role, ok := RoleByName("READER")
if !ok || role != READER {
t.Fatalf("string READER should return role READER")
}
role, ok = RoleByName("WRITER")
if !ok || role != WRITER {
t.Fatalf("string WRITER should return role WRITER")
}
role, ok = RoleByName("ADMIN")
if !ok || role != ADMIN {
t.Fatalf("string ADMIN should return role ADMIN")
}
}

Просмотреть файл

@ -1,31 +0,0 @@
package tableacl
var allAcl simpleACL
const (
ALL = "*"
)
// NewACL returns an ACL with the specified entries
func NewACL(entries []string) (ACL, error) {
a := simpleACL(map[string]bool{})
for _, e := range entries {
a[e] = true
}
return a, nil
}
// simpleACL keeps all entries in a unique in-memory list
type simpleACL map[string]bool
// IsMember checks the membership of a principal in this ACL
func (a simpleACL) IsMember(principal string) bool {
return a[principal] || a[ALL]
}
func all() ACL {
if allAcl == nil {
allAcl = simpleACL(map[string]bool{ALL: true})
}
return allAcl
}

Просмотреть файл

@ -0,0 +1,50 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package simpleacl
import "github.com/youtube/vitess/go/vt/tableacl/acl"
var allAcl SimpleAcl
const all = "*"
// SimpleAcl keeps all entries in a unique in-memory list
type SimpleAcl map[string]bool
// IsMember checks the membership of a principal in this ACL
func (sacl SimpleAcl) IsMember(principal string) bool {
return sacl[principal] || sacl[all]
}
// Factory is responsible to create new ACL instance.
type Factory struct{}
// New creates a new ACL instance.
func (factory *Factory) New(entries []string) (acl.ACL, error) {
acl := SimpleAcl(map[string]bool{})
for _, e := range entries {
acl[e] = true
}
return acl, nil
}
// All returns an ACL instance that contains all users.
func (factory *Factory) All() acl.ACL {
if allAcl == nil {
allAcl = SimpleAcl(map[string]bool{all: true})
}
return allAcl
}
// AllString returns a string representation of all users.
func (factory *Factory) AllString() string {
return all
}
// make sure SimpleAcl implements interface acl.ACL
var _ (acl.ACL) = (*SimpleAcl)(nil)
// make sure Factory implements interface acl.AclFactory
var _ (acl.Factory) = (*Factory)(nil)

Просмотреть файл

@ -0,0 +1,15 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package simpleacl
import (
"testing"
"github.com/youtube/vitess/go/vt/tableacl/testlib"
)
func TestSimpleAcl(t *testing.T) {
testlib.TestSuite(t, &Factory{})
}

Просмотреть файл

@ -1,31 +1,39 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tableacl
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"regexp"
"strings"
"sync"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/vt/tableacl/acl"
)
// ACL is an interface for Access Control List.
type ACL interface {
// IsMember checks the membership of a principal in this ACL
IsMember(principal string) bool
}
var mu sync.Mutex
var tableAcl map[*regexp.Regexp]map[Role]acl.ACL
var acls = make(map[string]acl.Factory)
var tableAcl map[*regexp.Regexp]map[Role]ACL
// defaultACL tells the default ACL implementation to use.
var defaultACL string
// Init initiates table ACLs.
func Init(configFile string) {
config, err := ioutil.ReadFile(configFile)
if err != nil {
log.Fatalf("unable to read tableACL config file: %v", err)
log.Errorf("unable to read tableACL config file: %v", err)
panic(fmt.Errorf("unable to read tableACL config file: %v", err))
}
tableAcl, err = load(config)
if err != nil {
log.Fatalf("tableACL initialization error: %v", err)
log.Errorf("tableACL initialization error: %v", err)
panic(fmt.Errorf("tableACL initialization error: %v", err))
}
}
@ -42,19 +50,19 @@ func InitFromBytes(config []byte) (err error) {
// <tableRegexPattern1>: {"READER": "*", "WRITER": "<u2>,<u4>...","ADMIN": "<u5>"},
// <tableRegexPattern2>: {"ADMIN": "<u5>"}
//}`)
func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
func load(config []byte) (map[*regexp.Regexp]map[Role]acl.ACL, error) {
var contents map[string]map[string]string
err := json.Unmarshal(config, &contents)
if err != nil {
return nil, err
}
tableAcl := make(map[*regexp.Regexp]map[Role]ACL)
tableAcl := make(map[*regexp.Regexp]map[Role]acl.ACL)
for tblPattern, accessMap := range contents {
re, err := regexp.Compile(tblPattern)
if err != nil {
return nil, fmt.Errorf("regexp compile error %v: %v", tblPattern, err)
}
tableAcl[re] = make(map[Role]ACL)
tableAcl[re] = make(map[Role]acl.ACL)
entriesByRole := make(map[Role][]string)
for i := READER; i < NumRoles; i++ {
@ -71,7 +79,7 @@ func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
}
}
for r, entries := range entriesByRole {
a, err := NewACL(entries)
a, err := newACL(entries)
if err != nil {
return nil, err
}
@ -83,8 +91,8 @@ func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
}
// Authorized returns the list of entities who have at least the
// minimum specified Role on a table.
func Authorized(table string, minRole Role) ACL {
// minimum specified Role on a tablel.
func Authorized(table string, minRole Role) acl.ACL {
// If table ACL is disabled, return nil
if tableAcl == nil {
return nil
@ -98,3 +106,48 @@ func Authorized(table string, minRole Role) ACL {
// No matching patterns for table, allow all access
return all()
}
// Register registers a AclFactory.
func Register(name string, factory acl.Factory) {
mu.Lock()
defer mu.Unlock()
if _, ok := acls[name]; ok {
panic(fmt.Sprintf("register a registered key: %s", name))
}
acls[name] = factory
}
// SetDefaultACL sets the default ACL implementation.
func SetDefaultACL(name string) {
mu.Lock()
defer mu.Unlock()
defaultACL = name
}
// GetCurrentAclFactory returns current table acl implementation.
func GetCurrentAclFactory() acl.Factory {
mu.Lock()
defer mu.Unlock()
if defaultACL == "" {
if len(acls) == 1 {
for _, aclFactory := range acls {
return aclFactory
}
}
panic("there are more than one AclFactory " +
"registered but no default has been given.")
}
aclFactory, ok := acls[defaultACL]
if !ok {
panic(fmt.Sprintf("aclFactory for given default: %s is not found.", defaultACL))
}
return aclFactory
}
func newACL(entries []string) (acl.ACL, error) {
return GetCurrentAclFactory().New(entries)
}
func all() acl.ACL {
return GetCurrentAclFactory().All()
}

Просмотреть файл

@ -1,93 +1,172 @@
// Copyright 2012, Google Inc. All rights reserved.
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tableacl
import "testing"
import (
"fmt"
"math/rand"
"reflect"
"testing"
"time"
func currentUser() string {
return "DummyUser"
"github.com/youtube/vitess/go/testfiles"
"github.com/youtube/vitess/go/vt/tableacl/acl"
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
)
type fakeAclFactory struct{}
func (factory *fakeAclFactory) New(entries []string) (acl.ACL, error) {
return nil, fmt.Errorf("unable to create a new ACL")
}
func TestParseInvalidJSON(t *testing.T) {
checkLoad([]byte(`{1:2}`), false, t)
checkLoad([]byte(`{"1":"2"}`), false, t)
checkLoad([]byte(`{"table1":{1:2}}`), false, t)
func (factory *fakeAclFactory) All() acl.ACL {
return &fakeACL{}
}
func TestInvalidRoleName(t *testing.T) {
checkLoad([]byte(`{"table1":{"SOMEROLE":"u1"}}`), false, t)
// AllString returns a string representation of all users.
func (factory *fakeAclFactory) AllString() string {
return ""
}
func TestInvalidRegex(t *testing.T) {
checkLoad([]byte(`{"table(1":{"READER":"u1"}}`), false, t)
type fakeACL struct{}
func (acl *fakeACL) IsMember(principal string) bool {
return false
}
func TestValidConfigs(t *testing.T) {
checkLoad([]byte(`{"table1":{"READER":"u1"}}`), true, t)
checkLoad([]byte(`{"table1":{"READER":"u1,u2", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,u2", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{
"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"},
"tbl[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3", "ADMIN":"u4"}
}`), true, t)
func TestInitWithInvalidFilePath(t *testing.T) {
setUpTableACL(&simpleacl.Factory{})
defer func() {
err := recover()
if err == nil {
t.Fatalf("init should fail for an invalid config file path")
}
}()
Init("/invalid_file_path")
}
func TestDenyReaderInsert(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", WRITER, t, false)
func TestInitWithInvalidConfigFile(t *testing.T) {
setUpTableACL(&simpleacl.Factory{})
defer func() {
err := recover()
if err == nil {
t.Fatalf("init should fail for an invalid config file")
}
}()
Init(testfiles.Locate("tableacl/invalid_tableacl_config.json"))
}
func TestAllowReaderSelect(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", READER, t, true)
func TestInitWithValidConfig(t *testing.T) {
setUpTableACL(&simpleacl.Factory{})
Init(testfiles.Locate("tableacl/test_table_tableacl_config.json"))
}
func TestDenyReaderDDL(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", ADMIN, t, false)
func TestInitFromBytes(t *testing.T) {
aclFactory := &simpleacl.Factory{}
setUpTableACL(aclFactory)
acl := Authorized("test_table", READER)
if acl != nil {
t.Fatalf("tableacl has not been initialized, should get nil ACL")
}
err := InitFromBytes([]byte(`{"test_table":{"Reader": "vt"}}`))
if err != nil {
t.Fatalf("tableacl init should succeed, but got error: %v", err)
}
acl = Authorized("unknown_table", READER)
if !reflect.DeepEqual(aclFactory.All(), acl) {
t.Fatalf("there is no config for unknown_table, should grand all permission")
}
acl = Authorized("test_table", READER)
if !acl.IsMember("vt") {
t.Fatalf("user: vt should have reader permission to table: test_table")
}
}
func TestAllowUnmatchedTable(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "UNMATCHED_TABLE", ADMIN, t, true)
func TestInvalidTableRegex(t *testing.T) {
setUpTableACL(&simpleacl.Factory{})
err := InitFromBytes([]byte(`{"table(":{"Reader": "vt", "WRITER":"vt"}}`))
if err == nil {
t.Fatalf("tableacl init should fail because config file has an invalid table regex")
}
}
func TestAllUserReadAccess(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + ALL + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", READER, t, true)
func TestInvalidRole(t *testing.T) {
setUpTableACL(&simpleacl.Factory{})
err := InitFromBytes([]byte(`{"test_table":{"InvalidRole": "vt", "Reader": "vt", "WRITER":"vt"}}`))
if err == nil {
t.Fatalf("tableacl init should fail because config file has an invalid role")
}
}
func TestAllUserWriteAccess(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"` + ALL + `"}}`)
checkAccess(configData, "table1", WRITER, t, true)
func TestFailedToCreateACL(t *testing.T) {
setUpTableACL(&fakeAclFactory{})
err := InitFromBytes([]byte(`{"test_table":{"Reader": "vt", "WRITER":"vt"}}`))
if err == nil {
t.Fatalf("tableacl init should fail because fake ACL returns an error")
}
}
func TestDisabled(t *testing.T) {
func TestDoubleRegisterTheSameKey(t *testing.T) {
acls = make(map[string]acl.Factory)
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
Register(name, &simpleacl.Factory{})
defer func() {
err := recover()
if err == nil {
t.Fatalf("the second tableacl register should fail")
}
}()
Register(name, &simpleacl.Factory{})
}
func TestGetAclFactory(t *testing.T) {
acls = make(map[string]acl.Factory)
defaultACL = ""
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
aclFactory := &simpleacl.Factory{}
Register(name, aclFactory)
if !reflect.DeepEqual(aclFactory, GetCurrentAclFactory()) {
t.Fatalf("should return registered acl factory even if default acl is not set.")
}
Register(name+"2", aclFactory)
defer func() {
err := recover()
if err == nil {
t.Fatalf("there are more than one acl factories, but the default is not set")
}
}()
GetCurrentAclFactory()
}
func TestGetAclFactoryWithWrongDefault(t *testing.T) {
acls = make(map[string]acl.Factory)
defaultACL = ""
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
aclFactory := &simpleacl.Factory{}
Register(name, aclFactory)
Register(name+"2", aclFactory)
SetDefaultACL("wrong_name")
defer func() {
err := recover()
if err == nil {
t.Fatalf("there are more than one acl factories, but the default given does not match any of these.")
}
}()
GetCurrentAclFactory()
}
func setUpTableACL(factory acl.Factory) {
tableAcl = nil
got := Authorized("table1", READER)
if got != nil {
t.Errorf("table acl disabled, got: %v, want: nil", got)
}
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
Register(name, factory)
SetDefaultACL(name)
}
func checkLoad(configData []byte, valid bool, t *testing.T) {
err := InitFromBytes(configData)
if !valid && err == nil {
t.Errorf("expecting parse error none returned")
}
if valid && err != nil {
t.Errorf("unexpected load error: %v", err)
}
}
func checkAccess(configData []byte, tableName string, role Role, t *testing.T, want bool) {
checkLoad(configData, true, t)
got := Authorized(tableName, role).IsMember(currentUser())
if want != got {
t.Errorf("got %v, want %v", got, want)
}
func init() {
rand.Seed(time.Now().UnixNano())
}

Просмотреть файл

@ -0,0 +1,119 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testlib
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tableacl/acl"
)
// TestSuite tests a concrete acl.Factory implementation.
func TestSuite(t *testing.T, factory acl.Factory) {
name := fmt.Sprintf("tableacl-test-%d", rand.Int63())
tableacl.Register(name, factory)
tableacl.SetDefaultACL(name)
testParseInvalidJSON(t)
testInvalidRoleName(t)
testInvalidRegex(t)
testValidConfigs(t)
testDenyReaderInsert(t)
testAllowReaderSelect(t)
testDenyReaderDDL(t)
testAllowUnmatchedTable(t)
testAllUserReadAccess(t)
testAllUserWriteAccess(t)
}
func currentUser() string {
return "DummyUser"
}
func testParseInvalidJSON(t *testing.T) {
checkLoad([]byte(`{1:2}`), false, t)
checkLoad([]byte(`{"1":"2"}`), false, t)
checkLoad([]byte(`{"table1":{1:2}}`), false, t)
}
func testInvalidRoleName(t *testing.T) {
checkLoad([]byte(`{"table1":{"SOMEROLE":"u1"}}`), false, t)
}
func testInvalidRegex(t *testing.T) {
checkLoad([]byte(`{"table(1":{"READER":"u1"}}`), false, t)
}
func testValidConfigs(t *testing.T) {
checkLoad([]byte(`{"table1":{"READER":"u1"}}`), true, t)
checkLoad([]byte(`{"table1":{"READER":"u1,u2", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,u2", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,`+allString()+`", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{
"table[0-9]+":{"Reader":"u1,`+allString()+`", "WRITER":"u3"},
"tbl[0-9]+":{"Reader":"u1,`+allString()+`", "WRITER":"u3", "ADMIN":"u4"}
}`), true, t)
}
func testDenyReaderInsert(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", tableacl.WRITER, t, false)
}
func testAllowReaderSelect(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", tableacl.READER, t, true)
}
func testDenyReaderDDL(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", tableacl.ADMIN, t, false)
}
func testAllowUnmatchedTable(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "UNMATCHED_TABLE", tableacl.ADMIN, t, true)
}
func testAllUserReadAccess(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + allString() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", tableacl.READER, t, true)
}
func testAllUserWriteAccess(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"` + allString() + `"}}`)
checkAccess(configData, "table1", tableacl.WRITER, t, true)
}
func checkLoad(configData []byte, valid bool, t *testing.T) {
err := tableacl.InitFromBytes(configData)
if !valid && err == nil {
t.Errorf("expecting parse error none returned")
}
if valid && err != nil {
t.Errorf("unexpected load error: %v", err)
}
}
func checkAccess(configData []byte, tableName string, role tableacl.Role, t *testing.T, want bool) {
checkLoad(configData, true, t)
got := tableacl.Authorized(tableName, role).IsMember(currentUser())
if want != got {
t.Errorf("got %v, want %v", got, want)
}
}
func allString() string {
return tableacl.GetCurrentAclFactory().AllString()
}
func init() {
rand.Seed(time.Now().UnixNano())
}

Просмотреть файл

@ -91,6 +91,9 @@ type ActionAgent struct {
// the reason we're not healthy.
_healthy error
// this is the last time health check ran
_healthyTime time.Time
// replication delay the last time we got it
_replicationDelay time.Duration
@ -247,10 +250,21 @@ func (agent *ActionAgent) Tablet() *topo.TabletInfo {
}
// Healthy reads the result of the latest healthcheck, protected by mutex.
// If that status is too old, it means healthcheck hasn't run for a while,
// and is probably stuck, this is not good, we're not healthy.
func (agent *ActionAgent) Healthy() (time.Duration, error) {
agent.mutex.Lock()
defer agent.mutex.Unlock()
return agent._replicationDelay, agent._healthy
healthy := agent._healthy
if healthy == nil {
timeSinceLastCheck := time.Now().Sub(agent._healthyTime)
if timeSinceLastCheck > *healthCheckInterval*3 {
healthy = fmt.Errorf("last health check is too old: %s > %s", timeSinceLastCheck, *healthCheckInterval*3)
}
}
return agent._replicationDelay, healthy
}
// BlacklistedTables reads the list of blacklisted tables from the TabletControl

Просмотреть файл

@ -228,6 +228,7 @@ func (agent *ActionAgent) runHealthCheck(targetTabletType topo.TabletType) {
// remember our health status
agent.mutex.Lock()
agent._healthy = err
agent._healthyTime = time.Now()
agent._replicationDelay = replicationDelay
agent.mutex.Unlock()

Просмотреть файл

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"html/template"
"strings"
"testing"
"time"
@ -153,6 +154,7 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
// first health check, should change us to replica, and update the
// mysql port to 3306
before := time.Now()
agent.runHealthCheck(targetTabletType)
ti, err := agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -167,9 +169,13 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
if !agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
// now make the tablet unhealthy
agent.HealthReporter.(*fakeHealthCheck).reportError = fmt.Errorf("tablet is unhealthy")
before = time.Now()
agent.runHealthCheck(targetTabletType)
ti, err = agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -181,6 +187,9 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
if agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should not be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
}
// TestQueryServiceNotStarting verifies that if a tablet cannot start the
@ -190,6 +199,7 @@ func TestQueryServiceNotStarting(t *testing.T) {
targetTabletType := topo.TYPE_REPLICA
agent.QueryServiceControl.(*tabletserver.TestQueryServiceControl).AllowQueriesError = fmt.Errorf("test cannot start query service")
before := time.Now()
agent.runHealthCheck(targetTabletType)
ti, err := agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -201,6 +211,9 @@ func TestQueryServiceNotStarting(t *testing.T) {
if agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should not be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
}
// TestQueryServiceStopped verifies that if a healthy tablet's query
@ -210,6 +223,7 @@ func TestQueryServiceStopped(t *testing.T) {
targetTabletType := topo.TYPE_REPLICA
// first health check, should change us to replica
before := time.Now()
agent.runHealthCheck(targetTabletType)
ti, err := agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -221,12 +235,16 @@ func TestQueryServiceStopped(t *testing.T) {
if !agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
// shut down query service and prevent it from starting again
agent.QueryServiceControl.DisallowQueries()
agent.QueryServiceControl.(*tabletserver.TestQueryServiceControl).AllowQueriesError = fmt.Errorf("test cannot start query service")
// health check should now fail
before = time.Now()
agent.runHealthCheck(targetTabletType)
ti, err = agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -238,6 +256,9 @@ func TestQueryServiceStopped(t *testing.T) {
if agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should not be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
}
// TestTabletControl verifies the shard's TabletControl record can disable
@ -247,6 +268,7 @@ func TestTabletControl(t *testing.T) {
targetTabletType := topo.TYPE_REPLICA
// first health check, should change us to replica
before := time.Now()
agent.runHealthCheck(targetTabletType)
ti, err := agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -258,6 +280,9 @@ func TestTabletControl(t *testing.T) {
if !agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
// now update the shard
si, err := agent.TopoServer.GetShard(keyspace, shard)
@ -286,6 +311,7 @@ func TestTabletControl(t *testing.T) {
}
// check running a health check will not start it again
before = time.Now()
agent.runHealthCheck(targetTabletType)
ti, err = agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -297,9 +323,13 @@ func TestTabletControl(t *testing.T) {
if agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should not be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
// go unhealthy, check we go to spare and QS is not running
agent.HealthReporter.(*fakeHealthCheck).reportError = fmt.Errorf("tablet is unhealthy")
before = time.Now()
agent.runHealthCheck(targetTabletType)
ti, err = agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -311,9 +341,13 @@ func TestTabletControl(t *testing.T) {
if agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should not be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
// go back healthy, check QS is still not running
agent.HealthReporter.(*fakeHealthCheck).reportError = nil
before = time.Now()
agent.runHealthCheck(targetTabletType)
ti, err = agent.TopoServer.GetTablet(tabletAlias)
if err != nil {
@ -325,4 +359,33 @@ func TestTabletControl(t *testing.T) {
if agent.QueryServiceControl.IsServing() {
t.Errorf("Query service should not be running")
}
if agent._healthyTime.Sub(before) < 0 {
t.Errorf("runHealthCheck did not update agent._healthyTime")
}
}
// TestOldHealthCheck verifies that a healthcheck that is too old will
// return an error
func TestOldHealthCheck(t *testing.T) {
agent := createTestAgent(t)
*healthCheckInterval = 20 * time.Second
agent._healthy = nil
// last health check time is now, we're good
agent._healthyTime = time.Now()
if _, healthy := agent.Healthy(); healthy != nil {
t.Errorf("Healthy returned unexpected error: %v", healthy)
}
// last health check time is 2x interval ago, we're good
agent._healthyTime = time.Now().Add(-2 * *healthCheckInterval)
if _, healthy := agent.Healthy(); healthy != nil {
t.Errorf("Healthy returned unexpected error: %v", healthy)
}
// last health check time is 4x interval ago, we're not good
agent._healthyTime = time.Now().Add(-4 * *healthCheckInterval)
if _, healthy := agent.Healthy(); healthy == nil || !strings.Contains(healthy.Error(), "last health check is too old") {
t.Errorf("Healthy returned wrong error: %v", healthy)
}
}

Просмотреть файл

@ -252,7 +252,7 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re
qre.logStats.CacheAbsent = absent
qre.logStats.CacheMisses = misses
qre.logStats.QuerySources |= QUERY_SOURCE_ROWCACHE
qre.logStats.QuerySources |= QuerySourceRowcache
tableInfo.hits.Add(hits)
tableInfo.absent.Add(absent)
@ -538,7 +538,7 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser
q.Result, q.Err = qre.execSQLNoPanic(conn, sql, false)
}
} else {
logStats.QuerySources |= QUERY_SOURCE_CONSOLIDATOR
logStats.QuerySources |= QuerySourceConsolidator
startTime := time.Now()
q.Wait()
waitStats.Record("Consolidations", startTime)

Просмотреть файл

@ -6,7 +6,6 @@ package tabletserver
import (
"fmt"
"html/template"
"math/rand"
"reflect"
"testing"
@ -15,6 +14,8 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/callinfo"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
"github.com/youtube/vitess/go/vt/tabletserver/fakecacheservice"
"github.com/youtube/vitess/go/vt/tabletserver/fakesqldb"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
@ -22,27 +23,6 @@ import (
"golang.org/x/net/context"
)
type fakeCallInfo struct {
remoteAddr string
username string
}
func (fci *fakeCallInfo) RemoteAddr() string {
return fci.remoteAddr
}
func (fci *fakeCallInfo) Username() string {
return fci.username
}
func (fci *fakeCallInfo) Text() string {
return ""
}
func (fci *fakeCallInfo) HTML() template.HTML {
return template.HTML("")
}
func TestQueryExecutorPlanDDL(t *testing.T) {
db := setUpQueryExecutorTest()
query := "alter table test_table add zipcode int"
@ -567,53 +547,57 @@ func TestQueryExecutorPlanOther(t *testing.T) {
checkEqual(t, expected, qre.Execute())
}
//func TestQueryExecutorTableAcl(t *testing.T) {
// db := setUpQueryExecutorTest()
// query := "select * from test_table limit 1000"
// expected := &mproto.QueryResult{
// Fields: getTestTableFields(),
// RowsAffected: 0,
// Rows: [][]sqltypes.Value{},
// }
// db.AddQuery(query, expected)
// db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{
// Fields: getTestTableFields(),
// })
//
// username := "u2"
// callInfo := &fakeCallInfo{
// remoteAddr: "1.2.3.4",
// username: username,
// }
// ctx := callinfo.NewContext(context.Background(), callInfo)
// if err := tableacl.InitFromBytes(
// []byte(fmt.Sprintf(`{"test_table":{"READER":"%s"}}`, username))); err != nil {
// t.Fatalf("unable to load tableacl config, error: %v", err)
// }
//
// qre, sqlQuery := newTestQueryExecutor(
// query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
// checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
// checkEqual(t, expected, qre.Execute())
// sqlQuery.disallowQueries()
//
// if err := tableacl.InitFromBytes([]byte(`{"test_table":{"READER":"superuser"}}`)); err != nil {
// t.Fatalf("unable to load tableacl config, error: %v", err)
// }
// // without enabling Config.StrictTableAcl
// qre, sqlQuery = newTestQueryExecutor(
// query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
// checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
// qre.Execute()
// sqlQuery.disallowQueries()
// // enable Config.StrictTableAcl
// qre, sqlQuery = newTestQueryExecutor(
// query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl)
// defer sqlQuery.disallowQueries()
// checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
// defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail)
// qre.Execute()
//}
func TestQueryExecutorTableAcl(t *testing.T) {
aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63())
tableacl.Register(aclName, &simpleacl.Factory{})
tableacl.SetDefaultACL(aclName)
db := setUpQueryExecutorTest()
query := "select * from test_table limit 1000"
expected := &mproto.QueryResult{
Fields: getTestTableFields(),
RowsAffected: 0,
Rows: [][]sqltypes.Value{},
}
db.AddQuery(query, expected)
db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{
Fields: getTestTableFields(),
})
username := "u2"
callInfo := &fakeCallInfo{
remoteAddr: "1.2.3.4",
username: username,
}
ctx := callinfo.NewContext(context.Background(), callInfo)
if err := tableacl.InitFromBytes(
[]byte(fmt.Sprintf(`{"test_table":{"READER":"%s"}}`, username))); err != nil {
t.Fatalf("unable to load tableacl config, error: %v", err)
}
qre, sqlQuery := newTestQueryExecutor(
query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
checkEqual(t, expected, qre.Execute())
sqlQuery.disallowQueries()
if err := tableacl.InitFromBytes([]byte(`{"test_table":{"READER":"superuser"}}`)); err != nil {
t.Fatalf("unable to load tableacl config, error: %v", err)
}
// without enabling Config.StrictTableAcl
qre, sqlQuery = newTestQueryExecutor(
query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
qre.Execute()
sqlQuery.disallowQueries()
// enable Config.StrictTableAcl
qre, sqlQuery = newTestQueryExecutor(
query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl)
defer sqlQuery.disallowQueries()
checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail)
qre.Execute()
}
func TestQueryExecutorBlacklistQRFail(t *testing.T) {
db := setUpQueryExecutorTest()

Просмотреть файл

@ -98,14 +98,9 @@ func (ql *QueryList) GetQueryzRows() []QueryDetailzRow {
ql.mu.Lock()
rows := []QueryDetailzRow{}
for _, qd := range ql.queryDetails {
var h template.HTML
ci, ok := callinfo.FromContext(qd.context)
if ok {
h = ci.HTML()
}
row := QueryDetailzRow{
Query: qd.conn.Current(),
ContextHTML: h,
ContextHTML: callinfo.HTMLFromContext(qd.context),
Start: qd.start,
Duration: time.Now().Sub(qd.start),
ConnID: qd.connID,

Просмотреть файл

@ -7,20 +7,25 @@ import (
)
type testConn struct {
id int64
query string
id int64
query string
killed bool
}
func (tc testConn) Current() string { return tc.query }
func (tc *testConn) Current() string { return tc.query }
func (tc testConn) ID() int64 { return tc.id }
func (tc *testConn) ID() int64 { return tc.id }
func (tc testConn) Kill() {}
func (tc *testConn) Kill() { tc.killed = true }
func (tc *testConn) IsKilled() bool {
return tc.killed
}
func TestQueryList(t *testing.T) {
ql := NewQueryList()
connID := int64(1)
qd := NewQueryDetail(context.Background(), testConn{id: connID})
qd := NewQueryDetail(context.Background(), &testConn{id: connID})
ql.Add(qd)
if qd1, ok := ql.queryDetails[connID]; !ok || qd1.connID != connID {
@ -28,7 +33,7 @@ func TestQueryList(t *testing.T) {
}
conn2ID := int64(2)
qd2 := NewQueryDetail(context.Background(), testConn{id: conn2ID})
qd2 := NewQueryDetail(context.Background(), &testConn{id: conn2ID})
ql.Add(qd2)
rows := ql.GetQueryzRows()

Просмотреть файл

@ -382,6 +382,15 @@ func (rqsc *realQueryServiceControl) registerDebugHealthHandler() {
})
}
func (rqsc *realQueryServiceControl) registerStreamQueryzHandlers() {
http.HandleFunc("/streamqueryz", func(w http.ResponseWriter, r *http.Request) {
streamQueryzHandler(rqsc.sqlQueryRPCService.qe.streamQList, w, r)
})
http.HandleFunc("/streamqueryz/terminate", func(w http.ResponseWriter, r *http.Request) {
streamQueryzTerminateHandler(rqsc.sqlQueryRPCService.qe.streamQList, w, r)
})
}
func buildFmter(logger *streamlog.StreamLogger) func(url.Values, interface{}) string {
type formatter interface {
Format(url.Values) string

Просмотреть файл

@ -113,7 +113,9 @@ func querylogzHandler(w http.ResponseWriter, r *http.Request) {
*SQLQueryStats
ColorLevel string
}{stats, level}
querylogzTmpl.Execute(w, tmplData)
if err := querylogzTmpl.Execute(w, tmplData); err != nil {
log.Errorf("querylogz: couldn't execute template: %v", err)
}
case <-tmr.C:
return
}

Просмотреть файл

@ -11,6 +11,7 @@ import (
"sort"
"time"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/acl"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
)
@ -150,7 +151,9 @@ func (rqsc *realQueryServiceControl) registerQueryzHandler() {
}
sort.Sort(&sorter)
for _, Value := range sorter.rows {
queryzTmpl.Execute(w, Value)
if err := queryzTmpl.Execute(w, Value); err != nil {
log.Errorf("queryz: couldn't execute template: %v", err)
}
}
})
}

Просмотреть файл

@ -22,6 +22,7 @@ import (
"github.com/youtube/vitess/go/timer"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/tableacl"
tacl "github.com/youtube/vitess/go/vt/tableacl/acl"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
"golang.org/x/net/context"
)
@ -44,7 +45,7 @@ type ExecPlan struct {
TableInfo *TableInfo
Fields []mproto.Field
Rules *QueryRules
Authorized tableacl.ACL
Authorized tacl.ACL
mu sync.Mutex
QueryCount int64

Просмотреть файл

@ -9,6 +9,7 @@ import (
"net/http"
"sort"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/acl"
"github.com/youtube/vitess/go/vt/schema"
)
@ -77,7 +78,9 @@ func (rqsc *realQueryServiceControl) registerSchemazHandler() {
}
for _, Value := range sorter.rows {
envelope.Table = Value
schemazTmpl.Execute(w, envelope)
if err := schemazTmpl.Execute(w, envelope); err != nil {
log.Errorf("schemaz: couldn't execute template: %v", err)
}
}
})
}

Просмотреть файл

@ -7,6 +7,7 @@ package tabletserver
import (
"encoding/json"
"fmt"
"html/template"
"net/url"
"strings"
"time"
@ -22,9 +23,12 @@ import (
var SqlQueryLogger = streamlog.New("SqlQuery", 50)
const (
QUERY_SOURCE_ROWCACHE = 1 << iota
QUERY_SOURCE_CONSOLIDATOR
QUERY_SOURCE_MYSQL
// QuerySourceRowcache means query result is found in rowcache.
QuerySourceRowcache = 1 << iota
// QuerySourceConsolidator means query result is found in consolidator.
QuerySourceConsolidator
// QuerySourceMySQL means query result is returned from MySQL.
QuerySourceMySQL
)
// SQLQueryStats records the stats for a single query
@ -67,7 +71,7 @@ func (stats *SQLQueryStats) Send() {
// AddRewrittenSql adds a single sql statement to the rewritten list
func (stats *SQLQueryStats) AddRewrittenSql(sql string, start time.Time) {
stats.QuerySources |= QUERY_SOURCE_MYSQL
stats.QuerySources |= QuerySourceMySQL
stats.NumberOfQueries++
stats.rewrittenSqls = append(stats.rewrittenSqls, sql)
stats.MysqlResponseTime += time.Now().Sub(start)
@ -139,21 +143,28 @@ func (stats *SQLQueryStats) FmtQuerySources() string {
}
sources := make([]string, 3)
n := 0
if stats.QuerySources&QUERY_SOURCE_MYSQL != 0 {
if stats.QuerySources&QuerySourceMySQL != 0 {
sources[n] = "mysql"
n++
}
if stats.QuerySources&QUERY_SOURCE_ROWCACHE != 0 {
if stats.QuerySources&QuerySourceRowcache != 0 {
sources[n] = "rowcache"
n++
}
if stats.QuerySources&QUERY_SOURCE_CONSOLIDATOR != 0 {
if stats.QuerySources&QuerySourceConsolidator != 0 {
sources[n] = "consolidator"
n++
}
return strings.Join(sources[:n], ",")
}
// ContextHTML returns the HTML version of the context that was used, or "".
// This is a method on SQLQueryStats instead of a field so that it doesn't need
// to be passed by value everywhere.
func (stats *SQLQueryStats) ContextHTML() template.HTML {
return callinfo.HTMLFromContext(stats.context)
}
// ErrorStr returns the error string or ""
func (stats *SQLQueryStats) ErrorStr() string {
if stats.Error != nil {

Просмотреть файл

@ -0,0 +1,142 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tabletserver
import (
"fmt"
"net/url"
"strings"
"testing"
"time"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/callinfo"
"golang.org/x/net/context"
)
func TestSqlQueryStats(t *testing.T) {
logStats := newSqlQueryStats("test", context.Background())
logStats.AddRewrittenSql("sql1", time.Now())
if !strings.Contains(logStats.RewrittenSql(), "sql1") {
t.Fatalf("RewrittenSql should contains sql: sql1")
}
if logStats.SizeOfResponse() != 0 {
t.Fatalf("there is no rows in log stats, estimated size should be 0 bytes")
}
logStats.Rows = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.MakeString([]byte("a"))}}
if logStats.SizeOfResponse() <= 0 {
t.Fatalf("log stats has some rows, should have positive response size")
}
params := map[string][]string{"full": []string{}}
logStats.Format(url.Values(params))
}
func TestSqlQueryStatsFormatBindVariables(t *testing.T) {
logStats := newSqlQueryStats("test", context.Background())
logStats.BindVariables = make(map[string]interface{})
logStats.BindVariables["key_1"] = "val_1"
logStats.BindVariables["key_2"] = 789
formattedStr := logStats.FmtBindVariables(true)
if !strings.Contains(formattedStr, "key_1") ||
!strings.Contains(formattedStr, "val_1") {
t.Fatalf("bind variable 'key_1': 'val_1' is not formatted")
}
if !strings.Contains(formattedStr, "key_2") ||
!strings.Contains(formattedStr, "789") {
t.Fatalf("bind variable 'key_2': '789' is not formatted")
}
logStats.BindVariables["key_3"] = []byte("val_3")
formattedStr = logStats.FmtBindVariables(false)
if !strings.Contains(formattedStr, "key_1") {
t.Fatalf("bind variable 'key_1' is not formatted")
}
if !strings.Contains(formattedStr, "key_2") ||
!strings.Contains(formattedStr, "789") {
t.Fatalf("bind variable 'key_2': '789' is not formatted")
}
if !strings.Contains(formattedStr, "key_3") {
t.Fatalf("bind variable 'key_3' is not formatted")
}
}
func TestSqlQueryStatsFormatQuerySources(t *testing.T) {
logStats := newSqlQueryStats("test", context.Background())
if logStats.FmtQuerySources() != "none" {
t.Fatalf("should return none since log stats does not have any query source, but got: %s", logStats.FmtQuerySources())
}
logStats.QuerySources |= QuerySourceMySQL
if !strings.Contains(logStats.FmtQuerySources(), "mysql") {
t.Fatalf("'mysql' should be in formated query sources")
}
logStats.QuerySources |= QuerySourceRowcache
if !strings.Contains(logStats.FmtQuerySources(), "rowcache") {
t.Fatalf("'rowcache' should be in formated query sources")
}
logStats.QuerySources |= QuerySourceConsolidator
if !strings.Contains(logStats.FmtQuerySources(), "consolidator") {
t.Fatalf("'consolidator' should be in formated query sources")
}
}
func TestSqlQueryStatsContextHTML(t *testing.T) {
html := "HtmlContext"
callInfo := &fakeCallInfo{
html: html,
}
ctx := callinfo.NewContext(context.Background(), callInfo)
logStats := newSqlQueryStats("test", ctx)
if string(logStats.ContextHTML()) != html {
t.Fatalf("expect to get html: %s, but got: %s", html, string(logStats.ContextHTML()))
}
}
func TestSqlQueryStatsErrorStr(t *testing.T) {
logStats := newSqlQueryStats("test", context.Background())
if logStats.ErrorStr() != "" {
t.Fatalf("should not get error in stats, but got: %s", logStats.ErrorStr())
}
errStr := "unknown error"
logStats.Error = fmt.Errorf(errStr)
if logStats.ErrorStr() != errStr {
t.Fatalf("expect to get error string: %s, but got: %s", errStr, logStats.ErrorStr())
}
}
func TestSqlQueryStatsRemoteAddrUsername(t *testing.T) {
logStats := newSqlQueryStats("test", context.Background())
addr, user := logStats.RemoteAddrUsername()
if addr != "" {
t.Fatalf("remote addr should be empty")
}
if user != "" {
t.Fatalf("username should be empty")
}
remoteAddr := "1.2.3.4"
username := "vt"
callInfo := &fakeCallInfo{
remoteAddr: remoteAddr,
username: username,
}
ctx := callinfo.NewContext(context.Background(), callInfo)
logStats = newSqlQueryStats("test", ctx)
addr, user = logStats.RemoteAddrUsername()
if addr != remoteAddr {
t.Fatalf("expected to get remote addr: %s, but got: %s", remoteAddr, addr)
}
if user != username {
t.Fatalf("expected to get username: %s, but got: %s", username, user)
}
}

Просмотреть файл

@ -1,3 +1,7 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tabletserver
import (
@ -7,6 +11,7 @@ import (
"strconv"
"text/template"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/acl"
)
@ -23,7 +28,7 @@ var (
</thead>
`)
streamqueryzTmpl = template.Must(template.New("example").Parse(`
<tr>
<tr>
<td>{{.Query}}</td>
<td>{{.ContextHTML}}</td>
<td>{{.Duration}}</td>
@ -34,56 +39,55 @@ var (
`))
)
func (rqsc *realQueryServiceControl) registerStreamQueryzHandlers() {
streamqueryzHandler := func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
acl.SendError(w, err)
func streamQueryzHandler(queryList *QueryList, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
acl.SendError(w, err)
return
}
rows := queryList.GetQueryzRows()
if err := r.ParseForm(); err != nil {
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
return
}
format := r.FormValue("format")
if format == "json" {
js, err := json.Marshal(rows)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
rows := rqsc.sqlQueryRPCService.qe.streamQList.GetQueryzRows()
if err := r.ParseForm(); err != nil {
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
return
}
format := r.FormValue("format")
if format == "json" {
js, err := json.Marshal(rows)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(js)
return
}
startHTMLTable(w)
defer endHTMLTable(w)
w.Write(streamqueryzHeader)
for i := range rows {
streamqueryzTmpl.Execute(w, rows[i])
w.Header().Set("Content-Type", "application/json")
w.Write(js)
return
}
startHTMLTable(w)
defer endHTMLTable(w)
w.Write(streamqueryzHeader)
for i := range rows {
if err := streamqueryzTmpl.Execute(w, rows[i]); err != nil {
log.Errorf("streamlogz: couldn't execute template: %v", err)
}
}
http.HandleFunc("/streamqueryz", streamqueryzHandler)
http.HandleFunc("/streamqueryz/terminate", func(w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.ADMIN); err != nil {
acl.SendError(w, err)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
return
}
connID := r.FormValue("connID")
c, err := strconv.Atoi(connID)
if err != nil {
http.Error(w, "invalid connID", http.StatusInternalServerError)
return
}
if err = rqsc.sqlQueryRPCService.qe.streamQList.Terminate(int64(c)); err != nil {
http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
return
}
streamqueryzHandler(w, r)
})
}
func streamQueryzTerminateHandler(queryList *QueryList, w http.ResponseWriter, r *http.Request) {
if err := acl.CheckAccessHTTP(r, acl.ADMIN); err != nil {
acl.SendError(w, err)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
return
}
connID := r.FormValue("connID")
c, err := strconv.Atoi(connID)
if err != nil {
http.Error(w, "invalid connID", http.StatusInternalServerError)
return
}
if err = queryList.Terminate(int64(c)); err != nil {
http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
return
}
streamQueryzHandler(queryList, w, r)
}

Просмотреть файл

@ -0,0 +1,95 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tabletserver
import (
"net/http"
"net/http/httptest"
"testing"
"golang.org/x/net/context"
)
func TestStreamQueryzHandlerJSON(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/streamqueryz?format=json", nil)
queryList := NewQueryList()
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
streamQueryzHandler(queryList, resp, req)
}
func TestStreamQueryzHandlerHTTP(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/streamqueryz", nil)
queryList := NewQueryList()
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
streamQueryzHandler(queryList, resp, req)
}
func TestStreamQueryzHandlerHTTPFailedInvalidForm(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/streamqueryz", nil)
streamQueryzHandler(NewQueryList(), resp, req)
if resp.Code != http.StatusInternalServerError {
t.Fatalf("http call should fail and return code: %d, but got: %d",
http.StatusInternalServerError, resp.Code)
}
}
func TestStreamQueryzHandlerTerminateConn(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/streamqueryz/terminate?connID=1", nil)
queryList := NewQueryList()
testConn := &testConn{id: 1}
queryList.Add(NewQueryDetail(context.Background(), testConn))
if testConn.IsKilled() {
t.Fatalf("conn should still be alive")
}
streamQueryzTerminateHandler(queryList, resp, req)
if !testConn.IsKilled() {
t.Fatalf("conn should be killed")
}
}
func TestStreamQueryzHandlerTerminateFailedInvalidConnID(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/streamqueryz/terminate?connID=invalid", nil)
streamQueryzTerminateHandler(NewQueryList(), resp, req)
if resp.Code != http.StatusInternalServerError {
t.Fatalf("http call should fail and return code: %d, but got: %d",
http.StatusInternalServerError, resp.Code)
}
}
func TestStreamQueryzHandlerTerminateFailedKnownConnID(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/streamqueryz/terminate?connID=10", nil)
streamQueryzTerminateHandler(NewQueryList(), resp, req)
if resp.Code != http.StatusInternalServerError {
t.Fatalf("http call should fail and return code: %d, but got: %d",
http.StatusInternalServerError, resp.Code)
}
}
func TestStreamQueryzHandlerTerminateFailedInvalidForm(t *testing.T) {
resp := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/streamqueryz/terminate?inva+lid=2", nil)
streamQueryzTerminateHandler(NewQueryList(), resp, req)
if resp.Code != http.StatusInternalServerError {
t.Fatalf("http call should fail and return code: %d, but got: %d",
http.StatusInternalServerError, resp.Code)
}
}

Просмотреть файл

@ -0,0 +1,30 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tabletserver
import "html/template"
type fakeCallInfo struct {
remoteAddr string
username string
text string
html string
}
func (fci *fakeCallInfo) RemoteAddr() string {
return fci.remoteAddr
}
func (fci *fakeCallInfo) Username() string {
return fci.username
}
func (fci *fakeCallInfo) Text() string {
return fci.text
}
func (fci *fakeCallInfo) HTML() template.HTML {
return template.HTML(fci.html)
}

Просмотреть файл

@ -122,7 +122,9 @@ func txlogzHandler(w http.ResponseWriter, req *http.Request) {
Duration float64
ColorLevel string
}{txc, duration, level}
txlogzTmpl.Execute(w, tmplData)
if err := txlogzTmpl.Execute(w, tmplData); err != nil {
log.Errorf("txlogz: couldn't execute template: %v", err)
}
case <-tmr.C:
return
}

Просмотреть файл

@ -46,6 +46,7 @@ func (batchQueryShard *BatchQueryShard) MarshalBson(buf *bytes2.ChunkedWriter, k
} else {
(*batchQueryShard.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", batchQueryShard.NotInTransaction)
lenWriter.Close()
}
@ -104,6 +105,8 @@ func (batchQueryShard *BatchQueryShard) UnmarshalBson(buf *bytes.Buffer, kind by
batchQueryShard.Session = new(Session)
(*batchQueryShard.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
batchQueryShard.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -47,6 +47,7 @@ func (entityIdsQuery *EntityIdsQuery) MarshalBson(buf *bytes2.ChunkedWriter, key
} else {
(*entityIdsQuery.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", entityIdsQuery.NotInTransaction)
lenWriter.Close()
}
@ -109,6 +110,8 @@ func (entityIdsQuery *EntityIdsQuery) UnmarshalBson(buf *bytes.Buffer, kind byte
entityIdsQuery.Session = new(Session)
(*entityIdsQuery.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
entityIdsQuery.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -12,7 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
kproto "github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/key"
)
// MarshalBson bson-encodes KeyRangeQuery.
@ -31,7 +31,7 @@ func (keyRangeQuery *KeyRangeQuery) MarshalBson(buf *bytes2.ChunkedWriter, key s
lenWriter.Close()
}
bson.EncodeString(buf, "Keyspace", keyRangeQuery.Keyspace)
// []kproto.KeyRange
// []key.KeyRange
{
bson.EncodePrefix(buf, bson.Array, "KeyRanges")
lenWriter := bson.NewLenWriter(buf)
@ -47,6 +47,7 @@ func (keyRangeQuery *KeyRangeQuery) MarshalBson(buf *bytes2.ChunkedWriter, key s
} else {
(*keyRangeQuery.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", keyRangeQuery.NotInTransaction)
lenWriter.Close()
}
@ -85,16 +86,16 @@ func (keyRangeQuery *KeyRangeQuery) UnmarshalBson(buf *bytes.Buffer, kind byte)
case "Keyspace":
keyRangeQuery.Keyspace = bson.DecodeString(buf, kind)
case "KeyRanges":
// []kproto.KeyRange
// []key.KeyRange
if kind != bson.Null {
if kind != bson.Array {
panic(bson.NewBsonError("unexpected kind %v for keyRangeQuery.KeyRanges", kind))
}
bson.Next(buf, 4)
keyRangeQuery.KeyRanges = make([]kproto.KeyRange, 0, 8)
keyRangeQuery.KeyRanges = make([]key.KeyRange, 0, 8)
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
bson.SkipIndex(buf)
var _v2 kproto.KeyRange
var _v2 key.KeyRange
_v2.UnmarshalBson(buf, kind)
keyRangeQuery.KeyRanges = append(keyRangeQuery.KeyRanges, _v2)
}
@ -107,6 +108,8 @@ func (keyRangeQuery *KeyRangeQuery) UnmarshalBson(buf *bytes.Buffer, kind byte)
keyRangeQuery.Session = new(Session)
(*keyRangeQuery.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
keyRangeQuery.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -12,7 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
kproto "github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
@ -31,7 +31,7 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) MarshalBson(buf *bytes2.Chunke
lenWriter.Close()
}
bson.EncodeString(buf, "Keyspace", keyspaceIdBatchQuery.Keyspace)
// []kproto.KeyspaceId
// []key.KeyspaceId
{
bson.EncodePrefix(buf, bson.Array, "KeyspaceIds")
lenWriter := bson.NewLenWriter(buf)
@ -47,6 +47,7 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) MarshalBson(buf *bytes2.Chunke
} else {
(*keyspaceIdBatchQuery.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", keyspaceIdBatchQuery.NotInTransaction)
lenWriter.Close()
}
@ -83,16 +84,16 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) UnmarshalBson(buf *bytes.Buffe
case "Keyspace":
keyspaceIdBatchQuery.Keyspace = bson.DecodeString(buf, kind)
case "KeyspaceIds":
// []kproto.KeyspaceId
// []key.KeyspaceId
if kind != bson.Null {
if kind != bson.Array {
panic(bson.NewBsonError("unexpected kind %v for keyspaceIdBatchQuery.KeyspaceIds", kind))
}
bson.Next(buf, 4)
keyspaceIdBatchQuery.KeyspaceIds = make([]kproto.KeyspaceId, 0, 8)
keyspaceIdBatchQuery.KeyspaceIds = make([]key.KeyspaceId, 0, 8)
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
bson.SkipIndex(buf)
var _v2 kproto.KeyspaceId
var _v2 key.KeyspaceId
_v2.UnmarshalBson(buf, kind)
keyspaceIdBatchQuery.KeyspaceIds = append(keyspaceIdBatchQuery.KeyspaceIds, _v2)
}
@ -105,6 +106,8 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) UnmarshalBson(buf *bytes.Buffe
keyspaceIdBatchQuery.Session = new(Session)
(*keyspaceIdBatchQuery.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
keyspaceIdBatchQuery.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -12,7 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
kproto "github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/key"
)
// MarshalBson bson-encodes KeyspaceIdQuery.
@ -31,7 +31,7 @@ func (keyspaceIdQuery *KeyspaceIdQuery) MarshalBson(buf *bytes2.ChunkedWriter, k
lenWriter.Close()
}
bson.EncodeString(buf, "Keyspace", keyspaceIdQuery.Keyspace)
// []kproto.KeyspaceId
// []key.KeyspaceId
{
bson.EncodePrefix(buf, bson.Array, "KeyspaceIds")
lenWriter := bson.NewLenWriter(buf)
@ -47,6 +47,7 @@ func (keyspaceIdQuery *KeyspaceIdQuery) MarshalBson(buf *bytes2.ChunkedWriter, k
} else {
(*keyspaceIdQuery.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", keyspaceIdQuery.NotInTransaction)
lenWriter.Close()
}
@ -85,16 +86,16 @@ func (keyspaceIdQuery *KeyspaceIdQuery) UnmarshalBson(buf *bytes.Buffer, kind by
case "Keyspace":
keyspaceIdQuery.Keyspace = bson.DecodeString(buf, kind)
case "KeyspaceIds":
// []kproto.KeyspaceId
// []key.KeyspaceId
if kind != bson.Null {
if kind != bson.Array {
panic(bson.NewBsonError("unexpected kind %v for keyspaceIdQuery.KeyspaceIds", kind))
}
bson.Next(buf, 4)
keyspaceIdQuery.KeyspaceIds = make([]kproto.KeyspaceId, 0, 8)
keyspaceIdQuery.KeyspaceIds = make([]key.KeyspaceId, 0, 8)
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
bson.SkipIndex(buf)
var _v2 kproto.KeyspaceId
var _v2 key.KeyspaceId
_v2.UnmarshalBson(buf, kind)
keyspaceIdQuery.KeyspaceIds = append(keyspaceIdQuery.KeyspaceIds, _v2)
}
@ -107,6 +108,8 @@ func (keyspaceIdQuery *KeyspaceIdQuery) UnmarshalBson(buf *bytes.Buffer, kind by
keyspaceIdQuery.Session = new(Session)
(*keyspaceIdQuery.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
keyspaceIdQuery.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -36,6 +36,7 @@ func (query *Query) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
} else {
(*query.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", query.NotInTransaction)
lenWriter.Close()
}
@ -79,6 +80,8 @@ func (query *Query) UnmarshalBson(buf *bytes.Buffer, kind byte) {
query.Session = new(Session)
(*query.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
query.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -46,6 +46,7 @@ func (queryShard *QueryShard) MarshalBson(buf *bytes2.ChunkedWriter, key string)
} else {
(*queryShard.Session).MarshalBson(buf, "Session")
}
bson.EncodeBool(buf, "NotInTransaction", queryShard.NotInTransaction)
lenWriter.Close()
}
@ -106,6 +107,8 @@ func (queryShard *QueryShard) UnmarshalBson(buf *bytes.Buffer, kind byte) {
queryShard.Session = new(Session)
(*queryShard.Session).UnmarshalBson(buf, kind)
}
case "NotInTransaction":
queryShard.NotInTransaction = bson.DecodeBool(buf, kind)
default:
bson.Skip(buf, kind)
}

Просмотреть файл

@ -43,10 +43,11 @@ func (shardSession *ShardSession) String() string {
// Query represents a keyspace agnostic query request.
type Query struct {
Sql string
BindVariables map[string]interface{}
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type Query -o query_bson.go
@ -54,12 +55,13 @@ type Query struct {
// QueryShard represents a query request for the
// specified list of shards.
type QueryShard struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type QueryShard -o query_shard_bson.go
@ -67,12 +69,13 @@ type QueryShard struct {
// KeyspaceIdQuery represents a query request for the
// specified list of keyspace IDs.
type KeyspaceIdQuery struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyspaceIds []key.KeyspaceId
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyspaceIds []key.KeyspaceId
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type KeyspaceIdQuery -o keyspace_id_query_bson.go
@ -80,12 +83,13 @@ type KeyspaceIdQuery struct {
// KeyRangeQuery represents a query request for the
// specified list of keyranges.
type KeyRangeQuery struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRanges []key.KeyRange
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRanges []key.KeyRange
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type KeyRangeQuery -o key_range_query_bson.go
@ -107,6 +111,7 @@ type EntityIdsQuery struct {
EntityKeyspaceIDs []EntityId
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type EntityIdsQuery -o entity_ids_query_bson.go
@ -123,11 +128,12 @@ type QueryResult struct {
// BatchQueryShard represents a batch query request
// for the specified shards.
type BatchQueryShard struct {
Queries []tproto.BoundQuery
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Queries []tproto.BoundQuery
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type BatchQueryShard -o batch_query_shard_bson.go
@ -135,11 +141,12 @@ type BatchQueryShard struct {
// KeyspaceIdBatchQuery represents a batch query request
// for the specified keyspace IDs.
type KeyspaceIdBatchQuery struct {
Queries []tproto.BoundQuery
Keyspace string
KeyspaceIds []key.KeyspaceId
TabletType topo.TabletType
Session *Session
Queries []tproto.BoundQuery
Keyspace string
KeyspaceIds []key.KeyspaceId
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
//go:generate bsongen -file $GOFILE -type KeyspaceIdBatchQuery -o keyspace_id_batch_query_bson.go

Просмотреть файл

@ -92,22 +92,24 @@ func TestSession(t *testing.T) {
}
type reflectQueryShard struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
type extraQueryShard struct {
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestQueryShard(t *testing.T) {
@ -231,20 +233,22 @@ type reflectBoundQuery struct {
}
type reflectBatchQueryShard struct {
Queries []reflectBoundQuery
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Queries []reflectBoundQuery
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
type extraBatchQueryShard struct {
Extra int
Queries []reflectBoundQuery
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Extra int
Queries []reflectBoundQuery
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestBatchQueryShard(t *testing.T) {
@ -312,11 +316,12 @@ func TestBatchQueryShard(t *testing.T) {
}
type badTypeBatchQueryShard struct {
Queries string
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
Queries string
Keyspace string
Shards []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestBatchQueryShardBadType(t *testing.T) {
@ -404,22 +409,24 @@ func TestQueryResultList(t *testing.T) {
}
type reflectKeyspaceIdQuery struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyspaceIds kproto.KeyspaceIdArray
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyspaceIds kproto.KeyspaceIdArray
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
type extraKeyspaceIdQuery struct {
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyspaceIds []kproto.KeyspaceId
TabletType topo.TabletType
Session *Session
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyspaceIds []kproto.KeyspaceId
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestKeyspaceIdQuery(t *testing.T) {
@ -474,22 +481,24 @@ func TestKeyspaceIdQuery(t *testing.T) {
}
type reflectKeyRangeQuery struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRanges kproto.KeyRangeArray
TabletType topo.TabletType
Session *Session
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRanges kproto.KeyRangeArray
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
type extraKeyRangeQuery struct {
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRanges []kproto.KeyRange
TabletType topo.TabletType
Session *Session
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRanges []kproto.KeyRange
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestKeyRangeQuery(t *testing.T) {
@ -544,20 +553,22 @@ func TestKeyRangeQuery(t *testing.T) {
}
type reflectKeyspaceIdBatchQuery struct {
Queries []reflectBoundQuery
Keyspace string
KeyspaceIds []kproto.KeyspaceId
TabletType topo.TabletType
Session *Session
Queries []reflectBoundQuery
Keyspace string
KeyspaceIds []kproto.KeyspaceId
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
type extraKeyspaceIdBatchQuery struct {
Extra int
Queries []reflectBoundQuery
Keyspace string
KeyspaceIds []kproto.KeyspaceId
TabletType topo.TabletType
Session *Session
Extra int
Queries []reflectBoundQuery
Keyspace string
KeyspaceIds []kproto.KeyspaceId
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestKeyspaceIdBatchQuery(t *testing.T) {
@ -625,11 +636,12 @@ func TestKeyspaceIdBatchQuery(t *testing.T) {
}
type badTypeKeyspaceIdsBatchQuery struct {
Queries string
Keyspace string
KeyspaceIds []string
TabletType topo.TabletType
Session *Session
Queries string
Keyspace string
KeyspaceIds []string
TabletType topo.TabletType
Session *Session
NotInTransaction bool
}
func TestKeyspaceIdsBatchQueryBadType(t *testing.T) {

Просмотреть файл

@ -73,7 +73,7 @@ func (res *Resolver) ExecuteKeyspaceIds(ctx context.Context, query *proto.Keyspa
query.TabletType,
query.KeyspaceIds)
}
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards)
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, query.NotInTransaction)
}
// ExecuteKeyRanges executes a non-streaming query based on KeyRanges.
@ -88,7 +88,7 @@ func (res *Resolver) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRange
query.TabletType,
query.KeyRanges)
}
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards)
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, query.NotInTransaction)
}
// Execute executes a non-streaming query based on shards resolved by given func.
@ -101,6 +101,7 @@ func (res *Resolver) Execute(
tabletType topo.TabletType,
session *proto.Session,
mapToShards func(string) (string, []string, error),
notInTransaction bool,
) (*mproto.QueryResult, error) {
keyspace, shards, err := mapToShards(keyspace)
if err != nil {
@ -114,7 +115,8 @@ func (res *Resolver) Execute(
keyspace,
shards,
tabletType,
NewSafeSession(session))
NewSafeSession(session),
notInTransaction)
if connError, ok := err.(*ShardConnError); ok && connError.Code == tabletconn.ERR_RETRY {
resharding := false
newKeyspace, newShards, err := mapToShards(keyspace)
@ -169,7 +171,8 @@ func (res *Resolver) ExecuteEntityIds(
bindVars,
query.Keyspace,
query.TabletType,
NewSafeSession(query.Session))
NewSafeSession(query.Session),
query.NotInTransaction)
if connError, ok := err.(*ShardConnError); ok && connError.Code == tabletconn.ERR_RETRY {
resharding := false
newKeyspace, newShardIDMap, err := mapEntityIdsToShards(
@ -219,7 +222,7 @@ func (res *Resolver) ExecuteBatchKeyspaceIds(ctx context.Context, query *proto.K
query.TabletType,
query.KeyspaceIds)
}
return res.ExecuteBatch(ctx, query.Queries, query.Keyspace, query.TabletType, query.Session, mapToShards)
return res.ExecuteBatch(ctx, query.Queries, query.Keyspace, query.TabletType, query.Session, mapToShards, query.NotInTransaction)
}
// ExecuteBatch executes a group of queries based on shards resolved by given func.
@ -231,6 +234,7 @@ func (res *Resolver) ExecuteBatch(
tabletType topo.TabletType,
session *proto.Session,
mapToShards func(string) (string, []string, error),
notInTransaction bool,
) (*tproto.QueryResultList, error) {
keyspace, shards, err := mapToShards(keyspace)
if err != nil {
@ -243,7 +247,8 @@ func (res *Resolver) ExecuteBatch(
keyspace,
shards,
tabletType,
NewSafeSession(session))
NewSafeSession(session),
notInTransaction)
if connError, ok := err.(*ShardConnError); ok && connError.Code == tabletconn.ERR_RETRY {
resharding := false
newKeyspace, newShards, err := mapToShards(keyspace)
@ -288,7 +293,7 @@ func (res *Resolver) StreamExecuteKeyspaceIds(ctx context.Context, query *proto.
query.TabletType,
query.KeyspaceIds)
}
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply)
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply, query.NotInTransaction)
}
// StreamExecuteKeyRanges executes a streaming query on the specified KeyRanges.
@ -307,7 +312,7 @@ func (res *Resolver) StreamExecuteKeyRanges(ctx context.Context, query *proto.Ke
query.TabletType,
query.KeyRanges)
}
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply)
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply, query.NotInTransaction)
}
// StreamExecute executes a streaming query on shards resolved by given func.
@ -323,6 +328,7 @@ func (res *Resolver) StreamExecute(
session *proto.Session,
mapToShards func(string) (string, []string, error),
sendReply func(*mproto.QueryResult) error,
notInTransaction bool,
) error {
keyspace, shards, err := mapToShards(keyspace)
if err != nil {
@ -336,7 +342,8 @@ func (res *Resolver) StreamExecute(
shards,
tabletType,
NewSafeSession(session),
sendReply)
sendReply,
notInTransaction)
return err
}

Просмотреть файл

@ -102,6 +102,7 @@ func (rtr *Router) Execute(ctx context.Context, query *proto.Query) (*mproto.Que
params.shardVars,
query.TabletType,
NewSafeSession(vcursor.query.Session),
query.NotInTransaction,
)
}
@ -140,6 +141,7 @@ func (rtr *Router) StreamExecute(ctx context.Context, query *proto.Query, sendRe
query.TabletType,
NewSafeSession(vcursor.query.Session),
sendReply,
query.NotInTransaction,
)
}
@ -250,7 +252,8 @@ func (rtr *Router) execUpdateEqual(vcursor *requestContext, plan *planbuilder.Pl
ks,
[]string{shard},
vcursor.query.TabletType,
NewSafeSession(vcursor.query.Session))
NewSafeSession(vcursor.query.Session),
vcursor.query.NotInTransaction)
}
func (rtr *Router) execDeleteEqual(vcursor *requestContext, plan *planbuilder.Plan) (*mproto.QueryResult, error) {
@ -280,7 +283,8 @@ func (rtr *Router) execDeleteEqual(vcursor *requestContext, plan *planbuilder.Pl
ks,
[]string{shard},
vcursor.query.TabletType,
NewSafeSession(vcursor.query.Session))
NewSafeSession(vcursor.query.Session),
vcursor.query.NotInTransaction)
}
func (rtr *Router) execInsertSharded(vcursor *requestContext, plan *planbuilder.Plan) (*mproto.QueryResult, error) {
@ -318,7 +322,8 @@ func (rtr *Router) execInsertSharded(vcursor *requestContext, plan *planbuilder.
ks,
[]string{shard},
vcursor.query.TabletType,
NewSafeSession(vcursor.query.Session))
NewSafeSession(vcursor.query.Session),
vcursor.query.NotInTransaction)
if err != nil {
return nil, fmt.Errorf("execInsertSharded: %v", err)
}
@ -421,7 +426,8 @@ func (rtr *Router) deleteVindexEntries(vcursor *requestContext, plan *planbuilde
ks,
[]string{shard},
vcursor.query.TabletType,
NewSafeSession(vcursor.query.Session))
NewSafeSession(vcursor.query.Session),
vcursor.query.NotInTransaction)
if err != nil {
return err
}

Просмотреть файл

@ -127,6 +127,7 @@ func (stc *ScatterConn) Execute(
shards []string,
tabletType topo.TabletType,
session *SafeSession,
notInTransaction bool,
) (*mproto.QueryResult, error) {
results, allErrors := stc.multiGo(
context,
@ -135,6 +136,7 @@ func (stc *ScatterConn) Execute(
shards,
tabletType,
session,
notInTransaction,
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
innerqr, err := sdc.Execute(context, query, bindVars, transactionId)
if err != nil {
@ -165,6 +167,7 @@ func (stc *ScatterConn) ExecuteMulti(
shardVars map[string]map[string]interface{},
tabletType topo.TabletType,
session *SafeSession,
notInTransaction bool,
) (*mproto.QueryResult, error) {
results, allErrors := stc.multiGo(
context,
@ -173,6 +176,7 @@ func (stc *ScatterConn) ExecuteMulti(
getShards(shardVars),
tabletType,
session,
notInTransaction,
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
innerqr, err := sdc.Execute(context, query, shardVars[sdc.shard], transactionId)
if err != nil {
@ -202,6 +206,7 @@ func (stc *ScatterConn) ExecuteEntityIds(
keyspace string,
tabletType topo.TabletType,
session *SafeSession,
notInTransaction bool,
) (*mproto.QueryResult, error) {
results, allErrors := stc.multiGo(
context,
@ -210,6 +215,7 @@ func (stc *ScatterConn) ExecuteEntityIds(
shards,
tabletType,
session,
notInTransaction,
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
shard := sdc.shard
sql := sqls[shard]
@ -241,6 +247,7 @@ func (stc *ScatterConn) ExecuteBatch(
shards []string,
tabletType topo.TabletType,
session *SafeSession,
notInTransaction bool,
) (qrs *tproto.QueryResultList, err error) {
results, allErrors := stc.multiGo(
context,
@ -249,6 +256,7 @@ func (stc *ScatterConn) ExecuteBatch(
shards,
tabletType,
session,
notInTransaction,
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
innerqrs, err := sdc.ExecuteBatch(context, queries, transactionId)
if err != nil {
@ -282,6 +290,7 @@ func (stc *ScatterConn) StreamExecute(
tabletType topo.TabletType,
session *SafeSession,
sendReply func(reply *mproto.QueryResult) error,
notInTransaction bool,
) error {
results, allErrors := stc.multiGo(
context,
@ -290,6 +299,7 @@ func (stc *ScatterConn) StreamExecute(
shards,
tabletType,
session,
notInTransaction,
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
sr, errFunc := sdc.StreamExecute(context, query, bindVars, transactionId)
if sr != nil {
@ -333,6 +343,7 @@ func (stc *ScatterConn) StreamExecuteMulti(
tabletType topo.TabletType,
session *SafeSession,
sendReply func(reply *mproto.QueryResult) error,
notInTransaction bool,
) error {
results, allErrors := stc.multiGo(
context,
@ -341,6 +352,7 @@ func (stc *ScatterConn) StreamExecuteMulti(
getShards(shardVars),
tabletType,
session,
notInTransaction,
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
sr, errFunc := sdc.StreamExecute(context, query, shardVars[sdc.shard], transactionId)
if sr != nil {
@ -446,7 +458,7 @@ func (stc *ScatterConn) SplitQuery(ctx context.Context, query tproto.BoundQuery,
for shard := range keyRangeByShard {
shards = append(shards, shard)
}
allSplits, allErrors := stc.multiGo(ctx, "SplitQuery", keyspace, shards, topo.TYPE_RDONLY, NewSafeSession(&proto.Session{}), actionFunc)
allSplits, allErrors := stc.multiGo(ctx, "SplitQuery", keyspace, shards, topo.TYPE_RDONLY, NewSafeSession(&proto.Session{}), false, actionFunc)
splits := []proto.SplitQueryPart{}
for s := range allSplits {
splits = append(splits, s.([]proto.SplitQueryPart)...)
@ -512,6 +524,7 @@ func (stc *ScatterConn) multiGo(
shards []string,
tabletType topo.TabletType,
session *SafeSession,
notInTransaction bool,
action shardActionFunc,
) (rResults <-chan interface{}, allErrors *concurrency.AllErrorRecorder) {
allErrors = new(concurrency.AllErrorRecorder)
@ -526,7 +539,7 @@ func (stc *ScatterConn) multiGo(
defer stc.timings.Record(statsKey, startTime)
sdc := stc.getConnection(context, keyspace, shard, tabletType)
transactionID, err := stc.updateSession(context, sdc, keyspace, shard, tabletType, session)
transactionID, err := stc.updateSession(context, sdc, keyspace, shard, tabletType, session, notInTransaction)
if err != nil {
allErrors.RecordError(err)
stc.tabletCallErrorCount.Add(statsKey, 1)
@ -577,6 +590,7 @@ func (stc *ScatterConn) updateSession(
keyspace, shard string,
tabletType topo.TabletType,
session *SafeSession,
notInTransaction bool,
) (transactionID int64, err error) {
if !session.InTransaction() {
return 0, nil
@ -589,6 +603,12 @@ func (stc *ScatterConn) updateSession(
if transactionID != 0 {
return transactionID, nil
}
// We are in a transaction at higher level,
// but client requires not to start a transaction for this query.
// If a transaction was started on this conn, we will use it (as above).
if notInTransaction {
return 0, nil
}
transactionID, err = sdc.Begin(context)
if err != nil {
return 0, err

Просмотреть файл

@ -22,7 +22,7 @@ import (
func TestScatterConnExecute(t *testing.T) {
testScatterConnGeneric(t, "TestScatterConnExecute", func(shards []string) (*mproto.QueryResult, error) {
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
return stc.Execute(context.Background(), "query", nil, "TestScatterConnExecute", shards, "", nil)
return stc.Execute(context.Background(), "query", nil, "TestScatterConnExecute", shards, "", nil, false)
})
}
@ -33,7 +33,7 @@ func TestScatterConnExecuteMulti(t *testing.T) {
for _, shard := range shards {
shardVars[shard] = nil
}
return stc.ExecuteMulti(context.Background(), "query", "TestScatterConnExecute", shardVars, "", nil)
return stc.ExecuteMulti(context.Background(), "query", "TestScatterConnExecute", shardVars, "", nil, false)
})
}
@ -41,7 +41,7 @@ func TestScatterConnExecuteBatch(t *testing.T) {
testScatterConnGeneric(t, "TestScatterConnExecuteBatch", func(shards []string) (*mproto.QueryResult, error) {
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
queries := []tproto.BoundQuery{{"query", nil}}
qrs, err := stc.ExecuteBatch(context.Background(), queries, "TestScatterConnExecuteBatch", shards, "", nil)
qrs, err := stc.ExecuteBatch(context.Background(), queries, "TestScatterConnExecuteBatch", shards, "", nil, false)
if err != nil {
return nil, err
}
@ -56,7 +56,7 @@ func TestScatterConnStreamExecute(t *testing.T) {
err := stc.StreamExecute(context.Background(), "query", nil, "TestScatterConnStreamExecute", shards, "", nil, func(r *mproto.QueryResult) error {
appendResult(qr, r)
return nil
})
}, false)
return qr, err
})
}
@ -72,7 +72,7 @@ func TestScatterConnStreamExecuteMulti(t *testing.T) {
err := stc.StreamExecuteMulti(context.Background(), "query", "TestScatterConnStreamExecute", shardVars, "", nil, func(r *mproto.QueryResult) error {
appendResult(qr, r)
return nil
})
}, false)
return qr, err
})
}
@ -173,7 +173,7 @@ func TestMultiExecs(t *testing.T) {
"bv1": 1,
},
}
_, _ = stc.ExecuteMulti(context.Background(), "query", "TestMultiExecs", shardVars, "", nil)
_, _ = stc.ExecuteMulti(context.Background(), "query", "TestMultiExecs", shardVars, "", nil, false)
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, shardVars["0"]) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, shardVars["0"])
}
@ -184,7 +184,7 @@ func TestMultiExecs(t *testing.T) {
sbc1.Queries = nil
_ = stc.StreamExecuteMulti(context.Background(), "query", "TestMultiExecs", shardVars, "", nil, func(*mproto.QueryResult) error {
return nil
})
}, false)
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, shardVars["0"]) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, shardVars["0"])
}
@ -200,7 +200,7 @@ func TestScatterConnStreamExecuteSendError(t *testing.T) {
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
err := stc.StreamExecute(context.Background(), "query", nil, "TestScatterConnStreamExecuteSendError", []string{"0"}, "", nil, func(*mproto.QueryResult) error {
return fmt.Errorf("send error")
})
}, false)
want := "send error"
// Ensure that we handle send errors.
if err == nil || err.Error() != want {
@ -241,7 +241,7 @@ func TestScatterConnCommitSuccess(t *testing.T) {
// Sequence the executes to ensure commit order
session := NewSafeSession(&proto.Session{InTransaction: true})
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0"}, "", session)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0"}, "", session, false)
wantSession := proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{{
@ -254,7 +254,7 @@ func TestScatterConnCommitSuccess(t *testing.T) {
if !reflect.DeepEqual(wantSession, *session.Session) {
t.Errorf("want\n%+v, got\n%+v", wantSession, *session.Session)
}
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0", "1"}, "", session)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0", "1"}, "", session, false)
wantSession = proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{{
@ -299,8 +299,8 @@ func TestScatterConnRollback(t *testing.T) {
// Sequence the executes to ensure commit order
session := NewSafeSession(&proto.Session{InTransaction: true})
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0"}, "", session)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0", "1"}, "", session)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0"}, "", session, false)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0", "1"}, "", session, false)
err := stc.Rollback(context.Background(), session)
if err != nil {
t.Errorf("want nil, got %v", err)
@ -322,7 +322,7 @@ func TestScatterConnClose(t *testing.T) {
sbc := &sandboxConn{}
s.MapTestConn("0", sbc)
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnClose", []string{"0"}, "", nil)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnClose", []string{"0"}, "", nil, false)
stc.Close()
time.Sleep(1)
if sbc.CloseCount != 1 {
@ -330,6 +330,111 @@ func TestScatterConnClose(t *testing.T) {
}
}
func TestScatterConnQueryNotInTransaction(t *testing.T) {
s := createSandbox("TestScatterConnQueryNotInTransaction")
// case 1: read query (not in transaction) followed by write query, not in the same shard.
sbc0 := &sandboxConn{}
s.MapTestConn("0", sbc0)
sbc1 := &sandboxConn{}
s.MapTestConn("1", sbc1)
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
session := NewSafeSession(&proto.Session{InTransaction: true})
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0"}, "", session, true)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"1"}, "", session, false)
wantSession := proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{{
Keyspace: "TestScatterConnQueryNotInTransaction",
Shard: "1",
TabletType: "",
TransactionId: 1,
}},
}
if !reflect.DeepEqual(wantSession, *session.Session) {
t.Errorf("want\n%+v\ngot\n%+v", wantSession, *session.Session)
}
stc.Commit(context.Background(), session)
if sbc0.ExecCount != 1 || sbc1.ExecCount != 3 {
t.Errorf("want 1/3, got %d/%d", sbc0.ExecCount, sbc1.ExecCount)
}
if sbc0.CommitCount != 0 {
t.Errorf("want 0, got %d", sbc0.CommitCount)
}
if sbc1.CommitCount != 1 {
t.Errorf("want 1, got %d", sbc1.CommitCount)
}
// case 2: write query followed by read query (not in transaction), not in the same shard.
s.Reset()
sbc0 = &sandboxConn{}
s.MapTestConn("0", sbc0)
sbc1 = &sandboxConn{}
s.MapTestConn("1", sbc1)
stc = NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
session = NewSafeSession(&proto.Session{InTransaction: true})
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0"}, "", session, false)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"1"}, "", session, true)
wantSession = proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{{
Keyspace: "TestScatterConnQueryNotInTransaction",
Shard: "0",
TabletType: "",
TransactionId: 1,
}},
}
if !reflect.DeepEqual(wantSession, *session.Session) {
t.Errorf("want\n%+v\ngot\n%+v", wantSession, *session.Session)
}
stc.Commit(context.Background(), session)
if sbc0.ExecCount != 3 || sbc1.ExecCount != 1 {
t.Errorf("want 3/1, got %d/%d", sbc0.ExecCount, sbc1.ExecCount)
}
if sbc0.CommitCount != 1 {
t.Errorf("want 1, got %d", sbc0.CommitCount)
}
if sbc1.CommitCount != 0 {
t.Errorf("want 0, got %d", sbc1.CommitCount)
}
// case 3: write query followed by read query, in the same shard.
s.Reset()
sbc0 = &sandboxConn{}
s.MapTestConn("0", sbc0)
sbc1 = &sandboxConn{}
s.MapTestConn("1", sbc1)
stc = NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
session = NewSafeSession(&proto.Session{InTransaction: true})
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0"}, "", session, false)
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0", "1"}, "", session, true)
wantSession = proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{{
Keyspace: "TestScatterConnQueryNotInTransaction",
Shard: "0",
TabletType: "",
TransactionId: 1,
}},
}
if !reflect.DeepEqual(wantSession, *session.Session) {
t.Errorf("want\n%+v\ngot\n%+v", wantSession, *session.Session)
}
stc.Commit(context.Background(), session)
if sbc0.ExecCount != 4 || sbc1.ExecCount != 1 {
t.Errorf("want 4/1, got %d/%d", sbc0.ExecCount, sbc1.ExecCount)
}
if sbc0.CommitCount != 1 {
t.Errorf("want 1, got %d", sbc0.CommitCount)
}
if sbc1.CommitCount != 0 {
t.Errorf("want 0, got %d", sbc1.CommitCount)
}
}
func TestAppendResult(t *testing.T) {
qr := new(mproto.QueryResult)
innerqr1 := &mproto.QueryResult{

Просмотреть файл

@ -57,7 +57,10 @@ func NewShardConn(ctx context.Context, serv SrvTopoServer, cell, keyspace, shard
return endpoints, nil
}
blc := NewBalancer(getAddresses, retryDelay)
ticker := timer.NewRandTicker(connLife, connLife/2)
var ticker *timer.RandTicker
if tabletType != topo.TYPE_MASTER {
ticker = timer.NewRandTicker(connLife, connLife/2)
}
sdc := &ShardConn{
keyspace: keyspace,
shard: shard,
@ -72,11 +75,13 @@ func NewShardConn(ctx context.Context, serv SrvTopoServer, cell, keyspace, shard
consolidator: sync2.NewConsolidator(),
connectTimings: tabletConnectTimings,
}
go func() {
for range ticker.C {
sdc.closeCurrent()
}
}()
if ticker != nil {
go func() {
for range ticker.C {
sdc.closeCurrent()
}
}()
}
return sdc
}
@ -180,7 +185,9 @@ func (sdc *ShardConn) SplitQuery(ctx context.Context, query tproto.BoundQuery, s
// Close closes the underlying TabletConn.
func (sdc *ShardConn) Close() {
sdc.ticker.Stop()
if sdc.ticker != nil {
sdc.ticker.Stop()
}
sdc.closeCurrent()
}

Просмотреть файл

@ -750,13 +750,14 @@ func TestShardConnReconnect(t *testing.T) {
}
}
func TestShardConnLife(t *testing.T) {
func TestReplicaShardConnLife(t *testing.T) {
// auto-reconnect for non-master
retryDelay := 10 * time.Millisecond
retryCount := 5
s := createSandbox("TestShardConnReconnect")
s := createSandbox("TestReplicaShardConnLife")
sbc := &sandboxConn{}
s.MapTestConn("0", sbc)
sdc := NewShardConn(context.Background(), new(sandboxTopo), "aa", "TestShardConnReconnect", "0", "", retryDelay, retryCount, connTimeoutTotal, connTimeoutPerConn, 10*time.Millisecond, connectTimings)
sdc := NewShardConn(context.Background(), new(sandboxTopo), "aa", "TestReplicaShardConnLife", "0", topo.TYPE_REPLICA, retryDelay, retryCount, connTimeoutTotal, connTimeoutPerConn, 10*time.Millisecond, connectTimings)
sdc.Execute(context.Background(), "query", nil, 0)
if s.DialCounter != 1 {
t.Errorf("DialCounter: %d, want 1", s.DialCounter)
@ -766,4 +767,25 @@ func TestShardConnLife(t *testing.T) {
if s.DialCounter != 2 {
t.Errorf("DialCounter: %d, want 2", s.DialCounter)
}
sdc.Close()
}
func TestMasterShardConnLife(t *testing.T) {
// Do not auto-reconnect for master
retryDelay := 10 * time.Millisecond
retryCount := 5
s := createSandbox("TestMasterShardConnLife")
sbc := &sandboxConn{}
s.MapTestConn("0", sbc)
sdc := NewShardConn(context.Background(), new(sandboxTopo), "aa", "TestMasterShardConnLife", "0", topo.TYPE_MASTER, retryDelay, retryCount, connTimeoutTotal, connTimeoutPerConn, 10*time.Millisecond, connectTimings)
sdc.Execute(context.Background(), "query", nil, 0)
if s.DialCounter != 1 {
t.Errorf("DialCounter: %d, want 1", s.DialCounter)
}
time.Sleep(20 * time.Millisecond)
sdc.Execute(context.Background(), "query", nil, 0)
if s.DialCounter != 1 {
t.Errorf("DialCounter: %d, want 1", s.DialCounter)
}
sdc.Close()
}

Просмотреть файл

@ -20,7 +20,7 @@ import (
func TestExecuteKeyspaceAlias(t *testing.T) {
testVerticalSplitGeneric(t, false, func(shards []string) (*mproto.QueryResult, error) {
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
return stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil)
return stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil, false)
})
}
@ -28,7 +28,7 @@ func TestBatchExecuteKeyspaceAlias(t *testing.T) {
testVerticalSplitGeneric(t, false, func(shards []string) (*mproto.QueryResult, error) {
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
queries := []tproto.BoundQuery{{"query", nil}}
qrs, err := stc.ExecuteBatch(context.Background(), queries, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil)
qrs, err := stc.ExecuteBatch(context.Background(), queries, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil, false)
if err != nil {
return nil, err
}
@ -43,7 +43,7 @@ func TestStreamExecuteKeyspaceAlias(t *testing.T) {
err := stc.StreamExecute(context.Background(), "query", nil, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil, func(r *mproto.QueryResult) error {
appendResult(qr, r)
return nil
})
}, false)
return qr, err
})
}
@ -63,7 +63,7 @@ func TestInTransactionKeyspaceAlias(t *testing.T) {
TransactionId: 1,
}},
})
_, err := stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, []string{"0"}, topo.TYPE_MASTER, session)
_, err := stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, []string{"0"}, topo.TYPE_MASTER, session, false)
want := "shard, host: TestUnshardedServedFrom.0.master, {Uid:0 Host:0 NamedPortMap:map[vt:1] Health:map[]}, retry: err"
if err == nil || err.Error() != want {
t.Errorf("want '%v', got '%v'", want, err)

Просмотреть файл

@ -188,6 +188,7 @@ func (vtg *VTGate) ExecuteShard(ctx context.Context, query *proto.QueryShard, re
func(keyspace string) (string, []string, error) {
return query.Keyspace, query.Shards, nil
},
query.NotInTransaction,
)
if err == nil {
reply.Result = qr
@ -289,6 +290,7 @@ func (vtg *VTGate) ExecuteBatchShard(ctx context.Context, batchQuery *proto.Batc
func(keyspace string) (string, []string, error) {
return batchQuery.Keyspace, batchQuery.Shards, nil
},
batchQuery.NotInTransaction,
)
if err == nil {
reply.List = qrs.List
@ -484,7 +486,8 @@ func (vtg *VTGate) StreamExecuteShard(ctx context.Context, query *proto.QuerySha
// Note we don't populate reply.Session here,
// as it may change incrementaly as responses are sent.
return sendReply(reply)
})
},
query.NotInTransaction)
vtg.rowsReturned.Add(statsKey, rowCount)
if err != nil {

Просмотреть файл

@ -168,7 +168,7 @@ func (wr *Wrangler) MigrateServedTypes(ctx context.Context, keyspace, shard stri
// rebuild the keyspace serving graph if there was no error
if !rec.HasErrors() {
rec.RecordError(wr.RebuildKeyspaceGraph(ctx, keyspace, nil))
rec.RecordError(wr.RebuildKeyspaceGraph(ctx, keyspace, cells))
}
// Send a refresh to the tablets we just disabled, iff:

Просмотреть файл

@ -10,8 +10,12 @@ class BindVarsProxy(object):
self.accessed_keys = set()
def __getitem__(self, name):
var = self.bind_vars[name]
self.bind_vars[name]
self.accessed_keys.add(name)
if isinstance(var, (list, set, tuple)):
return '::%s' % name
return ':%s' % name
def export_bind_vars(self):

Просмотреть файл

@ -93,7 +93,11 @@ def convert_bind_vars(bind_variables):
new_vars[key] = times.DateTimeToString(val)
elif isinstance(val, datetime.date):
new_vars[key] = times.DateToString(val)
elif isinstance(val, (int, long, float, str, List, NoneType)):
elif isinstance(val, set):
new_vars[key] = sorted(val)
elif isinstance(val, tuple):
new_vars[key] = list(val)
elif isinstance(val, (int, long, float, str, list, NoneType)):
new_vars[key] = val
else:
# NOTE(msolomon) begrudgingly I allow this - we just have too much code

Просмотреть файл

@ -85,7 +85,8 @@ class VTGateCursor(object):
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges)
keyranges=self.keyranges,
not_in_transaction=(not self.is_writable()))
self.index = 0
return self.rowcount
@ -106,7 +107,8 @@ class VTGateCursor(object):
self.keyspace,
self.tablet_type,
entity_keyspace_id_map,
entity_column_name)
entity_column_name,
not_in_transaction=(not self.is_writable()))
self.index = 0
return self.rowcount
@ -207,7 +209,8 @@ class BatchVTGateCursor(VTGateCursor):
self.bind_vars_list,
self.keyspace,
self.tablet_type,
self.keyspace_ids)
self.keyspace_ids,
not_in_transaction=(not self.is_writable()))
self.query_list = []
self.bind_vars_list = []
@ -236,7 +239,8 @@ class StreamVTGateCursor(VTGateCursor):
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges)
keyranges=self.keyranges,
not_in_transaction=(not self.is_writable()))
self.index = 0
return 0

Просмотреть файл

@ -73,7 +73,7 @@ def convert_exception(exc, *args, **kwargs):
return new_exc
def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspace_ids):
def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspace_ids, not_in_transaction):
# keyspace_ids are Keyspace Ids packed to byte[]
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
new_binds = field_types.convert_bind_vars(new_binds)
@ -83,11 +83,12 @@ def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspac
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyspaceIds': keyspace_ids,
'NotInTransaction': not_in_transaction,
}
return req
def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges):
def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges, not_in_transaction):
# keyranges are keyspace.KeyRange objects with start/end packed to byte[]
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
new_binds = field_types.convert_bind_vars(new_binds)
@ -97,6 +98,7 @@ def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges)
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyRanges': keyranges,
'NotInTransaction': not_in_transaction,
}
return req
@ -180,14 +182,14 @@ class VTGateConnection(object):
self.session = response.reply['Session']
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
def _execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None):
def _execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None, not_in_transaction=False):
exec_method = None
req = None
if keyspace_ids is not None:
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids)
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
exec_method = 'VTGate.ExecuteKeyspaceIds'
elif keyranges is not None:
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges)
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
exec_method = 'VTGate.ExecuteKeyRanges'
else:
raise dbexceptions.ProgrammingError('_execute called without specifying keyspace_ids or keyranges')
@ -227,7 +229,7 @@ class VTGateConnection(object):
return results, rowcount, lastrowid, fields
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
def _execute_entity_ids(self, sql, bind_variables, keyspace, tablet_type, entity_keyspace_id_map, entity_column_name):
def _execute_entity_ids(self, sql, bind_variables, keyspace, tablet_type, entity_keyspace_id_map, entity_column_name, not_in_transaction=False):
sql, new_binds = dbapi.prepare_query_bind_vars(sql, bind_variables)
new_binds = field_types.convert_bind_vars(new_binds)
req = {
@ -239,6 +241,7 @@ class VTGateConnection(object):
{'ExternalID': xid, 'KeyspaceID': kid}
for xid, kid in entity_keyspace_id_map.iteritems()],
'EntityColumnName': entity_column_name,
'NotInTransaction': not_in_transaction,
}
self._add_session(req)
@ -275,7 +278,7 @@ class VTGateConnection(object):
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
def _execute_batch(self, sql_list, bind_variables_list, keyspace, tablet_type, keyspace_ids):
def _execute_batch(self, sql_list, bind_variables_list, keyspace, tablet_type, keyspace_ids, not_in_transaction=False):
query_list = []
for sql, bind_vars in zip(sql_list, bind_variables_list):
sql, bind_vars = dbapi.prepare_query_bind_vars(sql, bind_vars)
@ -292,6 +295,7 @@ class VTGateConnection(object):
'Keyspace': keyspace,
'TabletType': tablet_type,
'KeyspaceIds': keyspace_ids,
'NotInTransaction': not_in_transaction,
}
self._add_session(req)
response = self.client.call('VTGate.ExecuteBatchKeyspaceIds', req)
@ -327,14 +331,14 @@ class VTGateConnection(object):
# the conversions will need to be passed back to _stream_next
# (that way we avoid using a member variable here for such a corner case)
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
def _stream_execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None):
def _stream_execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None, not_in_transaction=False):
exec_method = None
req = None
if keyspace_ids is not None:
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids)
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
exec_method = 'VTGate.StreamExecuteKeyspaceIds'
elif keyranges is not None:
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges)
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
exec_method = 'VTGate.StreamExecuteKeyRanges'
else:
raise dbexceptions.ProgrammingError('_stream_execute called without specifying keyspace_ids or keyranges')

Просмотреть файл

@ -76,12 +76,13 @@ def convert_exception(exc, *args):
return exc
def _create_req(sql, new_binds, tablet_type):
def _create_req(sql, new_binds, tablet_type, not_in_transaction):
new_binds = field_types.convert_bind_vars(new_binds)
req = {
'Sql': sql,
'BindVariables': new_binds,
'TabletType': tablet_type,
'NotInTransaction': not_in_transaction,
}
return req
@ -160,8 +161,8 @@ class VTGateConnection(object):
if 'Session' in response.reply and response.reply['Session']:
self.session = response.reply['Session']
def _execute(self, sql, bind_variables, tablet_type):
req = _create_req(sql, bind_variables, tablet_type)
def _execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
self._add_session(req)
fields = []
@ -196,7 +197,7 @@ class VTGateConnection(object):
return results, rowcount, lastrowid, fields
def _execute_batch(self, sql_list, bind_variables_list, tablet_type):
def _execute_batch(self, sql_list, bind_variables_list, tablet_type, not_in_transaction=False):
query_list = []
for sql, bind_vars in zip(sql_list, bind_variables_list):
query = {}
@ -210,6 +211,7 @@ class VTGateConnection(object):
req = {
'Queries': query_list,
'TabletType': tablet_type,
'NotInTransaction': not_in_transaction,
}
self._add_session(req)
response = self.client.call('VTGate.ExecuteBatch', req)
@ -243,8 +245,8 @@ class VTGateConnection(object):
# we return the fields for the response, and the column conversions
# the conversions will need to be passed back to _stream_next
# (that way we avoid using a member variable here for such a corner case)
def _stream_execute(self, sql, bind_variables, tablet_type):
req = _create_req(sql, bind_variables, tablet_type)
def _stream_execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
self._add_session(req)
self._stream_fields = []

Просмотреть файл

@ -5,6 +5,7 @@ import hmac
import json
import logging
import os
import pprint
import struct
import threading
import time
@ -69,10 +70,30 @@ keyspace_id bigint(20) unsigned NOT NULL,
primary key(eid, id)
) Engine=InnoDB'''
create_tables = [create_vt_insert_test, create_vt_a]
create_vt_field_types = '''create table vt_field_types (
id bigint(20) auto_increment,
uint_val bigint(20) unsigned,
str_val varchar(64),
unicode_val varchar(64),
float_val float(5, 1),
keyspace_id bigint(20) unsigned NOT NULL,
primary key(id)
) Engine=InnoDB'''
create_tables = [create_vt_insert_test, create_vt_a, create_vt_field_types]
pack_kid = struct.Struct('!Q').pack
class DBRow(object):
def __init__(self, column_names, row_tuple):
self.__dict__ = dict(zip(column_names, row_tuple))
def __repr__(self):
return pprint.pformat(self.__dict__, 4)
def setUpModule():
logging.debug("in setUpModule")
try:
@ -571,6 +592,91 @@ class TestVTGateFunctions(unittest.TestCase):
except Exception, e:
self.fail("Failed with error %s %s" % (str(e), traceback.print_exc()))
def test_field_types(self):
try:
vtgate_conn = get_connection()
_delete_all(self.shard_index, 'vt_field_types')
count = 10
base_uint = int('8'+'0'*15, base=16)
kid_list = shard_kid_map[shard_names[self.shard_index]]
for x in xrange(1, count):
keyspace_id = kid_list[count%len(kid_list)]
cursor = vtgate_conn.cursor(KEYSPACE_NAME, 'master',
keyspace_ids=[pack_kid(keyspace_id)],
writable=True)
cursor.begin()
cursor.execute(
"insert into vt_field_types "
"(uint_val, str_val, unicode_val, float_val, keyspace_id) "
"values (%(uint_val)s, %(str_val)s, %(unicode_val)s, "
"%(float_val)s, %(keyspace_id)s)",
{'uint_val': base_uint + x, 'str_val': 'str_%d' % x,
'unicode_val': unicode('str_%d' % x), 'float_val': x*1.2,
'keyspace_id': keyspace_id})
cursor.commit()
cursor = vtgate_conn.cursor(KEYSPACE_NAME, 'master',
keyranges=[self.keyrange])
rowcount = cursor.execute("select * from vt_field_types", {})
field_names = [f[0] for f in cursor.description]
self.assertEqual(rowcount, count -1, "rowcount doesn't match")
id_list = []
uint_val_list = []
str_val_list = []
unicode_val_list = []
float_val_list = []
for r in cursor.results:
row = DBRow(field_names, r)
id_list.append(row.id)
uint_val_list.append(row.uint_val)
str_val_list.append(row.str_val)
unicode_val_list.append(row.unicode_val)
float_val_list.append(row.float_val)
# iterable type checks - list, tuple, set are supported.
query = "select * from vt_field_types where id in %(id_1)s"
rowcount = cursor.execute(query, {'id_1': id_list})
self.assertEqual(rowcount, len(id_list), "rowcount doesn't match")
rowcount = cursor.execute(query, {'id_1': tuple(id_list)})
self.assertEqual(rowcount, len(id_list), "rowcount doesn't match")
rowcount = cursor.execute(query, {'id_1': set(id_list)})
self.assertEqual(rowcount, len(id_list), "rowcount doesn't match")
for i, r in enumerate(cursor.results):
row = DBRow(field_names, r)
self.assertIsInstance(row.id, (int, long))
# received field types same as input.
# uint
query = "select * from vt_field_types where uint_val in %(uint_val_1)s"
rowcount = cursor.execute(query, {'uint_val_1': uint_val_list})
self.assertEqual(rowcount, len(uint_val_list), "rowcount doesn't match")
for i, r in enumerate(cursor.results):
row = DBRow(field_names, r)
self.assertIsInstance(row.uint_val, long)
self.assertGreaterEqual(row.uint_val, base_uint, "uint value not in correct range")
# str
query = "select * from vt_field_types where str_val in %(str_val_1)s"
rowcount = cursor.execute(query, {'str_val_1': str_val_list})
self.assertEqual(rowcount, len(str_val_list), "rowcount doesn't match")
for i, r in enumerate(cursor.results):
row = DBRow(field_names, r)
self.assertIsInstance(row.str_val, str)
# unicode str
query = "select * from vt_field_types where unicode_val in %(unicode_val_1)s"
rowcount = cursor.execute(query, {'unicode_val_1': unicode_val_list})
self.assertEqual(rowcount, len(unicode_val_list), "rowcount doesn't match")
for i, r in enumerate(cursor.results):
row = DBRow(field_names, r)
self.assertIsInstance(row.unicode_val, (str,unicode))
# deliberately eliminating the float test since it is flaky due
# to mysql float precision handling.
except Exception, e:
logging.debug("Write failed with error %s" % str(e))
raise
def _query_lots(self,
conn,
query,
@ -607,6 +713,16 @@ class TestFailures(unittest.TestCase):
return tablet.start_vttablet(lameduck_period=lameduck_period)
# target_tablet_type=tablet_type)
def test_status_with_error(self):
"""Tests that the status page loads correctly after a VTGate error."""
vtgate_conn = get_connection()
cursor = vtgate_conn.cursor('INVALID_KEYSPACE', 'replica', keyspace_ids=['0'])
# We expect to see a DatabaseError due to an invalid keyspace
with self.assertRaises(dbexceptions.DatabaseError):
cursor.execute('select * from vt_insert_test', {})
# Page should have loaded successfully
self.assertIn('</html>', utils.get_status(vtgate_port))
def test_tablet_restart_read(self):
try:
vtgate_conn = get_connection()