зеркало из https://github.com/github/vitess-gh.git
Merge branch 'master' into resharding
This commit is contained in:
Коммит
364907137a
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"table1": {
|
||||
1:2
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"test_table": {
|
||||
"READER": "vt",
|
||||
"WRITER": "vt"
|
||||
}
|
||||
}
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/mysqlctl"
|
||||
"github.com/youtube/vitess/go/vt/servenv"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver"
|
||||
|
||||
// import mysql to register mysql connection function
|
||||
|
@ -76,6 +77,7 @@ func main() {
|
|||
log.Infof("schemaOverrides: %s\n", data)
|
||||
|
||||
if *tableAclConfig != "" {
|
||||
tableacl.Register("simpleacl", &simpleacl.Factory{})
|
||||
tableacl.Init(*tableAclConfig)
|
||||
}
|
||||
qsc := tabletserver.NewQueryServiceControl()
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/mysqlctl"
|
||||
"github.com/youtube/vitess/go/vt/servenv"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletmanager"
|
||||
"github.com/youtube/vitess/go/vt/tabletmanager/actionnode"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver"
|
||||
|
@ -85,6 +86,7 @@ func main() {
|
|||
dbcfgs.App.EnableRowcache = *enableRowcache
|
||||
|
||||
if *tableAclConfig != "" {
|
||||
tableacl.Register("simpleacl", &simpleacl.Factory{})
|
||||
tableacl.Init(*tableAclConfig)
|
||||
}
|
||||
|
||||
|
|
|
@ -38,3 +38,14 @@ func FromContext(ctx context.Context) (CallInfo, bool) {
|
|||
ci, ok := ctx.Value(callInfoKey).(CallInfo)
|
||||
return ci, ok
|
||||
}
|
||||
|
||||
// HTMLFromContext returns that value of HTML() from the context, or "" if we're
|
||||
// not able to recover one
|
||||
func HTMLFromContext(ctx context.Context) template.HTML {
|
||||
var h template.HTML
|
||||
ci, ok := FromContext(ctx)
|
||||
if ok {
|
||||
return ci.HTML()
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
|
|
@ -6,9 +6,7 @@ import (
|
|||
log "github.com/golang/glog"
|
||||
)
|
||||
|
||||
// ConsoleLogger is a Logger that uses glog directly to log.
|
||||
// We can't specify the depth of the stack trace,
|
||||
// So we just find it and add it to the message.
|
||||
// ConsoleLogger is a Logger that uses glog directly to log, at the right level.
|
||||
type ConsoleLogger struct{}
|
||||
|
||||
// NewConsoleLogger returns a simple ConsoleLogger
|
||||
|
@ -18,26 +16,17 @@ func NewConsoleLogger() ConsoleLogger {
|
|||
|
||||
// Infof is part of the Logger interface
|
||||
func (cl ConsoleLogger) Infof(format string, v ...interface{}) {
|
||||
file, line := fileAndLine(3)
|
||||
vals := []interface{}{file, line}
|
||||
vals = append(vals, v...)
|
||||
log.Infof("%v:%v] "+format, vals...)
|
||||
log.InfoDepth(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Warningf is part of the Logger interface
|
||||
func (cl ConsoleLogger) Warningf(format string, v ...interface{}) {
|
||||
file, line := fileAndLine(3)
|
||||
vals := []interface{}{file, line}
|
||||
vals = append(vals, v...)
|
||||
log.Warningf("%v:%v] "+format, vals...)
|
||||
log.WarningDepth(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Errorf is part of the Logger interface
|
||||
func (cl ConsoleLogger) Errorf(format string, v ...interface{}) {
|
||||
file, line := fileAndLine(3)
|
||||
vals := []interface{}{file, line}
|
||||
vals = append(vals, v...)
|
||||
log.Errorf("%v:%v] "+format, vals...)
|
||||
log.ErrorDepth(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Printf is part of the Logger interface
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package logutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -29,12 +30,12 @@ func NewThrottledLogger(name string, maxInterval time.Duration) *ThrottledLogger
|
|||
}
|
||||
}
|
||||
|
||||
type logFunc func(string, ...interface{})
|
||||
type logFunc func(int, ...interface{})
|
||||
|
||||
var (
|
||||
infof = log.Infof
|
||||
warningf = log.Warningf
|
||||
errorf = log.Errorf
|
||||
infoDepth = log.InfoDepth
|
||||
warningDepth = log.WarningDepth
|
||||
errorDepth = log.ErrorDepth
|
||||
)
|
||||
|
||||
func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
|
||||
|
@ -45,7 +46,7 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
|
|||
logWaitTime := tl.maxInterval - (now.Sub(tl.lastlogTime))
|
||||
if logWaitTime < 0 {
|
||||
tl.lastlogTime = now
|
||||
logF(tl.name+":"+format, v...)
|
||||
logF(2, fmt.Sprintf(tl.name+":"+format, v...))
|
||||
return
|
||||
}
|
||||
// If this is the first message to be skipped, start a goroutine
|
||||
|
@ -55,7 +56,9 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
|
|||
time.Sleep(d)
|
||||
tl.mu.Lock()
|
||||
defer tl.mu.Unlock()
|
||||
logF("%v: skipped %v log messages", tl.name, tl.skippedCount)
|
||||
// Because of the go func(), we lose the stack trace,
|
||||
// so we just use the current line for this.
|
||||
logF(0, fmt.Sprintf("%v: skipped %v log messages", tl.name, tl.skippedCount))
|
||||
tl.skippedCount = 0
|
||||
}(logWaitTime)
|
||||
}
|
||||
|
@ -64,15 +67,15 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) {
|
|||
|
||||
// Infof logs an info if not throttled.
|
||||
func (tl *ThrottledLogger) Infof(format string, v ...interface{}) {
|
||||
tl.log(infof, format, v...)
|
||||
tl.log(infoDepth, format, v...)
|
||||
}
|
||||
|
||||
// Warningf logs a warning if not throttled.
|
||||
func (tl *ThrottledLogger) Warningf(format string, v ...interface{}) {
|
||||
tl.log(warningf, format, v...)
|
||||
tl.log(warningDepth, format, v...)
|
||||
}
|
||||
|
||||
// Errorf logs an error if not throttled.
|
||||
func (tl *ThrottledLogger) Errorf(format string, v ...interface{}) {
|
||||
tl.log(errorf, format, v...)
|
||||
tl.log(errorDepth, format, v...)
|
||||
}
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
func TestThrottledLogger(t *testing.T) {
|
||||
// Install a fake log func for testing.
|
||||
log := make(chan string)
|
||||
infof = func(format string, args ...interface{}) {
|
||||
log <- fmt.Sprintf(format, args...)
|
||||
infoDepth = func(depth int, args ...interface{}) {
|
||||
log <- fmt.Sprint(args...)
|
||||
}
|
||||
interval := 100 * time.Millisecond
|
||||
tl := NewThrottledLogger("name", interval)
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package acl
|
||||
|
||||
// ACL is an interface for Access Control List.
|
||||
type ACL interface {
|
||||
// IsMember checks the membership of a principal in this ACL.
|
||||
IsMember(principal string) bool
|
||||
}
|
||||
|
||||
// Factory is responsible to create new ACL instance.
|
||||
type Factory interface {
|
||||
// New creates a new ACL instance.
|
||||
New(entries []string) (ACL, error)
|
||||
// All returns an ACL instance that contains all users.
|
||||
All() ACL
|
||||
// AllString returns a string representation of all users.
|
||||
AllString() string
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tableacl
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRoleName(t *testing.T) {
|
||||
if READER.Name() != roleNames[READER] {
|
||||
t.Fatalf("role READER does not return expected name, expected: %s, actual: %s", roleNames[READER], READER.Name())
|
||||
}
|
||||
|
||||
if WRITER.Name() != roleNames[WRITER] {
|
||||
t.Fatalf("role WRITER does not return expected name, expected: %s, actual: %s", roleNames[WRITER], WRITER.Name())
|
||||
}
|
||||
|
||||
if ADMIN.Name() != roleNames[ADMIN] {
|
||||
t.Fatalf("role ADMIN does not return expected name, expected: %s, actual: %s", roleNames[ADMIN], ADMIN.Name())
|
||||
}
|
||||
|
||||
unknownRole := Role(-1)
|
||||
if unknownRole.Name() != "" {
|
||||
t.Fatalf("role is not defined, expected to get an empty string but got: %s", unknownRole.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleByName(t *testing.T) {
|
||||
if role, ok := RoleByName("unknown"); ok {
|
||||
t.Fatalf("should not find a valid role for invalid name, but got role: %s", role.Name())
|
||||
}
|
||||
|
||||
role, ok := RoleByName("READER")
|
||||
if !ok || role != READER {
|
||||
t.Fatalf("string READER should return role READER")
|
||||
}
|
||||
|
||||
role, ok = RoleByName("WRITER")
|
||||
if !ok || role != WRITER {
|
||||
t.Fatalf("string WRITER should return role WRITER")
|
||||
}
|
||||
|
||||
role, ok = RoleByName("ADMIN")
|
||||
if !ok || role != ADMIN {
|
||||
t.Fatalf("string ADMIN should return role ADMIN")
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package tableacl
|
||||
|
||||
var allAcl simpleACL
|
||||
|
||||
const (
|
||||
ALL = "*"
|
||||
)
|
||||
|
||||
// NewACL returns an ACL with the specified entries
|
||||
func NewACL(entries []string) (ACL, error) {
|
||||
a := simpleACL(map[string]bool{})
|
||||
for _, e := range entries {
|
||||
a[e] = true
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// simpleACL keeps all entries in a unique in-memory list
|
||||
type simpleACL map[string]bool
|
||||
|
||||
// IsMember checks the membership of a principal in this ACL
|
||||
func (a simpleACL) IsMember(principal string) bool {
|
||||
return a[principal] || a[ALL]
|
||||
}
|
||||
|
||||
func all() ACL {
|
||||
if allAcl == nil {
|
||||
allAcl = simpleACL(map[string]bool{ALL: true})
|
||||
}
|
||||
return allAcl
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package simpleacl
|
||||
|
||||
import "github.com/youtube/vitess/go/vt/tableacl/acl"
|
||||
|
||||
var allAcl SimpleAcl
|
||||
|
||||
const all = "*"
|
||||
|
||||
// SimpleAcl keeps all entries in a unique in-memory list
|
||||
type SimpleAcl map[string]bool
|
||||
|
||||
// IsMember checks the membership of a principal in this ACL
|
||||
func (sacl SimpleAcl) IsMember(principal string) bool {
|
||||
return sacl[principal] || sacl[all]
|
||||
}
|
||||
|
||||
// Factory is responsible to create new ACL instance.
|
||||
type Factory struct{}
|
||||
|
||||
// New creates a new ACL instance.
|
||||
func (factory *Factory) New(entries []string) (acl.ACL, error) {
|
||||
acl := SimpleAcl(map[string]bool{})
|
||||
for _, e := range entries {
|
||||
acl[e] = true
|
||||
}
|
||||
return acl, nil
|
||||
}
|
||||
|
||||
// All returns an ACL instance that contains all users.
|
||||
func (factory *Factory) All() acl.ACL {
|
||||
if allAcl == nil {
|
||||
allAcl = SimpleAcl(map[string]bool{all: true})
|
||||
}
|
||||
return allAcl
|
||||
}
|
||||
|
||||
// AllString returns a string representation of all users.
|
||||
func (factory *Factory) AllString() string {
|
||||
return all
|
||||
}
|
||||
|
||||
// make sure SimpleAcl implements interface acl.ACL
|
||||
var _ (acl.ACL) = (*SimpleAcl)(nil)
|
||||
|
||||
// make sure Factory implements interface acl.AclFactory
|
||||
var _ (acl.Factory) = (*Factory)(nil)
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package simpleacl
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/tableacl/testlib"
|
||||
)
|
||||
|
||||
func TestSimpleAcl(t *testing.T) {
|
||||
testlib.TestSuite(t, &Factory{})
|
||||
}
|
|
@ -1,31 +1,39 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tableacl
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/acl"
|
||||
)
|
||||
|
||||
// ACL is an interface for Access Control List.
|
||||
type ACL interface {
|
||||
// IsMember checks the membership of a principal in this ACL
|
||||
IsMember(principal string) bool
|
||||
}
|
||||
var mu sync.Mutex
|
||||
var tableAcl map[*regexp.Regexp]map[Role]acl.ACL
|
||||
var acls = make(map[string]acl.Factory)
|
||||
|
||||
var tableAcl map[*regexp.Regexp]map[Role]ACL
|
||||
// defaultACL tells the default ACL implementation to use.
|
||||
var defaultACL string
|
||||
|
||||
// Init initiates table ACLs.
|
||||
func Init(configFile string) {
|
||||
config, err := ioutil.ReadFile(configFile)
|
||||
if err != nil {
|
||||
log.Fatalf("unable to read tableACL config file: %v", err)
|
||||
log.Errorf("unable to read tableACL config file: %v", err)
|
||||
panic(fmt.Errorf("unable to read tableACL config file: %v", err))
|
||||
}
|
||||
tableAcl, err = load(config)
|
||||
if err != nil {
|
||||
log.Fatalf("tableACL initialization error: %v", err)
|
||||
log.Errorf("tableACL initialization error: %v", err)
|
||||
panic(fmt.Errorf("tableACL initialization error: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,19 +50,19 @@ func InitFromBytes(config []byte) (err error) {
|
|||
// <tableRegexPattern1>: {"READER": "*", "WRITER": "<u2>,<u4>...","ADMIN": "<u5>"},
|
||||
// <tableRegexPattern2>: {"ADMIN": "<u5>"}
|
||||
//}`)
|
||||
func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
|
||||
func load(config []byte) (map[*regexp.Regexp]map[Role]acl.ACL, error) {
|
||||
var contents map[string]map[string]string
|
||||
err := json.Unmarshal(config, &contents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tableAcl := make(map[*regexp.Regexp]map[Role]ACL)
|
||||
tableAcl := make(map[*regexp.Regexp]map[Role]acl.ACL)
|
||||
for tblPattern, accessMap := range contents {
|
||||
re, err := regexp.Compile(tblPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regexp compile error %v: %v", tblPattern, err)
|
||||
}
|
||||
tableAcl[re] = make(map[Role]ACL)
|
||||
tableAcl[re] = make(map[Role]acl.ACL)
|
||||
|
||||
entriesByRole := make(map[Role][]string)
|
||||
for i := READER; i < NumRoles; i++ {
|
||||
|
@ -71,7 +79,7 @@ func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
|
|||
}
|
||||
}
|
||||
for r, entries := range entriesByRole {
|
||||
a, err := NewACL(entries)
|
||||
a, err := newACL(entries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -83,8 +91,8 @@ func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
|
|||
}
|
||||
|
||||
// Authorized returns the list of entities who have at least the
|
||||
// minimum specified Role on a table.
|
||||
func Authorized(table string, minRole Role) ACL {
|
||||
// minimum specified Role on a tablel.
|
||||
func Authorized(table string, minRole Role) acl.ACL {
|
||||
// If table ACL is disabled, return nil
|
||||
if tableAcl == nil {
|
||||
return nil
|
||||
|
@ -98,3 +106,48 @@ func Authorized(table string, minRole Role) ACL {
|
|||
// No matching patterns for table, allow all access
|
||||
return all()
|
||||
}
|
||||
|
||||
// Register registers a AclFactory.
|
||||
func Register(name string, factory acl.Factory) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if _, ok := acls[name]; ok {
|
||||
panic(fmt.Sprintf("register a registered key: %s", name))
|
||||
}
|
||||
acls[name] = factory
|
||||
}
|
||||
|
||||
// SetDefaultACL sets the default ACL implementation.
|
||||
func SetDefaultACL(name string) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
defaultACL = name
|
||||
}
|
||||
|
||||
// GetCurrentAclFactory returns current table acl implementation.
|
||||
func GetCurrentAclFactory() acl.Factory {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if defaultACL == "" {
|
||||
if len(acls) == 1 {
|
||||
for _, aclFactory := range acls {
|
||||
return aclFactory
|
||||
}
|
||||
}
|
||||
panic("there are more than one AclFactory " +
|
||||
"registered but no default has been given.")
|
||||
}
|
||||
aclFactory, ok := acls[defaultACL]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("aclFactory for given default: %s is not found.", defaultACL))
|
||||
}
|
||||
return aclFactory
|
||||
}
|
||||
|
||||
func newACL(entries []string) (acl.ACL, error) {
|
||||
return GetCurrentAclFactory().New(entries)
|
||||
}
|
||||
|
||||
func all() acl.ACL {
|
||||
return GetCurrentAclFactory().All()
|
||||
}
|
||||
|
|
|
@ -1,93 +1,172 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tableacl
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
func currentUser() string {
|
||||
return "DummyUser"
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/acl"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
|
||||
)
|
||||
|
||||
type fakeAclFactory struct{}
|
||||
|
||||
func (factory *fakeAclFactory) New(entries []string) (acl.ACL, error) {
|
||||
return nil, fmt.Errorf("unable to create a new ACL")
|
||||
}
|
||||
|
||||
func TestParseInvalidJSON(t *testing.T) {
|
||||
checkLoad([]byte(`{1:2}`), false, t)
|
||||
checkLoad([]byte(`{"1":"2"}`), false, t)
|
||||
checkLoad([]byte(`{"table1":{1:2}}`), false, t)
|
||||
func (factory *fakeAclFactory) All() acl.ACL {
|
||||
return &fakeACL{}
|
||||
}
|
||||
|
||||
func TestInvalidRoleName(t *testing.T) {
|
||||
checkLoad([]byte(`{"table1":{"SOMEROLE":"u1"}}`), false, t)
|
||||
// AllString returns a string representation of all users.
|
||||
func (factory *fakeAclFactory) AllString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestInvalidRegex(t *testing.T) {
|
||||
checkLoad([]byte(`{"table(1":{"READER":"u1"}}`), false, t)
|
||||
type fakeACL struct{}
|
||||
|
||||
func (acl *fakeACL) IsMember(principal string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestValidConfigs(t *testing.T) {
|
||||
checkLoad([]byte(`{"table1":{"READER":"u1"}}`), true, t)
|
||||
checkLoad([]byte(`{"table1":{"READER":"u1,u2", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,u2", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{
|
||||
"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"},
|
||||
"tbl[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3", "ADMIN":"u4"}
|
||||
}`), true, t)
|
||||
func TestInitWithInvalidFilePath(t *testing.T) {
|
||||
setUpTableACL(&simpleacl.Factory{})
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("init should fail for an invalid config file path")
|
||||
}
|
||||
}()
|
||||
Init("/invalid_file_path")
|
||||
}
|
||||
|
||||
func TestDenyReaderInsert(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", WRITER, t, false)
|
||||
func TestInitWithInvalidConfigFile(t *testing.T) {
|
||||
setUpTableACL(&simpleacl.Factory{})
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("init should fail for an invalid config file")
|
||||
}
|
||||
}()
|
||||
Init(testfiles.Locate("tableacl/invalid_tableacl_config.json"))
|
||||
}
|
||||
|
||||
func TestAllowReaderSelect(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", READER, t, true)
|
||||
func TestInitWithValidConfig(t *testing.T) {
|
||||
setUpTableACL(&simpleacl.Factory{})
|
||||
Init(testfiles.Locate("tableacl/test_table_tableacl_config.json"))
|
||||
}
|
||||
|
||||
func TestDenyReaderDDL(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", ADMIN, t, false)
|
||||
func TestInitFromBytes(t *testing.T) {
|
||||
aclFactory := &simpleacl.Factory{}
|
||||
setUpTableACL(aclFactory)
|
||||
acl := Authorized("test_table", READER)
|
||||
if acl != nil {
|
||||
t.Fatalf("tableacl has not been initialized, should get nil ACL")
|
||||
}
|
||||
err := InitFromBytes([]byte(`{"test_table":{"Reader": "vt"}}`))
|
||||
if err != nil {
|
||||
t.Fatalf("tableacl init should succeed, but got error: %v", err)
|
||||
}
|
||||
|
||||
acl = Authorized("unknown_table", READER)
|
||||
if !reflect.DeepEqual(aclFactory.All(), acl) {
|
||||
t.Fatalf("there is no config for unknown_table, should grand all permission")
|
||||
}
|
||||
|
||||
acl = Authorized("test_table", READER)
|
||||
if !acl.IsMember("vt") {
|
||||
t.Fatalf("user: vt should have reader permission to table: test_table")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowUnmatchedTable(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "UNMATCHED_TABLE", ADMIN, t, true)
|
||||
func TestInvalidTableRegex(t *testing.T) {
|
||||
setUpTableACL(&simpleacl.Factory{})
|
||||
err := InitFromBytes([]byte(`{"table(":{"Reader": "vt", "WRITER":"vt"}}`))
|
||||
if err == nil {
|
||||
t.Fatalf("tableacl init should fail because config file has an invalid table regex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllUserReadAccess(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + ALL + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", READER, t, true)
|
||||
func TestInvalidRole(t *testing.T) {
|
||||
setUpTableACL(&simpleacl.Factory{})
|
||||
err := InitFromBytes([]byte(`{"test_table":{"InvalidRole": "vt", "Reader": "vt", "WRITER":"vt"}}`))
|
||||
if err == nil {
|
||||
t.Fatalf("tableacl init should fail because config file has an invalid role")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllUserWriteAccess(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"` + ALL + `"}}`)
|
||||
checkAccess(configData, "table1", WRITER, t, true)
|
||||
func TestFailedToCreateACL(t *testing.T) {
|
||||
setUpTableACL(&fakeAclFactory{})
|
||||
err := InitFromBytes([]byte(`{"test_table":{"Reader": "vt", "WRITER":"vt"}}`))
|
||||
if err == nil {
|
||||
t.Fatalf("tableacl init should fail because fake ACL returns an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabled(t *testing.T) {
|
||||
func TestDoubleRegisterTheSameKey(t *testing.T) {
|
||||
acls = make(map[string]acl.Factory)
|
||||
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
|
||||
Register(name, &simpleacl.Factory{})
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("the second tableacl register should fail")
|
||||
}
|
||||
}()
|
||||
Register(name, &simpleacl.Factory{})
|
||||
}
|
||||
|
||||
func TestGetAclFactory(t *testing.T) {
|
||||
acls = make(map[string]acl.Factory)
|
||||
defaultACL = ""
|
||||
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
|
||||
aclFactory := &simpleacl.Factory{}
|
||||
Register(name, aclFactory)
|
||||
if !reflect.DeepEqual(aclFactory, GetCurrentAclFactory()) {
|
||||
t.Fatalf("should return registered acl factory even if default acl is not set.")
|
||||
}
|
||||
Register(name+"2", aclFactory)
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("there are more than one acl factories, but the default is not set")
|
||||
}
|
||||
}()
|
||||
GetCurrentAclFactory()
|
||||
}
|
||||
|
||||
func TestGetAclFactoryWithWrongDefault(t *testing.T) {
|
||||
acls = make(map[string]acl.Factory)
|
||||
defaultACL = ""
|
||||
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
|
||||
aclFactory := &simpleacl.Factory{}
|
||||
Register(name, aclFactory)
|
||||
Register(name+"2", aclFactory)
|
||||
SetDefaultACL("wrong_name")
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("there are more than one acl factories, but the default given does not match any of these.")
|
||||
}
|
||||
}()
|
||||
GetCurrentAclFactory()
|
||||
}
|
||||
|
||||
func setUpTableACL(factory acl.Factory) {
|
||||
tableAcl = nil
|
||||
got := Authorized("table1", READER)
|
||||
if got != nil {
|
||||
t.Errorf("table acl disabled, got: %v, want: nil", got)
|
||||
}
|
||||
name := fmt.Sprintf("tableacl-name-%d", rand.Int63())
|
||||
Register(name, factory)
|
||||
SetDefaultACL(name)
|
||||
}
|
||||
|
||||
func checkLoad(configData []byte, valid bool, t *testing.T) {
|
||||
err := InitFromBytes(configData)
|
||||
if !valid && err == nil {
|
||||
t.Errorf("expecting parse error none returned")
|
||||
}
|
||||
|
||||
if valid && err != nil {
|
||||
t.Errorf("unexpected load error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkAccess(configData []byte, tableName string, role Role, t *testing.T, want bool) {
|
||||
checkLoad(configData, true, t)
|
||||
got := Authorized(tableName, role).IsMember(currentUser())
|
||||
if want != got {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testlib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/acl"
|
||||
)
|
||||
|
||||
// TestSuite tests a concrete acl.Factory implementation.
|
||||
func TestSuite(t *testing.T, factory acl.Factory) {
|
||||
name := fmt.Sprintf("tableacl-test-%d", rand.Int63())
|
||||
tableacl.Register(name, factory)
|
||||
tableacl.SetDefaultACL(name)
|
||||
|
||||
testParseInvalidJSON(t)
|
||||
testInvalidRoleName(t)
|
||||
testInvalidRegex(t)
|
||||
testValidConfigs(t)
|
||||
testDenyReaderInsert(t)
|
||||
testAllowReaderSelect(t)
|
||||
testDenyReaderDDL(t)
|
||||
testAllowUnmatchedTable(t)
|
||||
testAllUserReadAccess(t)
|
||||
testAllUserWriteAccess(t)
|
||||
}
|
||||
|
||||
func currentUser() string {
|
||||
return "DummyUser"
|
||||
}
|
||||
|
||||
func testParseInvalidJSON(t *testing.T) {
|
||||
checkLoad([]byte(`{1:2}`), false, t)
|
||||
checkLoad([]byte(`{"1":"2"}`), false, t)
|
||||
checkLoad([]byte(`{"table1":{1:2}}`), false, t)
|
||||
}
|
||||
|
||||
func testInvalidRoleName(t *testing.T) {
|
||||
checkLoad([]byte(`{"table1":{"SOMEROLE":"u1"}}`), false, t)
|
||||
}
|
||||
|
||||
func testInvalidRegex(t *testing.T) {
|
||||
checkLoad([]byte(`{"table(1":{"READER":"u1"}}`), false, t)
|
||||
}
|
||||
|
||||
func testValidConfigs(t *testing.T) {
|
||||
checkLoad([]byte(`{"table1":{"READER":"u1"}}`), true, t)
|
||||
checkLoad([]byte(`{"table1":{"READER":"u1,u2", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,u2", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,`+allString()+`", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{
|
||||
"table[0-9]+":{"Reader":"u1,`+allString()+`", "WRITER":"u3"},
|
||||
"tbl[0-9]+":{"Reader":"u1,`+allString()+`", "WRITER":"u3", "ADMIN":"u4"}
|
||||
}`), true, t)
|
||||
}
|
||||
|
||||
func testDenyReaderInsert(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", tableacl.WRITER, t, false)
|
||||
}
|
||||
|
||||
func testAllowReaderSelect(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", tableacl.READER, t, true)
|
||||
}
|
||||
|
||||
func testDenyReaderDDL(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", tableacl.ADMIN, t, false)
|
||||
}
|
||||
|
||||
func testAllowUnmatchedTable(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "UNMATCHED_TABLE", tableacl.ADMIN, t, true)
|
||||
}
|
||||
|
||||
func testAllUserReadAccess(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + allString() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", tableacl.READER, t, true)
|
||||
}
|
||||
|
||||
func testAllUserWriteAccess(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"` + allString() + `"}}`)
|
||||
checkAccess(configData, "table1", tableacl.WRITER, t, true)
|
||||
}
|
||||
|
||||
func checkLoad(configData []byte, valid bool, t *testing.T) {
|
||||
err := tableacl.InitFromBytes(configData)
|
||||
if !valid && err == nil {
|
||||
t.Errorf("expecting parse error none returned")
|
||||
}
|
||||
|
||||
if valid && err != nil {
|
||||
t.Errorf("unexpected load error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkAccess(configData []byte, tableName string, role tableacl.Role, t *testing.T, want bool) {
|
||||
checkLoad(configData, true, t)
|
||||
got := tableacl.Authorized(tableName, role).IsMember(currentUser())
|
||||
if want != got {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func allString() string {
|
||||
return tableacl.GetCurrentAclFactory().AllString()
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
|
@ -91,6 +91,9 @@ type ActionAgent struct {
|
|||
// the reason we're not healthy.
|
||||
_healthy error
|
||||
|
||||
// this is the last time health check ran
|
||||
_healthyTime time.Time
|
||||
|
||||
// replication delay the last time we got it
|
||||
_replicationDelay time.Duration
|
||||
|
||||
|
@ -247,10 +250,21 @@ func (agent *ActionAgent) Tablet() *topo.TabletInfo {
|
|||
}
|
||||
|
||||
// Healthy reads the result of the latest healthcheck, protected by mutex.
|
||||
// If that status is too old, it means healthcheck hasn't run for a while,
|
||||
// and is probably stuck, this is not good, we're not healthy.
|
||||
func (agent *ActionAgent) Healthy() (time.Duration, error) {
|
||||
agent.mutex.Lock()
|
||||
defer agent.mutex.Unlock()
|
||||
return agent._replicationDelay, agent._healthy
|
||||
|
||||
healthy := agent._healthy
|
||||
if healthy == nil {
|
||||
timeSinceLastCheck := time.Now().Sub(agent._healthyTime)
|
||||
if timeSinceLastCheck > *healthCheckInterval*3 {
|
||||
healthy = fmt.Errorf("last health check is too old: %s > %s", timeSinceLastCheck, *healthCheckInterval*3)
|
||||
}
|
||||
}
|
||||
|
||||
return agent._replicationDelay, healthy
|
||||
}
|
||||
|
||||
// BlacklistedTables reads the list of blacklisted tables from the TabletControl
|
||||
|
|
|
@ -228,6 +228,7 @@ func (agent *ActionAgent) runHealthCheck(targetTabletType topo.TabletType) {
|
|||
// remember our health status
|
||||
agent.mutex.Lock()
|
||||
agent._healthy = err
|
||||
agent._healthyTime = time.Now()
|
||||
agent._replicationDelay = replicationDelay
|
||||
agent.mutex.Unlock()
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -153,6 +154,7 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
|
|||
|
||||
// first health check, should change us to replica, and update the
|
||||
// mysql port to 3306
|
||||
before := time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err := agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -167,9 +169,13 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
|
|||
if !agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
|
||||
// now make the tablet unhealthy
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportError = fmt.Errorf("tablet is unhealthy")
|
||||
before = time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err = agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -181,6 +187,9 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
|
|||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryServiceNotStarting verifies that if a tablet cannot start the
|
||||
|
@ -190,6 +199,7 @@ func TestQueryServiceNotStarting(t *testing.T) {
|
|||
targetTabletType := topo.TYPE_REPLICA
|
||||
agent.QueryServiceControl.(*tabletserver.TestQueryServiceControl).AllowQueriesError = fmt.Errorf("test cannot start query service")
|
||||
|
||||
before := time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err := agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -201,6 +211,9 @@ func TestQueryServiceNotStarting(t *testing.T) {
|
|||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryServiceStopped verifies that if a healthy tablet's query
|
||||
|
@ -210,6 +223,7 @@ func TestQueryServiceStopped(t *testing.T) {
|
|||
targetTabletType := topo.TYPE_REPLICA
|
||||
|
||||
// first health check, should change us to replica
|
||||
before := time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err := agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -221,12 +235,16 @@ func TestQueryServiceStopped(t *testing.T) {
|
|||
if !agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
|
||||
// shut down query service and prevent it from starting again
|
||||
agent.QueryServiceControl.DisallowQueries()
|
||||
agent.QueryServiceControl.(*tabletserver.TestQueryServiceControl).AllowQueriesError = fmt.Errorf("test cannot start query service")
|
||||
|
||||
// health check should now fail
|
||||
before = time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err = agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -238,6 +256,9 @@ func TestQueryServiceStopped(t *testing.T) {
|
|||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTabletControl verifies the shard's TabletControl record can disable
|
||||
|
@ -247,6 +268,7 @@ func TestTabletControl(t *testing.T) {
|
|||
targetTabletType := topo.TYPE_REPLICA
|
||||
|
||||
// first health check, should change us to replica
|
||||
before := time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err := agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -258,6 +280,9 @@ func TestTabletControl(t *testing.T) {
|
|||
if !agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
|
||||
// now update the shard
|
||||
si, err := agent.TopoServer.GetShard(keyspace, shard)
|
||||
|
@ -286,6 +311,7 @@ func TestTabletControl(t *testing.T) {
|
|||
}
|
||||
|
||||
// check running a health check will not start it again
|
||||
before = time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err = agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -297,9 +323,13 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
|
||||
// go unhealthy, check we go to spare and QS is not running
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportError = fmt.Errorf("tablet is unhealthy")
|
||||
before = time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err = agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -311,9 +341,13 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
|
||||
// go back healthy, check QS is still not running
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportError = nil
|
||||
before = time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err = agent.TopoServer.GetTablet(tabletAlias)
|
||||
if err != nil {
|
||||
|
@ -325,4 +359,33 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOldHealthCheck verifies that a healthcheck that is too old will
|
||||
// return an error
|
||||
func TestOldHealthCheck(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
*healthCheckInterval = 20 * time.Second
|
||||
agent._healthy = nil
|
||||
|
||||
// last health check time is now, we're good
|
||||
agent._healthyTime = time.Now()
|
||||
if _, healthy := agent.Healthy(); healthy != nil {
|
||||
t.Errorf("Healthy returned unexpected error: %v", healthy)
|
||||
}
|
||||
|
||||
// last health check time is 2x interval ago, we're good
|
||||
agent._healthyTime = time.Now().Add(-2 * *healthCheckInterval)
|
||||
if _, healthy := agent.Healthy(); healthy != nil {
|
||||
t.Errorf("Healthy returned unexpected error: %v", healthy)
|
||||
}
|
||||
|
||||
// last health check time is 4x interval ago, we're not good
|
||||
agent._healthyTime = time.Now().Add(-4 * *healthCheckInterval)
|
||||
if _, healthy := agent.Healthy(); healthy == nil || !strings.Contains(healthy.Error(), "last health check is too old") {
|
||||
t.Errorf("Healthy returned wrong error: %v", healthy)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -252,7 +252,7 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re
|
|||
qre.logStats.CacheAbsent = absent
|
||||
qre.logStats.CacheMisses = misses
|
||||
|
||||
qre.logStats.QuerySources |= QUERY_SOURCE_ROWCACHE
|
||||
qre.logStats.QuerySources |= QuerySourceRowcache
|
||||
|
||||
tableInfo.hits.Add(hits)
|
||||
tableInfo.absent.Add(absent)
|
||||
|
@ -538,7 +538,7 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser
|
|||
q.Result, q.Err = qre.execSQLNoPanic(conn, sql, false)
|
||||
}
|
||||
} else {
|
||||
logStats.QuerySources |= QUERY_SOURCE_CONSOLIDATOR
|
||||
logStats.QuerySources |= QuerySourceConsolidator
|
||||
startTime := time.Now()
|
||||
q.Wait()
|
||||
waitStats.Record("Consolidations", startTime)
|
||||
|
|
|
@ -6,7 +6,6 @@ package tabletserver
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -15,6 +14,8 @@ import (
|
|||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/callinfo"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tableacl/simpleacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/fakecacheservice"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/fakesqldb"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
|
@ -22,27 +23,6 @@ import (
|
|||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type fakeCallInfo struct {
|
||||
remoteAddr string
|
||||
username string
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) RemoteAddr() string {
|
||||
return fci.remoteAddr
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) Username() string {
|
||||
return fci.username
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) Text() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) HTML() template.HTML {
|
||||
return template.HTML("")
|
||||
}
|
||||
|
||||
func TestQueryExecutorPlanDDL(t *testing.T) {
|
||||
db := setUpQueryExecutorTest()
|
||||
query := "alter table test_table add zipcode int"
|
||||
|
@ -567,53 +547,57 @@ func TestQueryExecutorPlanOther(t *testing.T) {
|
|||
checkEqual(t, expected, qre.Execute())
|
||||
}
|
||||
|
||||
//func TestQueryExecutorTableAcl(t *testing.T) {
|
||||
// db := setUpQueryExecutorTest()
|
||||
// query := "select * from test_table limit 1000"
|
||||
// expected := &mproto.QueryResult{
|
||||
// Fields: getTestTableFields(),
|
||||
// RowsAffected: 0,
|
||||
// Rows: [][]sqltypes.Value{},
|
||||
// }
|
||||
// db.AddQuery(query, expected)
|
||||
// db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{
|
||||
// Fields: getTestTableFields(),
|
||||
// })
|
||||
//
|
||||
// username := "u2"
|
||||
// callInfo := &fakeCallInfo{
|
||||
// remoteAddr: "1.2.3.4",
|
||||
// username: username,
|
||||
// }
|
||||
// ctx := callinfo.NewContext(context.Background(), callInfo)
|
||||
// if err := tableacl.InitFromBytes(
|
||||
// []byte(fmt.Sprintf(`{"test_table":{"READER":"%s"}}`, username))); err != nil {
|
||||
// t.Fatalf("unable to load tableacl config, error: %v", err)
|
||||
// }
|
||||
//
|
||||
// qre, sqlQuery := newTestQueryExecutor(
|
||||
// query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
|
||||
// checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
|
||||
// checkEqual(t, expected, qre.Execute())
|
||||
// sqlQuery.disallowQueries()
|
||||
//
|
||||
// if err := tableacl.InitFromBytes([]byte(`{"test_table":{"READER":"superuser"}}`)); err != nil {
|
||||
// t.Fatalf("unable to load tableacl config, error: %v", err)
|
||||
// }
|
||||
// // without enabling Config.StrictTableAcl
|
||||
// qre, sqlQuery = newTestQueryExecutor(
|
||||
// query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
|
||||
// checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
|
||||
// qre.Execute()
|
||||
// sqlQuery.disallowQueries()
|
||||
// // enable Config.StrictTableAcl
|
||||
// qre, sqlQuery = newTestQueryExecutor(
|
||||
// query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl)
|
||||
// defer sqlQuery.disallowQueries()
|
||||
// checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
|
||||
// defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail)
|
||||
// qre.Execute()
|
||||
//}
|
||||
func TestQueryExecutorTableAcl(t *testing.T) {
|
||||
aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63())
|
||||
tableacl.Register(aclName, &simpleacl.Factory{})
|
||||
tableacl.SetDefaultACL(aclName)
|
||||
|
||||
db := setUpQueryExecutorTest()
|
||||
query := "select * from test_table limit 1000"
|
||||
expected := &mproto.QueryResult{
|
||||
Fields: getTestTableFields(),
|
||||
RowsAffected: 0,
|
||||
Rows: [][]sqltypes.Value{},
|
||||
}
|
||||
db.AddQuery(query, expected)
|
||||
db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{
|
||||
Fields: getTestTableFields(),
|
||||
})
|
||||
|
||||
username := "u2"
|
||||
callInfo := &fakeCallInfo{
|
||||
remoteAddr: "1.2.3.4",
|
||||
username: username,
|
||||
}
|
||||
ctx := callinfo.NewContext(context.Background(), callInfo)
|
||||
if err := tableacl.InitFromBytes(
|
||||
[]byte(fmt.Sprintf(`{"test_table":{"READER":"%s"}}`, username))); err != nil {
|
||||
t.Fatalf("unable to load tableacl config, error: %v", err)
|
||||
}
|
||||
|
||||
qre, sqlQuery := newTestQueryExecutor(
|
||||
query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
|
||||
checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
|
||||
checkEqual(t, expected, qre.Execute())
|
||||
sqlQuery.disallowQueries()
|
||||
|
||||
if err := tableacl.InitFromBytes([]byte(`{"test_table":{"READER":"superuser"}}`)); err != nil {
|
||||
t.Fatalf("unable to load tableacl config, error: %v", err)
|
||||
}
|
||||
// without enabling Config.StrictTableAcl
|
||||
qre, sqlQuery = newTestQueryExecutor(
|
||||
query, ctx, enableRowCache|enableSchemaOverrides|enableStrict)
|
||||
checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
|
||||
qre.Execute()
|
||||
sqlQuery.disallowQueries()
|
||||
// enable Config.StrictTableAcl
|
||||
qre, sqlQuery = newTestQueryExecutor(
|
||||
query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl)
|
||||
defer sqlQuery.disallowQueries()
|
||||
checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId)
|
||||
defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail)
|
||||
qre.Execute()
|
||||
}
|
||||
|
||||
func TestQueryExecutorBlacklistQRFail(t *testing.T) {
|
||||
db := setUpQueryExecutorTest()
|
||||
|
|
|
@ -98,14 +98,9 @@ func (ql *QueryList) GetQueryzRows() []QueryDetailzRow {
|
|||
ql.mu.Lock()
|
||||
rows := []QueryDetailzRow{}
|
||||
for _, qd := range ql.queryDetails {
|
||||
var h template.HTML
|
||||
ci, ok := callinfo.FromContext(qd.context)
|
||||
if ok {
|
||||
h = ci.HTML()
|
||||
}
|
||||
row := QueryDetailzRow{
|
||||
Query: qd.conn.Current(),
|
||||
ContextHTML: h,
|
||||
ContextHTML: callinfo.HTMLFromContext(qd.context),
|
||||
Start: qd.start,
|
||||
Duration: time.Now().Sub(qd.start),
|
||||
ConnID: qd.connID,
|
||||
|
|
|
@ -7,20 +7,25 @@ import (
|
|||
)
|
||||
|
||||
type testConn struct {
|
||||
id int64
|
||||
query string
|
||||
id int64
|
||||
query string
|
||||
killed bool
|
||||
}
|
||||
|
||||
func (tc testConn) Current() string { return tc.query }
|
||||
func (tc *testConn) Current() string { return tc.query }
|
||||
|
||||
func (tc testConn) ID() int64 { return tc.id }
|
||||
func (tc *testConn) ID() int64 { return tc.id }
|
||||
|
||||
func (tc testConn) Kill() {}
|
||||
func (tc *testConn) Kill() { tc.killed = true }
|
||||
|
||||
func (tc *testConn) IsKilled() bool {
|
||||
return tc.killed
|
||||
}
|
||||
|
||||
func TestQueryList(t *testing.T) {
|
||||
ql := NewQueryList()
|
||||
connID := int64(1)
|
||||
qd := NewQueryDetail(context.Background(), testConn{id: connID})
|
||||
qd := NewQueryDetail(context.Background(), &testConn{id: connID})
|
||||
ql.Add(qd)
|
||||
|
||||
if qd1, ok := ql.queryDetails[connID]; !ok || qd1.connID != connID {
|
||||
|
@ -28,7 +33,7 @@ func TestQueryList(t *testing.T) {
|
|||
}
|
||||
|
||||
conn2ID := int64(2)
|
||||
qd2 := NewQueryDetail(context.Background(), testConn{id: conn2ID})
|
||||
qd2 := NewQueryDetail(context.Background(), &testConn{id: conn2ID})
|
||||
ql.Add(qd2)
|
||||
|
||||
rows := ql.GetQueryzRows()
|
||||
|
|
|
@ -382,6 +382,15 @@ func (rqsc *realQueryServiceControl) registerDebugHealthHandler() {
|
|||
})
|
||||
}
|
||||
|
||||
func (rqsc *realQueryServiceControl) registerStreamQueryzHandlers() {
|
||||
http.HandleFunc("/streamqueryz", func(w http.ResponseWriter, r *http.Request) {
|
||||
streamQueryzHandler(rqsc.sqlQueryRPCService.qe.streamQList, w, r)
|
||||
})
|
||||
http.HandleFunc("/streamqueryz/terminate", func(w http.ResponseWriter, r *http.Request) {
|
||||
streamQueryzTerminateHandler(rqsc.sqlQueryRPCService.qe.streamQList, w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func buildFmter(logger *streamlog.StreamLogger) func(url.Values, interface{}) string {
|
||||
type formatter interface {
|
||||
Format(url.Values) string
|
||||
|
|
|
@ -113,7 +113,9 @@ func querylogzHandler(w http.ResponseWriter, r *http.Request) {
|
|||
*SQLQueryStats
|
||||
ColorLevel string
|
||||
}{stats, level}
|
||||
querylogzTmpl.Execute(w, tmplData)
|
||||
if err := querylogzTmpl.Execute(w, tmplData); err != nil {
|
||||
log.Errorf("querylogz: couldn't execute template: %v", err)
|
||||
}
|
||||
case <-tmr.C:
|
||||
return
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/acl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
)
|
||||
|
@ -150,7 +151,9 @@ func (rqsc *realQueryServiceControl) registerQueryzHandler() {
|
|||
}
|
||||
sort.Sort(&sorter)
|
||||
for _, Value := range sorter.rows {
|
||||
queryzTmpl.Execute(w, Value)
|
||||
if err := queryzTmpl.Execute(w, Value); err != nil {
|
||||
log.Errorf("queryz: couldn't execute template: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/youtube/vitess/go/timer"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
tacl "github.com/youtube/vitess/go/vt/tableacl/acl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
@ -44,7 +45,7 @@ type ExecPlan struct {
|
|||
TableInfo *TableInfo
|
||||
Fields []mproto.Field
|
||||
Rules *QueryRules
|
||||
Authorized tableacl.ACL
|
||||
Authorized tacl.ACL
|
||||
|
||||
mu sync.Mutex
|
||||
QueryCount int64
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
"sort"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/acl"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
)
|
||||
|
@ -77,7 +78,9 @@ func (rqsc *realQueryServiceControl) registerSchemazHandler() {
|
|||
}
|
||||
for _, Value := range sorter.rows {
|
||||
envelope.Table = Value
|
||||
schemazTmpl.Execute(w, envelope)
|
||||
if err := schemazTmpl.Execute(w, envelope); err != nil {
|
||||
log.Errorf("schemaz: couldn't execute template: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package tabletserver
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -22,9 +23,12 @@ import (
|
|||
var SqlQueryLogger = streamlog.New("SqlQuery", 50)
|
||||
|
||||
const (
|
||||
QUERY_SOURCE_ROWCACHE = 1 << iota
|
||||
QUERY_SOURCE_CONSOLIDATOR
|
||||
QUERY_SOURCE_MYSQL
|
||||
// QuerySourceRowcache means query result is found in rowcache.
|
||||
QuerySourceRowcache = 1 << iota
|
||||
// QuerySourceConsolidator means query result is found in consolidator.
|
||||
QuerySourceConsolidator
|
||||
// QuerySourceMySQL means query result is returned from MySQL.
|
||||
QuerySourceMySQL
|
||||
)
|
||||
|
||||
// SQLQueryStats records the stats for a single query
|
||||
|
@ -67,7 +71,7 @@ func (stats *SQLQueryStats) Send() {
|
|||
|
||||
// AddRewrittenSql adds a single sql statement to the rewritten list
|
||||
func (stats *SQLQueryStats) AddRewrittenSql(sql string, start time.Time) {
|
||||
stats.QuerySources |= QUERY_SOURCE_MYSQL
|
||||
stats.QuerySources |= QuerySourceMySQL
|
||||
stats.NumberOfQueries++
|
||||
stats.rewrittenSqls = append(stats.rewrittenSqls, sql)
|
||||
stats.MysqlResponseTime += time.Now().Sub(start)
|
||||
|
@ -139,21 +143,28 @@ func (stats *SQLQueryStats) FmtQuerySources() string {
|
|||
}
|
||||
sources := make([]string, 3)
|
||||
n := 0
|
||||
if stats.QuerySources&QUERY_SOURCE_MYSQL != 0 {
|
||||
if stats.QuerySources&QuerySourceMySQL != 0 {
|
||||
sources[n] = "mysql"
|
||||
n++
|
||||
}
|
||||
if stats.QuerySources&QUERY_SOURCE_ROWCACHE != 0 {
|
||||
if stats.QuerySources&QuerySourceRowcache != 0 {
|
||||
sources[n] = "rowcache"
|
||||
n++
|
||||
}
|
||||
if stats.QuerySources&QUERY_SOURCE_CONSOLIDATOR != 0 {
|
||||
if stats.QuerySources&QuerySourceConsolidator != 0 {
|
||||
sources[n] = "consolidator"
|
||||
n++
|
||||
}
|
||||
return strings.Join(sources[:n], ",")
|
||||
}
|
||||
|
||||
// ContextHTML returns the HTML version of the context that was used, or "".
|
||||
// This is a method on SQLQueryStats instead of a field so that it doesn't need
|
||||
// to be passed by value everywhere.
|
||||
func (stats *SQLQueryStats) ContextHTML() template.HTML {
|
||||
return callinfo.HTMLFromContext(stats.context)
|
||||
}
|
||||
|
||||
// ErrorStr returns the error string or ""
|
||||
func (stats *SQLQueryStats) ErrorStr() string {
|
||||
if stats.Error != nil {
|
|
@ -0,0 +1,142 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tabletserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/callinfo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestSqlQueryStats(t *testing.T) {
|
||||
logStats := newSqlQueryStats("test", context.Background())
|
||||
logStats.AddRewrittenSql("sql1", time.Now())
|
||||
|
||||
if !strings.Contains(logStats.RewrittenSql(), "sql1") {
|
||||
t.Fatalf("RewrittenSql should contains sql: sql1")
|
||||
}
|
||||
|
||||
if logStats.SizeOfResponse() != 0 {
|
||||
t.Fatalf("there is no rows in log stats, estimated size should be 0 bytes")
|
||||
}
|
||||
|
||||
logStats.Rows = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.MakeString([]byte("a"))}}
|
||||
if logStats.SizeOfResponse() <= 0 {
|
||||
t.Fatalf("log stats has some rows, should have positive response size")
|
||||
}
|
||||
|
||||
params := map[string][]string{"full": []string{}}
|
||||
|
||||
logStats.Format(url.Values(params))
|
||||
}
|
||||
|
||||
func TestSqlQueryStatsFormatBindVariables(t *testing.T) {
|
||||
logStats := newSqlQueryStats("test", context.Background())
|
||||
logStats.BindVariables = make(map[string]interface{})
|
||||
logStats.BindVariables["key_1"] = "val_1"
|
||||
logStats.BindVariables["key_2"] = 789
|
||||
|
||||
formattedStr := logStats.FmtBindVariables(true)
|
||||
if !strings.Contains(formattedStr, "key_1") ||
|
||||
!strings.Contains(formattedStr, "val_1") {
|
||||
t.Fatalf("bind variable 'key_1': 'val_1' is not formatted")
|
||||
}
|
||||
if !strings.Contains(formattedStr, "key_2") ||
|
||||
!strings.Contains(formattedStr, "789") {
|
||||
t.Fatalf("bind variable 'key_2': '789' is not formatted")
|
||||
}
|
||||
|
||||
logStats.BindVariables["key_3"] = []byte("val_3")
|
||||
formattedStr = logStats.FmtBindVariables(false)
|
||||
if !strings.Contains(formattedStr, "key_1") {
|
||||
t.Fatalf("bind variable 'key_1' is not formatted")
|
||||
}
|
||||
if !strings.Contains(formattedStr, "key_2") ||
|
||||
!strings.Contains(formattedStr, "789") {
|
||||
t.Fatalf("bind variable 'key_2': '789' is not formatted")
|
||||
}
|
||||
if !strings.Contains(formattedStr, "key_3") {
|
||||
t.Fatalf("bind variable 'key_3' is not formatted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlQueryStatsFormatQuerySources(t *testing.T) {
|
||||
logStats := newSqlQueryStats("test", context.Background())
|
||||
if logStats.FmtQuerySources() != "none" {
|
||||
t.Fatalf("should return none since log stats does not have any query source, but got: %s", logStats.FmtQuerySources())
|
||||
}
|
||||
|
||||
logStats.QuerySources |= QuerySourceMySQL
|
||||
if !strings.Contains(logStats.FmtQuerySources(), "mysql") {
|
||||
t.Fatalf("'mysql' should be in formated query sources")
|
||||
}
|
||||
|
||||
logStats.QuerySources |= QuerySourceRowcache
|
||||
if !strings.Contains(logStats.FmtQuerySources(), "rowcache") {
|
||||
t.Fatalf("'rowcache' should be in formated query sources")
|
||||
}
|
||||
|
||||
logStats.QuerySources |= QuerySourceConsolidator
|
||||
if !strings.Contains(logStats.FmtQuerySources(), "consolidator") {
|
||||
t.Fatalf("'consolidator' should be in formated query sources")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlQueryStatsContextHTML(t *testing.T) {
|
||||
html := "HtmlContext"
|
||||
callInfo := &fakeCallInfo{
|
||||
html: html,
|
||||
}
|
||||
ctx := callinfo.NewContext(context.Background(), callInfo)
|
||||
logStats := newSqlQueryStats("test", ctx)
|
||||
if string(logStats.ContextHTML()) != html {
|
||||
t.Fatalf("expect to get html: %s, but got: %s", html, string(logStats.ContextHTML()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlQueryStatsErrorStr(t *testing.T) {
|
||||
logStats := newSqlQueryStats("test", context.Background())
|
||||
if logStats.ErrorStr() != "" {
|
||||
t.Fatalf("should not get error in stats, but got: %s", logStats.ErrorStr())
|
||||
}
|
||||
errStr := "unknown error"
|
||||
logStats.Error = fmt.Errorf(errStr)
|
||||
if logStats.ErrorStr() != errStr {
|
||||
t.Fatalf("expect to get error string: %s, but got: %s", errStr, logStats.ErrorStr())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlQueryStatsRemoteAddrUsername(t *testing.T) {
|
||||
logStats := newSqlQueryStats("test", context.Background())
|
||||
addr, user := logStats.RemoteAddrUsername()
|
||||
if addr != "" {
|
||||
t.Fatalf("remote addr should be empty")
|
||||
}
|
||||
if user != "" {
|
||||
t.Fatalf("username should be empty")
|
||||
}
|
||||
|
||||
remoteAddr := "1.2.3.4"
|
||||
username := "vt"
|
||||
callInfo := &fakeCallInfo{
|
||||
remoteAddr: remoteAddr,
|
||||
username: username,
|
||||
}
|
||||
ctx := callinfo.NewContext(context.Background(), callInfo)
|
||||
logStats = newSqlQueryStats("test", ctx)
|
||||
addr, user = logStats.RemoteAddrUsername()
|
||||
if addr != remoteAddr {
|
||||
t.Fatalf("expected to get remote addr: %s, but got: %s", remoteAddr, addr)
|
||||
}
|
||||
if user != username {
|
||||
t.Fatalf("expected to get username: %s, but got: %s", username, user)
|
||||
}
|
||||
}
|
|
@ -1,3 +1,7 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tabletserver
|
||||
|
||||
import (
|
||||
|
@ -7,6 +11,7 @@ import (
|
|||
"strconv"
|
||||
"text/template"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/acl"
|
||||
)
|
||||
|
||||
|
@ -23,7 +28,7 @@ var (
|
|||
</thead>
|
||||
`)
|
||||
streamqueryzTmpl = template.Must(template.New("example").Parse(`
|
||||
<tr>
|
||||
<tr>
|
||||
<td>{{.Query}}</td>
|
||||
<td>{{.ContextHTML}}</td>
|
||||
<td>{{.Duration}}</td>
|
||||
|
@ -34,56 +39,55 @@ var (
|
|||
`))
|
||||
)
|
||||
|
||||
func (rqsc *realQueryServiceControl) registerStreamQueryzHandlers() {
|
||||
streamqueryzHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
|
||||
acl.SendError(w, err)
|
||||
func streamQueryzHandler(queryList *QueryList, w http.ResponseWriter, r *http.Request) {
|
||||
if err := acl.CheckAccessHTTP(r, acl.DEBUGGING); err != nil {
|
||||
acl.SendError(w, err)
|
||||
return
|
||||
}
|
||||
rows := queryList.GetQueryzRows()
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
format := r.FormValue("format")
|
||||
if format == "json" {
|
||||
js, err := json.Marshal(rows)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
rows := rqsc.sqlQueryRPCService.qe.streamQList.GetQueryzRows()
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
format := r.FormValue("format")
|
||||
if format == "json" {
|
||||
js, err := json.Marshal(rows)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(js)
|
||||
return
|
||||
}
|
||||
startHTMLTable(w)
|
||||
defer endHTMLTable(w)
|
||||
w.Write(streamqueryzHeader)
|
||||
for i := range rows {
|
||||
streamqueryzTmpl.Execute(w, rows[i])
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(js)
|
||||
return
|
||||
}
|
||||
startHTMLTable(w)
|
||||
defer endHTMLTable(w)
|
||||
w.Write(streamqueryzHeader)
|
||||
for i := range rows {
|
||||
if err := streamqueryzTmpl.Execute(w, rows[i]); err != nil {
|
||||
log.Errorf("streamlogz: couldn't execute template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
http.HandleFunc("/streamqueryz", streamqueryzHandler)
|
||||
http.HandleFunc("/streamqueryz/terminate", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := acl.CheckAccessHTTP(r, acl.ADMIN); err != nil {
|
||||
acl.SendError(w, err)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
connID := r.FormValue("connID")
|
||||
c, err := strconv.Atoi(connID)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid connID", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err = rqsc.sqlQueryRPCService.qe.streamQList.Terminate(int64(c)); err != nil {
|
||||
http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
streamqueryzHandler(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func streamQueryzTerminateHandler(queryList *QueryList, w http.ResponseWriter, r *http.Request) {
|
||||
if err := acl.CheckAccessHTTP(r, acl.ADMIN); err != nil {
|
||||
acl.SendError(w, err)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, fmt.Sprintf("cannot parse form: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
connID := r.FormValue("connID")
|
||||
c, err := strconv.Atoi(connID)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid connID", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err = queryList.Terminate(int64(c)); err != nil {
|
||||
http.Error(w, fmt.Sprintf("error: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
streamQueryzHandler(queryList, w, r)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tabletserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestStreamQueryzHandlerJSON(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/streamqueryz?format=json", nil)
|
||||
|
||||
queryList := NewQueryList()
|
||||
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
|
||||
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
|
||||
|
||||
streamQueryzHandler(queryList, resp, req)
|
||||
}
|
||||
|
||||
func TestStreamQueryzHandlerHTTP(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/streamqueryz", nil)
|
||||
|
||||
queryList := NewQueryList()
|
||||
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 1}))
|
||||
queryList.Add(NewQueryDetail(context.Background(), &testConn{id: 2}))
|
||||
|
||||
streamQueryzHandler(queryList, resp, req)
|
||||
}
|
||||
|
||||
func TestStreamQueryzHandlerHTTPFailedInvalidForm(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/streamqueryz", nil)
|
||||
|
||||
streamQueryzHandler(NewQueryList(), resp, req)
|
||||
if resp.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("http call should fail and return code: %d, but got: %d",
|
||||
http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamQueryzHandlerTerminateConn(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/streamqueryz/terminate?connID=1", nil)
|
||||
|
||||
queryList := NewQueryList()
|
||||
testConn := &testConn{id: 1}
|
||||
queryList.Add(NewQueryDetail(context.Background(), testConn))
|
||||
if testConn.IsKilled() {
|
||||
t.Fatalf("conn should still be alive")
|
||||
}
|
||||
streamQueryzTerminateHandler(queryList, resp, req)
|
||||
if !testConn.IsKilled() {
|
||||
t.Fatalf("conn should be killed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamQueryzHandlerTerminateFailedInvalidConnID(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/streamqueryz/terminate?connID=invalid", nil)
|
||||
|
||||
streamQueryzTerminateHandler(NewQueryList(), resp, req)
|
||||
if resp.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("http call should fail and return code: %d, but got: %d",
|
||||
http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamQueryzHandlerTerminateFailedKnownConnID(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/streamqueryz/terminate?connID=10", nil)
|
||||
|
||||
streamQueryzTerminateHandler(NewQueryList(), resp, req)
|
||||
if resp.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("http call should fail and return code: %d, but got: %d",
|
||||
http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamQueryzHandlerTerminateFailedInvalidForm(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/streamqueryz/terminate?inva+lid=2", nil)
|
||||
|
||||
streamQueryzTerminateHandler(NewQueryList(), resp, req)
|
||||
if resp.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("http call should fail and return code: %d, but got: %d",
|
||||
http.StatusInternalServerError, resp.Code)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tabletserver
|
||||
|
||||
import "html/template"
|
||||
|
||||
type fakeCallInfo struct {
|
||||
remoteAddr string
|
||||
username string
|
||||
text string
|
||||
html string
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) RemoteAddr() string {
|
||||
return fci.remoteAddr
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) Username() string {
|
||||
return fci.username
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) Text() string {
|
||||
return fci.text
|
||||
}
|
||||
|
||||
func (fci *fakeCallInfo) HTML() template.HTML {
|
||||
return template.HTML(fci.html)
|
||||
}
|
|
@ -122,7 +122,9 @@ func txlogzHandler(w http.ResponseWriter, req *http.Request) {
|
|||
Duration float64
|
||||
ColorLevel string
|
||||
}{txc, duration, level}
|
||||
txlogzTmpl.Execute(w, tmplData)
|
||||
if err := txlogzTmpl.Execute(w, tmplData); err != nil {
|
||||
log.Errorf("txlogz: couldn't execute template: %v", err)
|
||||
}
|
||||
case <-tmr.C:
|
||||
return
|
||||
}
|
||||
|
|
|
@ -46,6 +46,7 @@ func (batchQueryShard *BatchQueryShard) MarshalBson(buf *bytes2.ChunkedWriter, k
|
|||
} else {
|
||||
(*batchQueryShard.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", batchQueryShard.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -104,6 +105,8 @@ func (batchQueryShard *BatchQueryShard) UnmarshalBson(buf *bytes.Buffer, kind by
|
|||
batchQueryShard.Session = new(Session)
|
||||
(*batchQueryShard.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
batchQueryShard.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ func (entityIdsQuery *EntityIdsQuery) MarshalBson(buf *bytes2.ChunkedWriter, key
|
|||
} else {
|
||||
(*entityIdsQuery.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", entityIdsQuery.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -109,6 +110,8 @@ func (entityIdsQuery *EntityIdsQuery) UnmarshalBson(buf *bytes.Buffer, kind byte
|
|||
entityIdsQuery.Session = new(Session)
|
||||
(*entityIdsQuery.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
entityIdsQuery.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
kproto "github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
)
|
||||
|
||||
// MarshalBson bson-encodes KeyRangeQuery.
|
||||
|
@ -31,7 +31,7 @@ func (keyRangeQuery *KeyRangeQuery) MarshalBson(buf *bytes2.ChunkedWriter, key s
|
|||
lenWriter.Close()
|
||||
}
|
||||
bson.EncodeString(buf, "Keyspace", keyRangeQuery.Keyspace)
|
||||
// []kproto.KeyRange
|
||||
// []key.KeyRange
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "KeyRanges")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
@ -47,6 +47,7 @@ func (keyRangeQuery *KeyRangeQuery) MarshalBson(buf *bytes2.ChunkedWriter, key s
|
|||
} else {
|
||||
(*keyRangeQuery.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", keyRangeQuery.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -85,16 +86,16 @@ func (keyRangeQuery *KeyRangeQuery) UnmarshalBson(buf *bytes.Buffer, kind byte)
|
|||
case "Keyspace":
|
||||
keyRangeQuery.Keyspace = bson.DecodeString(buf, kind)
|
||||
case "KeyRanges":
|
||||
// []kproto.KeyRange
|
||||
// []key.KeyRange
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for keyRangeQuery.KeyRanges", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
keyRangeQuery.KeyRanges = make([]kproto.KeyRange, 0, 8)
|
||||
keyRangeQuery.KeyRanges = make([]key.KeyRange, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v2 kproto.KeyRange
|
||||
var _v2 key.KeyRange
|
||||
_v2.UnmarshalBson(buf, kind)
|
||||
keyRangeQuery.KeyRanges = append(keyRangeQuery.KeyRanges, _v2)
|
||||
}
|
||||
|
@ -107,6 +108,8 @@ func (keyRangeQuery *KeyRangeQuery) UnmarshalBson(buf *bytes.Buffer, kind byte)
|
|||
keyRangeQuery.Session = new(Session)
|
||||
(*keyRangeQuery.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
keyRangeQuery.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
kproto "github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
|
||||
)
|
||||
|
||||
|
@ -31,7 +31,7 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) MarshalBson(buf *bytes2.Chunke
|
|||
lenWriter.Close()
|
||||
}
|
||||
bson.EncodeString(buf, "Keyspace", keyspaceIdBatchQuery.Keyspace)
|
||||
// []kproto.KeyspaceId
|
||||
// []key.KeyspaceId
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "KeyspaceIds")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
@ -47,6 +47,7 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) MarshalBson(buf *bytes2.Chunke
|
|||
} else {
|
||||
(*keyspaceIdBatchQuery.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", keyspaceIdBatchQuery.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -83,16 +84,16 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) UnmarshalBson(buf *bytes.Buffe
|
|||
case "Keyspace":
|
||||
keyspaceIdBatchQuery.Keyspace = bson.DecodeString(buf, kind)
|
||||
case "KeyspaceIds":
|
||||
// []kproto.KeyspaceId
|
||||
// []key.KeyspaceId
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for keyspaceIdBatchQuery.KeyspaceIds", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
keyspaceIdBatchQuery.KeyspaceIds = make([]kproto.KeyspaceId, 0, 8)
|
||||
keyspaceIdBatchQuery.KeyspaceIds = make([]key.KeyspaceId, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v2 kproto.KeyspaceId
|
||||
var _v2 key.KeyspaceId
|
||||
_v2.UnmarshalBson(buf, kind)
|
||||
keyspaceIdBatchQuery.KeyspaceIds = append(keyspaceIdBatchQuery.KeyspaceIds, _v2)
|
||||
}
|
||||
|
@ -105,6 +106,8 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) UnmarshalBson(buf *bytes.Buffe
|
|||
keyspaceIdBatchQuery.Session = new(Session)
|
||||
(*keyspaceIdBatchQuery.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
keyspaceIdBatchQuery.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
kproto "github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
)
|
||||
|
||||
// MarshalBson bson-encodes KeyspaceIdQuery.
|
||||
|
@ -31,7 +31,7 @@ func (keyspaceIdQuery *KeyspaceIdQuery) MarshalBson(buf *bytes2.ChunkedWriter, k
|
|||
lenWriter.Close()
|
||||
}
|
||||
bson.EncodeString(buf, "Keyspace", keyspaceIdQuery.Keyspace)
|
||||
// []kproto.KeyspaceId
|
||||
// []key.KeyspaceId
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "KeyspaceIds")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
@ -47,6 +47,7 @@ func (keyspaceIdQuery *KeyspaceIdQuery) MarshalBson(buf *bytes2.ChunkedWriter, k
|
|||
} else {
|
||||
(*keyspaceIdQuery.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", keyspaceIdQuery.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -85,16 +86,16 @@ func (keyspaceIdQuery *KeyspaceIdQuery) UnmarshalBson(buf *bytes.Buffer, kind by
|
|||
case "Keyspace":
|
||||
keyspaceIdQuery.Keyspace = bson.DecodeString(buf, kind)
|
||||
case "KeyspaceIds":
|
||||
// []kproto.KeyspaceId
|
||||
// []key.KeyspaceId
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for keyspaceIdQuery.KeyspaceIds", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
keyspaceIdQuery.KeyspaceIds = make([]kproto.KeyspaceId, 0, 8)
|
||||
keyspaceIdQuery.KeyspaceIds = make([]key.KeyspaceId, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v2 kproto.KeyspaceId
|
||||
var _v2 key.KeyspaceId
|
||||
_v2.UnmarshalBson(buf, kind)
|
||||
keyspaceIdQuery.KeyspaceIds = append(keyspaceIdQuery.KeyspaceIds, _v2)
|
||||
}
|
||||
|
@ -107,6 +108,8 @@ func (keyspaceIdQuery *KeyspaceIdQuery) UnmarshalBson(buf *bytes.Buffer, kind by
|
|||
keyspaceIdQuery.Session = new(Session)
|
||||
(*keyspaceIdQuery.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
keyspaceIdQuery.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ func (query *Query) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
|||
} else {
|
||||
(*query.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", query.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -79,6 +80,8 @@ func (query *Query) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
|||
query.Session = new(Session)
|
||||
(*query.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
query.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -46,6 +46,7 @@ func (queryShard *QueryShard) MarshalBson(buf *bytes2.ChunkedWriter, key string)
|
|||
} else {
|
||||
(*queryShard.Session).MarshalBson(buf, "Session")
|
||||
}
|
||||
bson.EncodeBool(buf, "NotInTransaction", queryShard.NotInTransaction)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
@ -106,6 +107,8 @@ func (queryShard *QueryShard) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
|||
queryShard.Session = new(Session)
|
||||
(*queryShard.Session).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "NotInTransaction":
|
||||
queryShard.NotInTransaction = bson.DecodeBool(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
|
|
|
@ -43,10 +43,11 @@ func (shardSession *ShardSession) String() string {
|
|||
|
||||
// Query represents a keyspace agnostic query request.
|
||||
type Query struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type Query -o query_bson.go
|
||||
|
@ -54,12 +55,13 @@ type Query struct {
|
|||
// QueryShard represents a query request for the
|
||||
// specified list of shards.
|
||||
type QueryShard struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type QueryShard -o query_shard_bson.go
|
||||
|
@ -67,12 +69,13 @@ type QueryShard struct {
|
|||
// KeyspaceIdQuery represents a query request for the
|
||||
// specified list of keyspace IDs.
|
||||
type KeyspaceIdQuery struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyspaceIds []key.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyspaceIds []key.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type KeyspaceIdQuery -o keyspace_id_query_bson.go
|
||||
|
@ -80,12 +83,13 @@ type KeyspaceIdQuery struct {
|
|||
// KeyRangeQuery represents a query request for the
|
||||
// specified list of keyranges.
|
||||
type KeyRangeQuery struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyRanges []key.KeyRange
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyRanges []key.KeyRange
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type KeyRangeQuery -o key_range_query_bson.go
|
||||
|
@ -107,6 +111,7 @@ type EntityIdsQuery struct {
|
|||
EntityKeyspaceIDs []EntityId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type EntityIdsQuery -o entity_ids_query_bson.go
|
||||
|
@ -123,11 +128,12 @@ type QueryResult struct {
|
|||
// BatchQueryShard represents a batch query request
|
||||
// for the specified shards.
|
||||
type BatchQueryShard struct {
|
||||
Queries []tproto.BoundQuery
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Queries []tproto.BoundQuery
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type BatchQueryShard -o batch_query_shard_bson.go
|
||||
|
@ -135,11 +141,12 @@ type BatchQueryShard struct {
|
|||
// KeyspaceIdBatchQuery represents a batch query request
|
||||
// for the specified keyspace IDs.
|
||||
type KeyspaceIdBatchQuery struct {
|
||||
Queries []tproto.BoundQuery
|
||||
Keyspace string
|
||||
KeyspaceIds []key.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Queries []tproto.BoundQuery
|
||||
Keyspace string
|
||||
KeyspaceIds []key.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type KeyspaceIdBatchQuery -o keyspace_id_batch_query_bson.go
|
||||
|
|
|
@ -92,22 +92,24 @@ func TestSession(t *testing.T) {
|
|||
}
|
||||
|
||||
type reflectQueryShard struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
type extraQueryShard struct {
|
||||
Extra int
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Extra int
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestQueryShard(t *testing.T) {
|
||||
|
@ -231,20 +233,22 @@ type reflectBoundQuery struct {
|
|||
}
|
||||
|
||||
type reflectBatchQueryShard struct {
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
type extraBatchQueryShard struct {
|
||||
Extra int
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Extra int
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestBatchQueryShard(t *testing.T) {
|
||||
|
@ -312,11 +316,12 @@ func TestBatchQueryShard(t *testing.T) {
|
|||
}
|
||||
|
||||
type badTypeBatchQueryShard struct {
|
||||
Queries string
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Queries string
|
||||
Keyspace string
|
||||
Shards []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestBatchQueryShardBadType(t *testing.T) {
|
||||
|
@ -404,22 +409,24 @@ func TestQueryResultList(t *testing.T) {
|
|||
}
|
||||
|
||||
type reflectKeyspaceIdQuery struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyspaceIds kproto.KeyspaceIdArray
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyspaceIds kproto.KeyspaceIdArray
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
type extraKeyspaceIdQuery struct {
|
||||
Extra int
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyspaceIds []kproto.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Extra int
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyspaceIds []kproto.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestKeyspaceIdQuery(t *testing.T) {
|
||||
|
@ -474,22 +481,24 @@ func TestKeyspaceIdQuery(t *testing.T) {
|
|||
}
|
||||
|
||||
type reflectKeyRangeQuery struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyRanges kproto.KeyRangeArray
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyRanges kproto.KeyRangeArray
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
type extraKeyRangeQuery struct {
|
||||
Extra int
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyRanges []kproto.KeyRange
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Extra int
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
Keyspace string
|
||||
KeyRanges []kproto.KeyRange
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestKeyRangeQuery(t *testing.T) {
|
||||
|
@ -544,20 +553,22 @@ func TestKeyRangeQuery(t *testing.T) {
|
|||
}
|
||||
|
||||
type reflectKeyspaceIdBatchQuery struct {
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
KeyspaceIds []kproto.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
KeyspaceIds []kproto.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
type extraKeyspaceIdBatchQuery struct {
|
||||
Extra int
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
KeyspaceIds []kproto.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Extra int
|
||||
Queries []reflectBoundQuery
|
||||
Keyspace string
|
||||
KeyspaceIds []kproto.KeyspaceId
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestKeyspaceIdBatchQuery(t *testing.T) {
|
||||
|
@ -625,11 +636,12 @@ func TestKeyspaceIdBatchQuery(t *testing.T) {
|
|||
}
|
||||
|
||||
type badTypeKeyspaceIdsBatchQuery struct {
|
||||
Queries string
|
||||
Keyspace string
|
||||
KeyspaceIds []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
Queries string
|
||||
Keyspace string
|
||||
KeyspaceIds []string
|
||||
TabletType topo.TabletType
|
||||
Session *Session
|
||||
NotInTransaction bool
|
||||
}
|
||||
|
||||
func TestKeyspaceIdsBatchQueryBadType(t *testing.T) {
|
||||
|
|
|
@ -73,7 +73,7 @@ func (res *Resolver) ExecuteKeyspaceIds(ctx context.Context, query *proto.Keyspa
|
|||
query.TabletType,
|
||||
query.KeyspaceIds)
|
||||
}
|
||||
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards)
|
||||
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, query.NotInTransaction)
|
||||
}
|
||||
|
||||
// ExecuteKeyRanges executes a non-streaming query based on KeyRanges.
|
||||
|
@ -88,7 +88,7 @@ func (res *Resolver) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRange
|
|||
query.TabletType,
|
||||
query.KeyRanges)
|
||||
}
|
||||
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards)
|
||||
return res.Execute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, query.NotInTransaction)
|
||||
}
|
||||
|
||||
// Execute executes a non-streaming query based on shards resolved by given func.
|
||||
|
@ -101,6 +101,7 @@ func (res *Resolver) Execute(
|
|||
tabletType topo.TabletType,
|
||||
session *proto.Session,
|
||||
mapToShards func(string) (string, []string, error),
|
||||
notInTransaction bool,
|
||||
) (*mproto.QueryResult, error) {
|
||||
keyspace, shards, err := mapToShards(keyspace)
|
||||
if err != nil {
|
||||
|
@ -114,7 +115,8 @@ func (res *Resolver) Execute(
|
|||
keyspace,
|
||||
shards,
|
||||
tabletType,
|
||||
NewSafeSession(session))
|
||||
NewSafeSession(session),
|
||||
notInTransaction)
|
||||
if connError, ok := err.(*ShardConnError); ok && connError.Code == tabletconn.ERR_RETRY {
|
||||
resharding := false
|
||||
newKeyspace, newShards, err := mapToShards(keyspace)
|
||||
|
@ -169,7 +171,8 @@ func (res *Resolver) ExecuteEntityIds(
|
|||
bindVars,
|
||||
query.Keyspace,
|
||||
query.TabletType,
|
||||
NewSafeSession(query.Session))
|
||||
NewSafeSession(query.Session),
|
||||
query.NotInTransaction)
|
||||
if connError, ok := err.(*ShardConnError); ok && connError.Code == tabletconn.ERR_RETRY {
|
||||
resharding := false
|
||||
newKeyspace, newShardIDMap, err := mapEntityIdsToShards(
|
||||
|
@ -219,7 +222,7 @@ func (res *Resolver) ExecuteBatchKeyspaceIds(ctx context.Context, query *proto.K
|
|||
query.TabletType,
|
||||
query.KeyspaceIds)
|
||||
}
|
||||
return res.ExecuteBatch(ctx, query.Queries, query.Keyspace, query.TabletType, query.Session, mapToShards)
|
||||
return res.ExecuteBatch(ctx, query.Queries, query.Keyspace, query.TabletType, query.Session, mapToShards, query.NotInTransaction)
|
||||
}
|
||||
|
||||
// ExecuteBatch executes a group of queries based on shards resolved by given func.
|
||||
|
@ -231,6 +234,7 @@ func (res *Resolver) ExecuteBatch(
|
|||
tabletType topo.TabletType,
|
||||
session *proto.Session,
|
||||
mapToShards func(string) (string, []string, error),
|
||||
notInTransaction bool,
|
||||
) (*tproto.QueryResultList, error) {
|
||||
keyspace, shards, err := mapToShards(keyspace)
|
||||
if err != nil {
|
||||
|
@ -243,7 +247,8 @@ func (res *Resolver) ExecuteBatch(
|
|||
keyspace,
|
||||
shards,
|
||||
tabletType,
|
||||
NewSafeSession(session))
|
||||
NewSafeSession(session),
|
||||
notInTransaction)
|
||||
if connError, ok := err.(*ShardConnError); ok && connError.Code == tabletconn.ERR_RETRY {
|
||||
resharding := false
|
||||
newKeyspace, newShards, err := mapToShards(keyspace)
|
||||
|
@ -288,7 +293,7 @@ func (res *Resolver) StreamExecuteKeyspaceIds(ctx context.Context, query *proto.
|
|||
query.TabletType,
|
||||
query.KeyspaceIds)
|
||||
}
|
||||
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply)
|
||||
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply, query.NotInTransaction)
|
||||
}
|
||||
|
||||
// StreamExecuteKeyRanges executes a streaming query on the specified KeyRanges.
|
||||
|
@ -307,7 +312,7 @@ func (res *Resolver) StreamExecuteKeyRanges(ctx context.Context, query *proto.Ke
|
|||
query.TabletType,
|
||||
query.KeyRanges)
|
||||
}
|
||||
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply)
|
||||
return res.StreamExecute(ctx, query.Sql, query.BindVariables, query.Keyspace, query.TabletType, query.Session, mapToShards, sendReply, query.NotInTransaction)
|
||||
}
|
||||
|
||||
// StreamExecute executes a streaming query on shards resolved by given func.
|
||||
|
@ -323,6 +328,7 @@ func (res *Resolver) StreamExecute(
|
|||
session *proto.Session,
|
||||
mapToShards func(string) (string, []string, error),
|
||||
sendReply func(*mproto.QueryResult) error,
|
||||
notInTransaction bool,
|
||||
) error {
|
||||
keyspace, shards, err := mapToShards(keyspace)
|
||||
if err != nil {
|
||||
|
@ -336,7 +342,8 @@ func (res *Resolver) StreamExecute(
|
|||
shards,
|
||||
tabletType,
|
||||
NewSafeSession(session),
|
||||
sendReply)
|
||||
sendReply,
|
||||
notInTransaction)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -102,6 +102,7 @@ func (rtr *Router) Execute(ctx context.Context, query *proto.Query) (*mproto.Que
|
|||
params.shardVars,
|
||||
query.TabletType,
|
||||
NewSafeSession(vcursor.query.Session),
|
||||
query.NotInTransaction,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -140,6 +141,7 @@ func (rtr *Router) StreamExecute(ctx context.Context, query *proto.Query, sendRe
|
|||
query.TabletType,
|
||||
NewSafeSession(vcursor.query.Session),
|
||||
sendReply,
|
||||
query.NotInTransaction,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -250,7 +252,8 @@ func (rtr *Router) execUpdateEqual(vcursor *requestContext, plan *planbuilder.Pl
|
|||
ks,
|
||||
[]string{shard},
|
||||
vcursor.query.TabletType,
|
||||
NewSafeSession(vcursor.query.Session))
|
||||
NewSafeSession(vcursor.query.Session),
|
||||
vcursor.query.NotInTransaction)
|
||||
}
|
||||
|
||||
func (rtr *Router) execDeleteEqual(vcursor *requestContext, plan *planbuilder.Plan) (*mproto.QueryResult, error) {
|
||||
|
@ -280,7 +283,8 @@ func (rtr *Router) execDeleteEqual(vcursor *requestContext, plan *planbuilder.Pl
|
|||
ks,
|
||||
[]string{shard},
|
||||
vcursor.query.TabletType,
|
||||
NewSafeSession(vcursor.query.Session))
|
||||
NewSafeSession(vcursor.query.Session),
|
||||
vcursor.query.NotInTransaction)
|
||||
}
|
||||
|
||||
func (rtr *Router) execInsertSharded(vcursor *requestContext, plan *planbuilder.Plan) (*mproto.QueryResult, error) {
|
||||
|
@ -318,7 +322,8 @@ func (rtr *Router) execInsertSharded(vcursor *requestContext, plan *planbuilder.
|
|||
ks,
|
||||
[]string{shard},
|
||||
vcursor.query.TabletType,
|
||||
NewSafeSession(vcursor.query.Session))
|
||||
NewSafeSession(vcursor.query.Session),
|
||||
vcursor.query.NotInTransaction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execInsertSharded: %v", err)
|
||||
}
|
||||
|
@ -421,7 +426,8 @@ func (rtr *Router) deleteVindexEntries(vcursor *requestContext, plan *planbuilde
|
|||
ks,
|
||||
[]string{shard},
|
||||
vcursor.query.TabletType,
|
||||
NewSafeSession(vcursor.query.Session))
|
||||
NewSafeSession(vcursor.query.Session),
|
||||
vcursor.query.NotInTransaction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -127,6 +127,7 @@ func (stc *ScatterConn) Execute(
|
|||
shards []string,
|
||||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
notInTransaction bool,
|
||||
) (*mproto.QueryResult, error) {
|
||||
results, allErrors := stc.multiGo(
|
||||
context,
|
||||
|
@ -135,6 +136,7 @@ func (stc *ScatterConn) Execute(
|
|||
shards,
|
||||
tabletType,
|
||||
session,
|
||||
notInTransaction,
|
||||
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
|
||||
innerqr, err := sdc.Execute(context, query, bindVars, transactionId)
|
||||
if err != nil {
|
||||
|
@ -165,6 +167,7 @@ func (stc *ScatterConn) ExecuteMulti(
|
|||
shardVars map[string]map[string]interface{},
|
||||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
notInTransaction bool,
|
||||
) (*mproto.QueryResult, error) {
|
||||
results, allErrors := stc.multiGo(
|
||||
context,
|
||||
|
@ -173,6 +176,7 @@ func (stc *ScatterConn) ExecuteMulti(
|
|||
getShards(shardVars),
|
||||
tabletType,
|
||||
session,
|
||||
notInTransaction,
|
||||
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
|
||||
innerqr, err := sdc.Execute(context, query, shardVars[sdc.shard], transactionId)
|
||||
if err != nil {
|
||||
|
@ -202,6 +206,7 @@ func (stc *ScatterConn) ExecuteEntityIds(
|
|||
keyspace string,
|
||||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
notInTransaction bool,
|
||||
) (*mproto.QueryResult, error) {
|
||||
results, allErrors := stc.multiGo(
|
||||
context,
|
||||
|
@ -210,6 +215,7 @@ func (stc *ScatterConn) ExecuteEntityIds(
|
|||
shards,
|
||||
tabletType,
|
||||
session,
|
||||
notInTransaction,
|
||||
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
|
||||
shard := sdc.shard
|
||||
sql := sqls[shard]
|
||||
|
@ -241,6 +247,7 @@ func (stc *ScatterConn) ExecuteBatch(
|
|||
shards []string,
|
||||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
notInTransaction bool,
|
||||
) (qrs *tproto.QueryResultList, err error) {
|
||||
results, allErrors := stc.multiGo(
|
||||
context,
|
||||
|
@ -249,6 +256,7 @@ func (stc *ScatterConn) ExecuteBatch(
|
|||
shards,
|
||||
tabletType,
|
||||
session,
|
||||
notInTransaction,
|
||||
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
|
||||
innerqrs, err := sdc.ExecuteBatch(context, queries, transactionId)
|
||||
if err != nil {
|
||||
|
@ -282,6 +290,7 @@ func (stc *ScatterConn) StreamExecute(
|
|||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
sendReply func(reply *mproto.QueryResult) error,
|
||||
notInTransaction bool,
|
||||
) error {
|
||||
results, allErrors := stc.multiGo(
|
||||
context,
|
||||
|
@ -290,6 +299,7 @@ func (stc *ScatterConn) StreamExecute(
|
|||
shards,
|
||||
tabletType,
|
||||
session,
|
||||
notInTransaction,
|
||||
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
|
||||
sr, errFunc := sdc.StreamExecute(context, query, bindVars, transactionId)
|
||||
if sr != nil {
|
||||
|
@ -333,6 +343,7 @@ func (stc *ScatterConn) StreamExecuteMulti(
|
|||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
sendReply func(reply *mproto.QueryResult) error,
|
||||
notInTransaction bool,
|
||||
) error {
|
||||
results, allErrors := stc.multiGo(
|
||||
context,
|
||||
|
@ -341,6 +352,7 @@ func (stc *ScatterConn) StreamExecuteMulti(
|
|||
getShards(shardVars),
|
||||
tabletType,
|
||||
session,
|
||||
notInTransaction,
|
||||
func(sdc *ShardConn, transactionId int64, sResults chan<- interface{}) error {
|
||||
sr, errFunc := sdc.StreamExecute(context, query, shardVars[sdc.shard], transactionId)
|
||||
if sr != nil {
|
||||
|
@ -446,7 +458,7 @@ func (stc *ScatterConn) SplitQuery(ctx context.Context, query tproto.BoundQuery,
|
|||
for shard := range keyRangeByShard {
|
||||
shards = append(shards, shard)
|
||||
}
|
||||
allSplits, allErrors := stc.multiGo(ctx, "SplitQuery", keyspace, shards, topo.TYPE_RDONLY, NewSafeSession(&proto.Session{}), actionFunc)
|
||||
allSplits, allErrors := stc.multiGo(ctx, "SplitQuery", keyspace, shards, topo.TYPE_RDONLY, NewSafeSession(&proto.Session{}), false, actionFunc)
|
||||
splits := []proto.SplitQueryPart{}
|
||||
for s := range allSplits {
|
||||
splits = append(splits, s.([]proto.SplitQueryPart)...)
|
||||
|
@ -512,6 +524,7 @@ func (stc *ScatterConn) multiGo(
|
|||
shards []string,
|
||||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
notInTransaction bool,
|
||||
action shardActionFunc,
|
||||
) (rResults <-chan interface{}, allErrors *concurrency.AllErrorRecorder) {
|
||||
allErrors = new(concurrency.AllErrorRecorder)
|
||||
|
@ -526,7 +539,7 @@ func (stc *ScatterConn) multiGo(
|
|||
defer stc.timings.Record(statsKey, startTime)
|
||||
|
||||
sdc := stc.getConnection(context, keyspace, shard, tabletType)
|
||||
transactionID, err := stc.updateSession(context, sdc, keyspace, shard, tabletType, session)
|
||||
transactionID, err := stc.updateSession(context, sdc, keyspace, shard, tabletType, session, notInTransaction)
|
||||
if err != nil {
|
||||
allErrors.RecordError(err)
|
||||
stc.tabletCallErrorCount.Add(statsKey, 1)
|
||||
|
@ -577,6 +590,7 @@ func (stc *ScatterConn) updateSession(
|
|||
keyspace, shard string,
|
||||
tabletType topo.TabletType,
|
||||
session *SafeSession,
|
||||
notInTransaction bool,
|
||||
) (transactionID int64, err error) {
|
||||
if !session.InTransaction() {
|
||||
return 0, nil
|
||||
|
@ -589,6 +603,12 @@ func (stc *ScatterConn) updateSession(
|
|||
if transactionID != 0 {
|
||||
return transactionID, nil
|
||||
}
|
||||
// We are in a transaction at higher level,
|
||||
// but client requires not to start a transaction for this query.
|
||||
// If a transaction was started on this conn, we will use it (as above).
|
||||
if notInTransaction {
|
||||
return 0, nil
|
||||
}
|
||||
transactionID, err = sdc.Begin(context)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
func TestScatterConnExecute(t *testing.T) {
|
||||
testScatterConnGeneric(t, "TestScatterConnExecute", func(shards []string) (*mproto.QueryResult, error) {
|
||||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
return stc.Execute(context.Background(), "query", nil, "TestScatterConnExecute", shards, "", nil)
|
||||
return stc.Execute(context.Background(), "query", nil, "TestScatterConnExecute", shards, "", nil, false)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,7 @@ func TestScatterConnExecuteMulti(t *testing.T) {
|
|||
for _, shard := range shards {
|
||||
shardVars[shard] = nil
|
||||
}
|
||||
return stc.ExecuteMulti(context.Background(), "query", "TestScatterConnExecute", shardVars, "", nil)
|
||||
return stc.ExecuteMulti(context.Background(), "query", "TestScatterConnExecute", shardVars, "", nil, false)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ func TestScatterConnExecuteBatch(t *testing.T) {
|
|||
testScatterConnGeneric(t, "TestScatterConnExecuteBatch", func(shards []string) (*mproto.QueryResult, error) {
|
||||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
queries := []tproto.BoundQuery{{"query", nil}}
|
||||
qrs, err := stc.ExecuteBatch(context.Background(), queries, "TestScatterConnExecuteBatch", shards, "", nil)
|
||||
qrs, err := stc.ExecuteBatch(context.Background(), queries, "TestScatterConnExecuteBatch", shards, "", nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ func TestScatterConnStreamExecute(t *testing.T) {
|
|||
err := stc.StreamExecute(context.Background(), "query", nil, "TestScatterConnStreamExecute", shards, "", nil, func(r *mproto.QueryResult) error {
|
||||
appendResult(qr, r)
|
||||
return nil
|
||||
})
|
||||
}, false)
|
||||
return qr, err
|
||||
})
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func TestScatterConnStreamExecuteMulti(t *testing.T) {
|
|||
err := stc.StreamExecuteMulti(context.Background(), "query", "TestScatterConnStreamExecute", shardVars, "", nil, func(r *mproto.QueryResult) error {
|
||||
appendResult(qr, r)
|
||||
return nil
|
||||
})
|
||||
}, false)
|
||||
return qr, err
|
||||
})
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ func TestMultiExecs(t *testing.T) {
|
|||
"bv1": 1,
|
||||
},
|
||||
}
|
||||
_, _ = stc.ExecuteMulti(context.Background(), "query", "TestMultiExecs", shardVars, "", nil)
|
||||
_, _ = stc.ExecuteMulti(context.Background(), "query", "TestMultiExecs", shardVars, "", nil, false)
|
||||
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, shardVars["0"]) {
|
||||
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, shardVars["0"])
|
||||
}
|
||||
|
@ -184,7 +184,7 @@ func TestMultiExecs(t *testing.T) {
|
|||
sbc1.Queries = nil
|
||||
_ = stc.StreamExecuteMulti(context.Background(), "query", "TestMultiExecs", shardVars, "", nil, func(*mproto.QueryResult) error {
|
||||
return nil
|
||||
})
|
||||
}, false)
|
||||
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, shardVars["0"]) {
|
||||
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, shardVars["0"])
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ func TestScatterConnStreamExecuteSendError(t *testing.T) {
|
|||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
err := stc.StreamExecute(context.Background(), "query", nil, "TestScatterConnStreamExecuteSendError", []string{"0"}, "", nil, func(*mproto.QueryResult) error {
|
||||
return fmt.Errorf("send error")
|
||||
})
|
||||
}, false)
|
||||
want := "send error"
|
||||
// Ensure that we handle send errors.
|
||||
if err == nil || err.Error() != want {
|
||||
|
@ -241,7 +241,7 @@ func TestScatterConnCommitSuccess(t *testing.T) {
|
|||
|
||||
// Sequence the executes to ensure commit order
|
||||
session := NewSafeSession(&proto.Session{InTransaction: true})
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0"}, "", session)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0"}, "", session, false)
|
||||
wantSession := proto.Session{
|
||||
InTransaction: true,
|
||||
ShardSessions: []*proto.ShardSession{{
|
||||
|
@ -254,7 +254,7 @@ func TestScatterConnCommitSuccess(t *testing.T) {
|
|||
if !reflect.DeepEqual(wantSession, *session.Session) {
|
||||
t.Errorf("want\n%+v, got\n%+v", wantSession, *session.Session)
|
||||
}
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0", "1"}, "", session)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnCommitSuccess", []string{"0", "1"}, "", session, false)
|
||||
wantSession = proto.Session{
|
||||
InTransaction: true,
|
||||
ShardSessions: []*proto.ShardSession{{
|
||||
|
@ -299,8 +299,8 @@ func TestScatterConnRollback(t *testing.T) {
|
|||
|
||||
// Sequence the executes to ensure commit order
|
||||
session := NewSafeSession(&proto.Session{InTransaction: true})
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0"}, "", session)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0", "1"}, "", session)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0"}, "", session, false)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnRollback", []string{"0", "1"}, "", session, false)
|
||||
err := stc.Rollback(context.Background(), session)
|
||||
if err != nil {
|
||||
t.Errorf("want nil, got %v", err)
|
||||
|
@ -322,7 +322,7 @@ func TestScatterConnClose(t *testing.T) {
|
|||
sbc := &sandboxConn{}
|
||||
s.MapTestConn("0", sbc)
|
||||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnClose", []string{"0"}, "", nil)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnClose", []string{"0"}, "", nil, false)
|
||||
stc.Close()
|
||||
time.Sleep(1)
|
||||
if sbc.CloseCount != 1 {
|
||||
|
@ -330,6 +330,111 @@ func TestScatterConnClose(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestScatterConnQueryNotInTransaction(t *testing.T) {
|
||||
s := createSandbox("TestScatterConnQueryNotInTransaction")
|
||||
|
||||
// case 1: read query (not in transaction) followed by write query, not in the same shard.
|
||||
sbc0 := &sandboxConn{}
|
||||
s.MapTestConn("0", sbc0)
|
||||
sbc1 := &sandboxConn{}
|
||||
s.MapTestConn("1", sbc1)
|
||||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
session := NewSafeSession(&proto.Session{InTransaction: true})
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0"}, "", session, true)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"1"}, "", session, false)
|
||||
|
||||
wantSession := proto.Session{
|
||||
InTransaction: true,
|
||||
ShardSessions: []*proto.ShardSession{{
|
||||
Keyspace: "TestScatterConnQueryNotInTransaction",
|
||||
Shard: "1",
|
||||
TabletType: "",
|
||||
TransactionId: 1,
|
||||
}},
|
||||
}
|
||||
if !reflect.DeepEqual(wantSession, *session.Session) {
|
||||
t.Errorf("want\n%+v\ngot\n%+v", wantSession, *session.Session)
|
||||
}
|
||||
stc.Commit(context.Background(), session)
|
||||
if sbc0.ExecCount != 1 || sbc1.ExecCount != 3 {
|
||||
t.Errorf("want 1/3, got %d/%d", sbc0.ExecCount, sbc1.ExecCount)
|
||||
}
|
||||
if sbc0.CommitCount != 0 {
|
||||
t.Errorf("want 0, got %d", sbc0.CommitCount)
|
||||
}
|
||||
if sbc1.CommitCount != 1 {
|
||||
t.Errorf("want 1, got %d", sbc1.CommitCount)
|
||||
}
|
||||
|
||||
// case 2: write query followed by read query (not in transaction), not in the same shard.
|
||||
s.Reset()
|
||||
sbc0 = &sandboxConn{}
|
||||
s.MapTestConn("0", sbc0)
|
||||
sbc1 = &sandboxConn{}
|
||||
s.MapTestConn("1", sbc1)
|
||||
stc = NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
session = NewSafeSession(&proto.Session{InTransaction: true})
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0"}, "", session, false)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"1"}, "", session, true)
|
||||
|
||||
wantSession = proto.Session{
|
||||
InTransaction: true,
|
||||
ShardSessions: []*proto.ShardSession{{
|
||||
Keyspace: "TestScatterConnQueryNotInTransaction",
|
||||
Shard: "0",
|
||||
TabletType: "",
|
||||
TransactionId: 1,
|
||||
}},
|
||||
}
|
||||
if !reflect.DeepEqual(wantSession, *session.Session) {
|
||||
t.Errorf("want\n%+v\ngot\n%+v", wantSession, *session.Session)
|
||||
}
|
||||
stc.Commit(context.Background(), session)
|
||||
if sbc0.ExecCount != 3 || sbc1.ExecCount != 1 {
|
||||
t.Errorf("want 3/1, got %d/%d", sbc0.ExecCount, sbc1.ExecCount)
|
||||
}
|
||||
if sbc0.CommitCount != 1 {
|
||||
t.Errorf("want 1, got %d", sbc0.CommitCount)
|
||||
}
|
||||
if sbc1.CommitCount != 0 {
|
||||
t.Errorf("want 0, got %d", sbc1.CommitCount)
|
||||
}
|
||||
|
||||
// case 3: write query followed by read query, in the same shard.
|
||||
s.Reset()
|
||||
sbc0 = &sandboxConn{}
|
||||
s.MapTestConn("0", sbc0)
|
||||
sbc1 = &sandboxConn{}
|
||||
s.MapTestConn("1", sbc1)
|
||||
stc = NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
session = NewSafeSession(&proto.Session{InTransaction: true})
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0"}, "", session, false)
|
||||
stc.Execute(context.Background(), "query1", nil, "TestScatterConnQueryNotInTransaction", []string{"0", "1"}, "", session, true)
|
||||
|
||||
wantSession = proto.Session{
|
||||
InTransaction: true,
|
||||
ShardSessions: []*proto.ShardSession{{
|
||||
Keyspace: "TestScatterConnQueryNotInTransaction",
|
||||
Shard: "0",
|
||||
TabletType: "",
|
||||
TransactionId: 1,
|
||||
}},
|
||||
}
|
||||
if !reflect.DeepEqual(wantSession, *session.Session) {
|
||||
t.Errorf("want\n%+v\ngot\n%+v", wantSession, *session.Session)
|
||||
}
|
||||
stc.Commit(context.Background(), session)
|
||||
if sbc0.ExecCount != 4 || sbc1.ExecCount != 1 {
|
||||
t.Errorf("want 4/1, got %d/%d", sbc0.ExecCount, sbc1.ExecCount)
|
||||
}
|
||||
if sbc0.CommitCount != 1 {
|
||||
t.Errorf("want 1, got %d", sbc0.CommitCount)
|
||||
}
|
||||
if sbc1.CommitCount != 0 {
|
||||
t.Errorf("want 0, got %d", sbc1.CommitCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendResult(t *testing.T) {
|
||||
qr := new(mproto.QueryResult)
|
||||
innerqr1 := &mproto.QueryResult{
|
||||
|
|
|
@ -57,7 +57,10 @@ func NewShardConn(ctx context.Context, serv SrvTopoServer, cell, keyspace, shard
|
|||
return endpoints, nil
|
||||
}
|
||||
blc := NewBalancer(getAddresses, retryDelay)
|
||||
ticker := timer.NewRandTicker(connLife, connLife/2)
|
||||
var ticker *timer.RandTicker
|
||||
if tabletType != topo.TYPE_MASTER {
|
||||
ticker = timer.NewRandTicker(connLife, connLife/2)
|
||||
}
|
||||
sdc := &ShardConn{
|
||||
keyspace: keyspace,
|
||||
shard: shard,
|
||||
|
@ -72,11 +75,13 @@ func NewShardConn(ctx context.Context, serv SrvTopoServer, cell, keyspace, shard
|
|||
consolidator: sync2.NewConsolidator(),
|
||||
connectTimings: tabletConnectTimings,
|
||||
}
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
sdc.closeCurrent()
|
||||
}
|
||||
}()
|
||||
if ticker != nil {
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
sdc.closeCurrent()
|
||||
}
|
||||
}()
|
||||
}
|
||||
return sdc
|
||||
}
|
||||
|
||||
|
@ -180,7 +185,9 @@ func (sdc *ShardConn) SplitQuery(ctx context.Context, query tproto.BoundQuery, s
|
|||
|
||||
// Close closes the underlying TabletConn.
|
||||
func (sdc *ShardConn) Close() {
|
||||
sdc.ticker.Stop()
|
||||
if sdc.ticker != nil {
|
||||
sdc.ticker.Stop()
|
||||
}
|
||||
sdc.closeCurrent()
|
||||
}
|
||||
|
||||
|
|
|
@ -750,13 +750,14 @@ func TestShardConnReconnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestShardConnLife(t *testing.T) {
|
||||
func TestReplicaShardConnLife(t *testing.T) {
|
||||
// auto-reconnect for non-master
|
||||
retryDelay := 10 * time.Millisecond
|
||||
retryCount := 5
|
||||
s := createSandbox("TestShardConnReconnect")
|
||||
s := createSandbox("TestReplicaShardConnLife")
|
||||
sbc := &sandboxConn{}
|
||||
s.MapTestConn("0", sbc)
|
||||
sdc := NewShardConn(context.Background(), new(sandboxTopo), "aa", "TestShardConnReconnect", "0", "", retryDelay, retryCount, connTimeoutTotal, connTimeoutPerConn, 10*time.Millisecond, connectTimings)
|
||||
sdc := NewShardConn(context.Background(), new(sandboxTopo), "aa", "TestReplicaShardConnLife", "0", topo.TYPE_REPLICA, retryDelay, retryCount, connTimeoutTotal, connTimeoutPerConn, 10*time.Millisecond, connectTimings)
|
||||
sdc.Execute(context.Background(), "query", nil, 0)
|
||||
if s.DialCounter != 1 {
|
||||
t.Errorf("DialCounter: %d, want 1", s.DialCounter)
|
||||
|
@ -766,4 +767,25 @@ func TestShardConnLife(t *testing.T) {
|
|||
if s.DialCounter != 2 {
|
||||
t.Errorf("DialCounter: %d, want 2", s.DialCounter)
|
||||
}
|
||||
sdc.Close()
|
||||
}
|
||||
|
||||
func TestMasterShardConnLife(t *testing.T) {
|
||||
// Do not auto-reconnect for master
|
||||
retryDelay := 10 * time.Millisecond
|
||||
retryCount := 5
|
||||
s := createSandbox("TestMasterShardConnLife")
|
||||
sbc := &sandboxConn{}
|
||||
s.MapTestConn("0", sbc)
|
||||
sdc := NewShardConn(context.Background(), new(sandboxTopo), "aa", "TestMasterShardConnLife", "0", topo.TYPE_MASTER, retryDelay, retryCount, connTimeoutTotal, connTimeoutPerConn, 10*time.Millisecond, connectTimings)
|
||||
sdc.Execute(context.Background(), "query", nil, 0)
|
||||
if s.DialCounter != 1 {
|
||||
t.Errorf("DialCounter: %d, want 1", s.DialCounter)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
sdc.Execute(context.Background(), "query", nil, 0)
|
||||
if s.DialCounter != 1 {
|
||||
t.Errorf("DialCounter: %d, want 1", s.DialCounter)
|
||||
}
|
||||
sdc.Close()
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ import (
|
|||
func TestExecuteKeyspaceAlias(t *testing.T) {
|
||||
testVerticalSplitGeneric(t, false, func(shards []string) (*mproto.QueryResult, error) {
|
||||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
return stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil)
|
||||
return stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil, false)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ func TestBatchExecuteKeyspaceAlias(t *testing.T) {
|
|||
testVerticalSplitGeneric(t, false, func(shards []string) (*mproto.QueryResult, error) {
|
||||
stc := NewScatterConn(new(sandboxTopo), "", "aa", 1*time.Millisecond, 3, 2*time.Millisecond, 1*time.Millisecond, 24*time.Hour)
|
||||
queries := []tproto.BoundQuery{{"query", nil}}
|
||||
qrs, err := stc.ExecuteBatch(context.Background(), queries, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil)
|
||||
qrs, err := stc.ExecuteBatch(context.Background(), queries, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func TestStreamExecuteKeyspaceAlias(t *testing.T) {
|
|||
err := stc.StreamExecute(context.Background(), "query", nil, KsTestUnshardedServedFrom, shards, topo.TYPE_RDONLY, nil, func(r *mproto.QueryResult) error {
|
||||
appendResult(qr, r)
|
||||
return nil
|
||||
})
|
||||
}, false)
|
||||
return qr, err
|
||||
})
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func TestInTransactionKeyspaceAlias(t *testing.T) {
|
|||
TransactionId: 1,
|
||||
}},
|
||||
})
|
||||
_, err := stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, []string{"0"}, topo.TYPE_MASTER, session)
|
||||
_, err := stc.Execute(context.Background(), "query", nil, KsTestUnshardedServedFrom, []string{"0"}, topo.TYPE_MASTER, session, false)
|
||||
want := "shard, host: TestUnshardedServedFrom.0.master, {Uid:0 Host:0 NamedPortMap:map[vt:1] Health:map[]}, retry: err"
|
||||
if err == nil || err.Error() != want {
|
||||
t.Errorf("want '%v', got '%v'", want, err)
|
||||
|
|
|
@ -188,6 +188,7 @@ func (vtg *VTGate) ExecuteShard(ctx context.Context, query *proto.QueryShard, re
|
|||
func(keyspace string) (string, []string, error) {
|
||||
return query.Keyspace, query.Shards, nil
|
||||
},
|
||||
query.NotInTransaction,
|
||||
)
|
||||
if err == nil {
|
||||
reply.Result = qr
|
||||
|
@ -289,6 +290,7 @@ func (vtg *VTGate) ExecuteBatchShard(ctx context.Context, batchQuery *proto.Batc
|
|||
func(keyspace string) (string, []string, error) {
|
||||
return batchQuery.Keyspace, batchQuery.Shards, nil
|
||||
},
|
||||
batchQuery.NotInTransaction,
|
||||
)
|
||||
if err == nil {
|
||||
reply.List = qrs.List
|
||||
|
@ -484,7 +486,8 @@ func (vtg *VTGate) StreamExecuteShard(ctx context.Context, query *proto.QuerySha
|
|||
// Note we don't populate reply.Session here,
|
||||
// as it may change incrementaly as responses are sent.
|
||||
return sendReply(reply)
|
||||
})
|
||||
},
|
||||
query.NotInTransaction)
|
||||
vtg.rowsReturned.Add(statsKey, rowCount)
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -168,7 +168,7 @@ func (wr *Wrangler) MigrateServedTypes(ctx context.Context, keyspace, shard stri
|
|||
|
||||
// rebuild the keyspace serving graph if there was no error
|
||||
if !rec.HasErrors() {
|
||||
rec.RecordError(wr.RebuildKeyspaceGraph(ctx, keyspace, nil))
|
||||
rec.RecordError(wr.RebuildKeyspaceGraph(ctx, keyspace, cells))
|
||||
}
|
||||
|
||||
// Send a refresh to the tablets we just disabled, iff:
|
||||
|
|
|
@ -10,8 +10,12 @@ class BindVarsProxy(object):
|
|||
self.accessed_keys = set()
|
||||
|
||||
def __getitem__(self, name):
|
||||
var = self.bind_vars[name]
|
||||
self.bind_vars[name]
|
||||
self.accessed_keys.add(name)
|
||||
if isinstance(var, (list, set, tuple)):
|
||||
return '::%s' % name
|
||||
|
||||
return ':%s' % name
|
||||
|
||||
def export_bind_vars(self):
|
||||
|
|
|
@ -93,7 +93,11 @@ def convert_bind_vars(bind_variables):
|
|||
new_vars[key] = times.DateTimeToString(val)
|
||||
elif isinstance(val, datetime.date):
|
||||
new_vars[key] = times.DateToString(val)
|
||||
elif isinstance(val, (int, long, float, str, List, NoneType)):
|
||||
elif isinstance(val, set):
|
||||
new_vars[key] = sorted(val)
|
||||
elif isinstance(val, tuple):
|
||||
new_vars[key] = list(val)
|
||||
elif isinstance(val, (int, long, float, str, list, NoneType)):
|
||||
new_vars[key] = val
|
||||
else:
|
||||
# NOTE(msolomon) begrudgingly I allow this - we just have too much code
|
||||
|
|
|
@ -85,7 +85,8 @@ class VTGateCursor(object):
|
|||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges)
|
||||
keyranges=self.keyranges,
|
||||
not_in_transaction=(not self.is_writable()))
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -106,7 +107,8 @@ class VTGateCursor(object):
|
|||
self.keyspace,
|
||||
self.tablet_type,
|
||||
entity_keyspace_id_map,
|
||||
entity_column_name)
|
||||
entity_column_name,
|
||||
not_in_transaction=(not self.is_writable()))
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -207,7 +209,8 @@ class BatchVTGateCursor(VTGateCursor):
|
|||
self.bind_vars_list,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
self.keyspace_ids)
|
||||
self.keyspace_ids,
|
||||
not_in_transaction=(not self.is_writable()))
|
||||
self.query_list = []
|
||||
self.bind_vars_list = []
|
||||
|
||||
|
@ -236,7 +239,8 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges)
|
||||
keyranges=self.keyranges,
|
||||
not_in_transaction=(not self.is_writable()))
|
||||
self.index = 0
|
||||
return 0
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ def convert_exception(exc, *args, **kwargs):
|
|||
return new_exc
|
||||
|
||||
|
||||
def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspace_ids):
|
||||
def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspace_ids, not_in_transaction):
|
||||
# keyspace_ids are Keyspace Ids packed to byte[]
|
||||
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
|
@ -83,11 +83,12 @@ def _create_req_with_keyspace_ids(sql, new_binds, keyspace, tablet_type, keyspac
|
|||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyspaceIds': keyspace_ids,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
return req
|
||||
|
||||
|
||||
def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges):
|
||||
def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges, not_in_transaction):
|
||||
# keyranges are keyspace.KeyRange objects with start/end packed to byte[]
|
||||
sql, new_binds = dbapi.prepare_query_bind_vars(sql, new_binds)
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
|
@ -97,6 +98,7 @@ def _create_req_with_keyranges(sql, new_binds, keyspace, tablet_type, keyranges)
|
|||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyRanges': keyranges,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
return req
|
||||
|
||||
|
@ -180,14 +182,14 @@ class VTGateConnection(object):
|
|||
self.session = response.reply['Session']
|
||||
|
||||
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
|
||||
def _execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None):
|
||||
def _execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None, not_in_transaction=False):
|
||||
exec_method = None
|
||||
req = None
|
||||
if keyspace_ids is not None:
|
||||
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids)
|
||||
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
|
||||
exec_method = 'VTGate.ExecuteKeyspaceIds'
|
||||
elif keyranges is not None:
|
||||
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges)
|
||||
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
|
||||
exec_method = 'VTGate.ExecuteKeyRanges'
|
||||
else:
|
||||
raise dbexceptions.ProgrammingError('_execute called without specifying keyspace_ids or keyranges')
|
||||
|
@ -227,7 +229,7 @@ class VTGateConnection(object):
|
|||
return results, rowcount, lastrowid, fields
|
||||
|
||||
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
|
||||
def _execute_entity_ids(self, sql, bind_variables, keyspace, tablet_type, entity_keyspace_id_map, entity_column_name):
|
||||
def _execute_entity_ids(self, sql, bind_variables, keyspace, tablet_type, entity_keyspace_id_map, entity_column_name, not_in_transaction=False):
|
||||
sql, new_binds = dbapi.prepare_query_bind_vars(sql, bind_variables)
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
req = {
|
||||
|
@ -239,6 +241,7 @@ class VTGateConnection(object):
|
|||
{'ExternalID': xid, 'KeyspaceID': kid}
|
||||
for xid, kid in entity_keyspace_id_map.iteritems()],
|
||||
'EntityColumnName': entity_column_name,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
|
||||
self._add_session(req)
|
||||
|
@ -275,7 +278,7 @@ class VTGateConnection(object):
|
|||
|
||||
|
||||
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
|
||||
def _execute_batch(self, sql_list, bind_variables_list, keyspace, tablet_type, keyspace_ids):
|
||||
def _execute_batch(self, sql_list, bind_variables_list, keyspace, tablet_type, keyspace_ids, not_in_transaction=False):
|
||||
query_list = []
|
||||
for sql, bind_vars in zip(sql_list, bind_variables_list):
|
||||
sql, bind_vars = dbapi.prepare_query_bind_vars(sql, bind_vars)
|
||||
|
@ -292,6 +295,7 @@ class VTGateConnection(object):
|
|||
'Keyspace': keyspace,
|
||||
'TabletType': tablet_type,
|
||||
'KeyspaceIds': keyspace_ids,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
self._add_session(req)
|
||||
response = self.client.call('VTGate.ExecuteBatchKeyspaceIds', req)
|
||||
|
@ -327,14 +331,14 @@ class VTGateConnection(object):
|
|||
# the conversions will need to be passed back to _stream_next
|
||||
# (that way we avoid using a member variable here for such a corner case)
|
||||
@vtgate_utils.exponential_backoff_retry((dbexceptions.RequestBacklog))
|
||||
def _stream_execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None):
|
||||
def _stream_execute(self, sql, bind_variables, keyspace, tablet_type, keyspace_ids=None, keyranges=None, not_in_transaction=False):
|
||||
exec_method = None
|
||||
req = None
|
||||
if keyspace_ids is not None:
|
||||
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids)
|
||||
req = _create_req_with_keyspace_ids(sql, bind_variables, keyspace, tablet_type, keyspace_ids, not_in_transaction)
|
||||
exec_method = 'VTGate.StreamExecuteKeyspaceIds'
|
||||
elif keyranges is not None:
|
||||
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges)
|
||||
req = _create_req_with_keyranges(sql, bind_variables, keyspace, tablet_type, keyranges, not_in_transaction)
|
||||
exec_method = 'VTGate.StreamExecuteKeyRanges'
|
||||
else:
|
||||
raise dbexceptions.ProgrammingError('_stream_execute called without specifying keyspace_ids or keyranges')
|
||||
|
|
|
@ -76,12 +76,13 @@ def convert_exception(exc, *args):
|
|||
return exc
|
||||
|
||||
|
||||
def _create_req(sql, new_binds, tablet_type):
|
||||
def _create_req(sql, new_binds, tablet_type, not_in_transaction):
|
||||
new_binds = field_types.convert_bind_vars(new_binds)
|
||||
req = {
|
||||
'Sql': sql,
|
||||
'BindVariables': new_binds,
|
||||
'TabletType': tablet_type,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
return req
|
||||
|
||||
|
@ -160,8 +161,8 @@ class VTGateConnection(object):
|
|||
if 'Session' in response.reply and response.reply['Session']:
|
||||
self.session = response.reply['Session']
|
||||
|
||||
def _execute(self, sql, bind_variables, tablet_type):
|
||||
req = _create_req(sql, bind_variables, tablet_type)
|
||||
def _execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
|
||||
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
|
||||
self._add_session(req)
|
||||
|
||||
fields = []
|
||||
|
@ -196,7 +197,7 @@ class VTGateConnection(object):
|
|||
return results, rowcount, lastrowid, fields
|
||||
|
||||
|
||||
def _execute_batch(self, sql_list, bind_variables_list, tablet_type):
|
||||
def _execute_batch(self, sql_list, bind_variables_list, tablet_type, not_in_transaction=False):
|
||||
query_list = []
|
||||
for sql, bind_vars in zip(sql_list, bind_variables_list):
|
||||
query = {}
|
||||
|
@ -210,6 +211,7 @@ class VTGateConnection(object):
|
|||
req = {
|
||||
'Queries': query_list,
|
||||
'TabletType': tablet_type,
|
||||
'NotInTransaction': not_in_transaction,
|
||||
}
|
||||
self._add_session(req)
|
||||
response = self.client.call('VTGate.ExecuteBatch', req)
|
||||
|
@ -243,8 +245,8 @@ class VTGateConnection(object):
|
|||
# we return the fields for the response, and the column conversions
|
||||
# the conversions will need to be passed back to _stream_next
|
||||
# (that way we avoid using a member variable here for such a corner case)
|
||||
def _stream_execute(self, sql, bind_variables, tablet_type):
|
||||
req = _create_req(sql, bind_variables, tablet_type)
|
||||
def _stream_execute(self, sql, bind_variables, tablet_type, not_in_transaction=False):
|
||||
req = _create_req(sql, bind_variables, tablet_type, not_in_transaction)
|
||||
self._add_session(req)
|
||||
|
||||
self._stream_fields = []
|
||||
|
|
|
@ -5,6 +5,7 @@ import hmac
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import pprint
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
|
@ -69,10 +70,30 @@ keyspace_id bigint(20) unsigned NOT NULL,
|
|||
primary key(eid, id)
|
||||
) Engine=InnoDB'''
|
||||
|
||||
create_tables = [create_vt_insert_test, create_vt_a]
|
||||
create_vt_field_types = '''create table vt_field_types (
|
||||
id bigint(20) auto_increment,
|
||||
uint_val bigint(20) unsigned,
|
||||
str_val varchar(64),
|
||||
unicode_val varchar(64),
|
||||
float_val float(5, 1),
|
||||
keyspace_id bigint(20) unsigned NOT NULL,
|
||||
primary key(id)
|
||||
) Engine=InnoDB'''
|
||||
|
||||
|
||||
create_tables = [create_vt_insert_test, create_vt_a, create_vt_field_types]
|
||||
pack_kid = struct.Struct('!Q').pack
|
||||
|
||||
|
||||
class DBRow(object):
|
||||
|
||||
def __init__(self, column_names, row_tuple):
|
||||
self.__dict__ = dict(zip(column_names, row_tuple))
|
||||
|
||||
def __repr__(self):
|
||||
return pprint.pformat(self.__dict__, 4)
|
||||
|
||||
|
||||
def setUpModule():
|
||||
logging.debug("in setUpModule")
|
||||
try:
|
||||
|
@ -571,6 +592,91 @@ class TestVTGateFunctions(unittest.TestCase):
|
|||
except Exception, e:
|
||||
self.fail("Failed with error %s %s" % (str(e), traceback.print_exc()))
|
||||
|
||||
def test_field_types(self):
|
||||
try:
|
||||
vtgate_conn = get_connection()
|
||||
_delete_all(self.shard_index, 'vt_field_types')
|
||||
count = 10
|
||||
base_uint = int('8'+'0'*15, base=16)
|
||||
kid_list = shard_kid_map[shard_names[self.shard_index]]
|
||||
for x in xrange(1, count):
|
||||
keyspace_id = kid_list[count%len(kid_list)]
|
||||
cursor = vtgate_conn.cursor(KEYSPACE_NAME, 'master',
|
||||
keyspace_ids=[pack_kid(keyspace_id)],
|
||||
writable=True)
|
||||
cursor.begin()
|
||||
cursor.execute(
|
||||
"insert into vt_field_types "
|
||||
"(uint_val, str_val, unicode_val, float_val, keyspace_id) "
|
||||
"values (%(uint_val)s, %(str_val)s, %(unicode_val)s, "
|
||||
"%(float_val)s, %(keyspace_id)s)",
|
||||
{'uint_val': base_uint + x, 'str_val': 'str_%d' % x,
|
||||
'unicode_val': unicode('str_%d' % x), 'float_val': x*1.2,
|
||||
'keyspace_id': keyspace_id})
|
||||
cursor.commit()
|
||||
cursor = vtgate_conn.cursor(KEYSPACE_NAME, 'master',
|
||||
keyranges=[self.keyrange])
|
||||
rowcount = cursor.execute("select * from vt_field_types", {})
|
||||
field_names = [f[0] for f in cursor.description]
|
||||
self.assertEqual(rowcount, count -1, "rowcount doesn't match")
|
||||
id_list = []
|
||||
uint_val_list = []
|
||||
str_val_list = []
|
||||
unicode_val_list = []
|
||||
float_val_list = []
|
||||
for r in cursor.results:
|
||||
row = DBRow(field_names, r)
|
||||
id_list.append(row.id)
|
||||
uint_val_list.append(row.uint_val)
|
||||
str_val_list.append(row.str_val)
|
||||
unicode_val_list.append(row.unicode_val)
|
||||
float_val_list.append(row.float_val)
|
||||
|
||||
# iterable type checks - list, tuple, set are supported.
|
||||
query = "select * from vt_field_types where id in %(id_1)s"
|
||||
rowcount = cursor.execute(query, {'id_1': id_list})
|
||||
self.assertEqual(rowcount, len(id_list), "rowcount doesn't match")
|
||||
rowcount = cursor.execute(query, {'id_1': tuple(id_list)})
|
||||
self.assertEqual(rowcount, len(id_list), "rowcount doesn't match")
|
||||
rowcount = cursor.execute(query, {'id_1': set(id_list)})
|
||||
self.assertEqual(rowcount, len(id_list), "rowcount doesn't match")
|
||||
for i, r in enumerate(cursor.results):
|
||||
row = DBRow(field_names, r)
|
||||
self.assertIsInstance(row.id, (int, long))
|
||||
|
||||
# received field types same as input.
|
||||
# uint
|
||||
query = "select * from vt_field_types where uint_val in %(uint_val_1)s"
|
||||
rowcount = cursor.execute(query, {'uint_val_1': uint_val_list})
|
||||
self.assertEqual(rowcount, len(uint_val_list), "rowcount doesn't match")
|
||||
for i, r in enumerate(cursor.results):
|
||||
row = DBRow(field_names, r)
|
||||
self.assertIsInstance(row.uint_val, long)
|
||||
self.assertGreaterEqual(row.uint_val, base_uint, "uint value not in correct range")
|
||||
|
||||
# str
|
||||
query = "select * from vt_field_types where str_val in %(str_val_1)s"
|
||||
rowcount = cursor.execute(query, {'str_val_1': str_val_list})
|
||||
self.assertEqual(rowcount, len(str_val_list), "rowcount doesn't match")
|
||||
for i, r in enumerate(cursor.results):
|
||||
row = DBRow(field_names, r)
|
||||
self.assertIsInstance(row.str_val, str)
|
||||
|
||||
# unicode str
|
||||
query = "select * from vt_field_types where unicode_val in %(unicode_val_1)s"
|
||||
rowcount = cursor.execute(query, {'unicode_val_1': unicode_val_list})
|
||||
self.assertEqual(rowcount, len(unicode_val_list), "rowcount doesn't match")
|
||||
for i, r in enumerate(cursor.results):
|
||||
row = DBRow(field_names, r)
|
||||
self.assertIsInstance(row.unicode_val, (str,unicode))
|
||||
|
||||
# deliberately eliminating the float test since it is flaky due
|
||||
# to mysql float precision handling.
|
||||
except Exception, e:
|
||||
logging.debug("Write failed with error %s" % str(e))
|
||||
raise
|
||||
|
||||
|
||||
def _query_lots(self,
|
||||
conn,
|
||||
query,
|
||||
|
@ -607,6 +713,16 @@ class TestFailures(unittest.TestCase):
|
|||
return tablet.start_vttablet(lameduck_period=lameduck_period)
|
||||
# target_tablet_type=tablet_type)
|
||||
|
||||
def test_status_with_error(self):
|
||||
"""Tests that the status page loads correctly after a VTGate error."""
|
||||
vtgate_conn = get_connection()
|
||||
cursor = vtgate_conn.cursor('INVALID_KEYSPACE', 'replica', keyspace_ids=['0'])
|
||||
# We expect to see a DatabaseError due to an invalid keyspace
|
||||
with self.assertRaises(dbexceptions.DatabaseError):
|
||||
cursor.execute('select * from vt_insert_test', {})
|
||||
# Page should have loaded successfully
|
||||
self.assertIn('</html>', utils.get_status(vtgate_port))
|
||||
|
||||
def test_tablet_restart_read(self):
|
||||
try:
|
||||
vtgate_conn = get_connection()
|
||||
|
|
Загрузка…
Ссылка в новой задаче