Merge branch 'master' into replication

This commit is contained in:
Alain Jobart 2015-06-01 06:52:32 -07:00
Родитель 9043de7f67 b5f5970651
Коммит 5dc0717d4c
289 изменённых файлов: 10875 добавлений и 9607 удалений

Просмотреть файл

@ -92,7 +92,7 @@ small_integration_test_files = \
sharded.py \
secure.py \
binlog.py \
clone.py \
backup.py \
update_stream.py
medium_integration_test_files = \
@ -100,7 +100,8 @@ medium_integration_test_files = \
reparent.py \
vtdb_test.py \
vtgate_utils_test.py \
rowcache_invalidator.py
rowcache_invalidator.py \
worker.py
large_integration_test_files = \
vtgatev2_test.py \
@ -122,7 +123,8 @@ worker_integration_test_files = \
vertical_split.py \
vertical_split_vtgate.py \
initial_sharding.py \
initial_sharding_bytes.py
initial_sharding_bytes.py \
worker.py
.ONESHELL:
SHELL = /bin/bash
@ -183,6 +185,7 @@ proto:
cd go/vt/proto/queryservice && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/queryservice.proto --go_out=plugins=grpc:.
cd go/vt/proto/vtctl && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/vtctl.proto --go_out=plugins=grpc:.
cd go/vt/proto/tabletmanager && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/tabletmanager.proto --go_out=plugins=grpc:.
cd go/vt/proto/automation && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/automation.proto --go_out=plugins=grpc:.
find go/vt/proto -name "*.pb.go" | xargs sed --in-place -r -e 's,"([a-z0-9_]+).pb","github.com/youtube/vitess/go/vt/proto/\1",g'
cd py/vtctl && $$VTROOT/dist/protobuf/bin/protoc -I../../proto ../../proto/vtctl.proto --python_out=. --grpc_out=. --plugin=protoc-gen-grpc=$$VTROOT/dist/grpc/bin/grpc_python_plugin

Просмотреть файл

@ -24,10 +24,9 @@ and a more [detailed presentation from @Scale '14](http://youtu.be/5yDO-tmIoXY).
### Using Vitess
* [Getting Started](http://vitess.io/getting-started/):
running Vitess on Kubernetes.
* [Building](http://vitess.io/doc/GettingStarted):
how to manually build Vitess.
* Getting Started
* [On Kubernetes](http://vitess.io/getting-started/).
* [From the ground up](http://vitess.io/doc/GettingStarted).
* [Tools](http://vitess.io/doc/Tools):
all Vitess tools and servers.
* [vttablet/vtocc](http://vitess.io/doc/vtocc):

54
doc/BackupAndRestore.md Normal file
Просмотреть файл

@ -0,0 +1,54 @@
# Backup and Restore
This document describes Vitess Backup and Restore strategy.
### Overview
Backups are used in Vitess for two purposes: to provide a point-in-time backup for the data, and to bootstrap new instances.
### Backup Storage
Backups are stored on a Backup Storage service. Vitess core software contains an implementation that uses a local filesystem to store the files. Any network-mounted drive can then be used as the repository for backups.
We have plans to implement a version of the Backup Storage service for Google Cloud Storage (contact us if you are interested).
(The interface definition for the Backup Storage service is in [interface.go](https://github.com/youtube/vitess/blob/master/go/vt/mysqlctl/backupstorage/interface.go), see comments there for more details).
Concretely, the following command line flags are used for Backup Storage:
* -backup\_storage\_implementation: which implementation of the Backup Storage interface to use.
* -file\_backup\_storage\_root: the root of the backups if 'file' is used as a Backup Storage.
### Taking a Backup
To take a backup is very straightforward: just run the 'vtctl Backup <tablet-alias>' command. The designated tablet will take itself out of the healthy serving tablets, shutdown its mysqld process, copy the necessary files to the Backup Storage, restart mysql, restart replication, and join the cluster back.
With health-check enabled (the recommended default), the tablet goes back to spare state. Once it catches up on replication, it will go back to a serving state.
Note for this to work correctly, the tablet must be started with the right parameters to point it to the Backup Storage system (see previous section).
### Life of a Shard
To illustrate how backups are used in Vitess to bootstrap new instances, let's go through the creation and life of a Shard:
* A shard is initially brought up with no existing backup. All instances are started as replicas. With health-check enabled (the recommended default), each instance will realize replication is not running, and just stay unhealthy as spare.
* Once a few replicas are up, InitShardMaster is run, one host becomes the master, the others replicas. Master becomes healthy, replicas are not as no database exists.
* Initial schema can then be applied to the Master. Either use the usual Schema change tools, or use CopySchemaShard for shards created as targets for resharding.
* After replicating the schema creation, all replicas become healthy. At this point, we have a working and functionnal shard.
* The initial backup is taken (that stores the data and the current replication position), and backup data is copied to a network storage.
* When a replica comes up (either a new replica, or one whose instance was just restarted), it restores the latest backup, resets its master to the current shard master, and starts replicating.
* A Cronjob to backup the data on a regular basis should then be run. The frequency of the backups should be high enough (compared to MySQL binlog retention), so we can always have a backup to fall back upon.
Restoring a backup is enabled by the --restore\_from\_backup command line option in vttablet. It can be enabled all the time for all the tablets in a shard, as it doesn't prevent vttablet from starting if no backup can be found.
### Backup Management
Two vtctl commands exist to manage the backups:
* 'vtctl ListBackups <keyspace/shard>' will display the existing backups for a keyspace/shard in the order they were taken (oldest first).
* 'vtctl RemoveBackup <keyspace/shard> <backup name>' will remove a backup from Backup Storage.
### Details
Both Backup and Restore copy and compress / decompress multiple files simultaneously to increase throughput. The concurrency can be controlled by command-line flags (-concurrency for 'vtctl Backup', and -restore\_concurrency for vttablet). If the network link is fast enough, the concurrency will match the CPU usage of the process during backup / restore.

Просмотреть файл

@ -9,7 +9,7 @@ Lets assume that youve already got a keyspace up and running, with a singl
The first thing that we need to do is add a column to the soon-to-be-sharded keyspace which will be used as the "sharding key". This column will tell Vitess which shard a particular row of data should go to. You can add the column by running an alter on the unsharded keyspace - probably by running something like:
`vtctl ApplySchemaKeyspace -simple -sql="alter table <table name> add keyspace_id" test_keyspace`
`vtctl ApplySchema -sql="alter table <table name> add keyspace_id" test_keyspace`
for each table in the keyspace. Once the column is added everywhere, each row needs to be backfilled with the appropriate keyspace ID.

Просмотреть файл

@ -61,89 +61,14 @@ type SchemaChangeResult struct {
}
```
The ApplySchema action applies a schema change. It is described by the following structure (also returns a SchemaChangeResult):
The ApplySchema action applies a schema change to a specified keyspace, the performed steps are:
```go
type SchemaChange struct {
Sql string
Force bool
AllowReplication bool
BeforeSchema *SchemaDefinition
AfterSchema *SchemaDefinition
}
```
And the associated ApplySchema remote action for a tablet. Then the performed steps are:
* The database to use is either derived from the tablet dbName if UseVt is false, or is the _vt database. A use dbname is prepended to the Sql.
* (if BeforeSchema is not nil) read the schema, make sure it is equal to BeforeSchema. If not equal: if Force is not set, we will abort, if Force is set, well issue a warning and keep going.
* if AllowReplication is false, well disable replication (adding SET sql_log_bin=0 before the Sql).
* We will then apply the Sql command.
* (if AfterSchema is not nil) read the schema again, make sure it is equal to AfterSchema. If not equal: if Force is not set, we will issue an error, if Force is set, well issue a warning.
We will return the following information:
* whether it worked or not (doh!)
* BeforeSchema
* AfterSchema
### Use case 1: Single tablet update:
* we first do a Preflight (to know what BeforeSchema and AfterSchema will be). This can be disabled, but is not recommended.
* we then do the schema upgrade. We will check BeforeSchema before the upgrade, and AfterSchema after the upgrade.
### Use case 2: Single Shard update:
* need to figure out (or be told) if its a simple or complex schema update (does it require the shell game?). For now we'll use a command line flag.
* in any case, do a Preflight on the master, to get the BeforeSchema and AfterSchema values.
* in any case, gather the schema on all databases, to see which ones have been upgraded already or not. This guarantees we can interrupt and restart a schema change. Also, this makes sure no action is currently running on the databases we're about to change.
* if simple:
* nobody has it: apply to master, very similar to a single tablet update.
* some tablets have it but not others: error out
* if complex: do the shell game while disabling replication. Skip the tablets that already have it. Have an option to re-parent at the end.
* Note the Backup, and Lag servers won't apply a complex schema change. Only the servers actively in the replication graph will.
* the process can be interrupted at any time, restarting it as a complex schema upgrade should just work.
### Use case 3: Keyspace update:
* Similar to Single Shard, but the BeforeSchema and AfterSchema values are taken from the first shard, and used in all shards after that.
* We don't know the new masters to use on each shard, so just skip re-parenting all together.
This translates into the following vtctl commands:
* It first finds shards belong to this keyspace, including newly added shards in the presence of [resharding event](Resharding.md).
* Validate the sql syntax and reject the schema change if the sql 1) Alter more then 100,000 rows, or 2) The targed table has more then 2,000,000 rows. The rational behind this is that ApplySchema simply applies schema changes to the masters; therefore, a big schema change that takes too much time slows down the replication and may reduce the availability of the overall system.
* Create a temporary database that has the same schema as the targeted table. Apply the sql to it and makes sure it changes table structure.
* Apply the Sql command to the database.
* Read the schema again, make sure it is equal to AfterSchema.
```
PreflightSchema {-sql=<sql> || -sql_file=<filename>} <tablet alias>
ApplySchema {-sql=<sql> || -sql_file=<filename>} <keyspace>
```
apply the schema change to a temporary database to gather before and after schema and validate the change. The sql can be inlined or read from a file.
This will create a temporary database, copy the existing keyspace schema into it, apply the schema change, and re-read the resulting schema.
```
$ echo "create table test_table(id int);" > change.sql
$ vtctl PreflightSchema -sql_file=change.sql nyc-0002009001
```
```
ApplySchema {-sql=<sql> || -sql_file=<filename>} [-skip_preflight] [-stop_replication] <tablet alias>
```
apply the schema change to the specific tablet (allowing replication by default). The sql can be inlined or read from a file.
a PreflightSchema operation will first be used to make sure the schema is OK (unless skip_preflight is specified).
```
ApplySchemaShard {-sql=<sql> || -sql_file=<filename>} [-simple] [-new_parent=<tablet alias>] <keyspace/shard>
```
apply the schema change to the specific shard. If simple is specified, we just apply on the live master. Otherwise, we do the shell game and will optionally re-parent.
if new_parent is set, we will also reparent (otherwise the master won't be touched at all). Using the force flag will cause a bunch of checks to be ignored, use with care.
```
$ vtctl ApplySchemaShard --sql-file=change.sql -simple vtx/0
$ vtctl ApplySchemaShard --sql-file=change.sql -new_parent=nyc-0002009002 vtx/0
```
```
ApplySchemaKeyspace {-sql=<sql> || -sql_file=<filename>} [-simple] <keyspace>
```
apply the schema change to the specified shard. If simple is specified, we just apply on the live master. Otherwise we will need to do the shell game. So we will apply the schema change to every single slave.

Двоичные данные
doc/slides/Percona2015.pptx Normal file

Двоичный файл не отображается.

Просмотреть файл

Просмотреть файл

@ -18,8 +18,9 @@ if [[ ! -f bootstrap.sh ]]; then
exit 1
fi
# To avoid AUFS permission issues, files must allow access by "other"
chmod -R o=g *
# To avoid AUFS permission issues, files must allow access by "other" (permissions rX required).
# Mirror permissions to "other" from the owning group (for which we assume it has at least rX permissions).
chmod -R o=g .
args="$args --rm -e USER=vitess -v /dev/log:/dev/log"
args="$args -v $PWD:/tmp/src"

Просмотреть файл

@ -27,10 +27,11 @@ $ go get github.com/youtube/vitess/go/cmd/vtctlclient
### Set the path to kubectl
If you're running in Container Engine, set the `KUBECTL` environment variable
to point to the `gcloud` command:
to point to the `kubectl` command provided by the Google Cloud SDK (if you've
already added gcloud to your PATH, you likely have kubectl):
```
$ export KUBECTL='gcloud alpha container kubectl'
$ export KUBECTL='kubectl'
```
If you're running Kubernetes manually, set the `KUBECTL` environment variable

Просмотреть файл

@ -89,7 +89,7 @@ if [ -z "$GOPATH" ]; then
exit -1
fi
export KUBECTL='gcloud alpha container kubectl'
export KUBECTL='kubectl'
go get github.com/youtube/vitess/go/cmd/vtctlclient
gcloud config set compute/zone $GKE_ZONE
project_id=`gcloud config list project | sed -n 2p | cut -d " " -f 3`

Просмотреть файл

@ -42,26 +42,6 @@ func initCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) err
return nil
}
func restoreCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
dontWaitForSlaveStart := subFlags.Bool("dont_wait_for_slave_start", false, "won't wait for replication to start (useful when restoring from master server)")
fetchConcurrency := subFlags.Int("fetch_concurrency", 3, "how many files to fetch simultaneously")
fetchRetryCount := subFlags.Int("fetch_retry_count", 3, "how many times to retry a failed transfer")
subFlags.Parse(args)
if subFlags.NArg() != 1 {
return fmt.Errorf("Command restore requires <snapshot manifest file>")
}
rs, err := mysqlctl.ReadSnapshotManifest(subFlags.Arg(0))
if err != nil {
return fmt.Errorf("restore failed: ReadSnapshotManifest: %v", err)
}
err = mysqld.RestoreFromSnapshot(logutil.NewConsoleLogger(), rs, *fetchConcurrency, *fetchRetryCount, *dontWaitForSlaveStart, nil)
if err != nil {
return fmt.Errorf("restore failed: RestoreFromSnapshot: %v", err)
}
return nil
}
func shutdownCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
waitTime := subFlags.Duration("wait_time", mysqlctl.MysqlWaitTime, "how long to wait for shutdown")
subFlags.Parse(args)
@ -72,50 +52,6 @@ func shutdownCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string)
return nil
}
func snapshotCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
concurrency := subFlags.Int("concurrency", 4, "how many compression jobs to run simultaneously")
subFlags.Parse(args)
if subFlags.NArg() != 1 {
return fmt.Errorf("Command snapshot requires <db name>")
}
filename, _, _, err := mysqld.CreateSnapshot(logutil.NewConsoleLogger(), subFlags.Arg(0), tabletAddr, false, *concurrency, false, nil)
if err != nil {
return fmt.Errorf("snapshot failed: %v", err)
}
log.Infof("manifest location: %v", filename)
return nil
}
func snapshotSourceStartCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
concurrency := subFlags.Int("concurrency", 4, "how many checksum jobs to run simultaneously")
subFlags.Parse(args)
if subFlags.NArg() != 1 {
return fmt.Errorf("Command snapshotsourcestart requires <db name>")
}
filename, slaveStartRequired, readOnly, err := mysqld.CreateSnapshot(logutil.NewConsoleLogger(), subFlags.Arg(0), tabletAddr, false, *concurrency, true, nil)
if err != nil {
return fmt.Errorf("snapshot failed: %v", err)
}
log.Infof("manifest location: %v", filename)
log.Infof("slave start required: %v", slaveStartRequired)
log.Infof("read only: %v", readOnly)
return nil
}
func snapshotSourceEndCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
slaveStartRequired := subFlags.Bool("slave_start", false, "will restart replication")
readWrite := subFlags.Bool("read_write", false, "will make the server read-write")
subFlags.Parse(args)
err := mysqld.SnapshotSourceEnd(*slaveStartRequired, !(*readWrite), true, map[string]string{})
if err != nil {
return fmt.Errorf("snapshotsourceend failed: %v", err)
}
return nil
}
func startCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
waitTime := subFlags.Duration("wait_time", mysqlctl.MysqlWaitTime, "how long to wait for startup")
subFlags.Parse(args)
@ -188,19 +124,6 @@ var commands = []command{
command{"shutdown", shutdownCmd, "[-wait_time=20s]",
"Shuts down mysqld, does not remove any file"},
command{"snapshot", snapshotCmd,
"[-concurrency=4] <db name>",
"Takes a full snapshot, copying the innodb data files"},
command{"snapshotsourcestart", snapshotSourceStartCmd,
"[-concurrency=4] <db name>",
"Enters snapshot server mode (mysqld stopped, serving innodb data files)"},
command{"snapshotsourceend", snapshotSourceEndCmd,
"[-slave_start] [-read_write]",
"Gets out of snapshot server mode"},
command{"restore", restoreCmd,
"[-fetch_concurrency=3] [-fetch_retry_count=3] [-dont_wait_for_slave_start] <snapshot manifest file>",
"Restores a full snapshot"},
command{"position", positionCmd,
"<operation> <pos1> <pos2 | gtid>",
"Compute operations on replication positions"},

Просмотреть файл

@ -13,6 +13,7 @@ import (
"github.com/youtube/vitess/go/vt/logutil"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/topo/helpers"
"golang.org/x/net/context"
)
var fromTopo = flag.String("from", "", "topology to copy data from")
@ -41,19 +42,20 @@ func main() {
exit.Return(1)
}
ctx := context.Background()
fromTS := topo.GetServerByName(*fromTopo)
toTS := topo.GetServerByName(*toTopo)
if *doKeyspaces {
helpers.CopyKeyspaces(fromTS, toTS)
helpers.CopyKeyspaces(ctx, fromTS, toTS)
}
if *doShards {
helpers.CopyShards(fromTS, toTS, *deleteKeyspaceShards)
helpers.CopyShards(ctx, fromTS, toTS, *deleteKeyspaceShards)
}
if *doShardReplications {
helpers.CopyShardReplications(fromTS, toTS)
helpers.CopyShardReplications(ctx, fromTS, toTS)
}
if *doTablets {
helpers.CopyTablets(fromTS, toTS)
helpers.CopyTablets(ctx, fromTS, toTS)
}
}

Просмотреть файл

@ -0,0 +1,11 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
// Imports and register the gorpc vtgateconn client
import (
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
)

189
go/cmd/vtclient/vtclient.go Normal file
Просмотреть файл

@ -0,0 +1,189 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"database/sql"
"encoding/json"
"flag"
"fmt"
"os"
"strings"
"time"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/exit"
"github.com/youtube/vitess/go/vt/logutil"
// import the 'vitess' sql driver
_ "github.com/youtube/vitess/go/vt/client"
)
var (
usage = `
vtclient connects to a vtgate server using the standard go driver API.
Version 3 of the API is used, we do not send any hint to the server.
For query bound variables, we assume place-holders in the query string
in the form of :v1, :v2, etc.
`
server = flag.String("server", "", "vtgate server to connect to")
tabletType = flag.String("tablet_type", "rdonly", "tablet type to direct queries to")
timeout = flag.Duration("timeout", 30*time.Second, "timeout for queries")
streaming = flag.Bool("streaming", false, "use a streaming query")
bindVariables = newBindvars("bind_variables", "bind variables as a json list")
)
func init() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, usage)
}
}
type bindvars []interface{}
func (bv *bindvars) String() string {
b, err := json.Marshal(bv)
if err != nil {
return err.Error()
}
return string(b)
}
func (bv *bindvars) Set(s string) (err error) {
err = json.Unmarshal([]byte(s), &bv)
if err != nil {
return err
}
// json reads all numbers as float64
// So, we just ditch floats for bindvars
for i, v := range *bv {
if f, ok := v.(float64); ok {
if f > 0 {
(*bv)[i] = uint64(f)
} else {
(*bv)[i] = int64(f)
}
}
}
return nil
}
// For internal flag compatibility
func (bv *bindvars) Get() interface{} {
return bv
}
func newBindvars(name, usage string) *bindvars {
var bv bindvars
flag.Var(&bv, name, usage)
return &bv
}
// FIXME(alainjobart) this is a cheap trick. Should probably use the
// query parser if we needed this to be 100% reliable.
func isDml(sql string) bool {
lower := strings.TrimSpace(strings.ToLower(sql))
return strings.HasPrefix(lower, "insert") || strings.HasPrefix(lower, "update") || strings.HasPrefix(lower, "delete")
}
func main() {
defer exit.Recover()
defer logutil.Flush()
flag.Parse()
args := flag.Args()
if len(args) == 0 {
flag.Usage()
exit.Return(1)
}
connStr := fmt.Sprintf(`{"address": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`, *server, *tabletType, *streaming, int64(30*(*timeout)))
db, err := sql.Open("vitess", connStr)
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
log.Infof("Sending the query...")
now := time.Now()
// handle dml
if isDml(args[0]) {
tx, err := db.Begin()
if err != nil {
log.Errorf("begin failed: %v", err)
exit.Return(1)
}
result, err := db.Exec(args[0], []interface{}(*bindVariables)...)
if err != nil {
log.Errorf("exec failed: %v", err)
exit.Return(1)
}
err = tx.Commit()
if err != nil {
log.Errorf("commit failed: %v", err)
exit.Return(1)
}
rowsAffected, err := result.RowsAffected()
lastInsertId, err := result.LastInsertId()
log.Infof("Total time: %v / Row affected: %v / Last Insert Id: %v", time.Now().Sub(now), rowsAffected, lastInsertId)
} else {
// launch the query
rows, err := db.Query(args[0], []interface{}(*bindVariables)...)
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
defer rows.Close()
// print the headers
cols, err := rows.Columns()
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
line := "Index"
for _, field := range cols {
line += "\t" + field
}
fmt.Printf("%s\n", line)
// get the rows
rowIndex := 0
for rows.Next() {
row := make([]interface{}, len(cols))
for i := range row {
var col string
row[i] = &col
}
if err := rows.Scan(row...); err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
// print the line
line := fmt.Sprintf("%d", rowIndex)
for _, value := range row {
line += fmt.Sprintf("\t%v", *(value.(*string)))
}
fmt.Printf("%s\n", line)
rowIndex++
}
if err := rows.Err(); err != nil {
log.Errorf("Error %v\n", err)
exit.Return(1)
}
log.Infof("Total time: %v / Row count: %v", time.Now().Sub(now), rowIndex)
}
}

Просмотреть файл

@ -1,11 +0,0 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
// This plugin imports etcdtopo to register the etcd implementation of TopoServer.
import (
_ "github.com/youtube/vitess/go/vt/etcdtopo"
)

Просмотреть файл

@ -1,11 +0,0 @@
// Copyright 2013, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
// Imports and register the gorpc tabletconn client
import (
_ "github.com/youtube/vitess/go/vt/tabletserver/gorpctabletconn"
)

Просмотреть файл

@ -1,11 +0,0 @@
// Copyright 2013, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
// Imports and register the Zookeeper TopologyServer
import (
_ "github.com/youtube/vitess/go/vt/zktopo"
)

Просмотреть файл

@ -1,193 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"encoding/json"
"flag"
"fmt"
"os"
"strings"
"time"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/db"
"github.com/youtube/vitess/go/exit"
"github.com/youtube/vitess/go/vt/client2"
_ "github.com/youtube/vitess/go/vt/client2/tablet"
"github.com/youtube/vitess/go/vt/logutil"
)
var usage = `
The parameters are first the SQL command, then the bound variables.
For query arguments, we assume place-holders in the query string
in the form of :v0, :v1, etc.
`
var count = flag.Int("count", 1, "how many times to run the query")
var bindvars = FlagMap("bindvars", "bind vars as a json dictionary")
var server = flag.String("server", "localhost:6603/test_keyspace/0", "vtocc server as [user:password@]hostname:port/keyspace/shard[#keyrangestart-keyrangeend]")
var driver = flag.String("driver", "vttablet", "which driver to use (one of vttablet, vttablet-streaming, vtdb, vtdb-streaming)")
var verbose = flag.Bool("verbose", false, "show results")
func init() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, usage)
}
}
//----------------------------------
type Map map[string]interface{}
func (m *Map) String() string {
b, err := json.Marshal(*m)
if err != nil {
return err.Error()
}
return string(b)
}
func (m *Map) Set(s string) (err error) {
err = json.Unmarshal([]byte(s), m)
if err != nil {
return err
}
// json reads all numbers as float64
// So, we just ditch floats for bindvars
for k, v := range *m {
f, ok := v.(float64)
if ok {
if f > 0 {
(*m)[k] = uint64(f)
} else {
(*m)[k] = int64(f)
}
}
}
return nil
}
// For internal flag compatibility
func (m *Map) Get() interface{} {
return m
}
func FlagMap(name, usage string) (m map[string]interface{}) {
m = make(map[string]interface{})
mm := Map(m)
flag.Var(&mm, name, usage)
return m
}
// FIXME(alainjobart) this is a cheap trick. Should probably use the
// query parser if we needed this to be 100% reliable.
func isDml(sql string) bool {
lower := strings.TrimSpace(strings.ToLower(sql))
return strings.HasPrefix(lower, "insert") || strings.HasPrefix(lower, "update")
}
func main() {
defer exit.Recover()
defer logutil.Flush()
flag.Parse()
args := flag.Args()
if len(args) == 0 {
flag.Usage()
exit.Return(1)
}
client2.RegisterShardedDrivers()
conn, err := db.Open(*driver, *server)
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
log.Infof("Sending the query...")
now := time.Now()
// handle dml
if isDml(args[0]) {
t, err := conn.Begin()
if err != nil {
log.Errorf("begin failed: %v", err)
exit.Return(1)
}
r, err := conn.Exec(args[0], bindvars)
if err != nil {
log.Errorf("exec failed: %v", err)
exit.Return(1)
}
err = t.Commit()
if err != nil {
log.Errorf("commit failed: %v", err)
exit.Return(1)
}
n, err := r.RowsAffected()
log.Infof("Total time: %v / Row affected: %v", time.Now().Sub(now), n)
} else {
// launch the query
r, err := conn.Exec(args[0], bindvars)
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
// get the headers
cols := r.Columns()
if err != nil {
log.Errorf("client error: %v", err)
exit.Return(1)
}
// print the header
if *verbose {
line := "Index"
for _, field := range cols {
line += "\t" + field
}
log.Infof(line)
}
// get the rows
rowIndex := 0
for row := r.Next(); row != nil; row = r.Next() {
// print the line if needed
if *verbose {
line := fmt.Sprintf("%d", rowIndex)
for _, value := range row {
if value != nil {
switch value.(type) {
case []byte:
line += fmt.Sprintf("\t%s", value)
default:
line += fmt.Sprintf("\t%v", value)
}
} else {
line += "\t"
}
}
log.Infof(line)
}
rowIndex++
}
if err := r.Err(); err != nil {
log.Errorf("Error %v\n", err)
exit.Return(1)
}
log.Infof("Total time: %v / Row count: %v", time.Now().Sub(now), rowIndex)
}
}

Просмотреть файл

@ -0,0 +1,9 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
_ "github.com/youtube/vitess/go/vt/mysqlctl/filebackupstorage"
)

Просмотреть файл

@ -10,7 +10,6 @@ import (
"log/syslog"
"os"
"os/signal"
"sort"
"strings"
"syscall"
"time"
@ -18,7 +17,6 @@ import (
log "github.com/golang/glog"
"github.com/youtube/vitess/go/exit"
"github.com/youtube/vitess/go/vt/logutil"
myproto "github.com/youtube/vitess/go/vt/mysqlctl/proto"
"github.com/youtube/vitess/go/vt/tabletmanager/tmclient"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtctl"
@ -54,6 +52,12 @@ func installSignalHandlers(cancel func()) {
}()
}
// hooks to register plug-ins after flag init
type initFunc func()
var initFuncs []initFunc
func main() {
defer exit.RecoverAll()
defer logutil.Flush()
@ -81,6 +85,10 @@ func main() {
wr := wrangler.New(logutil.NewConsoleLogger(), topoServer, tmclient.NewTabletManagerClient(), *lockWaitTimeout)
installSignalHandlers(cancel)
for _, f := range initFuncs {
f()
}
err := vtctl.RunCommand(ctx, wr, args)
cancel()
switch err {
@ -94,52 +102,3 @@ func main() {
exit.Return(255)
}
}
type rTablet struct {
*topo.TabletInfo
*myproto.ReplicationStatus
}
type rTablets []*rTablet
func (rts rTablets) Len() int { return len(rts) }
func (rts rTablets) Swap(i, j int) { rts[i], rts[j] = rts[j], rts[i] }
// Sort for tablet replication.
// master first, then i/o position, then sql position
func (rts rTablets) Less(i, j int) bool {
// NOTE: Swap order of unpack to reverse sort
l, r := rts[j], rts[i]
// l or r ReplicationPosition would be nil if we failed to get
// the position (put them at the beginning of the list)
if l.ReplicationStatus == nil {
return r.ReplicationStatus != nil
}
if r.ReplicationStatus == nil {
return false
}
var lTypeMaster, rTypeMaster int
if l.Type == topo.TYPE_MASTER {
lTypeMaster = 1
}
if r.Type == topo.TYPE_MASTER {
rTypeMaster = 1
}
if lTypeMaster < rTypeMaster {
return true
}
if lTypeMaster == rTypeMaster {
return !l.Position.AtLeast(r.Position)
}
return false
}
func sortReplicatingTablets(tablets []*topo.TabletInfo, stats []*myproto.ReplicationStatus) []*rTablet {
rtablets := make([]*rTablet, len(tablets))
for i, status := range stats {
rtablets[i] = &rTablet{tablets[i], status}
}
sort.Sort(rTablets(rtablets))
return rtablets
}

Просмотреть файл

@ -48,12 +48,12 @@ func newTabletHealth(thc *tabletHealthCache, tabletAlias topo.TabletAlias) (*Tab
func (th *TabletHealth) update(thc *tabletHealthCache, tabletAlias topo.TabletAlias) {
defer thc.delete(tabletAlias)
ti, err := thc.ts.GetTablet(tabletAlias)
ctx := context.Background()
ti, err := thc.ts.GetTablet(ctx, tabletAlias)
if err != nil {
return
}
ctx := context.Background()
c, errFunc, err := thc.tmc.HealthStream(ctx, ti)
if err != nil {
return

Просмотреть файл

@ -235,7 +235,7 @@ func (loader *TemplateLoader) ServeTemplate(templateName string, data interface{
var (
modifyDbTopology func(context.Context, topo.Server, *topotools.Topology) error
modifyDbServingGraph func(topo.Server, *topotools.ServingGraph)
modifyDbServingGraph func(context.Context, topo.Server, *topotools.ServingGraph)
)
// SetDbTopologyPostprocessor installs a hook that can modify
@ -249,7 +249,7 @@ func SetDbTopologyPostprocessor(f func(context.Context, topo.Server, *topotools.
// SetDbServingGraphPostprocessor installs a hook that can modify
// topotools.ServingGraph struct before it's displayed.
func SetDbServingGraphPostprocessor(f func(topo.Server, *topotools.ServingGraph)) {
func SetDbServingGraphPostprocessor(f func(context.Context, topo.Server, *topotools.ServingGraph)) {
if modifyDbServingGraph != nil {
panic("Cannot set multiple DbServingGraph postprocessors")
}

Просмотреть файл

@ -9,6 +9,7 @@ import (
"time"
"github.com/youtube/vitess/go/vt/topo"
"golang.org/x/net/context"
)
// This file includes the support for serving topo data to an ajax-based
@ -47,7 +48,7 @@ func (bvo *BaseVersionedObject) SetVersion(version int) {
// GetVersionedObjectFunc is the function the cache will call to get
// the object itself.
type GetVersionedObjectFunc func() (VersionedObject, error)
type GetVersionedObjectFunc func(ctx context.Context) (VersionedObject, error)
// VersionedObjectCache is the main cache object. Just needs a method to get
// the content.
@ -68,7 +69,7 @@ func NewVersionedObjectCache(getObject GetVersionedObjectFunc) *VersionedObjectC
}
// Get returns the versioned value from the cache.
func (voc *VersionedObjectCache) Get() ([]byte, error) {
func (voc *VersionedObjectCache) Get(ctx context.Context) ([]byte, error) {
voc.mu.Lock()
defer voc.mu.Unlock()
@ -77,7 +78,7 @@ func (voc *VersionedObjectCache) Get() ([]byte, error) {
return voc.result, nil
}
newObject, err := voc.getObject()
newObject, err := voc.getObject(ctx)
if err != nil {
return nil, err
}
@ -142,7 +143,7 @@ func NewVersionedObjectCacheMap(factory VersionedObjectCacheFactory) *VersionedO
}
// Get finds the right VersionedObjectCache and returns its value
func (vocm *VersionedObjectCacheMap) Get(key string) ([]byte, error) {
func (vocm *VersionedObjectCacheMap) Get(ctx context.Context, key string) ([]byte, error) {
vocm.mu.Lock()
voc, ok := vocm.cacheMap[key]
if !ok {
@ -151,7 +152,7 @@ func (vocm *VersionedObjectCacheMap) Get(key string) ([]byte, error) {
}
vocm.mu.Unlock()
return voc.Get()
return voc.Get(ctx)
}
// Flush will flush the entire cache
@ -177,8 +178,8 @@ func (kc *KnownCells) Reset() {
}
func newKnownCellsCache(ts topo.Server) *VersionedObjectCache {
return NewVersionedObjectCache(func() (VersionedObject, error) {
cells, err := ts.GetKnownCells()
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
cells, err := ts.GetKnownCells(ctx)
if err != nil {
return nil, err
}
@ -202,8 +203,8 @@ func (k *Keyspaces) Reset() {
}
func newKeyspacesCache(ts topo.Server) *VersionedObjectCache {
return NewVersionedObjectCache(func() (VersionedObject, error) {
keyspaces, err := ts.GetKeyspaces()
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
keyspaces, err := ts.GetKeyspaces(ctx)
if err != nil {
return nil, err
}
@ -232,8 +233,8 @@ func (k *Keyspace) Reset() {
func newKeyspaceCache(ts topo.Server) *VersionedObjectCacheMap {
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
return NewVersionedObjectCache(func() (VersionedObject, error) {
k, err := ts.GetKeyspace(key)
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
k, err := ts.GetKeyspace(ctx, key)
if err != nil {
return nil, err
}
@ -264,8 +265,8 @@ func (s *ShardNames) Reset() {
func newShardNamesCache(ts topo.Server) *VersionedObjectCacheMap {
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
return NewVersionedObjectCache(func() (VersionedObject, error) {
sn, err := ts.GetShardNames(key)
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
sn, err := ts.GetShardNames(ctx, key)
if err != nil {
return nil, err
}
@ -301,13 +302,13 @@ func (s *Shard) Reset() {
func newShardCache(ts topo.Server) *VersionedObjectCacheMap {
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
return NewVersionedObjectCache(func() (VersionedObject, error) {
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
keyspace, shard, err := topo.ParseKeyspaceShardString(key)
if err != nil {
return nil, err
}
s, err := ts.GetShard(keyspace, shard)
s, err := ts.GetShard(ctx, keyspace, shard)
if err != nil {
return nil, err
}
@ -349,12 +350,12 @@ func (cst *CellShardTablets) Reset() {
func newCellShardTabletsCache(ts topo.Server) *VersionedObjectCacheMap {
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
return NewVersionedObjectCache(func() (VersionedObject, error) {
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
parts := strings.Split(key, "/")
if len(parts) != 3 {
return nil, fmt.Errorf("Invalid shard tablets path: %v", key)
}
sr, err := ts.GetShardReplication(parts[0], parts[1], parts[2])
sr, err := ts.GetShardReplication(ctx, parts[0], parts[1], parts[2])
if err != nil {
return nil, err
}

Просмотреть файл

@ -8,10 +8,12 @@ import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/zktopo"
"golang.org/x/net/context"
)
func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo VersionedObject, expectedVO VersionedObject) {
result, err := voc.Get()
ctx := context.Background()
result, err := voc.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -26,7 +28,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
t.Fatalf("Got bad result: %#v expected: %#v", vo, expectedVO)
}
result2, err := voc.Get()
result2, err := voc.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -36,7 +38,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
// force a re-get with same content, version shouldn't change
voc.timestamp = time.Time{}
result2, err = voc.Get()
result2, err = voc.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -47,7 +49,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
// force a reget with different content, version should change
voc.timestamp = time.Time{}
voc.versionedObject.Reset() // poking inside the object here
result, err = voc.Get()
result, err = voc.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -64,7 +66,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
// force a flush and see the version increase again
voc.Flush()
result, err = voc.Get()
result, err = voc.Get(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -81,7 +83,8 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
}
func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, key string, vo VersionedObject, expectedVO VersionedObject) {
result, err := vocm.Get(key)
ctx := context.Background()
result, err := vocm.Get(ctx, key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -96,7 +99,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
t.Fatalf("Got bad result: %#v expected: %#v", vo, expectedVO)
}
result2, err := vocm.Get(key)
result2, err := vocm.Get(ctx, key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -106,7 +109,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
// force a re-get with same content, version shouldn't change
vocm.cacheMap[key].timestamp = time.Time{}
result2, err = vocm.Get(key)
result2, err = vocm.Get(ctx, key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -117,7 +120,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
// force a reget with different content, version should change
vocm.cacheMap[key].timestamp = time.Time{}
vocm.cacheMap[key].versionedObject.Reset() // poking inside the object here
result, err = vocm.Get(key)
result, err = vocm.Get(ctx, key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -134,7 +137,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
// force a flush and see the version increase again
vocm.Flush()
result, err = vocm.Get(key)
result, err = vocm.Get(ctx, key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -162,11 +165,12 @@ func TestKnownCellsCache(t *testing.T) {
}
func TestKeyspacesCache(t *testing.T) {
ctx := context.Background()
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{}); err != nil {
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{}); err != nil {
t.Fatalf("CreateKeyspace failed: %v", err)
}
if err := ts.CreateKeyspace("ks2", &topo.Keyspace{}); err != nil {
if err := ts.CreateKeyspace(ctx, "ks2", &topo.Keyspace{}); err != nil {
t.Fatalf("CreateKeyspace failed: %v", err)
}
kc := newKeyspacesCache(ts)
@ -179,13 +183,14 @@ func TestKeyspacesCache(t *testing.T) {
}
func TestKeyspaceCache(t *testing.T) {
ctx := context.Background()
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{
ShardingColumnName: "sharding_key",
}); err != nil {
t.Fatalf("CreateKeyspace failed: %v", err)
}
if err := ts.CreateKeyspace("ks2", &topo.Keyspace{
if err := ts.CreateKeyspace(ctx, "ks2", &topo.Keyspace{
SplitShardCount: 10,
}); err != nil {
t.Fatalf("CreateKeyspace failed: %v", err)
@ -211,18 +216,19 @@ func TestKeyspaceCache(t *testing.T) {
}
func TestShardNamesCache(t *testing.T) {
ctx := context.Background()
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{
ShardingColumnName: "sharding_key",
}); err != nil {
t.Fatalf("CreateKeyspace failed: %v", err)
}
if err := ts.CreateShard("ks1", "s1", &topo.Shard{
if err := ts.CreateShard(ctx, "ks1", "s1", &topo.Shard{
Cells: []string{"cell1", "cell2"},
}); err != nil {
t.Fatalf("CreateShard failed: %v", err)
}
if err := ts.CreateShard("ks1", "s2", &topo.Shard{
if err := ts.CreateShard(ctx, "ks1", "s2", &topo.Shard{
MasterAlias: topo.TabletAlias{
Cell: "cell1",
Uid: 12,
@ -241,18 +247,19 @@ func TestShardNamesCache(t *testing.T) {
}
func TestShardCache(t *testing.T) {
ctx := context.Background()
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{
ShardingColumnName: "sharding_key",
}); err != nil {
t.Fatalf("CreateKeyspace failed: %v", err)
}
if err := ts.CreateShard("ks1", "s1", &topo.Shard{
if err := ts.CreateShard(ctx, "ks1", "s1", &topo.Shard{
Cells: []string{"cell1", "cell2"},
}); err != nil {
t.Fatalf("CreateShard failed: %v", err)
}
if err := ts.CreateShard("ks1", "s2", &topo.Shard{
if err := ts.CreateShard(ctx, "ks1", "s2", &topo.Shard{
MasterAlias: topo.TabletAlias{
Cell: "cell1",
Uid: 12,
@ -286,8 +293,9 @@ func TestShardCache(t *testing.T) {
}
func TestCellShardTabletsCache(t *testing.T) {
ctx := context.Background()
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
if err := ts.UpdateShardReplicationFields("cell1", "ks1", "s1", func(sr *topo.ShardReplication) error {
if err := ts.UpdateShardReplicationFields(ctx, "cell1", "ks1", "s1", func(sr *topo.ShardReplication) error {
sr.ReplicationLinks = []topo.ReplicationLink{
topo.ReplicationLink{
TabletAlias: topo.TabletAlias{

Просмотреть файл

@ -11,20 +11,22 @@ import (
log "github.com/golang/glog"
"github.com/youtube/vitess/go/acl"
schmgr "github.com/youtube/vitess/go/vt/schemamanager"
"github.com/youtube/vitess/go/vt/schemamanager/uihandler"
"github.com/youtube/vitess/go/timer"
"github.com/youtube/vitess/go/vt/schemamanager"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/tabletmanager/tmclient"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/topotools"
"github.com/youtube/vitess/go/vt/wrangler"
// register gorpc vtgate client
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
)
var (
templateDir = flag.String("templates", "", "directory containing templates")
debug = flag.Bool("debug", false, "recompile templates for every request")
schemaChangeDir = flag.String("schema-change-dir", "", "directory contains schema changes for all keyspaces. Each keyspace has its own directory and schema changes are expected to live in '$KEYSPACE/input' dir. e.g. test_keyspace/input/*sql, each sql file represents a schema change")
schemaChangeController = flag.String("schema-change-controller", "", "schema change controller is responsible for finding schema changes and responsing schema change events")
schemaChangeCheckInterval = flag.Int("schema-change-check-interval", 60, "this value decides how often we check schema change dir, in seconds")
schemaChangeUser = flag.String("schema-change-user", "", "The user who submits this schema change.")
)
func init() {
@ -114,7 +116,7 @@ func main() {
// tablet actions
actionRepo.RegisterTabletAction("Ping", "",
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
ti, err := wr.TopoServer().GetTablet(tabletAlias)
ti, err := wr.TopoServer().GetTablet(ctx, tabletAlias)
if err != nil {
return "", err
}
@ -124,7 +126,7 @@ func main() {
actionRepo.RegisterTabletAction("ScrapTablet", acl.ADMIN,
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
// refuse to scrap tablets that are not spare
ti, err := wr.TopoServer().GetTablet(tabletAlias)
ti, err := wr.TopoServer().GetTablet(ctx, tabletAlias)
if err != nil {
return "", err
}
@ -137,7 +139,7 @@ func main() {
actionRepo.RegisterTabletAction("ScrapTabletForce", acl.ADMIN,
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
// refuse to scrap tablets that are not spare
ti, err := wr.TopoServer().GetTablet(tabletAlias)
ti, err := wr.TopoServer().GetTablet(ctx, tabletAlias)
if err != nil {
return "", err
}
@ -149,7 +151,7 @@ func main() {
actionRepo.RegisterTabletAction("DeleteTablet", acl.ADMIN,
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
return "", wr.DeleteTablet(tabletAlias)
return "", wr.DeleteTablet(ctx, tabletAlias)
})
// keyspace actions
@ -254,7 +256,8 @@ func main() {
cell := parts[len(parts)-1]
if cell == "" {
cells, err := ts.GetKnownCells()
ctx := context.Background()
cells, err := ts.GetKnownCells(ctx)
if err != nil {
httpError(w, "cannot get known cells: %v", err)
return
@ -263,9 +266,10 @@ func main() {
return
}
servingGraph := topotools.DbServingGraph(ts, cell)
ctx := context.Background()
servingGraph := topotools.DbServingGraph(ctx, ts, cell)
if modifyDbServingGraph != nil {
modifyDbServingGraph(ts, servingGraph)
modifyDbServingGraph(ctx, ts, servingGraph)
}
templateLoader.ServeTemplate("serving_graph.html", servingGraph, w, r)
})
@ -292,12 +296,13 @@ func main() {
Error error
Input, Output string
}
ctx := context.Background()
switch r.Method {
case "POST":
data.Input = r.FormValue("vschema")
data.Error = schemafier.SaveVSchema(data.Input)
data.Error = schemafier.SaveVSchema(ctx, data.Input)
}
vschema, err := schemafier.GetVSchema()
vschema, err := schemafier.GetVSchema(ctx)
if err != nil {
if data.Error == nil {
data.Error = fmt.Errorf("Error fetching schema: %s", err)
@ -330,7 +335,8 @@ func main() {
// serve some data
knownCellsCache := newKnownCellsCache(ts)
http.HandleFunc("/json/KnownCells", func(w http.ResponseWriter, r *http.Request) {
result, err := knownCellsCache.Get()
ctx := context.Background()
result, err := knownCellsCache.Get(ctx)
if err != nil {
httpError(w, "error getting known cells: %v", err)
return
@ -340,7 +346,8 @@ func main() {
keyspacesCache := newKeyspacesCache(ts)
http.HandleFunc("/json/Keyspaces", func(w http.ResponseWriter, r *http.Request) {
result, err := keyspacesCache.Get()
ctx := context.Background()
result, err := keyspacesCache.Get(ctx)
if err != nil {
httpError(w, "error getting keyspaces: %v", err)
return
@ -359,7 +366,8 @@ func main() {
http.Error(w, "no keyspace provided", http.StatusBadRequest)
return
}
result, err := keyspaceCache.Get(keyspace)
ctx := context.Background()
result, err := keyspaceCache.Get(ctx, keyspace)
if err != nil {
httpError(w, "error getting keyspace: %v", err)
return
@ -378,7 +386,8 @@ func main() {
http.Error(w, "no keyspace provided", http.StatusBadRequest)
return
}
result, err := shardNamesCache.Get(keyspace)
ctx := context.Background()
result, err := shardNamesCache.Get(ctx, keyspace)
if err != nil {
httpError(w, "error getting shardNames: %v", err)
return
@ -402,7 +411,8 @@ func main() {
http.Error(w, "no shard provided", http.StatusBadRequest)
return
}
result, err := shardCache.Get(keyspace + "/" + shard)
ctx := context.Background()
result, err := shardCache.Get(ctx, keyspace+"/"+shard)
if err != nil {
httpError(w, "error getting shard: %v", err)
return
@ -431,7 +441,8 @@ func main() {
http.Error(w, "no shard provided", http.StatusBadRequest)
return
}
result, err := cellShardTabletsCache.Get(cell + "/" + keyspace + "/" + shard)
ctx := context.Background()
result, err := cellShardTabletsCache.Get(ctx, cell+"/"+keyspace+"/"+shard)
if err != nil {
httpError(w, "error getting shard: %v", err)
return
@ -485,16 +496,47 @@ func main() {
}
sqlStr := r.FormValue("data")
keyspace := r.FormValue("keyspace")
shards, err := ts.GetShardNames(keyspace)
if err != nil {
httpError(w, "error getting shards for keyspace: <"+keyspace+">, error: %v", err)
}
schmgr.Run(
schmgr.NewSimepleDataSourcer(sqlStr),
schmgr.NewVtGateExecutor(
keyspace, nil, 1*time.Second),
uihandler.NewUIEventHandler(w),
shards)
executor := schemamanager.NewTabletExecutor(
tmclient.NewTabletManagerClient(),
ts)
schemamanager.Run(
context.Background(),
schemamanager.NewUIController(sqlStr, keyspace, w),
executor,
)
})
if *schemaChangeDir != "" {
interval := 60
if *schemaChangeCheckInterval > 0 {
interval = *schemaChangeCheckInterval
}
timer := timer.NewTimer(time.Duration(interval) * time.Second)
controllerFactory, err :=
schemamanager.GetControllerFactory(*schemaChangeController)
if err != nil {
log.Fatalf("unable to get a controller factory, error: %v", err)
}
timer.Start(func() {
controller, err := controllerFactory(map[string]string{
schemamanager.SchemaChangeDirName: *schemaChangeDir,
schemamanager.SchemaChangeUser: *schemaChangeUser,
})
if err != nil {
log.Errorf("failed to get controller, error: %v", err)
return
}
err = schemamanager.Run(
context.Background(),
controller,
schemamanager.NewTabletExecutor(
tmclient.NewTabletManagerClient(), ts),
)
log.Errorf("Schema change failed, error: %v", err)
})
servenv.OnClose(func() { timer.Stop() })
}
servenv.RunDefault()
}

Просмотреть файл

@ -14,6 +14,7 @@ import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate"
"github.com/youtube/vitess/go/vt/vtgate/planbuilder"
"golang.org/x/net/context"
)
var (
@ -48,7 +49,6 @@ func main() {
defer topo.CloseServers()
var schema *planbuilder.Schema
log.Info(*cell, *schemaFile)
if *schemaFile != "" {
var err error
if schema, err = planbuilder.LoadFile(*schemaFile); err != nil {
@ -62,7 +62,8 @@ func main() {
log.Infof("Skipping v3 initialization: topo does not suppurt schemafier interface")
goto startServer
}
schemaJSON, err := schemafier.GetVSchema()
ctx := context.Background()
schemaJSON, err := schemafier.GetVSchema(ctx)
if err != nil {
log.Warningf("Skipping v3 initialization: GetVSchema failed: %v", err)
goto startServer

Просмотреть файл

@ -0,0 +1,9 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
_ "github.com/youtube/vitess/go/vt/mysqlctl/filebackupstorage"
)

Просмотреть файл

@ -108,7 +108,6 @@ func main() {
exit.Return(1)
}
tabletmanager.HttpHandleSnapshots(mycnf, tabletAlias.Uid)
servenv.OnRun(func() {
addStatusParts(qsc)
})

Просмотреть файл

@ -14,6 +14,7 @@ import (
log "github.com/golang/glog"
"github.com/youtube/vitess/go/vt/worker"
"github.com/youtube/vitess/go/vt/wrangler"
"golang.org/x/net/context"
)
var (
@ -23,7 +24,7 @@ var (
type command struct {
Name string
method func(wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []string) (worker.Worker, error)
interactive func(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request)
interactive func(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request)
params string
Help string // if help is empty, won't list the command
}
@ -112,7 +113,11 @@ func runCommand(args []string) error {
case <-done:
log.Infof("Command is done:")
log.Info(wrk.StatusAsText())
if wrk.Error() != nil {
currentWorkerMutex.Lock()
err := lastRunError
currentWorkerMutex.Unlock()
if err != nil {
log.Errorf("Ended with an error: %v", err)
os.Exit(1)
}
os.Exit(0)

Просмотреть файл

@ -10,6 +10,7 @@ import (
"net/http"
log "github.com/golang/glog"
"golang.org/x/net/context"
)
const indexHTML = `
@ -83,7 +84,8 @@ func initInteractiveMode() {
// closure.
pc := c
http.HandleFunc("/"+cg.Name+"/"+c.Name, func(w http.ResponseWriter, r *http.Request) {
pc.interactive(wr, w, r)
ctx := context.Background()
pc.interactive(ctx, wr, w, r)
})
}
}

Просмотреть файл

@ -18,6 +18,7 @@ import (
"github.com/youtube/vitess/go/vt/topotools"
"github.com/youtube/vitess/go/vt/worker"
"github.com/youtube/vitess/go/vt/wrangler"
"golang.org/x/net/context"
)
const splitCloneHTML = `
@ -107,8 +108,8 @@ func commandSplitClone(wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []str
return worker, nil
}
func keyspacesWithOverlappingShards(wr *wrangler.Wrangler) ([]map[string]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces()
func keyspacesWithOverlappingShards(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
if err != nil {
return nil, err
}
@ -121,7 +122,7 @@ func keyspacesWithOverlappingShards(wr *wrangler.Wrangler) ([]map[string]string,
wg.Add(1)
go func(keyspace string) {
defer wg.Done()
osList, err := topotools.FindOverlappingShards(wr.TopoServer(), keyspace)
osList, err := topotools.FindOverlappingShards(ctx, wr.TopoServer(), keyspace)
if err != nil {
rec.RecordError(err)
return
@ -147,7 +148,7 @@ func keyspacesWithOverlappingShards(wr *wrangler.Wrangler) ([]map[string]string,
return result, nil
}
func interactiveSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
func interactiveSplitClone(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httpError(w, "cannot parse form: %s", err)
return
@ -159,7 +160,7 @@ func interactiveSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter, r *http
// display the list of possible splits to choose from
// (just find all the overlapping guys)
result := make(map[string]interface{})
choices, err := keyspacesWithOverlappingShards(wr)
choices, err := keyspacesWithOverlappingShards(ctx, wr)
if err != nil {
result["Error"] = err.Error()
} else {

Просмотреть файл

@ -16,6 +16,7 @@ import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/worker"
"github.com/youtube/vitess/go/vt/wrangler"
"golang.org/x/net/context"
)
const splitDiffHTML = `
@ -76,8 +77,8 @@ func commandSplitDiff(wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []stri
// shardsWithSources returns all the shards that have SourceShards set
// with no Tables list.
func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces()
func shardsWithSources(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
if err != nil {
return nil, err
}
@ -90,7 +91,7 @@ func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
wg.Add(1)
go func(keyspace string) {
defer wg.Done()
shards, err := wr.TopoServer().GetShardNames(keyspace)
shards, err := wr.TopoServer().GetShardNames(ctx, keyspace)
if err != nil {
rec.RecordError(err)
return
@ -99,7 +100,7 @@ func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
wg.Add(1)
go func(keyspace, shard string) {
defer wg.Done()
si, err := wr.TopoServer().GetShard(keyspace, shard)
si, err := wr.TopoServer().GetShard(ctx, keyspace, shard)
if err != nil {
rec.RecordError(err)
return
@ -128,7 +129,7 @@ func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
return result, nil
}
func interactiveSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
func interactiveSplitDiff(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httpError(w, "cannot parse form: %s", err)
return
@ -139,7 +140,7 @@ func interactiveSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.
if keyspace == "" || shard == "" {
// display the list of possible shards to chose from
result := make(map[string]interface{})
shards, err := shardsWithSources(wr)
shards, err := shardsWithSources(ctx, wr)
if err != nil {
result["Error"] = err.Error()
} else {

Просмотреть файл

@ -5,6 +5,7 @@
package main
import (
"fmt"
"html/template"
"net/http"
"strings"
@ -68,17 +69,20 @@ func initStatusHandling() {
currentWorkerMutex.Lock()
wrk := currentWorker
logger := currentMemoryLogger
done := currentDone
ctx := currentContext
err := lastRunError
currentWorkerMutex.Unlock()
data := make(map[string]interface{})
if wrk != nil {
data["Status"] = wrk.StatusAsHTML()
select {
case <-done:
status := template.HTML("Current worker:<br>\n") + wrk.StatusAsHTML()
if ctx == nil {
data["Done"] = true
default:
if err != nil {
status += template.HTML(fmt.Sprintf("<br>\nEnded with an error: %v<br>\n", err))
}
}
data["Status"] = status
if logger != nil {
data["Logs"] = template.HTML(strings.Replace(logger.String(), "\n", "</br>\n", -1))
} else {
@ -99,29 +103,27 @@ func initStatusHandling() {
acl.SendError(w, err)
return
}
currentWorkerMutex.Lock()
wrk := currentWorker
done := currentDone
currentWorkerMutex.Unlock()
// no worker, we go to the menu
if wrk == nil {
if currentWorker == nil {
currentWorkerMutex.Unlock()
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
return
}
// check the worker is really done
select {
case <-done:
currentWorkerMutex.Lock()
if currentContext == nil {
currentWorker = nil
currentMemoryLogger = nil
currentDone = nil
currentWorkerMutex.Unlock()
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
default:
httpError(w, "worker still executing", nil)
return
}
currentWorkerMutex.Unlock()
httpError(w, "worker still executing", nil)
})
// cancel handler
@ -130,18 +132,20 @@ func initStatusHandling() {
acl.SendError(w, err)
return
}
currentWorkerMutex.Lock()
wrk := currentWorker
currentWorkerMutex.Unlock()
// no worker, we go to the menu
if wrk == nil {
currentWorkerMutex.Lock()
// no worker, or not running, we go to the menu
if currentWorker == nil || currentCancelFunc == nil {
currentWorkerMutex.Unlock()
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
return
}
// otherwise, cancel the running worker and go back to the status page
wrk.Cancel()
cancel := currentCancelFunc
currentWorkerMutex.Unlock()
cancel()
http.Redirect(w, r, servenv.StatusURLPath(), http.StatusTemporaryRedirect)
})

Просмотреть файл

@ -17,6 +17,7 @@ import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/worker"
"github.com/youtube/vitess/go/vt/wrangler"
"golang.org/x/net/context"
)
const (
@ -114,8 +115,8 @@ func commandVerticalSplitClone(wr *wrangler.Wrangler, subFlags *flag.FlagSet, ar
// keyspacesWithServedFrom returns all the keyspaces that have ServedFrom set
// to one value.
func keyspacesWithServedFrom(wr *wrangler.Wrangler) ([]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces()
func keyspacesWithServedFrom(ctx context.Context, wr *wrangler.Wrangler) ([]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
if err != nil {
return nil, err
}
@ -128,7 +129,7 @@ func keyspacesWithServedFrom(wr *wrangler.Wrangler) ([]string, error) {
wg.Add(1)
go func(keyspace string) {
defer wg.Done()
ki, err := wr.TopoServer().GetKeyspace(keyspace)
ki, err := wr.TopoServer().GetKeyspace(ctx, keyspace)
if err != nil {
rec.RecordError(err)
return
@ -151,7 +152,7 @@ func keyspacesWithServedFrom(wr *wrangler.Wrangler) ([]string, error) {
return result, nil
}
func interactiveVerticalSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
func interactiveVerticalSplitClone(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httpError(w, "cannot parse form: %s", err)
return
@ -161,7 +162,7 @@ func interactiveVerticalSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter,
if keyspace == "" {
// display the list of possible keyspaces to choose from
result := make(map[string]interface{})
keyspaces, err := keyspacesWithServedFrom(wr)
keyspaces, err := keyspacesWithServedFrom(ctx, wr)
if err != nil {
result["Error"] = err.Error()
} else {

Просмотреть файл

@ -16,6 +16,7 @@ import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/worker"
"github.com/youtube/vitess/go/vt/wrangler"
"golang.org/x/net/context"
)
const verticalSplitDiffHTML = `
@ -75,8 +76,8 @@ func commandVerticalSplitDiff(wr *wrangler.Wrangler, subFlags *flag.FlagSet, arg
// shardsWithTablesSources returns all the shards that have SourceShards set
// to one value, with an array of Tables.
func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces()
func shardsWithTablesSources(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) {
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
if err != nil {
return nil, err
}
@ -89,7 +90,7 @@ func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error)
wg.Add(1)
go func(keyspace string) {
defer wg.Done()
shards, err := wr.TopoServer().GetShardNames(keyspace)
shards, err := wr.TopoServer().GetShardNames(ctx, keyspace)
if err != nil {
rec.RecordError(err)
return
@ -98,7 +99,7 @@ func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error)
wg.Add(1)
go func(keyspace, shard string) {
defer wg.Done()
si, err := wr.TopoServer().GetShard(keyspace, shard)
si, err := wr.TopoServer().GetShard(ctx, keyspace, shard)
if err != nil {
rec.RecordError(err)
return
@ -127,7 +128,7 @@ func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error)
return result, nil
}
func interactiveVerticalSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
func interactiveVerticalSplitDiff(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
httpError(w, "cannot parse form: %s", err)
return
@ -138,7 +139,7 @@ func interactiveVerticalSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter,
if keyspace == "" || shard == "" {
// display the list of possible shards to chose from
result := make(map[string]interface{})
shards, err := shardsWithTablesSources(wr)
shards, err := shardsWithTablesSources(ctx, wr)
if err != nil {
result["Error"] = err.Error()
} else {

Просмотреть файл

@ -29,6 +29,7 @@ import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/worker"
"github.com/youtube/vitess/go/vt/wrangler"
"golang.org/x/net/context"
)
var (
@ -44,10 +45,20 @@ var (
wr *wrangler.Wrangler
// mutex is protecting all the following variables
// 3 states here:
// - no job ever ran (or reset was run): currentWorker is nil,
// currentContext/currentCancelFunc is nil, lastRunError is nil
// - one worker running: currentWorker is set,
// currentContext/currentCancelFunc is set, lastRunError is nil
// - (at least) one worker already ran, none is running atm:
// currentWorker is set, currentContext is nil, lastRunError
// has the error returned by the worker.
currentWorkerMutex sync.Mutex
currentWorker worker.Worker
currentMemoryLogger *logutil.MemoryLogger
currentDone chan struct{}
currentContext context.Context
currentCancelFunc context.CancelFunc
lastRunError error
)
// signal handling, centralized here
@ -59,7 +70,9 @@ func installSignalHandlers() {
// we got a signal, notify our modules
currentWorkerMutex.Lock()
defer currentWorkerMutex.Unlock()
currentWorker.Cancel()
if currentCancelFunc != nil {
currentCancelFunc()
}
}()
}
@ -75,17 +88,27 @@ func setAndStartWorker(wrk worker.Worker) (chan struct{}, error) {
currentWorker = wrk
currentMemoryLogger = logutil.NewMemoryLogger()
currentDone = make(chan struct{})
currentContext, currentCancelFunc = context.WithCancel(context.Background())
lastRunError = nil
done := make(chan struct{})
wr.SetLogger(logutil.NewTeeLogger(currentMemoryLogger, logutil.NewConsoleLogger()))
// one go function runs the worker, closes 'done' when done
// one go function runs the worker, changes state when done
go func() {
// run will take a long time
log.Infof("Starting worker...")
wrk.Run()
close(currentDone)
err := wrk.Run(currentContext)
// it's done, let's save our state
currentWorkerMutex.Lock()
currentContext = nil
currentCancelFunc = nil
lastRunError = err
currentWorkerMutex.Unlock()
close(done)
}()
return currentDone, nil
return done, nil
}
func main() {

Просмотреть файл

@ -50,12 +50,12 @@ func connect() *rpcplus.Client {
return rpcClient
}
func getSrvKeyspaceNames(rpcClient *rpcplus.Client, cell string, verbose bool) {
func getSrvKeyspaceNames(ctx context.Context, rpcClient *rpcplus.Client, cell string, verbose bool) {
req := &topo.GetSrvKeyspaceNamesArgs{
Cell: cell,
}
reply := &topo.SrvKeyspaceNames{}
if err := rpcClient.Call(context.TODO(), "TopoReader.GetSrvKeyspaceNames", req, reply); err != nil {
if err := rpcClient.Call(ctx, "TopoReader.GetSrvKeyspaceNames", req, reply); err != nil {
log.Fatalf("TopoReader.GetSrvKeyspaceNames error: %v", err)
}
if verbose {
@ -65,13 +65,13 @@ func getSrvKeyspaceNames(rpcClient *rpcplus.Client, cell string, verbose bool) {
}
}
func getSrvKeyspace(rpcClient *rpcplus.Client, cell, keyspace string, verbose bool) {
func getSrvKeyspace(ctx context.Context, rpcClient *rpcplus.Client, cell, keyspace string, verbose bool) {
req := &topo.GetSrvKeyspaceArgs{
Cell: cell,
Keyspace: keyspace,
}
reply := &topo.SrvKeyspace{}
if err := rpcClient.Call(context.TODO(), "TopoReader.GetSrvKeyspace", req, reply); err != nil {
if err := rpcClient.Call(ctx, "TopoReader.GetSrvKeyspace", req, reply); err != nil {
log.Fatalf("TopoReader.GetSrvKeyspace error: %v", err)
}
if verbose {
@ -89,7 +89,7 @@ func getSrvKeyspace(rpcClient *rpcplus.Client, cell, keyspace string, verbose bo
}
}
func getEndPoints(rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType string, verbose bool) {
func getEndPoints(ctx context.Context, rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType string, verbose bool) {
req := &topo.GetEndPointsArgs{
Cell: cell,
Keyspace: keyspace,
@ -97,7 +97,7 @@ func getEndPoints(rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType s
TabletType: topo.TabletType(tabletType),
}
reply := &topo.EndPoints{}
if err := rpcClient.Call(context.TODO(), "TopoReader.GetEndPoints", req, reply); err != nil {
if err := rpcClient.Call(ctx, "TopoReader.GetEndPoints", req, reply); err != nil {
log.Fatalf("TopoReader.GetEndPoints error: %v", err)
}
if verbose {
@ -109,14 +109,14 @@ func getEndPoints(rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType s
// qps is a function used by tests to run a vtgate load check.
// It will get the same srvKeyspaces as fast as possible and display the QPS.
func qps(cell string, keyspaces []string) {
func qps(ctx context.Context, cell string, keyspaces []string) {
var count sync2.AtomicInt32
for _, keyspace := range keyspaces {
for i := 0; i < 10; i++ {
go func() {
rpcClient := connect()
for true {
getSrvKeyspace(rpcClient, cell, keyspace, false)
getSrvKeyspace(ctx, rpcClient, cell, keyspace, false)
count.Add(1)
}
}()
@ -157,10 +157,11 @@ func main() {
defer pprof.StopCPUProfile()
}
ctx := context.Background()
if *mode == "getSrvKeyspaceNames" {
rpcClient := connect()
if len(args) == 1 {
getSrvKeyspaceNames(rpcClient, args[0], true)
getSrvKeyspaceNames(ctx, rpcClient, args[0], true)
} else {
log.Errorf("getSrvKeyspaceNames only takes one argument")
exit.Return(1)
@ -169,7 +170,7 @@ func main() {
} else if *mode == "getSrvKeyspace" {
rpcClient := connect()
if len(args) == 2 {
getSrvKeyspace(rpcClient, args[0], args[1], true)
getSrvKeyspace(ctx, rpcClient, args[0], args[1], true)
} else {
log.Errorf("getSrvKeyspace only takes two arguments")
exit.Return(1)
@ -178,14 +179,14 @@ func main() {
} else if *mode == "getEndPoints" {
rpcClient := connect()
if len(args) == 4 {
getEndPoints(rpcClient, args[0], args[1], args[2], args[3], true)
getEndPoints(ctx, rpcClient, args[0], args[1], args[2], args[3], true)
} else {
log.Errorf("getEndPoints only takes four arguments")
exit.Return(1)
}
} else if *mode == "qps" {
qps(args[0], args[1:])
qps(ctx, args[0], args[1:])
} else {
flag.Usage()

Просмотреть файл

@ -99,27 +99,28 @@ type Charset struct {
// Convert takes a type and a value, and returns the type:
// - nil for NULL value
// - int64 if possible, otherwise, uint64
// - uint64 for unsigned BIGINT values
// - int64 for all other integer values (signed and unsigned)
// - float64 for floating point values that fit in a float
// - []byte for everything else
// TODO(mberlin): Make this a method of "Field" and consider VT_UNSIGNED_FLAG in "Flags" as well.
func Convert(mysqlType int64, val sqltypes.Value) (interface{}, error) {
func Convert(field Field, val sqltypes.Value) (interface{}, error) {
if val.IsNull() {
return nil, nil
}
switch mysqlType {
case VT_TINY, VT_SHORT, VT_LONG, VT_LONGLONG, VT_INT24:
val := val.String()
signed, err := strconv.ParseInt(val, 0, 64)
if err == nil {
return signed, nil
switch field.Type {
case VT_LONGLONG:
if field.Flags&VT_UNSIGNED_FLAG == VT_UNSIGNED_FLAG {
return strconv.ParseUint(val.String(), 0, 64)
}
unsigned, err := strconv.ParseUint(val, 0, 64)
if err == nil {
return unsigned, nil
}
return nil, err
return strconv.ParseInt(val.String(), 0, 64)
case VT_TINY, VT_SHORT, VT_LONG, VT_INT24:
// Regardless of whether UNSIGNED_FLAG is set in field.Flags, we map all
// signed and unsigned values to a signed Go type because
// - Go doesn't officially support uint64 in their SQL interface
// - there is no loss of the value
// The only exception we make are for unsigned BIGINTs, see VT_LONGLONG above.
return strconv.ParseInt(val.String(), 0, 64)
case VT_FLOAT, VT_DOUBLE:
return strconv.ParseFloat(val.String(), 64)
}

Просмотреть файл

@ -12,81 +12,79 @@ import (
func TestConvert(t *testing.T) {
cases := []struct {
Desc string
Typ int64
Field Field
Val sqltypes.Value
Want interface{}
}{{
Desc: "null",
Typ: VT_LONG,
Field: Field{"null", VT_LONG, VT_ZEROVALUE_FLAG},
Val: sqltypes.Value{},
Want: nil,
}, {
Desc: "decimal",
Typ: VT_DECIMAL,
Field: Field{"decimal", VT_DECIMAL, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("aa")),
Want: "aa",
}, {
Desc: "tiny",
Typ: VT_TINY,
Field: Field{"tiny", VT_TINY, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: int64(1),
}, {
Desc: "short",
Typ: VT_SHORT,
Field: Field{"short", VT_SHORT, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: int64(1),
}, {
Desc: "long",
Typ: VT_LONG,
Field: Field{"long", VT_LONG, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: int64(1),
}, {
Desc: "longlong",
Typ: VT_LONGLONG,
Field: Field{"unsigned long", VT_LONG, VT_UNSIGNED_FLAG},
Val: sqltypes.MakeString([]byte("1")),
// Unsigned types which aren't VT_LONGLONG are mapped to int64.
Want: int64(1),
}, {
Field: Field{"longlong", VT_LONGLONG, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: int64(1),
}, {
Desc: "int24",
Typ: VT_INT24,
Field: Field{"int24", VT_INT24, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: int64(1),
}, {
Desc: "float",
Typ: VT_FLOAT,
Field: Field{"float", VT_FLOAT, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: float64(1),
}, {
Desc: "double",
Typ: VT_DOUBLE,
Field: Field{"double", VT_DOUBLE, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1")),
Want: float64(1),
}, {
Desc: "large int",
Typ: VT_LONGLONG,
Field: Field{"large int out of range for int64", VT_LONGLONG, VT_ZEROVALUE_FLAG},
// 2^63, out of range for int64
Val: sqltypes.MakeString([]byte("9223372036854775808")),
Want: `strconv.ParseInt: parsing "9223372036854775808": value out of range`,
}, {
Field: Field{"large int", VT_LONGLONG, VT_UNSIGNED_FLAG},
// 2^63, not out of range for uint64
Val: sqltypes.MakeString([]byte("9223372036854775808")),
Want: uint64(9223372036854775808),
}, {
Desc: "float for int",
Typ: VT_LONGLONG,
Field: Field{"float for int", VT_LONGLONG, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("1.1")),
Want: `strconv.ParseUint: parsing "1.1": invalid syntax`,
Want: `strconv.ParseInt: parsing "1.1": invalid syntax`,
}, {
Desc: "string for float",
Typ: VT_FLOAT,
Field: Field{"string for float", VT_FLOAT, VT_ZEROVALUE_FLAG},
Val: sqltypes.MakeString([]byte("aa")),
Want: `strconv.ParseFloat: parsing "aa": invalid syntax`,
}}
for _, c := range cases {
r, err := Convert(c.Typ, c.Val)
r, err := Convert(c.Field, c.Val)
if err != nil {
r = err.Error()
} else if _, ok := r.([]byte); ok {
r = string(r.([]byte))
}
if r != c.Want {
t.Errorf("%s: %+v, want %+v", c.Desc, r, c.Want)
t.Errorf("%s: %+v, want %+v", c.Field.Name, r, c.Want)
}
}
}

Просмотреть файл

@ -77,7 +77,7 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur
// Close empties the pool calling Close on all its resources.
// You can call Close while there are outstanding resources.
// It waits for all resources to be returned (Put).
// After a Close, Get and TryGet are not allowed.
// After a Close, Get is not allowed.
func (rp *ResourcePool) Close() {
_ = rp.SetCapacity(0)
}
@ -95,13 +95,6 @@ func (rp *ResourcePool) Get(ctx context.Context) (resource Resource, err error)
return rp.get(ctx, true)
}
// TryGet will return the next available resource. If none is available, and capacity
// has not been reached, it will create a new one using the factory. Otherwise,
// it will return nil with no error.
func (rp *ResourcePool) TryGet() (resource Resource, err error) {
return rp.get(context.TODO(), false)
}
func (rp *ResourcePool) get(ctx context.Context, wait bool) (resource Resource, err error) {
// If ctx has already expired, avoid racing with rp's resource channel.
select {

Просмотреть файл

@ -74,38 +74,6 @@ func TestOpen(t *testing.T) {
}
}
// Test TryGet
r, err := p.TryGet()
if err != nil {
t.Errorf("Unexpected error %v", err)
}
if r != nil {
t.Errorf("Expecting nil")
}
for i := 0; i < 5; i++ {
p.Put(resources[i])
_, available, _, _, _, _ := p.Stats()
if available != int64(i+1) {
t.Errorf("expecting %d, received %d", 5-i-1, available)
}
}
for i := 0; i < 5; i++ {
r, err := p.TryGet()
resources[i] = r
if err != nil {
t.Errorf("Unexpected error %v", err)
}
if r == nil {
t.Errorf("Expecting non-nil")
}
if lastID.Get() != 5 {
t.Errorf("Expecting 5, received %d", lastID.Get())
}
if count.Get() != 5 {
t.Errorf("Expecting 5, received %d", count.Get())
}
}
// Test that Get waits
ch := make(chan bool)
go func() {
@ -139,7 +107,7 @@ func TestOpen(t *testing.T) {
}
// Test Close resource
r, err = p.Get(ctx)
r, err := p.Get(ctx)
if err != nil {
t.Errorf("Unexpected error %v", err)
}
@ -241,15 +209,6 @@ func TestShrinking(t *testing.T) {
t.Errorf(`expecting '%s', received '%s'`, expected, stats)
}
// TryGet is allowed when shrinking
r, err := p.TryGet()
if err != nil {
t.Errorf("Unexpected error %v", err)
}
if r != nil {
t.Errorf("Expecting nil")
}
// Get is allowed when shrinking, but it will wait
getdone := make(chan bool)
go func() {
@ -278,6 +237,7 @@ func TestShrinking(t *testing.T) {
// Ensure no deadlock if SetCapacity is called after we start
// waiting for a resource
var err error
for i := 0; i < 3; i++ {
resources[i], err = p.Get(ctx)
if err != nil {

Просмотреть файл

@ -0,0 +1,52 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
// ClusterOperationInstance is a runtime type which enhances the protobuf message "ClusterOperation" with runtime specific data.
// Unlike the protobuf message, the additional runtime data will not be part of a checkpoint.
type ClusterOperationInstance struct {
pb.ClusterOperation
taskIDGenerator *IDGenerator
}
// NewClusterOperationInstance creates a new cluster operation instance with one initial task.
func NewClusterOperationInstance(clusterOpID string, initialTask *pb.TaskContainer, taskIDGenerator *IDGenerator) *ClusterOperationInstance {
c := &ClusterOperationInstance{
pb.ClusterOperation{
Id: clusterOpID,
SerialTasks: []*pb.TaskContainer{},
State: pb.ClusterOperationState_CLUSTER_OPERATION_NOT_STARTED,
},
taskIDGenerator,
}
c.InsertTaskContainers([]*pb.TaskContainer{initialTask}, 0)
return c
}
// addMissingTaskID assigns a task id to each task in "tc".
func (c *ClusterOperationInstance) addMissingTaskID(tc []*pb.TaskContainer) {
for _, taskContainer := range tc {
for _, task := range taskContainer.ParallelTasks {
if task.Id == "" {
task.Id = c.taskIDGenerator.GetNextID()
}
}
}
}
// InsertTaskContainers inserts "newTaskContainers" at pos in the current list of task containers. Existing task containers will be moved after the new task containers.
func (c *ClusterOperationInstance) InsertTaskContainers(newTaskContainers []*pb.TaskContainer, pos int) {
c.addMissingTaskID(newTaskContainers)
newSerialTasks := make([]*pb.TaskContainer, len(c.SerialTasks)+len(newTaskContainers))
copy(newSerialTasks, c.SerialTasks[:pos])
copy(newSerialTasks[pos:], newTaskContainers)
copy(newSerialTasks[pos+len(newTaskContainers):], c.SerialTasks[pos:])
c.SerialTasks = newSerialTasks
}

Просмотреть файл

@ -0,0 +1,106 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
"fmt"
"strings"
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
// HorizontalReshardingTask is a cluster operation which allows to increase the number of shards.
type HorizontalReshardingTask struct {
}
// TODO(mberlin): Uncomment/remove when "ForceReparent" and "CopySchemaShard" will be implemented.
//func selectAnyTabletFromShardByType(shard string, tabletType string) string {
// return ""
//}
func (t *HorizontalReshardingTask) run(parameters map[string]string) ([]*pb.TaskContainer, string, error) {
// Example: test_keyspace
keyspace := parameters["keyspace"]
// Example: 10-20
sourceShards := strings.Split(parameters["source_shard_list"], ",")
// Example: 10-18,18-20
destShards := strings.Split(parameters["dest_shard_list"], ",")
// Example: cell1-0000062352
sourceRdonlyTablets := strings.Split(parameters["source_shard_rdonly_list"], ",")
var newTasks []*pb.TaskContainer
// TODO(mberlin): Implement "ForceParent" task and uncomment this.
// reparentTasks := NewTaskContainer()
// for _, destShard := range destShards {
// newMaster := selectAnyTabletFromShardByType(destShard, "master")
// AddTask(reparentTasks, "ForceReparent", map[string]string{
// "shard": destShard,
// "master": newMaster,
// })
// }
// newTasks = append(newTasks, reparentTasks)
// TODO(mberlin): Implement "CopySchemaShard" task and uncomment this.
// copySchemaTasks := NewTaskContainer()
// sourceRdonlyTablet := selectAnyTabletFromShardByType(sourceShards[0], "rdonly")
// for _, destShard := range destShards {
// AddTask(copySchemaTasks, "CopySchemaShard", map[string]string{
// "shard": destShard,
// "source_rdonly_tablet": sourceRdonlyTablet,
// })
// }
// newTasks = append(newTasks, copySchemaTasks)
splitCloneTasks := NewTaskContainer()
for _, sourceShard := range sourceShards {
// TODO(mberlin): Add a semaphore as argument to limit the parallism.
AddTask(splitCloneTasks, "vtworker", map[string]string{
"command": "SplitClone",
"keyspace": keyspace,
"shard": sourceShard,
"vtworker_endpoint": parameters["vtworker_endpoint"],
})
}
newTasks = append(newTasks, splitCloneTasks)
// TODO(mberlin): Remove this once SplitClone does this on its own.
restoreTypeTasks := NewTaskContainer()
for _, sourceRdonlyTablet := range sourceRdonlyTablets {
AddTask(restoreTypeTasks, "vtctl", map[string]string{
"command": fmt.Sprintf("ChangeSlaveType %v rdonly", sourceRdonlyTablet),
})
}
newTasks = append(newTasks, restoreTypeTasks)
splitDiffTasks := NewTaskContainer()
for _, destShard := range destShards {
AddTask(splitDiffTasks, "vtworker", map[string]string{
"command": "SplitDiff",
"keyspace": keyspace,
"shard": destShard,
"vtworker_endpoint": parameters["vtworker_endpoint"],
})
}
newTasks = append(newTasks, splitDiffTasks)
// TODO(mberlin): Implement "CopySchemaShard" task and uncomment this.
// for _, servedType := range []string{"rdonly", "replica", "master"} {
// migrateServedTypesTasks := NewTaskContainer()
// for _, sourceShard := range sourceShards {
// AddTask(migrateServedTypesTasks, "MigrateServedTypes", map[string]string{
// "keyspace": keyspace,
// "shard": sourceShard,
// "served_type": servedType,
// })
// }
// newTasks = append(newTasks, migrateServedTypesTasks)
// }
return newTasks, "", nil
}
func (t *HorizontalReshardingTask) requiredParameters() []string {
return []string{"keyspace", "source_shard_list", "source_shard_rdonly_list", "dest_shard_list"}
}

Просмотреть файл

@ -0,0 +1,35 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
"testing"
"github.com/golang/protobuf/proto"
)
func TestHorizontalReshardingTaskEmittedTasks(t *testing.T) {
reshardingTask := &HorizontalReshardingTask{}
parameters := map[string]string{
"source_shard_rdonly_list": "cell1-0000062352",
"keyspace": "test_keyspace",
"source_shard_list": "10-20",
"dest_shard_list": "10-18,18-20",
"vtworker_endpoint": "localhost:12345",
}
err := checkRequiredParameters(reshardingTask, parameters)
if err != nil {
t.Fatalf("Not all required parameters were specified: %v", err)
}
newTaskContainers, _, _ := reshardingTask.run(parameters)
// TODO(mberlin): Check emitted tasks against expected output.
for _, tc := range newTaskContainers {
t.Logf("new tasks: %v", proto.MarshalTextString(tc))
}
}

Просмотреть файл

@ -0,0 +1,18 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import "strconv"
import "sync/atomic"
// IDGenerator generates unique task and cluster operation IDs.
type IDGenerator struct {
counter int64
}
// GetNextID returns an ID which wasn't returned before.
func (ig *IDGenerator) GetNextID() string {
return strconv.FormatInt(atomic.AddInt64(&ig.counter, 1), 10)
}

Просмотреть файл

@ -0,0 +1,317 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package automation contains code to execute high-level cluster operations
(e.g. resharding) as a series of low-level operations
(e.g. vtctl, shell commands, ...).
*/
package automation
import (
"fmt"
"sync"
log "github.com/golang/glog"
pb "github.com/youtube/vitess/go/vt/proto/automation"
"golang.org/x/net/context"
)
type schedulerState int32
const (
stateNotRunning schedulerState = iota
stateRunning
stateShuttingDown
stateShutdown
)
type taskCreator func(string) Task
// Scheduler executes automation tasks and maintains the execution state.
type Scheduler struct {
idGenerator IDGenerator
mu sync.Mutex
// Guarded by "mu".
registeredClusterOperations map[string]bool
// Guarded by "mu".
toBeScheduledClusterOperations chan *ClusterOperationInstance
// Guarded by "mu".
state schedulerState
// Guarded by "taskCreatorMu". May be overriden by testing code.
taskCreator taskCreator
taskCreatorMu sync.Mutex
pendingOpsWg *sync.WaitGroup
muOpList sync.Mutex
// Guarded by "muOpList".
activeClusterOperations map[string]*ClusterOperationInstance
// Guarded by "muOpList".
finishedClusterOperations map[string]*ClusterOperationInstance
}
// NewScheduler creates a new instance.
func NewScheduler() (*Scheduler, error) {
defaultClusterOperations := map[string]bool{
"HorizontalReshardingTask": true,
}
s := &Scheduler{
registeredClusterOperations: defaultClusterOperations,
idGenerator: IDGenerator{},
toBeScheduledClusterOperations: make(chan *ClusterOperationInstance, 10),
state: stateNotRunning,
taskCreator: defaultTaskCreator,
pendingOpsWg: &sync.WaitGroup{},
activeClusterOperations: make(map[string]*ClusterOperationInstance),
finishedClusterOperations: make(map[string]*ClusterOperationInstance),
}
return s, nil
}
func (s *Scheduler) registerClusterOperation(clusterOperationName string) {
s.mu.Lock()
defer s.mu.Unlock()
s.registeredClusterOperations[clusterOperationName] = true
}
// Run processes queued cluster operations.
func (s *Scheduler) Run() {
s.mu.Lock()
s.state = stateRunning
s.mu.Unlock()
s.startProcessRequestsLoop()
}
func (s *Scheduler) startProcessRequestsLoop() {
// Use a WaitGroup instead of just a done channel, because we want
// to be able to shut down the scheduler even if Run() was never executed.
s.pendingOpsWg.Add(1)
go s.processRequestsLoop()
}
func (s *Scheduler) processRequestsLoop() {
defer s.pendingOpsWg.Done()
for op := range s.toBeScheduledClusterOperations {
s.processClusterOperation(op)
}
log.Infof("Stopped processing loop for ClusterOperations.")
}
func (s *Scheduler) processClusterOperation(clusterOp *ClusterOperationInstance) {
if clusterOp.State == pb.ClusterOperationState_CLUSTER_OPERATION_DONE {
log.Infof("ClusterOperation: %v skipping because it is already done. Details: %v", clusterOp.Id, clusterOp)
return
}
log.Infof("ClusterOperation: %v running. Details: %v", clusterOp.Id, clusterOp)
var lastTaskError string
for i := 0; i < len(clusterOp.SerialTasks); i++ {
taskContainer := clusterOp.SerialTasks[i]
for _, taskProto := range taskContainer.ParallelTasks {
if taskProto.State == pb.TaskState_DONE {
if taskProto.Error != "" {
log.Errorf("Task: %v (%v/%v) failed before. Aborting the ClusterOperation. Error: %v Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, taskProto.Error, taskProto)
lastTaskError = taskProto.Error
break
} else {
log.Infof("Task: %v (%v/%v) skipped because it is already done. Full Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, taskProto)
}
}
task, err := s.createTaskInstance(taskProto.Name)
if err != nil {
log.Errorf("Task: %v (%v/%v) could not be instantiated. Error: %v Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, err, taskProto)
MarkTaskFailed(taskProto, "", err)
lastTaskError = err.Error()
break
}
taskProto.State = pb.TaskState_RUNNING
log.Infof("Task: %v (%v/%v) running. Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, taskProto)
newTaskContainers, output, errRun := task.run(taskProto.Parameters)
log.Infof("Task: %v (%v/%v) finished. newTaskContainers: %v, output: %v, error: %v", taskProto.Name, clusterOp.Id, taskProto.Id, newTaskContainers, output, errRun)
if errRun != nil {
MarkTaskFailed(taskProto, output, errRun)
lastTaskError = errRun.Error()
break
}
MarkTaskSucceeded(taskProto, output)
if newTaskContainers != nil {
// Make sure all new tasks do not miss any required parameters.
for _, newTaskContainer := range newTaskContainers {
for _, newTaskProto := range newTaskContainer.ParallelTasks {
err := s.validateTaskSpecification(newTaskProto.Name, newTaskProto.Parameters)
if err != nil {
log.Errorf("Task: %v (%v/%v) emitted a new task which is not valid. Error: %v Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, err, newTaskProto)
MarkTaskFailed(taskProto, output, err)
lastTaskError = err.Error()
break
}
}
}
if lastTaskError == "" {
clusterOp.InsertTaskContainers(newTaskContainers, i+1)
log.Infof("ClusterOperation: %v %d new task containers added by %v (%v/%v). Updated ClusterOperation: %v",
clusterOp.Id, len(newTaskContainers), taskProto.Name, clusterOp.Id, taskProto.Id, clusterOp)
}
}
}
}
clusterOp.State = pb.ClusterOperationState_CLUSTER_OPERATION_DONE
if lastTaskError != "" {
clusterOp.Error = lastTaskError
}
log.Infof("ClusterOperation: %v finished. Details: %v", clusterOp.Id, clusterOp)
// Move operation from active to finished.
s.muOpList.Lock()
if s.activeClusterOperations[clusterOp.Id] != clusterOp {
panic("Pending ClusterOperation was not recorded as active, but should have.")
}
delete(s.activeClusterOperations, clusterOp.Id)
s.finishedClusterOperations[clusterOp.Id] = clusterOp
s.muOpList.Unlock()
}
func defaultTaskCreator(taskName string) Task {
switch taskName {
case "HorizontalReshardingTask":
return &HorizontalReshardingTask{}
default:
return nil
}
}
func (s *Scheduler) setTaskCreator(creator taskCreator) {
s.taskCreatorMu.Lock()
defer s.taskCreatorMu.Unlock()
s.taskCreator = creator
}
func (s *Scheduler) validateTaskSpecification(taskName string, parameters map[string]string) error {
taskInstanceForParametersCheck, err := s.createTaskInstance(taskName)
if err != nil {
return err
}
errParameters := checkRequiredParameters(taskInstanceForParametersCheck, parameters)
if errParameters != nil {
return errParameters
}
return nil
}
func (s *Scheduler) createTaskInstance(taskName string) (Task, error) {
s.taskCreatorMu.Lock()
taskCreator := s.taskCreator
s.taskCreatorMu.Unlock()
task := taskCreator(taskName)
if task == nil {
return nil, fmt.Errorf("No implementation found for: %v", taskName)
}
return task, nil
}
// checkRequiredParameters returns an error if not all required parameters are provided in "parameters".
func checkRequiredParameters(task Task, parameters map[string]string) error {
for _, requiredParameter := range task.requiredParameters() {
if _, ok := parameters[requiredParameter]; !ok {
return fmt.Errorf("Parameter %v is required, but not provided", requiredParameter)
}
}
return nil
}
// EnqueueClusterOperation can be used to start a new cluster operation.
func (s *Scheduler) EnqueueClusterOperation(ctx context.Context, req *pb.EnqueueClusterOperationRequest) (*pb.EnqueueClusterOperationResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state != stateRunning {
return nil, fmt.Errorf("Scheduler is not running. State: %v", s.state)
}
if s.registeredClusterOperations[req.Name] != true {
return nil, fmt.Errorf("No ClusterOperation with name: %v is registered", req.Name)
}
err := s.validateTaskSpecification(req.Name, req.Parameters)
if err != nil {
return nil, err
}
clusterOpID := s.idGenerator.GetNextID()
taskIDGenerator := IDGenerator{}
initialTask := NewTaskContainerWithSingleTask(req.Name, req.Parameters)
clusterOp := NewClusterOperationInstance(clusterOpID, initialTask, &taskIDGenerator)
s.muOpList.Lock()
s.toBeScheduledClusterOperations <- clusterOp
s.activeClusterOperations[clusterOpID] = clusterOp
s.muOpList.Unlock()
return &pb.EnqueueClusterOperationResponse{
Id: clusterOp.Id,
}, nil
}
// findClusterOp checks for a given ClusterOperation ID if it's in the list of active or finished operations.
func (s *Scheduler) findClusterOp(id string) (*ClusterOperationInstance, error) {
var ok bool
var clusterOp *ClusterOperationInstance
s.muOpList.Lock()
defer s.muOpList.Unlock()
clusterOp, ok = s.activeClusterOperations[id]
if !ok {
clusterOp, ok = s.finishedClusterOperations[id]
}
if !ok {
return nil, fmt.Errorf("ClusterOperation with id: %v not found", id)
}
return clusterOp, nil
}
// GetClusterOperationDetails can be used to query the full details of active or finished operations.
func (s *Scheduler) GetClusterOperationDetails(ctx context.Context, req *pb.GetClusterOperationDetailsRequest) (*pb.GetClusterOperationDetailsResponse, error) {
clusterOp, err := s.findClusterOp(req.Id)
if err != nil {
return nil, err
}
return &pb.GetClusterOperationDetailsResponse{
ClusterOp: &clusterOp.ClusterOperation,
}, nil
}
// ShutdownAndWait shuts down the scheduler and waits infinitely until all pending cluster operations have finished.
func (s *Scheduler) ShutdownAndWait() {
s.mu.Lock()
if s.state != stateShuttingDown {
s.state = stateShuttingDown
close(s.toBeScheduledClusterOperations)
}
s.mu.Unlock()
log.Infof("Scheduler was shut down. Waiting for pending ClusterOperations to finish.")
s.pendingOpsWg.Wait()
s.mu.Lock()
s.state = stateShutdown
s.mu.Unlock()
log.Infof("All pending ClusterOperations finished.")
}

Просмотреть файл

@ -0,0 +1,242 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
"testing"
"time"
"github.com/golang/protobuf/proto"
context "golang.org/x/net/context"
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
// newTestScheduler constructs a scheduler with test tasks.
// If tasks should be available as cluster operation, they still have to be registered manually with scheduler.registerClusterOperation.
func newTestScheduler(t *testing.T) *Scheduler {
scheduler, err := NewScheduler()
if err != nil {
t.Fatalf("Failed to create scheduler: %v", err)
}
scheduler.setTaskCreator(testingTaskCreator)
return scheduler
}
// waitForClusterOperation is a helper function which blocks until the Cluster Operation has finished.
func waitForClusterOperation(t *testing.T, scheduler *Scheduler, id string, expectedOutputLastTask string, expectedErrorLastTask string) *pb.ClusterOperation {
if expectedOutputLastTask == "" && expectedErrorLastTask == "" {
t.Fatal("Error in test: Cannot wait for an operation where both output and error are expected to be empty.")
}
getDetailsRequest := &pb.GetClusterOperationDetailsRequest{
Id: id,
}
for {
getDetailsResponse, err := scheduler.GetClusterOperationDetails(context.TODO(), getDetailsRequest)
if err != nil {
t.Fatalf("Failed to get details for cluster operation. Request: %v Error: %v", getDetailsRequest, err)
}
if getDetailsResponse.ClusterOp.State == pb.ClusterOperationState_CLUSTER_OPERATION_DONE {
tc := getDetailsResponse.ClusterOp.SerialTasks
lastTc := tc[len(tc)-1]
if expectedOutputLastTask != "" {
if lastTc.ParallelTasks[len(lastTc.ParallelTasks)-1].Output != expectedOutputLastTask {
t.Fatalf("ClusterOperation finished but did not return expected output. want: %v Full ClusterOperation details: %v", expectedOutputLastTask, proto.MarshalTextString(getDetailsResponse.ClusterOp))
}
}
if expectedErrorLastTask != "" {
if lastTc.ParallelTasks[len(lastTc.ParallelTasks)-1].Error != expectedErrorLastTask {
t.Fatalf("ClusterOperation finished but did not return expected error. Full ClusterOperation details: %v", getDetailsResponse.ClusterOp)
}
}
return getDetailsResponse.ClusterOp
}
t.Logf("Waiting for clusterOp: %v", getDetailsResponse.ClusterOp)
time.Sleep(5 * time.Millisecond)
}
}
func TestSchedulerImmediateShutdown(t *testing.T) {
// Make sure that the scheduler shuts down cleanly when it was instantiated, but not started with Run().
scheduler, err := NewScheduler()
if err != nil {
t.Fatalf("Failed to create scheduler: %v", err)
}
scheduler.ShutdownAndWait()
}
func enqueueClusterOperationAndCheckOutput(t *testing.T, taskName string, expectedOutput string) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.registerClusterOperation("TestingEchoTask")
scheduler.registerClusterOperation("TestingEmitEchoTask")
scheduler.Run()
enqueueRequest := &pb.EnqueueClusterOperationRequest{
Name: taskName,
Parameters: map[string]string{
"echo_text": expectedOutput,
},
}
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
if err != nil {
t.Fatalf("Failed to start cluster operation. Request: %v Error: %v", enqueueRequest, err)
}
waitForClusterOperation(t, scheduler, enqueueResponse.Id, expectedOutput, "")
}
func TestEnqueueSingleTask(t *testing.T) {
enqueueClusterOperationAndCheckOutput(t, "TestingEchoTask", "echoed text")
}
func TestEnqueueEmittingTask(t *testing.T) {
enqueueClusterOperationAndCheckOutput(t, "TestingEmitEchoTask", "echoed text from emitted task")
}
func TestEnqueueFailsDueToMissingParameter(t *testing.T) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.registerClusterOperation("TestingEchoTask")
scheduler.Run()
enqueueRequest := &pb.EnqueueClusterOperationRequest{
Name: "TestingEchoTask",
Parameters: map[string]string{
"unrelevant-parameter": "value",
},
}
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
if err == nil {
t.Fatalf("Scheduler should have failed to start cluster operation because not all required parameters were provided. Request: %v Error: %v Response: %v", enqueueRequest, err, enqueueResponse)
}
want := "Parameter echo_text is required, but not provided"
if err.Error() != want {
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
}
}
func TestFailedTaskFailsClusterOperation(t *testing.T) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.registerClusterOperation("TestingFailTask")
scheduler.Run()
enqueueRequest := &pb.EnqueueClusterOperationRequest{
Name: "TestingFailTask",
}
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
if err != nil {
t.Fatalf("Failed to start cluster operation. Request: %v Error: %v", enqueueRequest, err)
}
waitForClusterOperation(t, scheduler, enqueueResponse.Id, "something went wrong", "full error message")
}
func TestEnqueueFailsDueToUnregisteredClusterOperation(t *testing.T) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.Run()
enqueueRequest := &pb.EnqueueClusterOperationRequest{
Name: "TestingEchoTask",
Parameters: map[string]string{
"unrelevant-parameter": "value",
},
}
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
if err == nil {
t.Fatalf("Scheduler should have failed to start cluster operation because it should not have been registered. Request: %v Error: %v Response: %v", enqueueRequest, err, enqueueResponse)
}
want := "No ClusterOperation with name: TestingEchoTask is registered"
if err.Error() != want {
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
}
}
func TestGetDetailsFailsUnknownId(t *testing.T) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.Run()
getDetailsRequest := &pb.GetClusterOperationDetailsRequest{
Id: "-1", // There will never be a ClusterOperation with this id.
}
getDetailsResponse, err := scheduler.GetClusterOperationDetails(context.TODO(), getDetailsRequest)
if err == nil {
t.Fatalf("Did not fail to get details for invalid ClusterOperation id. Request: %v Response: %v Error: %v", getDetailsRequest, getDetailsResponse, err)
}
want := "ClusterOperation with id: -1 not found"
if err.Error() != want {
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
}
}
func TestEnqueueFailsBecauseTaskInstanceCannotBeCreated(t *testing.T) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.setTaskCreator(defaultTaskCreator)
// TestingEchoTask is registered as cluster operation, but its task cannot be instantied because "testingTaskCreator" was not set.
scheduler.registerClusterOperation("TestingEchoTask")
scheduler.Run()
enqueueRequest := &pb.EnqueueClusterOperationRequest{
Name: "TestingEchoTask",
Parameters: map[string]string{
"unrelevant-parameter": "value",
},
}
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
if err == nil {
t.Fatalf("Scheduler should have failed to start cluster operation because the task could not be instantiated. Request: %v Error: %v Response: %v", enqueueRequest, err, enqueueResponse)
}
want := "No implementation found for: TestingEchoTask"
if err.Error() != want {
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
}
}
func TestTaskEmitsTaskWhichCannotBeInstantiated(t *testing.T) {
scheduler := newTestScheduler(t)
defer scheduler.ShutdownAndWait()
scheduler.setTaskCreator(func(taskName string) Task {
// TaskCreator which doesn't know TestingEchoTask (but emitted by TestingEmitEchoTask).
switch taskName {
case "TestingEmitEchoTask":
return &TestingEmitEchoTask{}
default:
return nil
}
})
scheduler.registerClusterOperation("TestingEmitEchoTask")
scheduler.Run()
enqueueRequest := &pb.EnqueueClusterOperationRequest{
Name: "TestingEmitEchoTask",
}
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
if err != nil {
t.Fatalf("Failed to start cluster operation. Request: %v Error: %v", enqueueRequest, err)
}
details := waitForClusterOperation(t, scheduler, enqueueResponse.Id, "emitted TestingEchoTask", "No implementation found for: TestingEchoTask")
if len(details.SerialTasks) != 1 {
t.Errorf("A task has been emitted, but it shouldn't. Details:\n%v", proto.MarshalTextString(details))
}
}

20
go/vt/automation/task.go Normal file
Просмотреть файл

@ -0,0 +1,20 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
// Task implementations can be executed by the scheduler.
type Task interface {
// run executes the task using the key/values from parameters.
// "newTaskContainers" contains new tasks which the task can emit. They'll be inserted in the cluster operation directly after this task. It may be "nil".
// "output" may be empty. It contains any text which maybe must e.g. to debug the task or show it in the UI.
run(parameters map[string]string) (newTaskContainers []*pb.TaskContainer, output string, err error)
// requiredParameters() returns a list of parameter keys which must be provided as input for run().
requiredParameters() []string
}

Просмотреть файл

@ -0,0 +1,32 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
// Helper functions for "TaskContainer" protobuf message.
// NewTaskContainerWithSingleTask creates a new task container with exactly one task.
func NewTaskContainerWithSingleTask(taskName string, parameters map[string]string) *pb.TaskContainer {
return &pb.TaskContainer{
ParallelTasks: []*pb.Task{
NewTask(taskName, parameters),
},
}
}
// NewTaskContainer creates an empty task container. Use AddTask() to add tasks to it.
func NewTaskContainer() *pb.TaskContainer {
return &pb.TaskContainer{
ParallelTasks: []*pb.Task{},
}
}
// AddTask adds a new task to an existing task container.
func AddTask(t *pb.TaskContainer, taskName string, parameters map[string]string) {
t.ParallelTasks = append(t.ParallelTasks, NewTask(taskName, parameters))
}

33
go/vt/automation/tasks.go Normal file
Просмотреть файл

@ -0,0 +1,33 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
// Helper functions for "Task" protobuf message.
// MarkTaskSucceeded marks the task as done.
func MarkTaskSucceeded(t *pb.Task, output string) {
t.State = pb.TaskState_DONE
t.Output = output
}
// MarkTaskFailed marks the task as failed.
func MarkTaskFailed(t *pb.Task, output string, err error) {
t.State = pb.TaskState_DONE
t.Output = output
t.Error = err.Error()
}
// NewTask creates a new task protobuf message for "taskName" with "parameters".
func NewTask(taskName string, parameters map[string]string) *pb.Task {
return &pb.Task{
State: pb.TaskState_NOT_STARTED,
Name: taskName,
Parameters: parameters,
}
}

Просмотреть файл

@ -0,0 +1,66 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package automation
import (
"errors"
pb "github.com/youtube/vitess/go/vt/proto/automation"
)
func testingTaskCreator(taskName string) Task {
switch taskName {
// Tasks for testing only.
case "TestingEchoTask":
return &TestingEchoTask{}
case "TestingEmitEchoTask":
return &TestingEmitEchoTask{}
case "TestingFailTask":
return &TestingFailTask{}
default:
return nil
}
}
// TestingEchoTask is used only for testing. It returns the join of all parameter values.
type TestingEchoTask struct {
}
func (t *TestingEchoTask) run(parameters map[string]string) (newTasks []*pb.TaskContainer, output string, err error) {
for _, v := range parameters {
output += v
}
return
}
func (t *TestingEchoTask) requiredParameters() []string {
return []string{"echo_text"}
}
// TestingEmitEchoTask is used only for testing. It emits a TestingEchoTask.
type TestingEmitEchoTask struct {
}
func (t *TestingEmitEchoTask) run(parameters map[string]string) (newTasks []*pb.TaskContainer, output string, err error) {
return []*pb.TaskContainer{
NewTaskContainerWithSingleTask("TestingEchoTask", parameters),
}, "emitted TestingEchoTask", nil
}
func (t *TestingEmitEchoTask) requiredParameters() []string {
return []string{}
}
// TestingFailTask is used only for testing. It always fails.
type TestingFailTask struct {
}
func (t *TestingFailTask) run(parameters map[string]string) (newTasks []*pb.TaskContainer, output string, err error) {
return nil, "something went wrong", errors.New("full error message")
}
func (t *TestingFailTask) requiredParameters() []string {
return []string{}
}

Просмотреть файл

@ -64,7 +64,7 @@ func getStatementCategory(sql []byte) int {
type BinlogStreamer struct {
// dbname and mysqld are set at creation.
dbname string
mysqld *mysqlctl.Mysqld
mysqld mysqlctl.MysqlDaemon
clientCharset *mproto.Charset
startPos myproto.ReplicationPosition
sendTransaction sendTransactionFunc
@ -79,7 +79,7 @@ type BinlogStreamer struct {
// charset is the default character set on the BinlogPlayer side.
// startPos is the position to start streaming at.
// sendTransaction is called each time a transaction is committed or rolled back.
func NewBinlogStreamer(dbname string, mysqld *mysqlctl.Mysqld, clientCharset *mproto.Charset, startPos myproto.ReplicationPosition, sendTransaction sendTransactionFunc) *BinlogStreamer {
func NewBinlogStreamer(dbname string, mysqld mysqlctl.MysqlDaemon, clientCharset *mproto.Charset, startPos myproto.ReplicationPosition, sendTransaction sendTransactionFunc) *BinlogStreamer {
return &BinlogStreamer{
dbname: dbname,
mysqld: mysqld,
@ -99,7 +99,7 @@ func (bls *BinlogStreamer) Stream(ctx *sync2.ServiceContext) (err error) {
log.Infof("stream ended @ %v, err = %v", stopPos, err)
}()
if bls.conn, err = mysqlctl.NewSlaveConnection(bls.mysqld); err != nil {
if bls.conn, err = bls.mysqld.NewSlaveConnection(); err != nil {
return err
}
defer bls.conn.Close()

Просмотреть файл

@ -39,7 +39,7 @@ type EventStreamer struct {
sendEvent sendEventFunc
}
func NewEventStreamer(dbname string, mysqld *mysqlctl.Mysqld, startPos myproto.ReplicationPosition, sendEvent sendEventFunc) *EventStreamer {
func NewEventStreamer(dbname string, mysqld mysqlctl.MysqlDaemon, startPos myproto.ReplicationPosition, sendEvent sendEventFunc) *EventStreamer {
evs := &EventStreamer{
sendEvent: sendEvent,
}

Просмотреть файл

@ -45,7 +45,7 @@ type UpdateStream struct {
actionLock sync.Mutex
state sync2.AtomicInt64
mysqld *mysqlctl.Mysqld
mysqld mysqlctl.MysqlDaemon
stateWaitGroup sync.WaitGroup
dbname string
streams streamList
@ -121,7 +121,7 @@ func logError() {
}
// EnableUpdateStreamService enables the RPC service for UpdateStream
func EnableUpdateStreamService(dbname string, mysqld *mysqlctl.Mysqld) {
func EnableUpdateStreamService(dbname string, mysqld mysqlctl.MysqlDaemon) {
defer logError()
UpdateStreamRpcService.enable(dbname, mysqld)
}
@ -148,7 +148,7 @@ func GetReplicationPosition() (myproto.ReplicationPosition, error) {
return UpdateStreamRpcService.getReplicationPosition()
}
func (updateStream *UpdateStream) enable(dbname string, mysqld *mysqlctl.Mysqld) {
func (updateStream *UpdateStream) enable(dbname string, mysqld mysqlctl.MysqlDaemon) {
updateStream.actionLock.Lock()
defer updateStream.actionLock.Unlock()
if updateStream.isEnabled() {

Просмотреть файл

@ -52,17 +52,17 @@ type conn struct {
TabletType topo.TabletType `json:"tablet_type"`
Streaming bool
Timeout time.Duration
vtgateConn vtgateconn.VTGateConn
tx vtgateconn.VTGateTx
vtgateConn *vtgateconn.VTGateConn
tx *vtgateconn.VTGateTx
}
func (c *conn) dial() error {
dialer := vtgateconn.GetDialerWithProtocol(c.Protocol)
if dialer == nil {
return fmt.Errorf("could not find dialer for protocol %s", c.Protocol)
}
var err error
c.vtgateConn, err = dialer(context.Background(), c.Address, c.Timeout)
if c.Protocol == "" {
c.vtgateConn, err = vtgateconn.Dial(context.Background(), c.Address, c.Timeout)
} else {
c.vtgateConn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address, c.Timeout)
}
return err
}
@ -150,7 +150,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
defer cancel()
if s.c.Streaming {
qrc, errFunc := s.c.vtgateConn.StreamExecute(ctx, s.query, makeBindVars(args), s.c.TabletType)
return vtgateconn.NewStreamingRows(qrc, errFunc), nil
return newStreamingRows(qrc, errFunc), nil
}
var qr *mproto.QueryResult
var err error
@ -162,7 +162,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
if err != nil {
return nil, err
}
return vtgateconn.NewRows(qr), nil
return newRows(qr), nil
}
func makeBindVars(args []driver.Value) map[string]interface{} {

Просмотреть файл

@ -94,7 +94,7 @@ func TestDial(t *testing.T) {
_ = c.Close()
_, err = drv{}.Open(`{"protocol": "none"}`)
want := "could not find dialer for protocol none"
want := "no dialer registered for VTGate protocol none"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("err: %v, want %s", err, want)
}

Просмотреть файл

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vtgateconn
package client
import (
"database/sql/driver"
@ -20,8 +20,8 @@ type rows struct {
index int
}
// NewRows creates a new rows from qr.
func NewRows(qr *mproto.QueryResult) driver.Rows {
// newRows creates a new rows from qr.
func newRows(qr *mproto.QueryResult) driver.Rows {
return &rows{qr: qr}
}
@ -46,6 +46,10 @@ func (ri *rows) Next(dest []driver.Value) error {
return err
}
// populateRow populates a row of data using the table's field descriptions.
// The returned types for "dest" include the list from the interface
// specification at https://golang.org/pkg/database/sql/driver/#Value
// and in addition the type "uint64" for unsigned BIGINT MySQL records.
func populateRow(dest []driver.Value, fields []mproto.Field, row []sqltypes.Value) error {
if len(dest) != len(fields) {
return fmt.Errorf("length mismatch: dest is %d, fields are %d", len(dest), len(fields))
@ -55,7 +59,7 @@ func populateRow(dest []driver.Value, fields []mproto.Field, row []sqltypes.Valu
}
var err error
for i := range dest {
dest[i], err = mproto.Convert(fields[i].Type, row[i])
dest[i], err = mproto.Convert(fields[i], row[i])
if err != nil {
return fmt.Errorf("conversion error: field: %v, val: %v: %v", fields[i], row[i], err)
}

Просмотреть файл

@ -1,4 +1,4 @@
package vtgateconn
package client
import (
"database/sql/driver"
@ -10,7 +10,7 @@ import (
"github.com/youtube/vitess/go/sqltypes"
)
var result1 = mproto.QueryResult{
var rowsResult1 = mproto.QueryResult{
Fields: []mproto.Field{
mproto.Field{
Name: "field1",
@ -24,29 +24,59 @@ var result1 = mproto.QueryResult{
Name: "field3",
Type: mproto.VT_VAR_STRING,
},
// Signed types which are smaller than uint64, will become an int64.
mproto.Field{
Name: "field4",
Type: mproto.VT_LONG,
Flags: mproto.VT_UNSIGNED_FLAG,
},
RowsAffected: 3,
// Signed uint64 values must be mapped to uint64.
mproto.Field{
Name: "field5",
Type: mproto.VT_LONGLONG,
Flags: mproto.VT_UNSIGNED_FLAG,
},
},
RowsAffected: 2,
InsertId: 0,
Rows: [][]sqltypes.Value{
[]sqltypes.Value{
sqltypes.MakeString([]byte("1")),
sqltypes.MakeString([]byte("1.1")),
sqltypes.MakeString([]byte("value1")),
sqltypes.MakeString([]byte("2147483647")), // 2^31-1, NOT out of range for int32 => should become int64
sqltypes.MakeString([]byte("9223372036854775807")), // 2^63-1, NOT out of range for int64
},
[]sqltypes.Value{
sqltypes.MakeString([]byte("2")),
sqltypes.MakeString([]byte("2.2")),
sqltypes.MakeString([]byte("value2")),
sqltypes.MakeString([]byte("4294967295")), // 2^32, out of range for int32 => should become int64
sqltypes.MakeString([]byte("18446744073709551615")), // 2^64, out of range for int64
},
},
}
func logMismatchedTypes(t *testing.T, gotRow, wantRow []driver.Value) {
for i := 1; i < len(wantRow); i++ {
got := gotRow[i]
want := wantRow[i]
v1 := reflect.ValueOf(got)
v2 := reflect.ValueOf(want)
if v1.Type() != v2.Type() {
t.Errorf("Wrong type: field: %d got: %T want: %T", i+1, got, want)
}
}
}
func TestRows(t *testing.T) {
ri := NewRows(&result1)
ri := newRows(&rowsResult1)
wantCols := []string{
"field1",
"field2",
"field3",
"field4",
"field5",
}
gotCols := ri.Columns()
if !reflect.DeepEqual(gotCols, wantCols) {
@ -57,20 +87,25 @@ func TestRows(t *testing.T) {
int64(1),
float64(1.1),
[]byte("value1"),
int64(2147483647),
uint64(9223372036854775807),
}
gotRow := make([]driver.Value, 3)
gotRow := make([]driver.Value, len(wantRow))
err := ri.Next(gotRow)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gotRow, wantRow) {
t.Errorf("row1: %v, want %v", gotRow, wantRow)
t.Errorf("row1: %v, want %v type: %T", gotRow, wantRow, wantRow[3])
logMismatchedTypes(t, gotRow, wantRow)
}
wantRow = []driver.Value{
int64(2),
float64(2.2),
[]byte("value2"),
int64(4294967295),
uint64(18446744073709551615),
}
err = ri.Next(gotRow)
if err != nil {
@ -78,6 +113,7 @@ func TestRows(t *testing.T) {
}
if !reflect.DeepEqual(gotRow, wantRow) {
t.Errorf("row1: %v, want %v", gotRow, wantRow)
logMismatchedTypes(t, gotRow, wantRow)
}
err = ri.Next(gotRow)
@ -112,7 +148,7 @@ var badResult2 = mproto.QueryResult{
}
func TestRowsFail(t *testing.T) {
ri := NewRows(&badResult1)
ri := newRows(&badResult1)
var dest []driver.Value
err := ri.Next(dest)
want := "length mismatch: dest is 0, fields are 1"
@ -120,7 +156,7 @@ func TestRowsFail(t *testing.T) {
t.Errorf("Next: %v, want %s", err, want)
}
ri = NewRows(&badResult1)
ri = newRows(&badResult1)
dest = make([]driver.Value, 1)
err = ri.Next(dest)
want = "internal error: length mismatch: dest is 1, fields are 0"
@ -128,10 +164,10 @@ func TestRowsFail(t *testing.T) {
t.Errorf("Next: %v, want %s", err, want)
}
ri = NewRows(&badResult2)
ri = newRows(&badResult2)
dest = make([]driver.Value, 1)
err = ri.Next(dest)
want = `conversion error: field: {field1 3 0}, val: value: strconv.ParseUint: parsing "value": invalid syntax`
want = `conversion error: field: {field1 3 0}, val: value: strconv.ParseInt: parsing "value": invalid syntax`
if err == nil || err.Error() != want {
t.Errorf("Next: %v, want %s", err, want)
}

Просмотреть файл

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vtgateconn
package client
import (
"database/sql/driver"
@ -10,21 +10,22 @@ import (
"io"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
)
// streamingRows creates a database/sql/driver compliant Row iterator
// for a streaming query.
type streamingRows struct {
qrc <-chan *mproto.QueryResult
errFunc ErrFunc
errFunc vtgateconn.ErrFunc
failed error
fields []mproto.Field
qr *mproto.QueryResult
index int
}
// NewStreamingRows creates a new streamingRows from qrc and errFunc.
func NewStreamingRows(qrc <-chan *mproto.QueryResult, errFunc ErrFunc) driver.Rows {
// newStreamingRows creates a new streamingRows from qrc and errFunc.
func newStreamingRows(qrc <-chan *mproto.QueryResult, errFunc vtgateconn.ErrFunc) driver.Rows {
return &streamingRows{qrc: qrc, errFunc: errFunc}
}

Просмотреть файл

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vtgateconn
package client
import (
"database/sql/driver"
@ -14,6 +14,7 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
)
var packet1 = mproto.QueryResult{
@ -54,7 +55,7 @@ var packet3 = mproto.QueryResult{
}
func TestStreamingRows(t *testing.T) {
qrc, errFunc := func() (<-chan *mproto.QueryResult, ErrFunc) {
qrc, errFunc := func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
ch := make(chan *mproto.QueryResult)
go func() {
ch <- &packet1
@ -64,7 +65,7 @@ func TestStreamingRows(t *testing.T) {
}()
return ch, func() error { return nil }
}()
ri := NewStreamingRows(qrc, errFunc)
ri := newStreamingRows(qrc, errFunc)
wantCols := []string{
"field1",
"field2",
@ -111,7 +112,7 @@ func TestStreamingRows(t *testing.T) {
}
func TestStreamingRowsReversed(t *testing.T) {
qrc, errFunc := func() (<-chan *mproto.QueryResult, ErrFunc) {
qrc, errFunc := func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
ch := make(chan *mproto.QueryResult)
go func() {
ch <- &packet1
@ -121,7 +122,7 @@ func TestStreamingRowsReversed(t *testing.T) {
}()
return ch, func() error { return nil }
}()
ri := NewStreamingRows(qrc, errFunc)
ri := newStreamingRows(qrc, errFunc)
wantRow := []driver.Value{
int64(1),
@ -151,14 +152,14 @@ func TestStreamingRowsReversed(t *testing.T) {
}
func TestStreamingRowsError(t *testing.T) {
qrc, errFunc := func() (<-chan *mproto.QueryResult, ErrFunc) {
qrc, errFunc := func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
ch := make(chan *mproto.QueryResult)
go func() {
close(ch)
}()
return ch, func() error { return errors.New("error before fields") }
}()
ri := NewStreamingRows(qrc, errFunc)
ri := newStreamingRows(qrc, errFunc)
gotCols := ri.Columns()
if gotCols != nil {
t.Errorf("cols: %v, want nil", gotCols)
@ -171,7 +172,7 @@ func TestStreamingRowsError(t *testing.T) {
}
_ = ri.Close()
qrc, errFunc = func() (<-chan *mproto.QueryResult, ErrFunc) {
qrc, errFunc = func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
ch := make(chan *mproto.QueryResult)
go func() {
ch <- &packet1
@ -179,7 +180,7 @@ func TestStreamingRowsError(t *testing.T) {
}()
return ch, func() error { return errors.New("error after fields") }
}()
ri = NewStreamingRows(qrc, errFunc)
ri = newStreamingRows(qrc, errFunc)
wantCols := []string{
"field1",
"field2",
@ -202,7 +203,7 @@ func TestStreamingRowsError(t *testing.T) {
}
_ = ri.Close()
qrc, errFunc = func() (<-chan *mproto.QueryResult, ErrFunc) {
qrc, errFunc = func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
ch := make(chan *mproto.QueryResult)
go func() {
ch <- &packet1
@ -211,7 +212,7 @@ func TestStreamingRowsError(t *testing.T) {
}()
return ch, func() error { return errors.New("error after rows") }
}()
ri = NewStreamingRows(qrc, errFunc)
ri = newStreamingRows(qrc, errFunc)
gotRow = make([]driver.Value, 3)
err = ri.Next(gotRow)
if err != nil {
@ -224,7 +225,7 @@ func TestStreamingRowsError(t *testing.T) {
}
_ = ri.Close()
qrc, errFunc = func() (<-chan *mproto.QueryResult, ErrFunc) {
qrc, errFunc = func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
ch := make(chan *mproto.QueryResult)
go func() {
ch <- &packet2
@ -232,7 +233,7 @@ func TestStreamingRowsError(t *testing.T) {
}()
return ch, func() error { return nil }
}()
ri = NewStreamingRows(qrc, errFunc)
ri = newStreamingRows(qrc, errFunc)
gotRow = make([]driver.Value, 3)
err = ri.Next(gotRow)
wantErr = "first packet did not return fields"

Просмотреть файл

@ -1,290 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package client2
import (
"fmt"
"strconv"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/sqlparser"
)
const (
EID_NODE = iota
VALUE_NODE
LIST_NODE
OTHER_NODE
)
type RoutingPlan struct {
criteria sqlparser.SQLNode
}
func GetShardList(sql string, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardlist []int, err error) {
plan, err := buildPlan(sql)
if err != nil {
return nil, err
}
return shardListFromPlan(plan, bindVariables, tabletKeys)
}
func buildPlan(sql string) (plan *RoutingPlan, err error) {
statement, err := sqlparser.Parse(sql)
if err != nil {
return nil, err
}
return getRoutingPlan(statement)
}
func shardListFromPlan(plan *RoutingPlan, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardList []int, err error) {
if plan.criteria == nil {
return makeList(0, len(tabletKeys)), nil
}
switch criteria := plan.criteria.(type) {
case sqlparser.Values:
index, err := findInsertShard(criteria, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
return []int{index}, nil
case *sqlparser.ComparisonExpr:
switch criteria.Operator {
case "=", "<=>":
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
return []int{index}, nil
case "<", "<=":
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
return makeList(0, index+1), nil
case ">", ">=":
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
return makeList(index, len(tabletKeys)), nil
case "in":
return findShardList(criteria.Right, bindVariables, tabletKeys)
}
case *sqlparser.RangeCond:
if criteria.Operator == "between" {
start, err := findShard(criteria.From, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
last, err := findShard(criteria.To, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
if last < start {
start, last = last, start
}
return makeList(start, last+1), nil
}
}
return makeList(0, len(tabletKeys)), nil
}
func getRoutingPlan(statement sqlparser.Statement) (plan *RoutingPlan, err error) {
plan = &RoutingPlan{}
if ins, ok := statement.(*sqlparser.Insert); ok {
if sel, ok := ins.Rows.(sqlparser.SelectStatement); ok {
return getRoutingPlan(sel)
}
plan.criteria, err = routingAnalyzeValues(ins.Rows.(sqlparser.Values))
if err != nil {
return nil, err
}
return plan, nil
}
var where *sqlparser.Where
switch stmt := statement.(type) {
case *sqlparser.Select:
where = stmt.Where
case *sqlparser.Update:
where = stmt.Where
case *sqlparser.Delete:
where = stmt.Where
}
if where != nil {
plan.criteria = routingAnalyzeBoolean(where.Expr)
}
return plan, nil
}
func routingAnalyzeValues(vals sqlparser.Values) (sqlparser.Values, error) {
// Analyze first value of every item in the list
for i := 0; i < len(vals); i++ {
switch tuple := vals[i].(type) {
case sqlparser.ValTuple:
result := routingAnalyzeValue(tuple[0])
if result != VALUE_NODE {
return nil, fmt.Errorf("insert is too complex")
}
default:
return nil, fmt.Errorf("insert is too complex")
}
}
return vals, nil
}
func routingAnalyzeBoolean(node sqlparser.BoolExpr) sqlparser.BoolExpr {
switch node := node.(type) {
case *sqlparser.AndExpr:
left := routingAnalyzeBoolean(node.Left)
right := routingAnalyzeBoolean(node.Right)
if left != nil && right != nil {
return nil
} else if left != nil {
return left
} else {
return right
}
case *sqlparser.ParenBoolExpr:
return routingAnalyzeBoolean(node.Expr)
case *sqlparser.ComparisonExpr:
switch {
case sqlparser.StringIn(node.Operator, "=", "<", ">", "<=", ">=", "<=>"):
left := routingAnalyzeValue(node.Left)
right := routingAnalyzeValue(node.Right)
if (left == EID_NODE && right == VALUE_NODE) || (left == VALUE_NODE && right == EID_NODE) {
return node
}
case node.Operator == "in":
left := routingAnalyzeValue(node.Left)
right := routingAnalyzeValue(node.Right)
if left == EID_NODE && right == LIST_NODE {
return node
}
}
case *sqlparser.RangeCond:
if node.Operator != "between" {
return nil
}
left := routingAnalyzeValue(node.Left)
from := routingAnalyzeValue(node.From)
to := routingAnalyzeValue(node.To)
if left == EID_NODE && from == VALUE_NODE && to == VALUE_NODE {
return node
}
}
return nil
}
func routingAnalyzeValue(valExpr sqlparser.ValExpr) int {
switch node := valExpr.(type) {
case *sqlparser.ColName:
if string(node.Name) == "entity_id" {
return EID_NODE
}
case sqlparser.ValTuple:
for _, n := range node {
if routingAnalyzeValue(n) != VALUE_NODE {
return OTHER_NODE
}
}
return LIST_NODE
case sqlparser.StrVal, sqlparser.NumVal, sqlparser.ValArg:
return VALUE_NODE
}
return OTHER_NODE
}
func findShardList(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) ([]int, error) {
shardset := make(map[int]bool)
switch node := valExpr.(type) {
case sqlparser.ValTuple:
for _, n := range node {
index, err := findShard(n, bindVariables, tabletKeys)
if err != nil {
return nil, err
}
shardset[index] = true
}
}
shardlist := make([]int, len(shardset))
index := 0
for k := range shardset {
shardlist[index] = k
index++
}
return shardlist, nil
}
func findInsertShard(vals sqlparser.Values, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
index := -1
for i := 0; i < len(vals); i++ {
first_value_expression := vals[i].(sqlparser.ValTuple)[0]
newIndex, err := findShard(first_value_expression, bindVariables, tabletKeys)
if err != nil {
return -1, err
}
if index == -1 {
index = newIndex
} else if index != newIndex {
return -1, fmt.Errorf("insert has multiple shard targets")
}
}
return index, nil
}
func findShard(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
value, err := getBoundValue(valExpr, bindVariables)
if err != nil {
return -1, err
}
return key.FindShardForValue(value, tabletKeys), nil
}
func getBoundValue(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}) (string, error) {
switch node := valExpr.(type) {
case sqlparser.ValTuple:
if len(node) != 1 {
return "", fmt.Errorf("tuples not allowed as insert values")
}
// TODO: Change parser to create single value tuples into non-tuples.
return getBoundValue(node[0], bindVariables)
case sqlparser.StrVal:
return string(node), nil
case sqlparser.NumVal:
val, err := strconv.ParseInt(string(node), 10, 64)
if err != nil {
return "", err
}
return key.Uint64Key(val).String(), nil
case sqlparser.ValArg:
value, err := findBindValue(node, bindVariables)
if err != nil {
return "", err
}
return key.EncodeValue(value), nil
}
panic("Unexpected token")
}
func findBindValue(valArg sqlparser.ValArg, bindVariables map[string]interface{}) (interface{}, error) {
if bindVariables == nil {
return nil, fmt.Errorf("No bind variable for " + string(valArg))
}
value, ok := bindVariables[string(valArg[1:])]
if !ok {
return nil, fmt.Errorf("No bind variable for " + string(valArg))
}
return value, nil
}
func makeList(start, end int) []int {
list := make([]int, end-start)
for i := start; i < end; i++ {
list[i-start] = i
}
return list
}

Просмотреть файл

@ -1,105 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package client2
import (
"bufio"
"fmt"
"io"
"os"
"sort"
"strings"
"testing"
"github.com/youtube/vitess/go/testfiles"
"github.com/youtube/vitess/go/vt/key"
)
func TestRouting(t *testing.T) {
tabletkeys := []key.KeyspaceId{
"\x00\x00\x00\x00\x00\x00\x00\x02",
"\x00\x00\x00\x00\x00\x00\x00\x04",
"\x00\x00\x00\x00\x00\x00\x00\x06",
"a",
"b",
"d",
}
bindVariables := make(map[string]interface{})
bindVariables["id0"] = 0
bindVariables["id2"] = 2
bindVariables["id3"] = 3
bindVariables["id4"] = 4
bindVariables["id6"] = 6
bindVariables["id8"] = 8
bindVariables["ids"] = []interface{}{1, 4}
bindVariables["a"] = "a"
bindVariables["b"] = "b"
bindVariables["c"] = "c"
bindVariables["d"] = "d"
bindVariables["e"] = "e"
for tcase := range iterateFiles("sqlparser_test/routing_cases.txt") {
if tcase.output == "" {
tcase.output = tcase.input
}
out, err := GetShardList(tcase.input, bindVariables, tabletkeys)
if err != nil {
if err.Error() != tcase.output {
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.input, err))
}
continue
}
sort.Ints(out)
outstr := fmt.Sprintf("%v", out)
if outstr != tcase.output {
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, outstr))
}
}
}
// TODO(sougou): This is now duplicated in three plcaes. Refactor.
type testCase struct {
file string
lineno int
input string
output string
}
func iterateFiles(pattern string) (testCaseIterator chan testCase) {
names := testfiles.Glob(pattern)
testCaseIterator = make(chan testCase)
go func() {
defer close(testCaseIterator)
for _, name := range names {
fd, err := os.OpenFile(name, os.O_RDONLY, 0)
if err != nil {
panic(fmt.Sprintf("Could not open file %s", name))
}
r := bufio.NewReader(fd)
lineno := 0
for {
line, err := r.ReadString('\n')
lines := strings.Split(strings.TrimRight(line, "\n"), "#")
lineno++
if err != nil {
if err != io.EOF {
panic(fmt.Sprintf("Error reading file %s: %s", name, err.Error()))
}
break
}
input := lines[0]
output := ""
if len(lines) > 1 {
output = lines[1]
}
if input == "" {
continue
}
testCaseIterator <- testCase{name, lineno, input, output}
}
}
}()
return testCaseIterator
}

Просмотреть файл

@ -1,589 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package client2
import (
"fmt"
"net/url"
"path"
"strings"
"sync"
"time"
"github.com/youtube/vitess/go/db"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/vt/client2/tablet"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/zktopo"
"github.com/youtube/vitess/go/zk"
)
// The sharded client handles writing to multiple shards across the
// database.
//
// The ShardedConn can handles several separate aspects:
// * loading/reloading tablet addresses on demand from zk
// * maintaining at most one connection to each tablet as required
// * transaction tracking across shards
// * preflight checking all transactions before attempting to commit
// (reduce partial commit probability)
//
// NOTE: Queries with aggregate results will not produce expected
// results right now. For instance, running a count(*) on a table
// across all tablets will return one row per tablet. In the future,
// the SQL parser and query engine can handle these more
// automatically. For now, clients will have to do the rollup at a
// higher level.
var (
ErrNotConnected = fmt.Errorf("vt: not connected")
)
type VtClientError struct {
msg string
partial bool
}
func (err VtClientError) Error() string {
return err.msg
}
func (err VtClientError) Partial() bool {
return err.partial
}
// Not thread safe, as per sql package.
type ShardedConn struct {
ts topo.Server
cell string
keyspace string
tabletType topo.TabletType
stream bool // Use streaming RPC
srvKeyspace *topo.SrvKeyspace
// Keep a map per shard mapping tabletType to a real connection.
// connByType []map[string]*Conn
// Sorted list of the max keys for each shard.
shardMaxKeys []key.KeyspaceId
conns []*tablet.VtConn
timeout time.Duration // How long should we wait for a given operation?
// Currently running transaction (or nil if not inside a transaction)
currentTransaction *MetaTx
}
// FIXME(msolomon) Normally a connect method would actually connect up
// to the appropriate endpoints. In the distributed case, it's unclear
// that this is necessary. You have to deal with transient failures
// anyway, so the whole system degenerates to managing connections on
// demand.
func Dial(ts topo.Server, cell, keyspace string, tabletType topo.TabletType, stream bool, timeout time.Duration) (*ShardedConn, error) {
sc := &ShardedConn{
ts: ts,
cell: cell,
keyspace: keyspace,
tabletType: tabletType,
stream: stream,
}
err := sc.readKeyspace()
if err != nil {
return nil, err
}
return sc, nil
}
func (sc *ShardedConn) Close() error {
if sc.conns == nil {
return nil
}
if sc.currentTransaction != nil {
sc.rollback()
}
for _, conn := range sc.conns {
if conn != nil {
conn.Close()
}
}
sc.conns = nil
sc.srvKeyspace = nil
sc.shardMaxKeys = nil
return nil
}
func (sc *ShardedConn) readKeyspace() error {
sc.Close()
var err error
sc.srvKeyspace, err = sc.ts.GetSrvKeyspace(sc.cell, sc.keyspace)
if err != nil {
return fmt.Errorf("vt: GetSrvKeyspace failed %v", err)
}
sc.conns = make([]*tablet.VtConn, len(sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences))
sc.shardMaxKeys = make([]key.KeyspaceId, len(sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences))
for i, shardReference := range sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences {
sc.shardMaxKeys[i] = shardReference.KeyRange.End
}
// Disabled for now.
// sc.connByType = make([]map[string]*Conn, len(sc.srvKeyspace.ShardReferences))
// for i := 0; i < len(sc.connByType); i++ {
// sc.connByType[i] = make(map[string]*Conn, 8)
// }
return nil
}
// A "transaction" that may be across and thus, not transactional at
// this point.
type MetaTx struct {
// The connections involved in this transaction, in the order they
// were added to the transaction.
shardedConn *ShardedConn
conns []*tablet.VtConn
}
// makes sure the given transaction was issued a Begin() call
func (tx *MetaTx) begin(conn *tablet.VtConn) (err error) {
for _, v := range tx.conns {
if v == conn {
return
}
}
_, err = conn.Begin()
if err != nil {
// the caller will need to take care of the rollback,
// and therefore issue a rollback on all pre-existing
// transactions
return err
}
tx.conns = append(tx.conns, conn)
return nil
}
func (tx *MetaTx) Commit() (err error) {
if tx.shardedConn.currentTransaction == nil {
return tablet.ErrBadRollback
}
commit := true
for _, conn := range tx.conns {
if commit {
if err = conn.Commit(); err != nil {
commit = false
}
}
if !commit {
conn.Rollback()
}
}
tx.shardedConn.currentTransaction = nil
return err
}
func (tx *MetaTx) Rollback() error {
if tx.shardedConn.currentTransaction == nil {
return tablet.ErrBadRollback
}
var someErr error
for _, conn := range tx.conns {
if err := conn.Rollback(); err != nil {
someErr = err
}
}
tx.shardedConn.currentTransaction = nil
return someErr
}
func (sc *ShardedConn) Begin() (db.Tx, error) {
if sc.srvKeyspace == nil {
return nil, ErrNotConnected
}
if sc.currentTransaction != nil {
return nil, tablet.ErrNoNestedTxn
}
tx := &MetaTx{sc, make([]*tablet.VtConn, 0, 32)}
sc.currentTransaction = tx
return tx, nil
}
func (sc *ShardedConn) rollback() error {
if sc.currentTransaction == nil {
return tablet.ErrBadRollback
}
var someErr error
for _, conn := range sc.conns {
if conn.TransactionId != 0 {
if err := conn.Rollback(); err != nil {
someErr = err
}
}
}
sc.currentTransaction = nil
return someErr
}
func (sc *ShardedConn) Exec(query string, bindVars map[string]interface{}) (db.Result, error) {
if sc.srvKeyspace == nil {
return nil, ErrNotConnected
}
shards, err := GetShardList(query, bindVars, sc.shardMaxKeys)
if err != nil {
return nil, err
}
if sc.stream {
return sc.execOnShardsStream(query, bindVars, shards)
}
return sc.execOnShards(query, bindVars, shards)
}
// FIXME(msolomon) define key interface "Keyer" or force a concrete type?
func (sc *ShardedConn) ExecWithKey(query string, bindVars map[string]interface{}, keyVal interface{}) (db.Result, error) {
shardIdx, err := key.FindShardForKey(keyVal, sc.shardMaxKeys)
if err != nil {
return nil, err
}
if sc.stream {
return sc.execOnShardsStream(query, bindVars, []int{shardIdx})
}
return sc.execOnShards(query, bindVars, []int{shardIdx})
}
type tabletResult struct {
error
*tablet.Result
}
func (sc *ShardedConn) execOnShards(query string, bindVars map[string]interface{}, shards []int) (metaResult *tablet.Result, err error) {
rchan := make(chan tabletResult, len(shards))
for _, shardIdx := range shards {
go func(shardIdx int) {
qr, err := sc.execOnShard(query, bindVars, shardIdx)
if err != nil {
rchan <- tabletResult{error: err}
} else {
rchan <- tabletResult{Result: qr.(*tablet.Result)}
}
}(shardIdx)
}
results := make([]tabletResult, len(shards))
rowCount := int64(0)
rowsAffected := int64(0)
lastInsertId := int64(0)
var hasError error
for i := range results {
results[i] = <-rchan
if results[i].error != nil {
hasError = results[i].error
continue
}
affected, _ := results[i].RowsAffected()
insertId, _ := results[i].LastInsertId()
rowsAffected += affected
if insertId > 0 {
if lastInsertId == 0 {
lastInsertId = insertId
}
// FIXME(msolomon) issue an error when you have multiple last inserts?
}
rowCount += results[i].RowsRetrieved()
}
// FIXME(msolomon) allow partial result set?
if hasError != nil {
return nil, fmt.Errorf("vt: partial result set (%v)", hasError)
}
for _, tr := range results {
if tr.error != nil {
return nil, tr.error
}
// FIXME(msolomon) This error message should be a const. Should this
// be deferred until we get a next query?
if tr.error != nil && tr.error.Error() == "retry: unavailable" {
sc.readKeyspace()
}
}
var fields []mproto.Field
if len(results) > 0 {
fields = results[0].Fields()
}
// check the schemas all match (both names and types)
if len(results) > 1 {
firstFields := results[0].Fields()
for _, r := range results[1:] {
fields := r.Fields()
if len(fields) != len(firstFields) {
return nil, fmt.Errorf("vt: column count mismatch: %v != %v", len(firstFields), len(fields))
}
for i, name := range fields {
if name.Name != firstFields[i].Name {
return nil, fmt.Errorf("vt: column[%v] name mismatch: %v != %v", i, name.Name, firstFields[i].Name)
}
}
}
}
// Combine results.
metaResult = tablet.NewResult(rowCount, rowsAffected, lastInsertId, fields)
curIndex := 0
rows := metaResult.Rows()
for _, tr := range results {
for _, row := range tr.Rows() {
rows[curIndex] = row
curIndex++
}
}
return metaResult, nil
}
func (sc *ShardedConn) execOnShard(query string, bindVars map[string]interface{}, shardIdx int) (db.Result, error) {
if sc.conns[shardIdx] == nil {
conn, err := sc.dial(shardIdx)
if err != nil {
return nil, err
}
sc.conns[shardIdx] = conn
}
conn := sc.conns[shardIdx]
// if we haven't started the transaction on that shard and need to, now is the time
if sc.currentTransaction != nil {
err := sc.currentTransaction.begin(conn)
if err != nil {
return nil, err
}
}
// Retries should have already taken place inside the tablet connection.
// At this point, all that's left are more sinister failures.
// FIXME(msolomon) reload just this shard unless the failure pertains to
// needing to reload the entire keyspace.
return conn.Exec(query, bindVars)
}
// when doing a streaming query, we send this structure back
type streamTabletResult struct {
error
row []interface{}
}
// our streaming result, just aggregates from all streaming results
// it implements both driver.Result and driver.Rows
type multiStreamResult struct {
cols []string
// results flow through this, maybe with errors
rows chan streamTabletResult
err error
}
// driver.Result interface
func (*multiStreamResult) LastInsertId() (int64, error) {
return 0, tablet.ErrNoLastInsertId
}
func (*multiStreamResult) RowsAffected() (int64, error) {
return 0, tablet.ErrNoRowsAffected
}
// driver.Rows interface
func (sr *multiStreamResult) Columns() []string {
return sr.cols
}
func (sr *multiStreamResult) Close() error {
close(sr.rows)
return nil
}
// read from the stream and gets the next value
// if one of the go routines returns an error, we want to save it and return it
// eventually. (except if it's EOF, then we just know that routine is done)
func (sr *multiStreamResult) Next() (row []interface{}) {
for {
str, ok := <-sr.rows
if !ok {
return nil
}
if str.error != nil {
sr.err = str.error
continue
}
return str.row
}
}
func (sr *multiStreamResult) Err() error {
return sr.err
}
func (sc *ShardedConn) execOnShardsStream(query string, bindVars map[string]interface{}, shards []int) (msr *multiStreamResult, err error) {
// we synchronously do the exec on each shard
// so we can get the Columns from the first one
// and check the others match them
var cols []string
qrs := make([]db.Result, len(shards))
for i, shardIdx := range shards {
qr, err := sc.execOnShard(query, bindVars, shardIdx)
if err != nil {
// FIXME(alainjobart) if the first queries went through
// we need to cancel them
return nil, err
}
// we know the result is a tablet.StreamResult,
// and we use it as a driver.Rows
qrs[i] = qr.(db.Result)
// save the columns or check they match
if i == 0 {
cols = qrs[i].Columns()
} else {
ncols := qrs[i].Columns()
if len(ncols) != len(cols) {
return nil, fmt.Errorf("vt: column count mismatch: %v != %v", len(ncols), len(cols))
}
for i, name := range cols {
if name != ncols[i] {
return nil, fmt.Errorf("vt: column[%v] name mismatch: %v != %v", i, name, ncols[i])
}
}
}
}
// now we create the result, its channel, and run background
// routines to stream results
msr = &multiStreamResult{cols: cols, rows: make(chan streamTabletResult, 10*len(shards))}
var wg sync.WaitGroup
for i, shardIdx := range shards {
wg.Add(1)
go func(i, shardIdx int) {
defer wg.Done()
for row := qrs[i].Next(); row != nil; row = qrs[i].Next() {
msr.rows <- streamTabletResult{row: row}
}
if err := qrs[i].Err(); err != nil {
msr.rows <- streamTabletResult{error: err}
}
}(i, shardIdx)
}
// Close channel once all data has been sent
go func() {
wg.Wait()
close(msr.rows)
}()
return msr, nil
}
/*
type ClientQuery struct {
Sql string
BindVariables map[string]interface{}
}
// FIXME(msolomon) There are multiple options for an efficient ExecMulti.
// * Use a special stmt object, buffer all statements, connections, etc and send when it's ready.
// * Take a list of (sql, bind) pairs and just send that - have to parse and route that anyway.
// * Probably need separate support for the a MultiTx too.
func (sc *ShardedConn) ExecuteBatch(queryList []ClientQuery, keyVal interface{}) (*tabletserver.QueryResult, error) {
shardIdx, err := key.FindShardForKey(keyVal, sc.shardMaxKeys)
shards := []int{shardIdx}
if err = sc.tabletPrepare(shardIdx); err != nil {
return nil, err
}
reqs := make([]tabletserver.Query, len(queryList))
for i, cq := range queryList {
reqs[i] = tabletserver.Query{
Sql: cq.Sql,
BindVariables: cq.BindVariables,
TransactionId: sc.conns[shardIdx].TransactionId,
SessionId: sc.conns[shardIdx].SessionId,
}
}
res := new(tabletserver.QueryResult)
err = sc.conns[shardIdx].Call("SqlQuery.ExecuteBatch", reqs, res)
if err != nil {
return nil, err
}
return res, nil
}
*/
func (sc *ShardedConn) dial(shardIdx int) (conn *tablet.VtConn, err error) {
shardReference := &(sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences[shardIdx])
addrs, err := sc.ts.GetEndPoints(sc.cell, sc.keyspace, shardReference.Name, sc.tabletType)
if err != nil {
return nil, fmt.Errorf("vt: GetEndPoints failed %v", err)
}
srvs, err := topo.SrvEntries(addrs, "")
if err != nil {
return nil, err
}
// Try to connect to any address.
for _, srv := range srvs {
name := topo.SrvAddr(srv) + "/" + sc.keyspace + "/" + shardReference.Name
conn, err = tablet.DialVtdb(name, sc.stream, tablet.DefaultTimeout)
if err == nil {
return conn, nil
}
}
return nil, err
}
type sDriver struct {
ts topo.Server
stream bool
}
// for direct zk connection: vtzk://host:port/cell/keyspace/tabletType
// we always use a MetaConn, host and port are ignored.
// the driver name dictates if we streaming or not
func (driver *sDriver) Open(name string) (sc db.Conn, err error) {
if !strings.HasPrefix(name, "vtzk://") {
// add a default protocol talking to zk
name = "vtzk://" + name
}
u, err := url.Parse(name)
if err != nil {
return nil, err
}
dbi, tabletType := path.Split(u.Path)
dbi = strings.Trim(dbi, "/")
tabletType = strings.Trim(tabletType, "/")
cell, keyspace := path.Split(dbi)
cell = strings.Trim(cell, "/")
keyspace = strings.Trim(keyspace, "/")
return Dial(driver.ts, cell, keyspace, topo.TabletType(tabletType), driver.stream, tablet.DefaultTimeout)
}
func RegisterShardedDrivers() {
// default topo server
ts := topo.GetServer()
db.Register("vtdb", &sDriver{ts, false})
db.Register("vtdb-streaming", &sDriver{ts, true})
// forced zk topo server
zconn := zk.NewMetaConn()
zkts := zktopo.NewServer(zconn)
db.Register("vtdb-zk", &sDriver{zkts, false})
db.Register("vtdb-zk-streaming", &sDriver{zkts, true})
}

Просмотреть файл

@ -1,332 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tablet is an API compliant to the requirements of database/sql
// Open expects name to be "hostname:port/keyspace/shard"
// For query arguments, we assume place-holders in the query string
// in the form of :v0, :v1, etc.
package tablet
import (
"errors"
"fmt"
"net/url"
"strings"
"time"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/db"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/netutil"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
"github.com/youtube/vitess/go/vt/topo"
"golang.org/x/net/context"
)
var (
ErrNoNestedTxn = errors.New("vt: no nested transactions")
ErrBadCommit = errors.New("vt: commit without corresponding begin")
ErrBadRollback = errors.New("vt: rollback without corresponding begin")
ErrNoLastInsertId = errors.New("vt: no LastInsertId available after streaming statement")
ErrNoRowsAffected = errors.New("vt: no RowsAffected available after streaming statement")
ErrFieldLengthMismatch = errors.New("vt: no RowsAffected available after streaming statement")
)
type TabletError struct {
err error
addr string
}
func (te TabletError) Error() string {
return fmt.Sprintf("vt: client error on %v %v", te.addr, te.err)
}
// Not thread safe, as per sql package.
type Conn struct {
dbi *url.URL
stream bool
tabletConn tabletconn.TabletConn
TransactionId int64
timeout time.Duration
}
type Tx struct {
conn *Conn
}
type StreamResult struct {
errFunc tabletconn.ErrFunc
sr <-chan *mproto.QueryResult
columns *mproto.QueryResult
// current result and index on it
qr *mproto.QueryResult
index int
err error
}
func (conn *Conn) keyspace() string {
return strings.Split(conn.dbi.Path, "/")[1]
}
func (conn *Conn) shard() string {
return strings.Split(conn.dbi.Path, "/")[2]
}
// parseDbi parses the dbi and a URL. The dbi may or may not contain
// the scheme part.
func parseDbi(dbi string) (*url.URL, error) {
if !strings.HasPrefix(dbi, "vttp://") {
dbi = "vttp://" + dbi
}
return url.Parse(dbi)
}
func DialTablet(dbi string, stream bool, timeout time.Duration) (conn *Conn, err error) {
conn = new(Conn)
if conn.dbi, err = parseDbi(dbi); err != nil {
return
}
conn.stream = stream
conn.timeout = timeout
if err = conn.dial(); err != nil {
return nil, conn.fmtErr(err)
}
return
}
// Format error for exported methods to give callers more information.
func (conn *Conn) fmtErr(err error) error {
if err == nil {
return nil
}
return TabletError{err, conn.dbi.Host}
}
func (conn *Conn) dial() (err error) {
// build the endpoint in the right format
host, port, err := netutil.SplitHostPort(conn.dbi.Host)
if err != nil {
return err
}
endPoint := topo.EndPoint{
Host: host,
NamedPortMap: map[string]int{
"vt": port,
},
}
// and dial
tabletConn, err := tabletconn.GetDialer()(context.TODO(), endPoint, conn.keyspace(), conn.shard(), conn.timeout)
if err != nil {
return err
}
conn.tabletConn = tabletConn
return
}
func (conn *Conn) Close() error {
conn.tabletConn.Close()
return nil
}
func (conn *Conn) Exec(query string, bindVars map[string]interface{}) (db.Result, error) {
if conn.stream {
sr, errFunc, err := conn.tabletConn.StreamExecute(context.TODO(), query, bindVars, conn.TransactionId)
if err != nil {
return nil, conn.fmtErr(err)
}
// read the columns, or grab the error
cols, ok := <-sr
if !ok {
return nil, conn.fmtErr(errFunc())
}
return &StreamResult{errFunc, sr, cols, nil, 0, nil}, nil
}
qr, err := conn.tabletConn.Execute(context.TODO(), query, bindVars, conn.TransactionId)
if err != nil {
return nil, conn.fmtErr(err)
}
return &Result{qr, 0, nil}, nil
}
func (conn *Conn) Begin() (db.Tx, error) {
if conn.TransactionId != 0 {
return &Tx{}, ErrNoNestedTxn
}
if transactionId, err := conn.tabletConn.Begin(context.TODO()); err != nil {
return &Tx{}, conn.fmtErr(err)
} else {
conn.TransactionId = transactionId
}
return &Tx{conn}, nil
}
func (conn *Conn) Commit() error {
if conn.TransactionId == 0 {
return ErrBadCommit
}
// NOTE(msolomon) Unset the transaction_id irrespective of the RPC's
// response. The intent of commit is that no more statements can be
// made on this transaction, so we guarantee that. Transient errors
// between the db and the client shouldn't affect this part of the
// bookkeeping. According to the Go Driver API, this will not be
// called concurrently. Defer this because we this affects the
// session referenced in the request.
defer func() { conn.TransactionId = 0 }()
return conn.fmtErr(conn.tabletConn.Commit(context.TODO(), conn.TransactionId))
}
func (conn *Conn) Rollback() error {
if conn.TransactionId == 0 {
return ErrBadRollback
}
// See note in Commit about the behavior of TransactionId.
defer func() { conn.TransactionId = 0 }()
return conn.fmtErr(conn.tabletConn.Rollback(context.TODO(), conn.TransactionId))
}
// driver.Tx interface (forwarded to Conn)
func (tx *Tx) Commit() error {
return tx.conn.Commit()
}
func (tx *Tx) Rollback() error {
return tx.conn.Rollback()
}
type Result struct {
qr *mproto.QueryResult
index int
err error
}
// TODO(mberlin): Populate flags here as well (e.g. to correctly identify unsigned integer type)?
func NewResult(rowCount, rowsAffected, insertId int64, fields []mproto.Field) *Result {
return &Result{
qr: &mproto.QueryResult{
Rows: make([][]sqltypes.Value, int(rowCount)),
Fields: fields,
RowsAffected: uint64(rowsAffected),
InsertId: uint64(insertId),
},
}
}
func (result *Result) RowsRetrieved() int64 {
return int64(len(result.qr.Rows))
}
func (result *Result) LastInsertId() (int64, error) {
return int64(result.qr.InsertId), nil
}
func (result *Result) RowsAffected() (int64, error) {
return int64(result.qr.RowsAffected), nil
}
// driver.Rows interface
func (result *Result) Columns() []string {
cols := make([]string, len(result.qr.Fields))
for i, f := range result.qr.Fields {
cols[i] = f.Name
}
return cols
}
func (result *Result) Rows() [][]sqltypes.Value {
return result.qr.Rows
}
// FIXME(msolomon) This should be intependent of the mysql module.
func (result *Result) Fields() []mproto.Field {
return result.qr.Fields
}
func (result *Result) Close() error {
result.index = 0
return nil
}
func (result *Result) Next() (row []interface{}) {
if result.index >= len(result.qr.Rows) {
return nil
}
row = make([]interface{}, len(result.qr.Rows[result.index]))
for i, v := range result.qr.Rows[result.index] {
var err error
row[i], err = mproto.Convert(result.qr.Fields[i].Type, v)
if err != nil {
panic(err) // unexpected
}
}
result.index++
return row
}
func (result *Result) Err() error {
return result.err
}
// driver.Result interface
func (*StreamResult) LastInsertId() (int64, error) {
return 0, ErrNoLastInsertId
}
func (*StreamResult) RowsAffected() (int64, error) {
return 0, ErrNoRowsAffected
}
// driver.Rows interface
func (sr *StreamResult) Columns() (cols []string) {
cols = make([]string, len(sr.columns.Fields))
for i, f := range sr.columns.Fields {
cols[i] = f.Name
}
return cols
}
func (*StreamResult) Close() error {
return nil
}
func (sr *StreamResult) Next() (row []interface{}) {
if sr.qr == nil {
// we need to read the next record that may contain
// multiple rows
qr, ok := <-sr.sr
if !ok {
if sr.errFunc() != nil {
log.Warningf("vt: error reading the next value %v", sr.errFunc())
sr.err = sr.errFunc()
}
return nil
}
sr.qr = qr
sr.index = 0
}
row = make([]interface{}, len(sr.qr.Rows[sr.index]))
for i, v := range sr.qr.Rows[sr.index] {
var err error
row[i], err = mproto.Convert(sr.columns.Fields[i].Type, v)
if err != nil {
panic(err) // unexpected
}
}
sr.index++
if sr.index == len(sr.qr.Rows) {
// we reached the end of our rows, nil it so next run
// will fetch the next one
sr.qr = nil
}
return row
}
func (sr *StreamResult) Err() error {
return sr.err
}

Просмотреть файл

@ -1,178 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tablet implements some additional error handling logic to
// make the client more robust in the face of transient problems with
// easy solutions.
package tablet
import (
"fmt"
"net"
"strings"
"time"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/db"
)
const (
ErrTypeFatal = 1 //errors.New("vt: fatal: reresolve endpoint")
ErrTypeRetry = 2 //errors.New("vt: retry: reconnect endpoint")
ErrTypeApp = 3 //errors.New("vt: app level error")
)
const (
DefaultReconnectDelay = 2 * time.Millisecond
DefaultMaxAttempts = 2
DefaultTimeout = 30 * time.Second
)
var zeroTime time.Time
// Layer some logic on top of the basic tablet protocol to support
// fast retry when we can.
type VtConn struct {
Conn
maxAttempts int // How many times should try each retriable operation?
timeFailed time.Time // This is the time a client transitioned from presumable health to failure.
reconnectDelay time.Duration
}
// How long should we wait to try to recover?
// FIXME(msolomon) not sure if maxAttempts is still useful
func (vtc *VtConn) recoveryTimeout() time.Duration {
return vtc.timeout * 2
}
func (vtc *VtConn) handleErr(err error) (int, error) {
now := time.Now()
if vtc.timeFailed.IsZero() {
vtc.timeFailed = now
} else if now.Sub(vtc.timeFailed) > vtc.recoveryTimeout() {
vtc.Close()
return ErrTypeFatal, fmt.Errorf("vt: max recovery time exceeded: %v", err)
}
errType := ErrTypeApp
if tabletErr, ok := err.(TabletError); ok {
msg := strings.ToLower(tabletErr.err.Error())
if strings.HasPrefix(msg, "fatal") {
errType = ErrTypeFatal
} else if strings.HasPrefix(msg, "retry") {
errType = ErrTypeRetry
}
} else if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
errType = ErrTypeRetry
}
if errType == ErrTypeRetry && vtc.TransactionId != 0 {
errType = ErrTypeApp
err = fmt.Errorf("vt: cannot retry within a transaction: %v", err)
time.Sleep(vtc.reconnectDelay)
vtc.Close()
dialErr := vtc.dial()
log.Warningf("vt: redial error %v", dialErr)
}
return errType, err
}
func (vtc *VtConn) Exec(query string, bindVars map[string]interface{}) (db.Result, error) {
attempt := 0
for {
result, err := vtc.Conn.Exec(query, bindVars)
if err == nil {
vtc.timeFailed = zeroTime
return result, nil
}
errType, err := vtc.handleErr(err)
if errType != ErrTypeRetry {
return nil, err
}
for {
attempt++
if attempt > vtc.maxAttempts {
return nil, fmt.Errorf("vt: max recovery attempts exceeded: %v", err)
}
vtc.Close()
time.Sleep(vtc.reconnectDelay)
if err := vtc.dial(); err == nil {
break
}
log.Warningf("vt: error dialing on exec %v", vtc.Conn.dbi.Host)
}
}
}
func (vtc *VtConn) Begin() (db.Tx, error) {
attempt := 0
for {
tx, err := vtc.Conn.Begin()
if err == nil {
vtc.timeFailed = zeroTime
return tx, nil
}
errType, err := vtc.handleErr(err)
if errType != ErrTypeRetry {
return nil, err
}
for {
attempt++
if attempt > vtc.maxAttempts {
return nil, fmt.Errorf("vt: max recovery attempts exceeded: %v", err)
}
vtc.Close()
time.Sleep(vtc.reconnectDelay)
if err := vtc.dial(); err == nil {
break
}
log.Warningf("vt: error dialing on begin %v", vtc.Conn.dbi.Host)
}
}
}
func (vtc *VtConn) Commit() (err error) {
if err = vtc.Conn.Commit(); err == nil {
vtc.timeFailed = zeroTime
return nil
}
// Not much we can do at this point, just annotate the error and return.
_, err = vtc.handleErr(err)
return err
}
func DialVtdb(dbi string, stream bool, timeout time.Duration) (*VtConn, error) {
url, err := parseDbi(dbi)
if err != nil {
return nil, err
}
conn := &VtConn{
Conn: Conn{dbi: url, stream: stream, timeout: timeout},
maxAttempts: DefaultMaxAttempts,
reconnectDelay: DefaultReconnectDelay,
}
if err := conn.dial(); err != nil {
return nil, err
}
return conn, nil
}
type vDriver struct {
stream bool
}
func (driver *vDriver) Open(name string) (db.Conn, error) {
return DialVtdb(name, driver.stream, DefaultTimeout)
}
func init() {
db.Register("vttablet", &vDriver{})
db.Register("vttablet-streaming", &vDriver{true})
}

Просмотреть файл

@ -1,106 +0,0 @@
// Copyright 2013, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package concurrency
import (
"fmt"
"sync"
"github.com/youtube/vitess/go/sync2"
)
// ResourceConstraint combines 3 different features:
// - a WaitGroup to wait for all tasks to be done
// - a Semaphore to control concurrency
// - an ErrorRecorder
type ResourceConstraint struct {
semaphore *sync2.Semaphore
wg sync.WaitGroup
FirstErrorRecorder
}
// NewResourceConstraint creates a ResourceConstraint with
// max concurrency.
func NewResourceConstraint(max int) *ResourceConstraint {
return &ResourceConstraint{semaphore: sync2.NewSemaphore(max, 0)}
}
func (rc *ResourceConstraint) Add(n int) {
rc.wg.Add(n)
}
func (rc *ResourceConstraint) Done() {
rc.wg.Done()
}
// Wait waits for the WG and returns the firstError we encountered, or nil
func (rc *ResourceConstraint) Wait() error {
rc.wg.Wait()
return rc.Error()
}
// Acquire will wait until we have a resource to use
func (rc *ResourceConstraint) Acquire() {
rc.semaphore.Acquire()
}
func (rc *ResourceConstraint) Release() {
rc.semaphore.Release()
}
func (rc *ResourceConstraint) ReleaseAndDone() {
rc.Release()
rc.Done()
}
// MultiResourceConstraint combines 3 different features:
// - a WaitGroup to wait for all tasks to be done
// - a Semaphore map to control multiple concurrencies
// - an ErrorRecorder
type MultiResourceConstraint struct {
semaphoreMap map[string]*sync2.Semaphore
wg sync.WaitGroup
FirstErrorRecorder
}
func NewMultiResourceConstraint(semaphoreMap map[string]*sync2.Semaphore) *MultiResourceConstraint {
return &MultiResourceConstraint{semaphoreMap: semaphoreMap}
}
func (mrc *MultiResourceConstraint) Add(n int) {
mrc.wg.Add(n)
}
func (mrc *MultiResourceConstraint) Done() {
mrc.wg.Done()
}
// Returns the firstError we encountered, or nil
func (mrc *MultiResourceConstraint) Wait() error {
mrc.wg.Wait()
return mrc.Error()
}
// Acquire will wait until we have a resource to use
func (mrc *MultiResourceConstraint) Acquire(name string) {
s, ok := mrc.semaphoreMap[name]
if !ok {
panic(fmt.Errorf("MultiResourceConstraint: No resource named %v in semaphore map", name))
}
s.Acquire()
}
func (mrc *MultiResourceConstraint) Release(name string) {
s, ok := mrc.semaphoreMap[name]
if !ok {
panic(fmt.Errorf("MultiResourceConstraint: No resource named %v in semaphore map", name))
}
s.Release()
}
func (mrc *MultiResourceConstraint) ReleaseAndDone(name string) {
mrc.Release(name)
mrc.Done()
}

Просмотреть файл

@ -107,7 +107,7 @@ func (cp *ConnectionPool) Get(timeout time.Duration) (PoolConnection, error) {
ctx := context.Background()
if timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(context.Background(), timeout)
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
r, err := p.Get(ctx)
@ -117,20 +117,6 @@ func (cp *ConnectionPool) Get(timeout time.Duration) (PoolConnection, error) {
return r.(PoolConnection), nil
}
// TryGet returns a connection, or nil.
// You must call Recycle on the PoolConnection once done.
func (cp *ConnectionPool) TryGet() (PoolConnection, error) {
p := cp.pool()
if p == nil {
return nil, ErrConnPoolClosed
}
r, err := p.TryGet()
if err != nil || r == nil {
return nil, err
}
return r.(PoolConnection), nil
}
// Put puts a connection into the pool.
func (cp *ConnectionPool) Put(conn PoolConnection) {
p := cp.pool()

Просмотреть файл

@ -14,6 +14,7 @@ import (
"github.com/youtube/vitess/go/jscfg"
"github.com/youtube/vitess/go/vt/topo"
"golang.org/x/net/context"
)
func TestSplitCellPath(t *testing.T) {
@ -78,14 +79,15 @@ func TestHandlePathKeyspace(t *testing.T) {
shard := &topo.Shard{}
want := jscfg.ToJSON(keyspace)
ctx := context.Background()
ts := newTestServer(t, cells)
if err := ts.CreateKeyspace("test_keyspace", keyspace); err != nil {
if err := ts.CreateKeyspace(ctx, "test_keyspace", keyspace); err != nil {
t.Fatalf("CreateKeyspace error: %v", err)
}
if err := ts.CreateShard("test_keyspace", "10-20", shard); err != nil {
if err := ts.CreateShard(ctx, "test_keyspace", "10-20", shard); err != nil {
t.Fatalf("CreateShard error: %v", err)
}
if err := ts.CreateShard("test_keyspace", "20-30", shard); err != nil {
if err := ts.CreateShard(ctx, "test_keyspace", "20-30", shard); err != nil {
t.Fatalf("CreateShard error: %v", err)
}
@ -114,11 +116,12 @@ func TestHandlePathShard(t *testing.T) {
shard := &topo.Shard{}
want := jscfg.ToJSON(shard)
ctx := context.Background()
ts := newTestServer(t, cells)
if err := ts.CreateKeyspace("test_keyspace", keyspace); err != nil {
if err := ts.CreateKeyspace(ctx, "test_keyspace", keyspace); err != nil {
t.Fatalf("CreateKeyspace error: %v", err)
}
if err := ts.CreateShard("test_keyspace", "-80", shard); err != nil {
if err := ts.CreateShard(ctx, "test_keyspace", "-80", shard); err != nil {
t.Fatalf("CreateShard error: %v", err)
}
@ -150,8 +153,9 @@ func TestHandlePathTablet(t *testing.T) {
}
want := jscfg.ToJSON(tablet)
ctx := context.Background()
ts := newTestServer(t, cells)
if err := ts.CreateTablet(tablet); err != nil {
if err := ts.CreateTablet(ctx, tablet); err != nil {
t.Fatalf("CreateTablet error: %v", err)
}

Просмотреть файл

@ -14,10 +14,11 @@ import (
"github.com/youtube/vitess/go/vt/concurrency"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/topo/events"
"golang.org/x/net/context"
)
// CreateKeyspace implements topo.Server.
func (s *Server) CreateKeyspace(keyspace string, value *topo.Keyspace) error {
func (s *Server) CreateKeyspace(ctx context.Context, keyspace string, value *topo.Keyspace) error {
data := jscfg.ToJSON(value)
global := s.getGlobal()
@ -44,7 +45,7 @@ func (s *Server) CreateKeyspace(keyspace string, value *topo.Keyspace) error {
}
// UpdateKeyspace implements topo.Server.
func (s *Server) UpdateKeyspace(ki *topo.KeyspaceInfo, existingVersion int64) (int64, error) {
func (s *Server) UpdateKeyspace(ctx context.Context, ki *topo.KeyspaceInfo, existingVersion int64) (int64, error) {
data := jscfg.ToJSON(ki.Keyspace)
resp, err := s.getGlobal().CompareAndSwap(keyspaceFilePath(ki.KeyspaceName()),
@ -64,7 +65,7 @@ func (s *Server) UpdateKeyspace(ki *topo.KeyspaceInfo, existingVersion int64) (i
}
// GetKeyspace implements topo.Server.
func (s *Server) GetKeyspace(keyspace string) (*topo.KeyspaceInfo, error) {
func (s *Server) GetKeyspace(ctx context.Context, keyspace string) (*topo.KeyspaceInfo, error) {
resp, err := s.getGlobal().Get(keyspaceFilePath(keyspace), false /* sort */, false /* recursive */)
if err != nil {
return nil, convertError(err)
@ -82,7 +83,7 @@ func (s *Server) GetKeyspace(keyspace string) (*topo.KeyspaceInfo, error) {
}
// GetKeyspaces implements topo.Server.
func (s *Server) GetKeyspaces() ([]string, error) {
func (s *Server) GetKeyspaces(ctx context.Context) ([]string, error) {
resp, err := s.getGlobal().Get(keyspacesDirPath, true /* sort */, false /* recursive */)
if err != nil {
err = convertError(err)
@ -95,8 +96,8 @@ func (s *Server) GetKeyspaces() ([]string, error) {
}
// DeleteKeyspaceShards implements topo.Server.
func (s *Server) DeleteKeyspaceShards(keyspace string) error {
shards, err := s.GetShardNames(keyspace)
func (s *Server) DeleteKeyspaceShards(ctx context.Context, keyspace string) error {
shards, err := s.GetShardNames(ctx, keyspace)
if err != nil {
return err
}

Просмотреть файл

@ -185,7 +185,7 @@ func (s *Server) LockSrvShardForAction(ctx context.Context, cellName, keyspace,
}
// UnlockSrvShardForAction implements topo.Server.
func (s *Server) UnlockSrvShardForAction(cellName, keyspace, shard, actionPath, results string) error {
func (s *Server) UnlockSrvShardForAction(ctx context.Context, cellName, keyspace, shard, actionPath, results string) error {
log.Infof("results of %v: %v", actionPath, results)
cell, err := s.getCell(cellName)
@ -204,7 +204,7 @@ func (s *Server) LockKeyspaceForAction(ctx context.Context, keyspace, contents s
}
// UnlockKeyspaceForAction implements topo.Server.
func (s *Server) UnlockKeyspaceForAction(keyspace, actionPath, results string) error {
func (s *Server) UnlockKeyspaceForAction(ctx context.Context, keyspace, actionPath, results string) error {
log.Infof("results of %v: %v", actionPath, results)
return unlock(s.getGlobal(), keyspaceDirPath(keyspace), actionPath,
@ -218,7 +218,7 @@ func (s *Server) LockShardForAction(ctx context.Context, keyspace, shard, conten
}
// UnlockShardForAction implements topo.Server.
func (s *Server) UnlockShardForAction(keyspace, shard, actionPath, results string) error {
func (s *Server) UnlockShardForAction(ctx context.Context, keyspace, shard, actionPath, results string) error {
log.Infof("results of %v: %v", actionPath, results)
return unlock(s.getGlobal(), shardDirPath(keyspace, shard), actionPath,

Просмотреть файл

@ -10,10 +10,11 @@ import (
"github.com/youtube/vitess/go/jscfg"
"github.com/youtube/vitess/go/vt/topo"
"golang.org/x/net/context"
)
// UpdateShardReplicationFields implements topo.Server.
func (s *Server) UpdateShardReplicationFields(cell, keyspace, shard string, updateFunc func(*topo.ShardReplication) error) error {
func (s *Server) UpdateShardReplicationFields(ctx context.Context, cell, keyspace, shard string, updateFunc func(*topo.ShardReplication) error) error {
var sri *topo.ShardReplicationInfo
var version int64
var err error
@ -82,7 +83,7 @@ func (s *Server) createShardReplication(sri *topo.ShardReplicationInfo) (int64,
}
// GetShardReplication implements topo.Server.
func (s *Server) GetShardReplication(cell, keyspace, shard string) (*topo.ShardReplicationInfo, error) {
func (s *Server) GetShardReplication(ctx context.Context, cell, keyspace, shard string) (*topo.ShardReplicationInfo, error) {
sri, _, err := s.getShardReplication(cell, keyspace, shard)
return sri, err
}
@ -110,7 +111,7 @@ func (s *Server) getShardReplication(cellName, keyspace, shard string) (*topo.Sh
}
// DeleteShardReplication implements topo.Server.
func (s *Server) DeleteShardReplication(cellName, keyspace, shard string) error {
func (s *Server) DeleteShardReplication(ctx context.Context, cellName, keyspace, shard string) error {
cell, err := s.getCell(cellName)
if err != nil {
return err

Просмотреть файл

@ -25,6 +25,7 @@ import (
"sync"
"github.com/youtube/vitess/go/vt/topo"
"golang.org/x/net/context"
)
// Server is the implementation of topo.Server for etcd.
@ -52,7 +53,7 @@ func (s *Server) Close() {
}
// GetKnownCells implements topo.Server.
func (s *Server) GetKnownCells() ([]string, error) {
func (s *Server) GetKnownCells(ctx context.Context) ([]string, error) {
resp, err := s.getGlobal().Get(cellsDirPath, true /* sort */, false /* recursive */)
if err != nil {
return nil, convertError(err)

Просмотреть файл

@ -31,73 +31,83 @@ func newTestServer(t *testing.T, cells []string) *Server {
}
func TestKeyspace(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckKeyspace(t, ts)
test.CheckKeyspace(ctx, t, ts)
}
func TestShard(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckShard(context.Background(), t, ts)
test.CheckShard(ctx, t, ts)
}
func TestTablet(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckTablet(context.Background(), t, ts)
test.CheckTablet(ctx, t, ts)
}
func TestShardReplication(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckShardReplication(t, ts)
test.CheckShardReplication(ctx, t, ts)
}
func TestServingGraph(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckServingGraph(context.Background(), t, ts)
test.CheckServingGraph(ctx, t, ts)
}
func TestWatchEndPoints(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckWatchEndPoints(context.Background(), t, ts)
test.CheckWatchEndPoints(ctx, t, ts)
}
func TestKeyspaceLock(t *testing.T) {
ctx := context.Background()
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckKeyspaceLock(t, ts)
test.CheckKeyspaceLock(ctx, t, ts)
}
func TestShardLock(t *testing.T) {
ctx := context.Background()
if testing.Short() {
t.Skip("skipping wait-based test in short mode.")
}
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckShardLock(t, ts)
test.CheckShardLock(ctx, t, ts)
}
func TestSrvShardLock(t *testing.T) {
ctx := context.Background()
if testing.Short() {
t.Skip("skipping wait-based test in short mode.")
}
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckSrvShardLock(t, ts)
test.CheckSrvShardLock(ctx, t, ts)
}
func TestVSchema(t *testing.T) {
ctx := context.Background()
if testing.Short() {
t.Skip("skipping wait-based test in short mode.")
}
ts := newTestServer(t, []string{"test"})
defer ts.Close()
test.CheckVSchema(t, ts)
test.CheckVSchema(ctx, t, ts)
}

Просмотреть файл

@ -14,6 +14,7 @@ import (
log "github.com/golang/glog"
"github.com/youtube/vitess/go/jscfg"
"github.com/youtube/vitess/go/vt/topo"
"golang.org/x/net/context"
)
// WatchSleepDuration is how many seconds interval to poll for in case
@ -22,7 +23,7 @@ import (
var WatchSleepDuration = 30 * time.Second
// GetSrvTabletTypesPerShard implements topo.Server.
func (s *Server) GetSrvTabletTypesPerShard(cellName, keyspace, shard string) ([]topo.TabletType, error) {
func (s *Server) GetSrvTabletTypesPerShard(ctx context.Context, cellName, keyspace, shard string) ([]topo.TabletType, error) {
cell, err := s.getCell(cellName)
if err != nil {
return nil, err
@ -44,7 +45,7 @@ func (s *Server) GetSrvTabletTypesPerShard(cellName, keyspace, shard string) ([]
}
// UpdateEndPoints implements topo.Server.
func (s *Server) UpdateEndPoints(cellName, keyspace, shard string, tabletType topo.TabletType, addrs *topo.EndPoints) error {
func (s *Server) UpdateEndPoints(ctx context.Context, cellName, keyspace, shard string, tabletType topo.TabletType, addrs *topo.EndPoints) error {
cell, err := s.getCell(cellName)
if err != nil {
return err
@ -71,7 +72,7 @@ func (s *Server) updateEndPoints(cellName, keyspace, shard string, tabletType to
}
// GetEndPoints implements topo.Server.
func (s *Server) GetEndPoints(cell, keyspace, shard string, tabletType topo.TabletType) (*topo.EndPoints, error) {
func (s *Server) GetEndPoints(ctx context.Context, cell, keyspace, shard string, tabletType topo.TabletType) (*topo.EndPoints, error) {
value, _, err := s.getEndPoints(cell, keyspace, shard, tabletType)
return value, err
}
@ -100,7 +101,7 @@ func (s *Server) getEndPoints(cellName, keyspace, shard string, tabletType topo.
}
// DeleteEndPoints implements topo.Server.
func (s *Server) DeleteEndPoints(cellName, keyspace, shard string, tabletType topo.TabletType) error {
func (s *Server) DeleteEndPoints(ctx context.Context, cellName, keyspace, shard string, tabletType topo.TabletType) error {
cell, err := s.getCell(cellName)
if err != nil {
return err
@ -111,7 +112,7 @@ func (s *Server) DeleteEndPoints(cellName, keyspace, shard string, tabletType to
}
// UpdateSrvShard implements topo.Server.
func (s *Server) UpdateSrvShard(cellName, keyspace, shard string, srvShard *topo.SrvShard) error {
func (s *Server) UpdateSrvShard(ctx context.Context, cellName, keyspace, shard string, srvShard *topo.SrvShard) error {
cell, err := s.getCell(cellName)
if err != nil {
return err
@ -124,7 +125,7 @@ func (s *Server) UpdateSrvShard(cellName, keyspace, shard string, srvShard *topo
}
// GetSrvShard implements topo.Server.
func (s *Server) GetSrvShard(cellName, keyspace, shard string) (*topo.SrvShard, error) {
func (s *Server) GetSrvShard(ctx context.Context, cellName, keyspace, shard string) (*topo.SrvShard, error) {
cell, err := s.getCell(cellName)
if err != nil {
return nil, err
@ -146,7 +147,7 @@ func (s *Server) GetSrvShard(cellName, keyspace, shard string) (*topo.SrvShard,
}
// DeleteSrvShard implements topo.Server.
func (s *Server) DeleteSrvShard(cellName, keyspace, shard string) error {
func (s *Server) DeleteSrvShard(ctx context.Context, cellName, keyspace, shard string) error {
cell, err := s.getCell(cellName)
if err != nil {
return err
@ -157,7 +158,7 @@ func (s *Server) DeleteSrvShard(cellName, keyspace, shard string) error {
}
// UpdateSrvKeyspace implements topo.Server.
func (s *Server) UpdateSrvKeyspace(cellName, keyspace string, srvKeyspace *topo.SrvKeyspace) error {
func (s *Server) UpdateSrvKeyspace(ctx context.Context, cellName, keyspace string, srvKeyspace *topo.SrvKeyspace) error {
cell, err := s.getCell(cellName)
if err != nil {
return err
@ -170,7 +171,7 @@ func (s *Server) UpdateSrvKeyspace(cellName, keyspace string, srvKeyspace *topo.
}
// GetSrvKeyspace implements topo.Server.
func (s *Server) GetSrvKeyspace(cellName, keyspace string) (*topo.SrvKeyspace, error) {
func (s *Server) GetSrvKeyspace(ctx context.Context, cellName, keyspace string) (*topo.SrvKeyspace, error) {
cell, err := s.getCell(cellName)
if err != nil {
return nil, err
@ -192,7 +193,7 @@ func (s *Server) GetSrvKeyspace(cellName, keyspace string) (*topo.SrvKeyspace, e
}
// GetSrvKeyspaceNames implements topo.Server.
func (s *Server) GetSrvKeyspaceNames(cellName string) ([]string, error) {
func (s *Server) GetSrvKeyspaceNames(ctx context.Context, cellName string) ([]string, error) {
cell, err := s.getCell(cellName)
if err != nil {
return nil, err
@ -206,7 +207,7 @@ func (s *Server) GetSrvKeyspaceNames(cellName string) ([]string, error) {
}
// UpdateTabletEndpoint implements topo.Server.
func (s *Server) UpdateTabletEndpoint(cell, keyspace, shard string, tabletType topo.TabletType, addr *topo.EndPoint) error {
func (s *Server) UpdateTabletEndpoint(ctx context.Context, cell, keyspace, shard string, tabletType topo.TabletType, addr *topo.EndPoint) error {
for {
addrs, version, err := s.getEndPoints(cell, keyspace, shard, tabletType)
if err == topo.ErrNoNode {
@ -239,7 +240,7 @@ func (s *Server) UpdateTabletEndpoint(cell, keyspace, shard string, tabletType t
}
// WatchEndPoints is part of the topo.Server interface
func (s *Server) WatchEndPoints(cellName, keyspace, shard string, tabletType topo.TabletType) (<-chan *topo.EndPoints, chan<- struct{}, error) {
func (s *Server) WatchEndPoints(ctx context.Context, cellName, keyspace, shard string, tabletType topo.TabletType) (<-chan *topo.EndPoints, chan<- struct{}, error) {
cell, err := s.getCell(cellName)
if err != nil {
return nil, nil, fmt.Errorf("WatchEndPoints cannot get cell: %v", err)

Просмотреть файл

@ -12,10 +12,11 @@ import (
"github.com/youtube/vitess/go/jscfg"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/topo/events"
"golang.org/x/net/context"
)
// CreateShard implements topo.Server.
func (s *Server) CreateShard(keyspace, shard string, value *topo.Shard) error {
func (s *Server) CreateShard(ctx context.Context, keyspace, shard string, value *topo.Shard) error {
data := jscfg.ToJSON(value)
global := s.getGlobal()
@ -42,7 +43,7 @@ func (s *Server) CreateShard(keyspace, shard string, value *topo.Shard) error {
}
// UpdateShard implements topo.Server.
func (s *Server) UpdateShard(si *topo.ShardInfo, existingVersion int64) (int64, error) {
func (s *Server) UpdateShard(ctx context.Context, si *topo.ShardInfo, existingVersion int64) (int64, error) {
data := jscfg.ToJSON(si.Shard)
resp, err := s.getGlobal().CompareAndSwap(shardFilePath(si.Keyspace(), si.ShardName()),
@ -62,13 +63,13 @@ func (s *Server) UpdateShard(si *topo.ShardInfo, existingVersion int64) (int64,
}
// ValidateShard implements topo.Server.
func (s *Server) ValidateShard(keyspace, shard string) error {
_, err := s.GetShard(keyspace, shard)
func (s *Server) ValidateShard(ctx context.Context, keyspace, shard string) error {
_, err := s.GetShard(ctx, keyspace, shard)
return err
}
// GetShard implements topo.Server.
func (s *Server) GetShard(keyspace, shard string) (*topo.ShardInfo, error) {
func (s *Server) GetShard(ctx context.Context, keyspace, shard string) (*topo.ShardInfo, error) {
resp, err := s.getGlobal().Get(shardFilePath(keyspace, shard), false /* sort */, false /* recursive */)
if err != nil {
return nil, convertError(err)
@ -86,7 +87,7 @@ func (s *Server) GetShard(keyspace, shard string) (*topo.ShardInfo, error) {
}
// GetShardNames implements topo.Server.
func (s *Server) GetShardNames(keyspace string) ([]string, error) {
func (s *Server) GetShardNames(ctx context.Context, keyspace string) ([]string, error) {
resp, err := s.getGlobal().Get(shardsDirPath(keyspace), true /* sort */, false /* recursive */)
if err != nil {
return nil, convertError(err)
@ -95,7 +96,7 @@ func (s *Server) GetShardNames(keyspace string) ([]string, error) {
}
// DeleteShard implements topo.Server.
func (s *Server) DeleteShard(keyspace, shard string) error {
func (s *Server) DeleteShard(ctx context.Context, keyspace, shard string) error {
_, err := s.getGlobal().Delete(shardDirPath(keyspace, shard), true /* recursive */)
if err != nil {
return convertError(err)

Просмотреть файл

@ -12,10 +12,11 @@ import (
"github.com/youtube/vitess/go/jscfg"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/topo/events"
"golang.org/x/net/context"
)
// CreateTablet implements topo.Server.
func (s *Server) CreateTablet(tablet *topo.Tablet) error {
func (s *Server) CreateTablet(ctx context.Context, tablet *topo.Tablet) error {
cell, err := s.getCell(tablet.Alias.Cell)
if err != nil {
return err
@ -35,7 +36,7 @@ func (s *Server) CreateTablet(tablet *topo.Tablet) error {
}
// UpdateTablet implements topo.Server.
func (s *Server) UpdateTablet(ti *topo.TabletInfo, existingVersion int64) (int64, error) {
func (s *Server) UpdateTablet(ctx context.Context, ti *topo.TabletInfo, existingVersion int64) (int64, error) {
cell, err := s.getCell(ti.Alias.Cell)
if err != nil {
return -1, err
@ -59,18 +60,18 @@ func (s *Server) UpdateTablet(ti *topo.TabletInfo, existingVersion int64) (int64
}
// UpdateTabletFields implements topo.Server.
func (s *Server) UpdateTabletFields(tabletAlias topo.TabletAlias, updateFunc func(*topo.Tablet) error) error {
func (s *Server) UpdateTabletFields(ctx context.Context, tabletAlias topo.TabletAlias, updateFunc func(*topo.Tablet) error) error {
var ti *topo.TabletInfo
var err error
for {
if ti, err = s.GetTablet(tabletAlias); err != nil {
if ti, err = s.GetTablet(ctx, tabletAlias); err != nil {
return err
}
if err = updateFunc(ti.Tablet); err != nil {
return err
}
if _, err = s.UpdateTablet(ti, ti.Version()); err != topo.ErrBadVersion {
if _, err = s.UpdateTablet(ctx, ti, ti.Version()); err != topo.ErrBadVersion {
break
}
}
@ -86,14 +87,14 @@ func (s *Server) UpdateTabletFields(tabletAlias topo.TabletAlias, updateFunc fun
}
// DeleteTablet implements topo.Server.
func (s *Server) DeleteTablet(tabletAlias topo.TabletAlias) error {
func (s *Server) DeleteTablet(ctx context.Context, tabletAlias topo.TabletAlias) error {
cell, err := s.getCell(tabletAlias.Cell)
if err != nil {
return err
}
// Get the keyspace and shard names for the TabletChange event.
ti, tiErr := s.GetTablet(tabletAlias)
ti, tiErr := s.GetTablet(ctx, tabletAlias)
_, err = cell.Delete(tabletDirPath(tabletAlias.String()), true /* recursive */)
if err != nil {
@ -116,13 +117,13 @@ func (s *Server) DeleteTablet(tabletAlias topo.TabletAlias) error {
}
// ValidateTablet implements topo.Server.
func (s *Server) ValidateTablet(tabletAlias topo.TabletAlias) error {
_, err := s.GetTablet(tabletAlias)
func (s *Server) ValidateTablet(ctx context.Context, tabletAlias topo.TabletAlias) error {
_, err := s.GetTablet(ctx, tabletAlias)
return err
}
// GetTablet implements topo.Server.
func (s *Server) GetTablet(tabletAlias topo.TabletAlias) (*topo.TabletInfo, error) {
func (s *Server) GetTablet(ctx context.Context, tabletAlias topo.TabletAlias) (*topo.TabletInfo, error) {
cell, err := s.getCell(tabletAlias.Cell)
if err != nil {
return nil, err
@ -145,7 +146,7 @@ func (s *Server) GetTablet(tabletAlias topo.TabletAlias) (*topo.TabletInfo, erro
}
// GetTabletsByCell implements topo.Server.
func (s *Server) GetTabletsByCell(cellName string) ([]topo.TabletAlias, error) {
func (s *Server) GetTabletsByCell(ctx context.Context, cellName string) ([]topo.TabletAlias, error) {
cell, err := s.getCell(cellName)
if err != nil {
return nil, err

Просмотреть файл

@ -3,6 +3,7 @@ package etcdtopo
import (
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate/planbuilder"
"golang.org/x/net/context"
// vindexes needs to be imported so that they register
// themselves against vtgate/planbuilder. This will allow
// us to sanity check the schema being uploaded.
@ -14,7 +15,7 @@ This file contains the vschema management code for etcdtopo.Server
*/
// SaveVSchema saves the JSON vschema into the topo.
func (s *Server) SaveVSchema(vschema string) error {
func (s *Server) SaveVSchema(ctx context.Context, vschema string) error {
_, err := planbuilder.NewSchema([]byte(vschema))
if err != nil {
return err
@ -28,7 +29,7 @@ func (s *Server) SaveVSchema(vschema string) error {
}
// GetVSchema fetches the JSON vschema from the topo.
func (s *Server) GetVSchema() (string, error) {
func (s *Server) GetVSchema(ctx context.Context) (string, error) {
resp, err := s.getGlobal().Get(vschemaPath, false /* sort */, false /* recursive */)
if err != nil {
err = convertError(err)

Просмотреть файл

@ -9,7 +9,6 @@ import (
"time"
"github.com/youtube/vitess/go/vt/concurrency"
"github.com/youtube/vitess/go/vt/topo"
)
var (
@ -26,11 +25,11 @@ func init() {
type Reporter interface {
// Report returns the replication delay gathered by this
// module (or 0 if it thinks it's not behind), assuming that
// its tablet type is TabletType, and that its query service
// it is a slave type or not, and that its query service
// should be running or not. If Report returns an error it
// implies that the tablet is in a bad shape and not able to
// handle queries.
Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (replicationDelay time.Duration, err error)
Report(isSlaveType, shouldQueryServiceBeRunning bool) (replicationDelay time.Duration, err error)
// HTMLName returns a displayable name for the module.
// Can be used to be displayed in the status page.
@ -38,11 +37,11 @@ type Reporter interface {
}
// FunctionReporter is a function that may act as a Reporter.
type FunctionReporter func(topo.TabletType, bool) (time.Duration, error)
type FunctionReporter func(bool, bool) (time.Duration, error)
// Report implements Reporter.Report
func (fc FunctionReporter) Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
return fc(tabletType, shouldQueryServiceBeRunning)
func (fc FunctionReporter) Report(isSlaveType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
return fc(isSlaveType, shouldQueryServiceBeRunning)
}
// HTMLName implements Reporter.HTMLName
@ -71,7 +70,7 @@ func NewAggregator() *Aggregator {
// The returned replication delay will be the highest of all the replication
// delays returned by the Reporter implementations (although typically
// only one implementation will actually return a meaningful one).
func (ag *Aggregator) Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
func (ag *Aggregator) Report(isSlaveType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
var (
wg sync.WaitGroup
rec concurrency.AllErrorRecorder
@ -83,7 +82,7 @@ func (ag *Aggregator) Report(tabletType topo.TabletType, shouldQueryServiceBeRun
wg.Add(1)
go func(name string, rep Reporter) {
defer wg.Done()
replicationDelay, err := rep.Report(tabletType, shouldQueryServiceBeRunning)
replicationDelay, err := rep.Report(isSlaveType, shouldQueryServiceBeRunning)
if err != nil {
rec.RecordError(fmt.Errorf("%v: %v", name, err))
return

Просмотреть файл

@ -4,23 +4,21 @@ import (
"errors"
"testing"
"time"
"github.com/youtube/vitess/go/vt/topo"
)
func TestReporters(t *testing.T) {
ag := NewAggregator()
ag.Register("a", FunctionReporter(func(topo.TabletType, bool) (time.Duration, error) {
ag.Register("a", FunctionReporter(func(bool, bool) (time.Duration, error) {
return 10 * time.Second, nil
}))
ag.Register("b", FunctionReporter(func(topo.TabletType, bool) (time.Duration, error) {
ag.Register("b", FunctionReporter(func(bool, bool) (time.Duration, error) {
return 5 * time.Second, nil
}))
delay, err := ag.Report(topo.TYPE_REPLICA, true)
delay, err := ag.Report(true, true)
if err != nil {
t.Error(err)
@ -29,10 +27,10 @@ func TestReporters(t *testing.T) {
t.Errorf("delay=%v, want 10s", delay)
}
ag.Register("c", FunctionReporter(func(topo.TabletType, bool) (time.Duration, error) {
ag.Register("c", FunctionReporter(func(bool, bool) (time.Duration, error) {
return 0, errors.New("e error")
}))
if _, err := ag.Report(topo.TYPE_REPLICA, false); err == nil {
if _, err := ag.Report(true, false); err == nil {
t.Errorf("ag.Run: expected error")
}

Просмотреть файл

@ -116,9 +116,14 @@ func (hook *Hook) Execute() (result *HookResult) {
// Execute an optional hook, returns a printable error
func (hook *Hook) ExecuteOptional() error {
hr := hook.Execute()
if hr.ExitStatus == HOOK_DOES_NOT_EXIST {
switch hr.ExitStatus {
case HOOK_DOES_NOT_EXIST:
log.Infof("%v hook doesn't exist", hook.Name)
} else if hr.ExitStatus != HOOK_SUCCESS {
case HOOK_VTROOT_ERROR:
log.Infof("VTROOT not set, so %v hook doesn't exist", hook.Name)
case HOOK_SUCCESS:
// nothing to do here
default:
return fmt.Errorf("%v hook failed(%v): %v", hook.Name, hr.ExitStatus, hr.Stderr)
}
return nil

557
go/vt/mysqlctl/backup.go Normal file
Просмотреть файл

@ -0,0 +1,557 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mysqlctl
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"strings"
"sync"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/cgzip"
"github.com/youtube/vitess/go/sync2"
"github.com/youtube/vitess/go/vt/concurrency"
"github.com/youtube/vitess/go/vt/logutil"
"github.com/youtube/vitess/go/vt/mysqlctl/backupstorage"
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
)
// This file handles the backup and restore related code
const (
// the three bases for files to restore
backupInnodbDataHomeDir = "InnoDBData"
backupInnodbLogGroupHomeDir = "InnoDBLog"
backupData = "Data"
// the manifest file name
backupManifest = "MANIFEST"
)
const (
// slaveStartDeadline is the deadline for starting a slave
slaveStartDeadline = 30
)
var (
// ErrNoBackup is returned when there is no backup
ErrNoBackup = errors.New("no available backup")
)
// FileEntry is one file to backup
type FileEntry struct {
// Base is one of:
// - backupInnodbDataHomeDir for files that go into Mycnf.InnodbDataHomeDir
// - backupInnodbLogGroupHomeDir for files that go into Mycnf.InnodbLogGroupHomeDir
// - backupData for files that go into Mycnf.DataDir
Base string
// Name is the file name, relative to Base
Name string
// Hash is the hash of the gzip compressed data stored in the
// BackupStorage.
Hash string
}
func (fe *FileEntry) open(cnf *Mycnf, readOnly bool) (*os.File, error) {
// find the root to use
var root string
switch fe.Base {
case backupInnodbDataHomeDir:
root = cnf.InnodbDataHomeDir
case backupInnodbLogGroupHomeDir:
root = cnf.InnodbLogGroupHomeDir
case backupData:
root = cnf.DataDir
default:
return nil, fmt.Errorf("unknown base: %v", fe.Base)
}
// and open the file
name := path.Join(root, fe.Name)
var fd *os.File
var err error
if readOnly {
fd, err = os.Open(name)
} else {
dir := path.Dir(name)
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return nil, fmt.Errorf("cannot create destination directory %v: %v", dir, err)
}
fd, err = os.Create(name)
}
if err != nil {
return nil, fmt.Errorf("cannot open source file %v: %v", name, err)
}
return fd, nil
}
// BackupManifest represents the backup. It lists all the files, and
// the ReplicationPosition that the backup was taken at.
type BackupManifest struct {
// FileEntries contains all the files in the backup
FileEntries []FileEntry
// ReplicationPosition is the position at which the backup was taken
ReplicationPosition proto.ReplicationPosition
}
// isDbDir returns true if the given directory contains a DB
func isDbDir(p string) bool {
// db.opt is there
if _, err := os.Stat(path.Join(p, "db.opt")); err == nil {
return true
}
// Look for at least one .frm file
fis, err := ioutil.ReadDir(p)
if err != nil {
return false
}
for _, fi := range fis {
if strings.HasSuffix(fi.Name(), ".frm") {
return true
}
}
return false
}
func addDirectory(fes []FileEntry, base string, baseDir string, subDir string) ([]FileEntry, error) {
p := path.Join(baseDir, subDir)
fis, err := ioutil.ReadDir(p)
if err != nil {
return nil, err
}
for _, fi := range fis {
fes = append(fes, FileEntry{
Base: base,
Name: path.Join(subDir, fi.Name()),
})
}
return fes, nil
}
func findFilesTobackup(cnf *Mycnf) ([]FileEntry, error) {
var err error
var result []FileEntry
// first add inno db files
result, err = addDirectory(result, backupInnodbDataHomeDir, cnf.InnodbDataHomeDir, "")
if err != nil {
return nil, err
}
result, err = addDirectory(result, backupInnodbLogGroupHomeDir, cnf.InnodbLogGroupHomeDir, "")
if err != nil {
return nil, err
}
// then add DB directories
fis, err := ioutil.ReadDir(cnf.DataDir)
if err != nil {
return nil, err
}
for _, fi := range fis {
p := path.Join(cnf.DataDir, fi.Name())
if isDbDir(p) {
result, err = addDirectory(result, backupData, cnf.DataDir, fi.Name())
if err != nil {
return nil, err
}
}
}
return result, nil
}
// Backup is the main entry point for a backup:
// - uses the BackupStorage service to store a new backup
// - shuts down Mysqld during the backup
// - remember if we were replicating, restore the exact same state
func Backup(mysqld MysqlDaemon, logger logutil.Logger, bucket, name string, backupConcurrency int, hookExtraEnv map[string]string) error {
// start the backup with the BackupStorage
bs, err := backupstorage.GetBackupStorage()
if err != nil {
return err
}
bh, err := bs.StartBackup(bucket, name)
if err != nil {
return fmt.Errorf("StartBackup failed: %v", err)
}
if err = backup(mysqld, logger, bh, backupConcurrency, hookExtraEnv); err != nil {
if abortErr := bh.AbortBackup(); abortErr != nil {
logger.Errorf("failed to abort backup: %v", abortErr)
}
return err
}
return bh.EndBackup()
}
func backup(mysqld MysqlDaemon, logger logutil.Logger, bh backupstorage.BackupHandle, backupConcurrency int, hookExtraEnv map[string]string) error {
// save initial state so we can restore
slaveStartRequired := false
sourceIsMaster := false
readOnly := true
var replicationPosition proto.ReplicationPosition
// see if we need to restart replication after backup
logger.Infof("getting current replication status")
slaveStatus, err := mysqld.SlaveStatus()
switch err {
case nil:
slaveStartRequired = slaveStatus.SlaveRunning()
case ErrNotSlave:
// keep going if we're the master, might be a degenerate case
sourceIsMaster = true
default:
return fmt.Errorf("cannot get slave status: %v", err)
}
// get the read-only flag
readOnly, err = mysqld.IsReadOnly()
if err != nil {
return fmt.Errorf("cannot get read only status: %v", err)
}
// get the replication position
if sourceIsMaster {
if !readOnly {
logger.Infof("turning master read-onyl before backup")
if err = mysqld.SetReadOnly(true); err != nil {
return fmt.Errorf("cannot get read only status: %v", err)
}
}
replicationPosition, err = mysqld.MasterPosition()
if err != nil {
return fmt.Errorf("cannot get master position: %v", err)
}
} else {
if err = StopSlave(mysqld, hookExtraEnv); err != nil {
return fmt.Errorf("cannot stop slave: %v", err)
}
var slaveStatus proto.ReplicationStatus
slaveStatus, err = mysqld.SlaveStatus()
if err != nil {
return fmt.Errorf("cannot get slave status: %v", err)
}
replicationPosition = slaveStatus.Position
}
logger.Infof("using replication position: %v", replicationPosition)
// shutdown mysqld
if err = mysqld.Shutdown(true, MysqlWaitTime); err != nil {
return fmt.Errorf("cannot shutdown mysqld: %v", err)
}
// get the files to backup
fes, err := findFilesTobackup(mysqld.Cnf())
if err != nil {
return fmt.Errorf("cannot find files to backup: %v", err)
}
logger.Infof("found %v files to backup", len(fes))
// backup everything
if err := backupFiles(mysqld, logger, bh, fes, replicationPosition, backupConcurrency); err != nil {
return fmt.Errorf("cannot backup files: %v", err)
}
// Try to restart mysqld
if err := mysqld.Start(MysqlWaitTime); err != nil {
return fmt.Errorf("cannot restart mysqld: %v", err)
}
// Restore original mysqld state that we saved above.
if slaveStartRequired {
logger.Infof("restarting mysql replication")
if err := StartSlave(mysqld, hookExtraEnv); err != nil {
return fmt.Errorf("cannot restart slave: %v", err)
}
// this should be quick, but we might as well just wait
if err := WaitForSlaveStart(mysqld, slaveStartDeadline); err != nil {
return fmt.Errorf("slave is not restarting: %v", err)
}
}
// And set read-only mode
logger.Infof("resetting mysqld read-only to %v", readOnly)
if err := mysqld.SetReadOnly(readOnly); err != nil {
return err
}
return nil
}
func backupFiles(mysqld MysqlDaemon, logger logutil.Logger, bh backupstorage.BackupHandle, fes []FileEntry, replicationPosition proto.ReplicationPosition, backupConcurrency int) (err error) {
sema := sync2.NewSemaphore(backupConcurrency, 0)
rec := concurrency.AllErrorRecorder{}
wg := sync.WaitGroup{}
for i, fe := range fes {
wg.Add(1)
go func(i int, fe FileEntry) {
defer wg.Done()
// wait until we are ready to go, skip if we already
// encountered an error
sema.Acquire()
defer sema.Release()
if rec.HasErrors() {
return
}
// open the source file for reading
source, err := fe.open(mysqld.Cnf(), true)
if err != nil {
rec.RecordError(err)
return
}
defer source.Close()
// open the destination file for writing, and a buffer
name := fmt.Sprintf("%v", i)
wc, err := bh.AddFile(name)
if err != nil {
rec.RecordError(fmt.Errorf("cannot add file: %v", err))
return
}
defer func() { rec.RecordError(wc.Close()) }()
dst := bufio.NewWriterSize(wc, 2*1024*1024)
// create the hasher and the tee on top
hasher := newHasher()
tee := io.MultiWriter(dst, hasher)
// create the gzip compression filter
gzip, err := cgzip.NewWriterLevel(tee, cgzip.Z_BEST_SPEED)
if err != nil {
rec.RecordError(fmt.Errorf("cannot create gziper: %v", err))
return
}
// copy from the source file to gzip to tee to output file and hasher
_, err = io.Copy(gzip, source)
if err != nil {
rec.RecordError(fmt.Errorf("cannot copy data: %v", err))
return
}
// close gzip to flush it, after that the hash is good
if err = gzip.Close(); err != nil {
rec.RecordError(fmt.Errorf("cannot close gzip: %v", err))
return
}
// flush the buffer to finish writing, save the hash
rec.RecordError(dst.Flush())
fes[i].Hash = hasher.HashString()
}(i, fe)
}
wg.Wait()
if rec.HasErrors() {
return rec.Error()
}
// open the MANIFEST
wc, err := bh.AddFile(backupManifest)
if err != nil {
return fmt.Errorf("cannot add %v to backup: %v", backupManifest, err)
}
defer func() {
if closeErr := wc.Close(); err == nil {
err = closeErr
}
}()
// JSON-encode and write the MANIFEST
bm := &BackupManifest{
FileEntries: fes,
ReplicationPosition: replicationPosition,
}
data, err := json.MarshalIndent(bm, "", " ")
if err != nil {
return fmt.Errorf("cannot JSON encode %v: %v", backupManifest, err)
}
if _, err := wc.Write([]byte(data)); err != nil {
return fmt.Errorf("cannot write %v: %v", backupManifest, err)
}
return nil
}
// checkNoDB makes sure there is no vt_ db already there. Used by Restore,
// we do not wnat to destroy an existing DB.
func checkNoDB(mysqld MysqlDaemon) error {
qr, err := mysqld.FetchSuperQuery("SHOW DATABASES")
if err != nil {
return fmt.Errorf("checkNoDB failed: %v", err)
}
for _, row := range qr.Rows {
if strings.HasPrefix(row[0].String(), "vt_") {
dbName := row[0].String()
tableQr, err := mysqld.FetchSuperQuery("SHOW TABLES FROM " + dbName)
if err != nil {
return fmt.Errorf("checkNoDB failed: %v", err)
} else if len(tableQr.Rows) == 0 {
// no tables == empty db, all is well
continue
}
return fmt.Errorf("checkNoDB failed, found active db %v", dbName)
}
}
return nil
}
// restoreFiles will copy all the files from the BackupStorage to the
// right place
func restoreFiles(cnf *Mycnf, bh backupstorage.BackupHandle, fes []FileEntry, restoreConcurrency int) error {
sema := sync2.NewSemaphore(restoreConcurrency, 0)
rec := concurrency.AllErrorRecorder{}
wg := sync.WaitGroup{}
for i, fe := range fes {
wg.Add(1)
go func(i int, fe FileEntry) {
defer wg.Done()
// wait until we are ready to go, skip if we already
// encountered an error
sema.Acquire()
defer sema.Release()
if rec.HasErrors() {
return
}
// open the source file for reading
name := fmt.Sprintf("%v", i)
source, err := bh.ReadFile(name)
if err != nil {
rec.RecordError(err)
return
}
defer source.Close()
// open the destination file for writing
dstFile, err := fe.open(cnf, false)
if err != nil {
rec.RecordError(err)
return
}
defer func() { rec.RecordError(dstFile.Close()) }()
// create a buffering output
dst := bufio.NewWriterSize(dstFile, 2*1024*1024)
// create hash to write the compressed data to
hasher := newHasher()
// create a Tee: we split the input into the hasher
// and into the gunziper
tee := io.TeeReader(source, hasher)
// create the uncompresser
gz, err := cgzip.NewReader(tee)
if err != nil {
rec.RecordError(err)
return
}
defer func() { rec.RecordError(gz.Close()) }()
// copy the data. Will also write to the hasher
if _, err = io.Copy(dst, gz); err != nil {
rec.RecordError(err)
return
}
// check the hash
hash := hasher.HashString()
if hash != fe.Hash {
rec.RecordError(fmt.Errorf("hash mismatch for %v, got %v expected %v", fe.Name, hash, fe.Hash))
return
}
// flush the buffer
rec.RecordError(dst.Flush())
}(i, fe)
}
wg.Wait()
return rec.Error()
}
// Restore is the main entry point for backup restore. If there is no
// appropriate backup on the BackupStorage, Restore logs an error
// and returns ErrNoBackup. Any other error is returned.
func Restore(mysqld MysqlDaemon, bucket string, restoreConcurrency int, hookExtraEnv map[string]string) (proto.ReplicationPosition, error) {
// find the right backup handle: most recent one, with a MANIFEST
log.Infof("Restore: looking for a suitable backup to restore")
bs, err := backupstorage.GetBackupStorage()
if err != nil {
return proto.ReplicationPosition{}, err
}
bhs, err := bs.ListBackups(bucket)
if err != nil {
return proto.ReplicationPosition{}, fmt.Errorf("ListBackups failed: %v", err)
}
toRestore := len(bhs) - 1
var bh backupstorage.BackupHandle
var bm BackupManifest
for toRestore >= 0 {
bh = bhs[toRestore]
if rc, err := bh.ReadFile(backupManifest); err == nil {
dec := json.NewDecoder(rc)
err := dec.Decode(&bm)
rc.Close()
if err != nil {
log.Warningf("Possibly incomplete backup %v in bucket %v on BackupStorage (cannot JSON decode MANIFEST: %v)", bh.Name(), bucket, err)
} else {
log.Infof("Restore: found backup %v %v to restore with %v files", bh.Bucket(), bh.Name(), len(bm.FileEntries))
break
}
} else {
log.Warningf("Possibly incomplete backup %v in bucket %v on BackupStorage (cannot read MANIFEST)", bh.Name(), bucket)
}
toRestore--
}
if toRestore < 0 {
log.Errorf("No backup to restore on BackupStorage for bucket %v", bucket)
return proto.ReplicationPosition{}, ErrNoBackup
}
log.Infof("Restore: checking no existing data is present")
if err := checkNoDB(mysqld); err != nil {
return proto.ReplicationPosition{}, err
}
log.Infof("Restore: shutdown mysqld")
if err := mysqld.Shutdown(true, MysqlWaitTime); err != nil {
return proto.ReplicationPosition{}, err
}
log.Infof("Restore: copying all files")
if err := restoreFiles(mysqld.Cnf(), bh, bm.FileEntries, restoreConcurrency); err != nil {
return proto.ReplicationPosition{}, err
}
log.Infof("Restore: restart mysqld")
if err := mysqld.Start(MysqlWaitTime); err != nil {
return proto.ReplicationPosition{}, err
}
return bm.ReplicationPosition, nil
}

Просмотреть файл

@ -0,0 +1,93 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mysqlctl
import (
"io/ioutil"
"os"
"path"
"reflect"
"sort"
"testing"
)
func TestFindFilesToBackup(t *testing.T) {
root, err := ioutil.TempDir("", "backuptest")
if err != nil {
t.Fatalf("os.TempDir failed: %v", err)
}
defer os.RemoveAll(root)
// Initialize the fake mysql root directories
innodbDataDir := path.Join(root, "innodb_data")
innodbLogDir := path.Join(root, "innodb_log")
dataDir := path.Join(root, "data")
dataDbDir := path.Join(dataDir, "vt_db")
extraDir := path.Join(dataDir, "extra_dir")
outsideDbDir := path.Join(root, "outside_db")
for _, s := range []string{innodbDataDir, innodbLogDir, dataDbDir, extraDir, outsideDbDir} {
if err := os.MkdirAll(s, os.ModePerm); err != nil {
t.Fatalf("failed to create directory %v: %v", s, err)
}
}
if err := ioutil.WriteFile(path.Join(innodbDataDir, "innodb_data_1"), []byte("innodb data 1 contents"), os.ModePerm); err != nil {
t.Fatalf("failed to write file innodb_data_1: %v", err)
}
if err := ioutil.WriteFile(path.Join(innodbLogDir, "innodb_log_1"), []byte("innodb log 1 contents"), os.ModePerm); err != nil {
t.Fatalf("failed to write file innodb_log_1: %v", err)
}
if err := ioutil.WriteFile(path.Join(dataDbDir, "db.opt"), []byte("db opt file"), os.ModePerm); err != nil {
t.Fatalf("failed to write file db.opt: %v", err)
}
if err := ioutil.WriteFile(path.Join(extraDir, "extra.stuff"), []byte("extra file"), os.ModePerm); err != nil {
t.Fatalf("failed to write file extra.stuff: %v", err)
}
if err := ioutil.WriteFile(path.Join(outsideDbDir, "table1.frm"), []byte("frm file"), os.ModePerm); err != nil {
t.Fatalf("failed to write file table1.opt: %v", err)
}
if err := os.Symlink(outsideDbDir, path.Join(dataDir, "vt_symlink")); err != nil {
t.Fatalf("failed to symlink vt_symlink: %v", err)
}
cnf := &Mycnf{
InnodbDataHomeDir: innodbDataDir,
InnodbLogGroupHomeDir: innodbLogDir,
DataDir: dataDir,
}
result, err := findFilesTobackup(cnf)
if err != nil {
t.Fatalf("findFilesTobackup failed: %v", err)
}
sort.Sort(forTest(result))
t.Logf("findFilesTobackup returned: %v", result)
expected := []FileEntry{
FileEntry{
Base: "Data",
Name: "vt_db/db.opt",
},
FileEntry{
Base: "Data",
Name: "vt_symlink/table1.frm",
},
FileEntry{
Base: "InnoDBData",
Name: "innodb_data_1",
},
FileEntry{
Base: "InnoDBLog",
Name: "innodb_log_1",
},
}
if !reflect.DeepEqual(result, expected) {
t.Fatalf("got wrong list of FileEntry %v, expected %v", result, expected)
}
}
type forTest []FileEntry
func (f forTest) Len() int { return len(f) }
func (f forTest) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
func (f forTest) Less(i, j int) bool { return f[i].Base+f[i].Name < f[j].Base+f[j].Name }

Просмотреть файл

@ -0,0 +1,85 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package backupstorage contains the interface and file system implementation
// of the backup system.
package backupstorage
import (
"flag"
"fmt"
"io"
)
var (
// BackupStorageImplementation is the implementation to use
// for BackupStorage. Exported for test purposes.
BackupStorageImplementation = flag.String("backup_storage_implementation", "", "which implementation to use for the backup storage feature")
)
// BackupHandle describes an individual backup.
type BackupHandle interface {
// Bucket is the location of the backup. Will contain keyspace/shard.
Bucket() string
// Name is the individual name of the backup. Will contain
// tabletAlias-timestamp.
Name() string
// AddFile opens a new file to be added to the backup.
// Only works for read-write backups (created by StartBackup).
// filename is guaranteed to only contain alphanumerical
// characters and hyphens.
// It should be thread safe, it is possible to call AddFile in
// multiple go routines once a backup has been started.
AddFile(filename string) (io.WriteCloser, error)
// EndBackup stops and closes a backup. The contents should be kept.
// Only works for read-write backups (created by StartBackup).
EndBackup() error
// AbortBackup stops a backup, and removes the contents that
// have been copied already. It is called if an error occurs
// while the backup is being taken, and the backup cannot be finished.
// Only works for read-write backups (created by StartBackup).
AbortBackup() error
// ReadFile starts reading a file from a backup.
// Only works for read-only backups (created by ListBackups).
ReadFile(filename string) (io.ReadCloser, error)
}
// BackupStorage is the interface to the storage system
type BackupStorage interface {
// ListBackups returns all the backups in a bucket. The
// returned backups are read-only (ReadFile can be called, but
// AddFile/EndBackup/AbortBackup cannot).
// The backups are string-sorted by Name(), ascending (ends up
// being the oldest backup first).
ListBackups(bucket string) ([]BackupHandle, error)
// StartBackup creates a new backup with the given name. If a
// backup with the same name already exists, it's an error.
// The returned backup is read-write
// (AddFile/EndBackup/AbortBackup cann all be called, not
// ReadFile)
StartBackup(bucket, name string) (BackupHandle, error)
// RemoveBackup removes all the data associated with a backup.
// It will not appear in ListBackups after RemoveBackup succeeds.
RemoveBackup(bucket, name string) error
}
// BackupStorageMap contains the registered implementations for BackupStorage
var BackupStorageMap = make(map[string]BackupStorage)
// GetBackupStorage returns the current BackupStorage implementation.
// Should be called after flags have been initialized.
func GetBackupStorage() (BackupStorage, error) {
bs, ok := BackupStorageMap[*BackupStorageImplementation]
if !ok {
return nil, fmt.Errorf("no registered implementation of BackupStorage")
}
return bs, nil
}

Просмотреть файл

@ -1,474 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mysqlctl
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/ioutil2"
"github.com/youtube/vitess/go/vt/hook"
"github.com/youtube/vitess/go/vt/logutil"
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
)
// These methods deal with cloning a running instance of mysql.
const (
maxLagSeconds = 5
)
const (
// slaveStartDeadline is the deadline for starting a slave
slaveStartDeadline = 30
)
const (
// SnapshotManifestFile is the file name for the snapshot manifest.
SnapshotManifestFile = "snapshot_manifest.json"
// SnapshotURLPath is the URL where to find the snapshot manifest.
SnapshotURLPath = "/snapshot"
)
// Validate that this instance is a reasonable source of data.
func (mysqld *Mysqld) validateCloneSource(serverMode bool, hookExtraEnv map[string]string) error {
// NOTE(msolomon) Removing this check for now - I don't see the value of validating this.
// // needs to be master, or slave that's not too far behind
// slaveStatus, err := mysqld.slaveStatus()
// if err != nil {
// if err != ErrNotSlave {
// return fmt.Errorf("mysqlctl: validateCloneSource failed, %v", err)
// }
// } else {
// lagSeconds, _ := strconv.Atoi(slaveStatus["seconds_behind_master"])
// if lagSeconds > maxLagSeconds {
// return fmt.Errorf("mysqlctl: validateCloneSource failed, lag_seconds exceed maximum tolerance (%v)", lagSeconds)
// }
// }
// make sure we can write locally
if err := mysqld.ValidateSnapshotPath(); err != nil {
return err
}
// run a hook to check local things
// FIXME(alainjobart) What other parameters do we have to
// provide? dbname, host, socket?
params := make([]string, 0, 1)
if serverMode {
params = append(params, "--server-mode")
}
h := hook.NewHook("preflight_snapshot", params)
h.ExtraEnv = hookExtraEnv
if err := h.ExecuteOptional(); err != nil {
return err
}
// FIXME(msolomon) check free space based on an estimate of the current
// size of the db files.
// Also, check that we aren't already cloning/compressing or acting as a
// source. Mysqld being down isn't enough, presumably that will be
// restarted as soon as the snapshot is taken.
return nil
}
// ValidateCloneTarget makes sure this mysql daemon is a valid target
// for a clone.
func (mysqld *Mysqld) ValidateCloneTarget(hookExtraEnv map[string]string) error {
// run a hook to check local things
h := hook.NewSimpleHook("preflight_restore")
h.ExtraEnv = hookExtraEnv
if err := h.ExecuteOptional(); err != nil {
return err
}
qr, err := mysqld.fetchSuperQuery("SHOW DATABASES")
if err != nil {
return fmt.Errorf("mysqlctl: ValidateCloneTarget failed, %v", err)
}
for _, row := range qr.Rows {
if strings.HasPrefix(row[0].String(), "vt_") {
dbName := row[0].String()
tableQr, err := mysqld.fetchSuperQuery("SHOW TABLES FROM " + dbName)
if err != nil {
return fmt.Errorf("mysqlctl: ValidateCloneTarget failed, %v", err)
} else if len(tableQr.Rows) == 0 {
// no tables == empty db, all is well
continue
}
return fmt.Errorf("mysqlctl: ValidateCloneTarget failed, found active db %v", dbName)
}
}
return nil
}
func findFilesToServe(srcDir, dstDir string, compress bool) ([]string, []string, error) {
fiList, err := ioutil.ReadDir(srcDir)
if err != nil {
return nil, nil, err
}
sources := make([]string, 0, len(fiList))
destinations := make([]string, 0, len(fiList))
for _, fi := range fiList {
if !fi.IsDir() {
srcPath := path.Join(srcDir, fi.Name())
var dstPath string
if compress {
dstPath = path.Join(dstDir, fi.Name()+".gz")
} else {
dstPath = path.Join(dstDir, fi.Name())
}
sources = append(sources, srcPath)
destinations = append(destinations, dstPath)
}
}
return sources, destinations, nil
}
func (mysqld *Mysqld) createSnapshot(logger logutil.Logger, concurrency int, serverMode bool) ([]SnapshotFile, error) {
sources := make([]string, 0, 128)
destinations := make([]string, 0, 128)
// clean out and start fresh
logger.Infof("removing previous snapshots: %v", mysqld.SnapshotDir)
if err := os.RemoveAll(mysqld.SnapshotDir); err != nil {
return nil, err
}
// FIXME(msolomon) innodb paths must match patterns in mycnf -
// probably belongs as a derived path.
type snapPair struct{ srcDir, dstDir string }
dps := []snapPair{
{mysqld.config.InnodbDataHomeDir, path.Join(mysqld.SnapshotDir, innodbDataSubdir)},
{mysqld.config.InnodbLogGroupHomeDir, path.Join(mysqld.SnapshotDir, innodbLogSubdir)},
}
dataDirEntries, err := ioutil.ReadDir(mysqld.config.DataDir)
if err != nil {
return nil, err
}
for _, de := range dataDirEntries {
dbDirPath := path.Join(mysqld.config.DataDir, de.Name())
// If this is not a directory, try to eval it as a syslink.
if !de.IsDir() {
dbDirPath, err = filepath.EvalSymlinks(dbDirPath)
if err != nil {
return nil, err
}
de, err = os.Stat(dbDirPath)
if err != nil {
return nil, err
}
}
if de.IsDir() {
// Copy anything that defines a db.opt file - that includes empty databases.
_, err := os.Stat(path.Join(dbDirPath, "db.opt"))
if err == nil {
dps = append(dps, snapPair{dbDirPath, path.Join(mysqld.SnapshotDir, dataDir, de.Name())})
} else {
// Look for at least one .frm file
dbDirEntries, err := ioutil.ReadDir(dbDirPath)
if err == nil {
for _, dbEntry := range dbDirEntries {
if strings.HasSuffix(dbEntry.Name(), ".frm") {
dps = append(dps, snapPair{dbDirPath, path.Join(mysqld.SnapshotDir, dataDir, de.Name())})
break
}
}
} else {
return nil, err
}
}
}
}
for _, dp := range dps {
if err := os.MkdirAll(dp.dstDir, 0775); err != nil {
return nil, err
}
if s, d, err := findFilesToServe(dp.srcDir, dp.dstDir, !serverMode); err != nil {
return nil, err
} else {
sources = append(sources, s...)
destinations = append(destinations, d...)
}
}
return newSnapshotFiles(sources, destinations, mysqld.SnapshotDir, concurrency, !serverMode)
}
// CreateSnapshot runs on the machine acting as the source for the clone.
//
// Check master/slave status and determine restore needs.
// If this instance is a slave, stop replication, otherwise place in read-only mode.
// Record replication position.
// Shutdown mysql
// Check paths for storing data
//
// Depending on the serverMode flag, we do the following:
// serverMode = false:
// Compress /vt/vt_[0-9a-f]+/data/vt_.+
// Compute hash (of compressed files, as we serve .gz files here)
// Place in /vt/clone_src where they will be served by http server (not rpc)
// Restart mysql
// serverMode = true:
// Make symlinks for /vt/vt_[0-9a-f]+/data/vt_.+ to innodb files
// Compute hash (of uncompressed files, as we serve uncompressed files)
// Place symlinks in /vt/clone_src where they will be served by http server
// Leave mysql stopped, return slaveStartRequired, readOnly
func (mysqld *Mysqld) CreateSnapshot(logger logutil.Logger, dbName, sourceAddr string, allowHierarchicalReplication bool, concurrency int, serverMode bool, hookExtraEnv map[string]string) (snapshotManifestURLPath string, slaveStartRequired, readOnly bool, err error) {
if dbName == "" {
return "", false, false, errors.New("CreateSnapshot failed: no database name provided")
}
if err = mysqld.validateCloneSource(serverMode, hookExtraEnv); err != nil {
return
}
// save initial state so we can restore on Start()
slaveStartRequired = false
sourceIsMaster := false
readOnly = true
slaveStatus, err := mysqld.SlaveStatus()
if err == nil {
slaveStartRequired = slaveStatus.SlaveRunning()
} else if err == ErrNotSlave {
sourceIsMaster = true
} else {
// If we can't get any data, just fail.
return
}
readOnly, err = mysqld.IsReadOnly()
if err != nil {
return
}
// Stop sources of writes so we can get a consistent replication position.
// If the source is a slave use the master replication position
// unless we are allowing hierarchical replicas.
masterAddr := ""
var replicationPosition proto.ReplicationPosition
if sourceIsMaster {
if err = mysqld.SetReadOnly(true); err != nil {
return
}
replicationPosition, err = mysqld.MasterPosition()
if err != nil {
return
}
masterAddr = mysqld.IPAddr()
} else {
if err = mysqld.StopSlave(hookExtraEnv); err != nil {
return
}
var slaveStatus *proto.ReplicationStatus
slaveStatus, err = mysqld.SlaveStatus()
if err != nil {
return
}
replicationPosition = slaveStatus.Position
// We are a slave, check our replication strategy before
// choosing the master address.
if allowHierarchicalReplication {
masterAddr = mysqld.IPAddr()
} else {
masterAddr, err = mysqld.GetMasterAddr()
if err != nil {
return
}
}
}
if err = mysqld.Shutdown(true, MysqlWaitTime); err != nil {
return
}
var smFile string
dataFiles, snapshotErr := mysqld.createSnapshot(logger, concurrency, serverMode)
if snapshotErr != nil {
logger.Errorf("CreateSnapshot failed: %v", snapshotErr)
} else {
var sm *SnapshotManifest
sm, snapshotErr = newSnapshotManifest(sourceAddr, mysqld.IPAddr(),
masterAddr, dbName, dataFiles, replicationPosition, proto.ReplicationPosition{})
if snapshotErr != nil {
logger.Errorf("CreateSnapshot failed: %v", snapshotErr)
} else {
smFile = path.Join(mysqld.SnapshotDir, SnapshotManifestFile)
if snapshotErr = writeJSON(smFile, sm); snapshotErr != nil {
logger.Errorf("CreateSnapshot failed: %v", snapshotErr)
}
}
}
// restore our state if required
if serverMode && snapshotErr == nil {
logger.Infof("server mode snapshot worked, not restarting mysql")
} else {
if err = mysqld.SnapshotSourceEnd(slaveStartRequired, readOnly, false /*deleteSnapshot*/, hookExtraEnv); err != nil {
return
}
}
if snapshotErr != nil {
return "", slaveStartRequired, readOnly, snapshotErr
}
relative, err := filepath.Rel(mysqld.SnapshotDir, smFile)
if err != nil {
return "", slaveStartRequired, readOnly, nil
}
return path.Join(SnapshotURLPath, relative), slaveStartRequired, readOnly, nil
}
// SnapshotSourceEnd removes the current snapshot, and restarts mysqld.
func (mysqld *Mysqld) SnapshotSourceEnd(slaveStartRequired, readOnly, deleteSnapshot bool, hookExtraEnv map[string]string) error {
if deleteSnapshot {
// clean out our files
log.Infof("removing snapshot links: %v", mysqld.SnapshotDir)
if err := os.RemoveAll(mysqld.SnapshotDir); err != nil {
log.Warningf("failed to remove old snapshot: %v", err)
return err
}
}
// Try to restart mysqld
if err := mysqld.Start(MysqlWaitTime); err != nil {
return err
}
// Restore original mysqld state that we saved above.
if slaveStartRequired {
if err := mysqld.StartSlave(hookExtraEnv); err != nil {
return err
}
// this should be quick, but we might as well just wait
if err := mysqld.WaitForSlaveStart(slaveStartDeadline); err != nil {
return err
}
}
// And set read-only mode
if err := mysqld.SetReadOnly(readOnly); err != nil {
return err
}
return nil
}
func writeJSON(filename string, x interface{}) error {
data, err := json.MarshalIndent(x, " ", " ")
if err != nil {
return err
}
return ioutil2.WriteFileAtomic(filename, data, 0660)
}
// ReadSnapshotManifest reads and unpacks a SnapshotManifest
func ReadSnapshotManifest(filename string) (*SnapshotManifest, error) {
data, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
sm := new(SnapshotManifest)
if err = json.Unmarshal(data, sm); err != nil {
return nil, fmt.Errorf("ReadSnapshotManifest failed: %v %v", filename, err)
}
return sm, nil
}
// RestoreFromSnapshot runs on the presumably empty machine acting as
// the target in the create replica action.
//
// validate target (self)
// shutdown_mysql()
// create temp data directory /vt/target/vt_<keyspace>
// copy compressed data files via HTTP
// verify hash of compressed files
// uncompress into /vt/vt_<target-uid>/data/vt_<keyspace>
// start_mysql()
// clean up compressed files
func (mysqld *Mysqld) RestoreFromSnapshot(logger logutil.Logger, snapshotManifest *SnapshotManifest, fetchConcurrency, fetchRetryCount int, dontWaitForSlaveStart bool, hookExtraEnv map[string]string) error {
if snapshotManifest == nil {
return errors.New("RestoreFromSnapshot: nil snapshotManifest")
}
logger.Infof("ValidateCloneTarget")
if err := mysqld.ValidateCloneTarget(hookExtraEnv); err != nil {
return err
}
logger.Infof("Shutdown mysqld")
if err := mysqld.Shutdown(true, MysqlWaitTime); err != nil {
return err
}
logger.Infof("Fetch snapshot")
if err := mysqld.fetchSnapshot(snapshotManifest, fetchConcurrency, fetchRetryCount); err != nil {
return err
}
logger.Infof("Restart mysqld")
if err := mysqld.Start(MysqlWaitTime); err != nil {
return err
}
cmdList, err := mysqld.StartReplicationCommands(snapshotManifest.ReplicationStatus)
if err != nil {
return err
}
if err := mysqld.ExecuteSuperQueryList(cmdList); err != nil {
return err
}
if !dontWaitForSlaveStart {
if err := mysqld.WaitForSlaveStart(slaveStartDeadline); err != nil {
return err
}
}
h := hook.NewSimpleHook("postflight_restore")
h.ExtraEnv = hookExtraEnv
if err := h.ExecuteOptional(); err != nil {
return err
}
return nil
}
func (mysqld *Mysqld) fetchSnapshot(snapshotManifest *SnapshotManifest, fetchConcurrency, fetchRetryCount int) error {
replicaDbPath := path.Join(mysqld.config.DataDir, snapshotManifest.DbName)
cleanDirs := []string{mysqld.SnapshotDir, replicaDbPath,
mysqld.config.InnodbDataHomeDir, mysqld.config.InnodbLogGroupHomeDir}
// clean out and start fresh
// FIXME(msolomon) this might be changed to allow partial recovery, but at that point
// we are starting to reimplement rsync.
for _, dir := range cleanDirs {
if err := os.RemoveAll(dir); err != nil {
return err
}
if err := os.MkdirAll(dir, 0775); err != nil {
return err
}
}
return fetchFiles(snapshotManifest, mysqld.TabletDir, fetchConcurrency, fetchRetryCount)
}

Просмотреть файл

@ -0,0 +1,141 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package filebackupstorage implements the BacksupStorage interface
// for a local filesystem (which can be an NFS mount).
package filebackupstorage
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"github.com/youtube/vitess/go/vt/mysqlctl/backupstorage"
)
var (
// FileBackupStorageRoot is where the backups will go.
// Exported for test purposes.
FileBackupStorageRoot = flag.String("file_backup_storage_root", "", "root directory for the file backup storage")
)
// FileBackupHandle implements BackupHandle for local file system.
type FileBackupHandle struct {
fbs *FileBackupStorage
bucket string
name string
readOnly bool
}
// Bucket is part of the BackupHandle interface
func (fbh *FileBackupHandle) Bucket() string {
return fbh.bucket
}
// Name is part of the BackupHandle interface
func (fbh *FileBackupHandle) Name() string {
return fbh.name
}
// AddFile is part of the BackupHandle interface
func (fbh *FileBackupHandle) AddFile(filename string) (io.WriteCloser, error) {
if fbh.readOnly {
return nil, fmt.Errorf("AddFile cannot be called on read-only backup")
}
p := path.Join(*FileBackupStorageRoot, fbh.bucket, fbh.name, filename)
return os.Create(p)
}
// EndBackup is part of the BackupHandle interface
func (fbh *FileBackupHandle) EndBackup() error {
if fbh.readOnly {
return fmt.Errorf("EndBackup cannot be called on read-only backup")
}
return nil
}
// AbortBackup is part of the BackupHandle interface
func (fbh *FileBackupHandle) AbortBackup() error {
if fbh.readOnly {
return fmt.Errorf("AbortBackup cannot be called on read-only backup")
}
return fbh.fbs.RemoveBackup(fbh.bucket, fbh.name)
}
// ReadFile is part of the BackupHandle interface
func (fbh *FileBackupHandle) ReadFile(filename string) (io.ReadCloser, error) {
if !fbh.readOnly {
return nil, fmt.Errorf("ReadFile cannot be called on read-write backup")
}
p := path.Join(*FileBackupStorageRoot, fbh.bucket, fbh.name, filename)
return os.Open(p)
}
// FileBackupStorage implements BackupStorage for local file system.
type FileBackupStorage struct{}
// ListBackups is part of the BackupStorage interface
func (fbs *FileBackupStorage) ListBackups(bucket string) ([]backupstorage.BackupHandle, error) {
// ReadDir already sorts the results
p := path.Join(*FileBackupStorageRoot, bucket)
fi, err := ioutil.ReadDir(p)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
result := make([]backupstorage.BackupHandle, 0, len(fi))
for _, info := range fi {
if !info.IsDir() {
continue
}
if info.Name() == "." || info.Name() == ".." {
continue
}
result = append(result, &FileBackupHandle{
fbs: fbs,
bucket: bucket,
name: info.Name(),
readOnly: true,
})
}
return result, nil
}
// StartBackup is part of the BackupStorage interface
func (fbs *FileBackupStorage) StartBackup(bucket, name string) (backupstorage.BackupHandle, error) {
// make sure the bucket directory exists
p := path.Join(*FileBackupStorageRoot, bucket)
if err := os.MkdirAll(p, os.ModePerm); err != nil {
return nil, err
}
// creates the backup directory
p = path.Join(p, name)
if err := os.Mkdir(p, os.ModePerm); err != nil {
return nil, err
}
return &FileBackupHandle{
fbs: fbs,
bucket: bucket,
name: name,
readOnly: false,
}, nil
}
// RemoveBackup is part of the BackupStorage interface
func (fbs *FileBackupStorage) RemoveBackup(bucket, name string) error {
p := path.Join(*FileBackupStorageRoot, bucket, name)
return os.RemoveAll(p)
}
func init() {
backupstorage.BackupStorageMap["file"] = &FileBackupStorage{}
}

Просмотреть файл

@ -0,0 +1,190 @@
// Copyright 2015, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filebackupstorage
import (
"io"
"io/ioutil"
"os"
"testing"
)
// This file tests the file BackupStorage engine.
// Note this is a very generic test for BackupStorage implementations,
// we test the interface only. But making it a generic test library is
// more cumbersome, we'll do that when we have an actual need for
// another BackupStorage implementation.
// setupFileBackupStorage creates a temporary directory, and
// returns a FileBackupStorage based on it
func setupFileBackupStorage(t *testing.T) *FileBackupStorage {
root, err := ioutil.TempDir("", "fbstest")
if err != nil {
t.Fatalf("os.TempDir failed: %v", err)
}
*FileBackupStorageRoot = root
return &FileBackupStorage{}
}
// cleanupFileBackupStorage removes the entire directory
func cleanupFileBackupStorage(fbs *FileBackupStorage) {
os.RemoveAll(*FileBackupStorageRoot)
}
func TestListBackups(t *testing.T) {
fbs := setupFileBackupStorage(t)
defer cleanupFileBackupStorage(fbs)
// verify we have no entry now
bucket := "keyspace/shard"
bhs, err := fbs.ListBackups(bucket)
if err != nil {
t.Fatalf("ListBackups on empty fbs failed: %v", err)
}
if len(bhs) != 0 {
t.Fatalf("ListBackups on empty fbs returned results: %#v", bhs)
}
// add one empty backup
firstBackup := "cell-0001-2015-01-14-10-00-00"
bh, err := fbs.StartBackup(bucket, firstBackup)
if err != nil {
t.Fatalf("fbs.StartBackup failed: %v", err)
}
if err := bh.EndBackup(); err != nil {
t.Fatalf("bh.EndBackup failed: %v", err)
}
// verify we have one entry now
bhs, err = fbs.ListBackups(bucket)
if err != nil {
t.Fatalf("ListBackups on empty fbs failed: %v", err)
}
if len(bhs) != 1 ||
bhs[0].Bucket() != bucket ||
bhs[0].Name() != firstBackup {
t.Fatalf("ListBackups with one backup returned wrong results: %#v", bhs)
}
// add another one, with earlier date
secondBackup := "cell-0001-2015-01-12-10-00-00"
bh, err = fbs.StartBackup(bucket, secondBackup)
if err != nil {
t.Fatalf("fbs.StartBackup failed: %v", err)
}
if err := bh.EndBackup(); err != nil {
t.Fatalf("bh.EndBackup failed: %v", err)
}
// verify we have two sorted entries now
bhs, err = fbs.ListBackups(bucket)
if err != nil {
t.Fatalf("ListBackups on empty fbs failed: %v", err)
}
if len(bhs) != 2 ||
bhs[0].Bucket() != bucket ||
bhs[0].Name() != secondBackup ||
bhs[1].Bucket() != bucket ||
bhs[1].Name() != firstBackup {
t.Fatalf("ListBackups with two backups returned wrong results: %#v", bhs)
}
// remove a backup, back to one
if err := fbs.RemoveBackup(bucket, secondBackup); err != nil {
t.Fatalf("RemoveBackup failed: %v", err)
}
bhs, err = fbs.ListBackups(bucket)
if err != nil {
t.Fatalf("ListBackups after deletion failed: %v", err)
}
if len(bhs) != 1 ||
bhs[0].Bucket() != bucket ||
bhs[0].Name() != firstBackup {
t.Fatalf("ListBackups after deletion returned wrong results: %#v", bhs)
}
// add a backup but abort it, should stay at one
bh, err = fbs.StartBackup(bucket, secondBackup)
if err != nil {
t.Fatalf("fbs.StartBackup failed: %v", err)
}
if err := bh.AbortBackup(); err != nil {
t.Fatalf("bh.AbortBackup failed: %v", err)
}
bhs, err = fbs.ListBackups(bucket)
if err != nil {
t.Fatalf("ListBackups after abort failed: %v", err)
}
if len(bhs) != 1 ||
bhs[0].Bucket() != bucket ||
bhs[0].Name() != firstBackup {
t.Fatalf("ListBackups after abort returned wrong results: %#v", bhs)
}
// check we cannot chaneg a backup we listed
if _, err := bhs[0].AddFile("test"); err == nil {
t.Fatalf("was able to AddFile to read-only backup")
}
if err := bhs[0].EndBackup(); err == nil {
t.Fatalf("was able to EndBackup a read-only backup")
}
if err := bhs[0].AbortBackup(); err == nil {
t.Fatalf("was able to AbortBackup a read-only backup")
}
}
func TestFileContents(t *testing.T) {
fbs := setupFileBackupStorage(t)
defer cleanupFileBackupStorage(fbs)
bucket := "keyspace/shard"
name := "cell-0001-2015-01-14-10-00-00"
filename1 := "file1"
contents1 := "contents of the first file"
// start a backup, add a file
bh, err := fbs.StartBackup(bucket, name)
if err != nil {
t.Fatalf("fbs.StartBackup failed: %v", err)
}
wc, err := bh.AddFile(filename1)
if err != nil {
t.Fatalf("bh.AddFile failed: %v", err)
}
if _, err := wc.Write([]byte(contents1)); err != nil {
t.Fatalf("wc.Write failed: %v", err)
}
if err := wc.Close(); err != nil {
t.Fatalf("wc.Close failed: %v", err)
}
// test we can't read back on read-write backup
if _, err := bh.ReadFile(filename1); err == nil {
t.Fatalf("was able to ReadFile to read-write backup")
}
// and close
if err := bh.EndBackup(); err != nil {
t.Fatalf("bh.EndBackup failed: %v", err)
}
// re-read the file
bhs, err := fbs.ListBackups(bucket)
if err != nil || len(bhs) != 1 {
t.Fatalf("ListBackups after abort returned wrong return: %v %v", err, bhs)
}
rc, err := bhs[0].ReadFile(filename1)
if err != nil {
t.Fatalf("bhs[0].ReadFile failed: %v", err)
}
buf := make([]byte, len(contents1)+10)
if n, err := rc.Read(buf); (err != nil && err != io.EOF) || n != len(contents1) {
t.Fatalf("rc.Read returned wrong result: %v %#v", err, n)
}
if err := rc.Close(); err != nil {
t.Fatalf("rc.Close failed: %v", err)
}
}

Просмотреть файл

@ -5,25 +5,13 @@
package mysqlctl
import (
"bufio"
// "crypto/md5"
"encoding/hex"
"fmt"
"hash"
// "hash/crc64"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"sort"
"strings"
"sync"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/cgzip"
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
)
// Use this to simulate failures in tests
@ -75,452 +63,3 @@ func newHasher() *hasher {
func (h *hasher) HashString() string {
return hex.EncodeToString(h.Sum(nil))
}
// SnapshotFile describes a file to serve.
// 'Path' is the path component of the URL. SnapshotManifest.Addr is
// the host+port component of the URL.
// If path ends in '.gz', it is compressed.
// Size and Hash are computed on the Path itself
// if TableName is set, this file belongs to that table
type SnapshotFile struct {
Path string
Size int64
Hash string
TableName string
}
type SnapshotFiles []SnapshotFile
// sort.Interface
// we sort by descending file size
func (s SnapshotFiles) Len() int { return len(s) }
func (s SnapshotFiles) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s SnapshotFiles) Less(i, j int) bool { return s[i].Size > s[j].Size }
// This function returns the local file used to store the SnapshotFile,
// relative to the basePath.
// for instance, if the source path is something like:
// /vt/snapshot/vt_0000062344/data/vt_snapshot_test-MA,Mw/vt_insert_test.csv.gz
// we will get everything starting with 'data/...', append it to basepath,
// and remove the .gz extension. So with basePath=myPath, it will return:
// myPath/data/vt_snapshot_test-MA,Mw/vt_insert_test.csv
func (dataFile *SnapshotFile) getLocalFilename(basePath string) string {
filename := path.Join(basePath, dataFile.Path)
// trim compression extension if present
if strings.HasSuffix(filename, ".gz") {
filename = filename[:len(filename)-3]
}
return filename
}
// newSnapshotFile behavior depends on the compress flag:
// - if compress is true , it compresses a single file with gzip, and
// computes the hash on the compressed version.
// - if compress is false, just symlinks and computes the hash on the file
// The source file is always left intact.
// The path of the returned SnapshotFile will be relative
// to root.
func newSnapshotFile(srcPath, dstPath, root string, compress bool) (*SnapshotFile, error) {
// open the source file
srcFile, err := os.OpenFile(srcPath, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
defer srcFile.Close()
src := bufio.NewReaderSize(srcFile, 2*1024*1024)
var hash string
var size int64
if compress {
log.Infof("newSnapshotFile: starting to compress %v into %v", srcPath, dstPath)
// open the temporary destination file
dir, filePrefix := path.Split(dstPath)
dstFile, err := ioutil.TempFile(dir, filePrefix)
if err != nil {
return nil, err
}
defer func() {
// try to close and delete the file. in the
// success case, the file will already be
// closed and renamed, so all of this would
// fail anyway, no biggie
dstFile.Close()
os.Remove(dstFile.Name())
}()
dst := bufio.NewWriterSize(dstFile, 2*1024*1024)
// create the hasher and the tee on top
hasher := newHasher()
tee := io.MultiWriter(dst, hasher)
// create the gzip compression filter
gzip, err := cgzip.NewWriterLevel(tee, cgzip.Z_BEST_SPEED)
if err != nil {
return nil, err
}
// copy from the file to gzip to tee to output file and hasher
_, err = io.Copy(gzip, src)
if err != nil {
return nil, err
}
// close gzip to flush it
if err = gzip.Close(); err != nil {
return nil, err
}
// close dst manually to flush all buffers to disk
dst.Flush()
dstFile.Close()
hash = hasher.HashString()
// atomically move completed compressed file
err = os.Rename(dstFile.Name(), dstPath)
if err != nil {
return nil, err
}
// and get its size
fi, err := os.Stat(dstPath)
if err != nil {
return nil, err
}
size = fi.Size()
} else {
log.Infof("newSnapshotFile: starting to hash and symlinking %v to %v", srcPath, dstPath)
// get the hash
hasher := newHasher()
_, err = io.Copy(hasher, src)
if err != nil {
return nil, err
}
hash = hasher.HashString()
// do the symlink
err = os.Symlink(srcPath, dstPath)
if err != nil {
return nil, err
}
// and get the size
fi, err := os.Stat(srcPath)
if err != nil {
return nil, err
}
size = fi.Size()
}
log.Infof("clone data ready %v:%v", dstPath, hash)
relativeDst, err := filepath.Rel(root, dstPath)
if err != nil {
return nil, err
}
return &SnapshotFile{relativeDst, size, hash, ""}, nil
}
// newSnapshotFiles processes multiple files in parallel. The Paths of
// the returned SnapshotFiles will be relative to root.
// - if compress is true, we compress the files and compute the hash on
// the compressed version.
// - if compress is false, we symlink the files, and compute the hash on
// the original version.
func newSnapshotFiles(sources, destinations []string, root string, concurrency int, compress bool) ([]SnapshotFile, error) {
if len(sources) != len(destinations) || len(sources) == 0 {
return nil, fmt.Errorf("programming error: bad array lengths: %v %v", len(sources), len(destinations))
}
workQueue := make(chan int, len(sources))
for i := 0; i < len(sources); i++ {
workQueue <- i
}
close(workQueue)
snapshotFiles := make([]SnapshotFile, len(sources))
resultQueue := make(chan error, len(sources))
for i := 0; i < concurrency; i++ {
go func() {
for i := range workQueue {
sf, err := newSnapshotFile(sources[i], destinations[i], root, compress)
if err == nil {
snapshotFiles[i] = *sf
}
resultQueue <- err
}
}()
}
var err error
for i := 0; i < len(sources); i++ {
if compressErr := <-resultQueue; compressErr != nil {
err = compressErr
}
}
// clean up files if we had an error
// FIXME(alainjobart) it seems extreme to delete all files if
// the last one failed. Since we only move the file into
// its destination when it worked, we could assume if the file
// already exists it's good, and re-compute its hash.
if err != nil {
log.Infof("Error happened, deleting all the files we already compressed")
for _, dest := range destinations {
os.Remove(dest)
}
return nil, err
}
return snapshotFiles, nil
}
// a SnapshotManifest describes multiple SnapshotFiles and where
// to get them from.
type SnapshotManifest struct {
Addr string // this is the address of the tabletserver, not mysql
DbName string
Files SnapshotFiles
ReplicationStatus *proto.ReplicationStatus
MasterPosition proto.ReplicationPosition
}
func newSnapshotManifest(addr, mysqlAddr, masterAddr, dbName string, files []SnapshotFile, pos, masterPos proto.ReplicationPosition) (*SnapshotManifest, error) {
nrs, err := proto.NewReplicationStatus(masterAddr)
if err != nil {
return nil, err
}
rs := &SnapshotManifest{
Addr: addr,
DbName: dbName,
Files: files,
ReplicationStatus: nrs,
MasterPosition: masterPos,
}
sort.Sort(rs.Files)
rs.ReplicationStatus.Position = pos
return rs, nil
}
// fetchFile fetches data from the web server. It then sends it to a
// tee, which on one side has an hash checksum reader, and on the other
// a gunzip reader writing to a file. It will compare the hash
// checksum after the copy is done.
func fetchFile(srcUrl, srcHash, dstFilename string) error {
log.Infof("fetchFile: starting to fetch %v from %v", dstFilename, srcUrl)
// open the URL
req, err := http.NewRequest("GET", srcUrl, nil)
if err != nil {
return fmt.Errorf("NewRequest failed for %v: %v", srcUrl, err)
}
// we set the 'gzip' encoding ourselves so the library doesn't
// do it for us and ends up using go gzip (we want to use our own
// cgzip which is much faster)
req.Header.Set("Accept-Encoding", "gzip")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
if resp.StatusCode != 200 {
return fmt.Errorf("failed fetching %v: %v", srcUrl, resp.Status)
}
defer resp.Body.Close()
// see if we need some uncompression
var reader io.Reader = resp.Body
ce := resp.Header.Get("Content-Encoding")
if ce != "" {
if ce == "gzip" {
gz, err := cgzip.NewReader(reader)
if err != nil {
return err
}
defer gz.Close()
reader = gz
} else {
return fmt.Errorf("unsupported Content-Encoding: %v", ce)
}
}
return uncompressAndCheck(reader, srcHash, dstFilename, strings.HasSuffix(srcUrl, ".gz"))
}
// uncompressAndCheck uses the provided reader to read data, and then
// sends it to a tee, which on one side has an hash checksum reader,
// and on the other a gunzip reader writing to a file. It will
// compare the hash checksum after the copy is done.
func uncompressAndCheck(reader io.Reader, srcHash, dstFilename string, needsUncompress bool) error {
// create destination directory
dir, filePrefix := path.Split(dstFilename)
if dirErr := os.MkdirAll(dir, 0775); dirErr != nil {
return dirErr
}
// create a temporary file to uncompress to
dstFile, err := ioutil.TempFile(dir, filePrefix)
if err != nil {
return err
}
defer func() {
// try to close and delete the file.
// in the success case, the file will already be closed
// and renamed, so all of this would fail anyway, no biggie
dstFile.Close()
os.Remove(dstFile.Name())
}()
// create a buffering output
dst := bufio.NewWriterSize(dstFile, 2*1024*1024)
// create hash to write the compressed data to
hasher := newHasher()
// create a Tee: we split the HTTP input into the hasher
// and into the gunziper
tee := io.TeeReader(reader, hasher)
// create the uncompresser
var decompressor io.Reader
if needsUncompress {
gz, err := cgzip.NewReader(tee)
if err != nil {
return err
}
defer gz.Close()
decompressor = gz
} else {
decompressor = tee
}
// see if we need to introduce failures
if simulateFailures {
failureCounter++
if failureCounter%10 == 0 {
return fmt.Errorf("Simulated error")
}
}
// copy the data. Will also write to the hasher
if _, err = io.Copy(dst, decompressor); err != nil {
return err
}
// check the hash
hash := hasher.HashString()
if srcHash != hash {
return fmt.Errorf("hash mismatch for %v, %v != %v", dstFilename, srcHash, hash)
}
// we're good
log.Infof("processed snapshot file: %v", dstFilename)
dst.Flush()
dstFile.Close()
// atomically move uncompressed file
if err := os.Chmod(dstFile.Name(), 0664); err != nil {
return err
}
return os.Rename(dstFile.Name(), dstFilename)
}
// fetchFileWithRetry fetches data from the web server, retrying a few
// times.
func fetchFileWithRetry(srcUrl, srcHash, dstFilename string, fetchRetryCount int) (err error) {
for i := 0; i < fetchRetryCount; i++ {
err = fetchFile(srcUrl, srcHash, dstFilename)
if err == nil {
return nil
}
log.Warningf("fetching snapshot file %v failed (try=%v): %v", dstFilename, i, err)
}
log.Errorf("fetching snapshot file %v failed too many times", dstFilename)
return err
}
// FIXME(msolomon) Should we add deadlines? What really matters more
// than a deadline is probably a sense of progress, more like a
// "progress timeout" - how long will we wait if there is no change in
// received bytes.
// FIXME(alainjobart) support fetching files in chunks: create a new
// struct fileChunk {
// snapshotFile *SnapshotFile
// relatedChunks []*fileChunk
// start,end uint64
// observedCrc32 uint32
// }
// Create a slice of fileChunk objects, populate it:
// For files smaller than <threshold>, create one fileChunk
// For files bigger than <threshold>, create N fileChunks
// (the first one has the list of all the others)
// Fetch them all:
// - change the workqueue to have indexes on the fileChunk slice
// - compute the crc32 while fetching, but don't compare right away
// Collect results the same way, write observedCrc32 in the fileChunk
// For each fileChunk, compare checksum:
// - if single file, compare snapshotFile.hash with observedCrc32
// - if multiple chunks and first chunk, merge observedCrc32, and compare
func fetchFiles(snapshotManifest *SnapshotManifest, destinationPath string, fetchConcurrency, fetchRetryCount int) (err error) {
// create a workQueue, a resultQueue, and the go routines
// to process entries out of workQueue into resultQueue
// the mutex protects the error response
workQueue := make(chan SnapshotFile, len(snapshotManifest.Files))
resultQueue := make(chan error, len(snapshotManifest.Files))
mutex := sync.Mutex{}
for i := 0; i < fetchConcurrency; i++ {
go func() {
for sf := range workQueue {
// if someone else errored out, we skip our job
mutex.Lock()
previousError := err
mutex.Unlock()
if previousError != nil {
resultQueue <- previousError
continue
}
// do our fetch, save the error
filename := sf.getLocalFilename(destinationPath)
furl := "http://" + snapshotManifest.Addr + path.Join(SnapshotURLPath, sf.Path)
fetchErr := fetchFileWithRetry(furl, sf.Hash, filename, fetchRetryCount)
if fetchErr != nil {
mutex.Lock()
err = fetchErr
mutex.Unlock()
}
resultQueue <- fetchErr
}
}()
}
// add the jobs (writing on the channel will block if the queue
// is full, no big deal)
jobCount := 0
for _, fi := range snapshotManifest.Files {
workQueue <- fi
jobCount++
}
close(workQueue)
// read the responses (we guarantee one response per job)
for i := 0; i < jobCount; i++ {
<-resultQueue
}
// clean up files if we had an error
// FIXME(alainjobart) it seems extreme to delete all files if
// the last one failed. Maybe we shouldn't, and if a file already
// exists, we hash it before retransmitting.
if err != nil {
log.Infof("Error happened, deleting all the files we already got")
for _, fi := range snapshotManifest.Files {
filename := fi.getLocalFilename(destinationPath)
os.Remove(filename)
}
}
return err
}

Просмотреть файл

@ -6,7 +6,6 @@ import (
"time"
"github.com/youtube/vitess/go/vt/health"
"github.com/youtube/vitess/go/vt/topo"
)
// mysqlReplicationLag implements health.Reporter
@ -15,8 +14,8 @@ type mysqlReplicationLag struct {
}
// Report is part of the health.Reporter interface
func (mrl *mysqlReplicationLag) Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
if !topo.IsSlaveType(tabletType) {
func (mrl *mysqlReplicationLag) Report(isSlaveType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
if !isSlaveType {
return 0, nil
}

Просмотреть файл

@ -10,6 +10,7 @@ import (
"strings"
"time"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqldb"
"github.com/youtube/vitess/go/stats"
"github.com/youtube/vitess/go/vt/dbconnpool"
@ -19,21 +20,23 @@ import (
// MysqlDaemon is the interface we use for abstracting Mysqld.
type MysqlDaemon interface {
// GetMasterAddr returns the mysql master address, as shown by
// 'show slave status'.
GetMasterAddr() (string, error)
// Cnf returns the underlying mycnf
Cnf() *Mycnf
// methods related to mysql running or not
Start(mysqlWaitTime time.Duration) error
Shutdown(waitForMysqld bool, mysqlWaitTime time.Duration) error
// GetMysqlPort returns the current port mysql is listening on.
GetMysqlPort() (int, error)
// replication related methods
StartSlave(hookExtraEnv map[string]string) error
StopSlave(hookExtraEnv map[string]string) error
SlaveStatus() (*proto.ReplicationStatus, error)
SlaveStatus() (proto.ReplicationStatus, error)
// reparenting related methods
ResetReplicationCommands() ([]string, error)
MasterPosition() (proto.ReplicationPosition, error)
IsReadOnly() (bool, error)
SetReadOnly(on bool) error
StartReplicationCommands(status *proto.ReplicationStatus) ([]string, error)
SetMasterCommands(masterHost string, masterPort int) ([]string, error)
@ -52,21 +55,42 @@ type MysqlDaemon interface {
// Schema related methods
GetSchema(dbName string, tables, excludeTables []string, includeViews bool) (*proto.SchemaDefinition, error)
PreflightSchemaChange(dbName string, change string) (*proto.SchemaChangeResult, error)
ApplySchemaChange(dbName string, change *proto.SchemaChange) (*proto.SchemaChangeResult, error)
// GetAppConnection returns a app connection to be able to talk to the database.
GetAppConnection() (dbconnpool.PoolConnection, error)
// GetDbaConnection returns a dba connection.
GetDbaConnection() (*dbconnpool.DBConnection, error)
// query execution methods
// ExecuteSuperQueryList executes a list of queries, no result
ExecuteSuperQueryList(queryList []string) error
// FetchSuperQuery executes one query, returns the result
FetchSuperQuery(query string) (*mproto.QueryResult, error)
// NewSlaveConnection returns a SlaveConnection to the database.
NewSlaveConnection() (*SlaveConnection, error)
// EnableBinlogPlayback enables playback of binlog events
EnableBinlogPlayback() error
// DisableBinlogPlayback disable playback of binlog events
DisableBinlogPlayback() error
// Close will close this instance of Mysqld. It will wait for all dba
// queries to be finished.
Close()
}
// FakeMysqlDaemon implements MysqlDaemon and allows the user to fake
// everything.
type FakeMysqlDaemon struct {
// MasterAddr will be returned by GetMasterAddr(). Set to "" to return
// ErrNotSlave, or to "ERROR" to return an error.
MasterAddr string
// Mycnf will be returned by Cnf()
Mycnf *Mycnf
// Running is used by Start / Shutdown
Running bool
// MysqlPort will be returned by GetMysqlPort(). Set to -1 to
// return an error.
@ -77,9 +101,6 @@ type FakeMysqlDaemon struct {
// test owner responsability to have these two match)
Replicating bool
// CurrentSlaveStatus is returned by SlaveStatus
CurrentSlaveStatus *proto.ReplicationStatus
// ResetReplicationResult is returned by ResetReplication
ResetReplicationResult []string
@ -87,8 +108,15 @@ type FakeMysqlDaemon struct {
ResetReplicationError error
// CurrentMasterPosition is returned by MasterPosition
// and SlaveStatus
CurrentMasterPosition proto.ReplicationPosition
// CurrentMasterHost is returned by SlaveStatus
CurrentMasterHost string
// CurrentMasterport is returned by SlaveStatus
CurrentMasterPort int
// ReadOnly is the current value of the flag
ReadOnly bool
@ -120,12 +148,17 @@ type FakeMysqlDaemon struct {
// PromoteSlaveResult is returned by PromoteSlave
PromoteSlaveResult proto.ReplicationPosition
// Schema that will be returned by GetSchema. If nil we'll
// Schema will be returned by GetSchema. If nil we'll
// return an error.
Schema *proto.SchemaDefinition
// DbaConnectionFactory is the factory for making fake dba connections
DbaConnectionFactory func() (dbconnpool.PoolConnection, error)
// PreflightSchemaChangeResult will be returned by PreflightSchemaChange.
// If nil we'll return an error.
PreflightSchemaChangeResult *proto.SchemaChangeResult
// ApplySchemaChangeResult will be returned by ApplySchemaChange.
// If nil we'll return an error.
ApplySchemaChangeResult *proto.SchemaChangeResult
// DbAppConnectionFactory is the factory for making fake db app connections
DbAppConnectionFactory func() (dbconnpool.PoolConnection, error)
@ -141,17 +174,43 @@ type FakeMysqlDaemon struct {
// ExpectedExecuteSuperQueryCurrent is the current index of the queries
// we expect
ExpectedExecuteSuperQueryCurrent int
// FetchSuperQueryResults is used by FetchSuperQuery
FetchSuperQueryMap map[string]*mproto.QueryResult
// BinlogPlayerEnabled is used by {Enable,Disable}BinlogPlayer
BinlogPlayerEnabled bool
}
// GetMasterAddr is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) GetMasterAddr() (string, error) {
if fmd.MasterAddr == "" {
return "", ErrNotSlave
// NewFakeMysqlDaemon returns a FakeMysqlDaemon where mysqld appears
// to be running
func NewFakeMysqlDaemon() *FakeMysqlDaemon {
return &FakeMysqlDaemon{
Running: true,
}
if fmd.MasterAddr == "ERROR" {
return "", fmt.Errorf("FakeMysqlDaemon.GetMasterAddr returns an error")
}
// Cnf is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) Cnf() *Mycnf {
return fmd.Mycnf
}
// Start is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) Start(mysqlWaitTime time.Duration) error {
if fmd.Running {
return fmt.Errorf("fake mysql daemon already running")
}
return fmd.MasterAddr, nil
fmd.Running = true
return nil
}
// Shutdown is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) Shutdown(waitForMysqld bool, mysqlWaitTime time.Duration) error {
if !fmd.Running {
return fmt.Errorf("fake mysql daemon not running")
}
fmd.Running = false
return nil
}
// GetMysqlPort is part of the MysqlDaemon interface
@ -162,24 +221,15 @@ func (fmd *FakeMysqlDaemon) GetMysqlPort() (int, error) {
return fmd.MysqlPort, nil
}
// StartSlave is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) StartSlave(hookExtraEnv map[string]string) error {
fmd.Replicating = true
return nil
}
// StopSlave is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) StopSlave(hookExtraEnv map[string]string) error {
fmd.Replicating = false
return nil
}
// SlaveStatus is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) SlaveStatus() (*proto.ReplicationStatus, error) {
if fmd.CurrentSlaveStatus == nil {
return nil, fmt.Errorf("no slave status defined")
}
return fmd.CurrentSlaveStatus, nil
func (fmd *FakeMysqlDaemon) SlaveStatus() (proto.ReplicationStatus, error) {
return proto.ReplicationStatus{
Position: fmd.CurrentMasterPosition,
SlaveIORunning: fmd.Replicating,
SlaveSQLRunning: fmd.Replicating,
MasterHost: fmd.CurrentMasterHost,
MasterPort: fmd.CurrentMasterPort,
}, nil
}
// ResetReplicationCommands is part of the MysqlDaemon interface
@ -192,6 +242,11 @@ func (fmd *FakeMysqlDaemon) MasterPosition() (proto.ReplicationPosition, error)
return fmd.CurrentMasterPosition, nil
}
// IsReadOnly is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) IsReadOnly() (bool, error) {
return fmd.ReadOnly, nil
}
// SetReadOnly is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) SetReadOnly(on bool) error {
fmd.ReadOnly = on
@ -261,10 +316,58 @@ func (fmd *FakeMysqlDaemon) ExecuteSuperQueryList(queryList []string) error {
if expected != query {
return fmt.Errorf("wrong query for ExecuteSuperQueryList: expected %v got %v", expected, query)
}
// intercept some queries to update our status
switch query {
case SqlStartSlave:
fmd.Replicating = true
case SqlStopSlave:
fmd.Replicating = false
}
}
return nil
}
// FetchSuperQuery returns the results from the map, if any
func (fmd *FakeMysqlDaemon) FetchSuperQuery(query string) (*mproto.QueryResult, error) {
if fmd.FetchSuperQueryMap == nil {
return nil, fmt.Errorf("unexpected query: %v", query)
}
qr, ok := fmd.FetchSuperQueryMap[query]
if !ok {
return nil, fmt.Errorf("unexpected query: %v", query)
}
return qr, nil
}
// NewSlaveConnection is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) NewSlaveConnection() (*SlaveConnection, error) {
panic(fmt.Errorf("not implemented on FakeMysqlDaemon"))
}
// EnableBinlogPlayback is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) EnableBinlogPlayback() error {
if fmd.BinlogPlayerEnabled {
return fmt.Errorf("binlog player already enabled")
}
fmd.BinlogPlayerEnabled = true
return nil
}
// DisableBinlogPlayback disable playback of binlog events
func (fmd *FakeMysqlDaemon) DisableBinlogPlayback() error {
if fmd.BinlogPlayerEnabled {
return fmt.Errorf("binlog player already disabled")
}
fmd.BinlogPlayerEnabled = false
return nil
}
// Close is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) Close() {
}
// CheckSuperQueryList returns an error if all the queries we expected
// haven't been seen.
func (fmd *FakeMysqlDaemon) CheckSuperQueryList() error {
@ -282,6 +385,22 @@ func (fmd *FakeMysqlDaemon) GetSchema(dbName string, tables, excludeTables []str
return fmd.Schema.FilterTables(tables, excludeTables, includeViews)
}
// PreflightSchemaChange is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) PreflightSchemaChange(dbName string, change string) (*proto.SchemaChangeResult, error) {
if fmd.PreflightSchemaChangeResult == nil {
return nil, fmt.Errorf("no preflight result defined")
}
return fmd.PreflightSchemaChangeResult, nil
}
// ApplySchemaChange is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) ApplySchemaChange(dbName string, change *proto.SchemaChange) (*proto.SchemaChangeResult, error) {
if fmd.ApplySchemaChangeResult == nil {
return nil, fmt.Errorf("no apply schema defined")
}
return fmd.ApplySchemaChangeResult, nil
}
// GetAppConnection is part of the MysqlDaemon interface
func (fmd *FakeMysqlDaemon) GetAppConnection() (dbconnpool.PoolConnection, error) {
if fmd.DbAppConnectionFactory == nil {

Просмотреть файл

@ -30,7 +30,7 @@ type MysqlFlavor interface {
MasterPosition(mysqld *Mysqld) (proto.ReplicationPosition, error)
// SlaveStatus returns the ReplicationStatus of a slave.
SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error)
SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error)
// ResetReplicationCommands returns the commands to completely reset
// replication on the host.
@ -63,7 +63,7 @@ type MysqlFlavor interface {
// SendBinlogDumpCommand sends the flavor-specific version of
// the COM_BINLOG_DUMP command to start dumping raw binlog
// events over a slave connection, starting at a given GTID.
SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error
SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error
// MakeBinlogEvent takes a raw packet from the MySQL binlog
// stream connection and returns a BinlogEvent through which
@ -120,7 +120,7 @@ func (mysqld *Mysqld) detectFlavor() (MysqlFlavor, error) {
// If no environment variable set, fall back to auto-detect.
log.Infof("MYSQL_FLAVOR empty or unset, attempting to auto-detect...")
qr, err := mysqld.fetchSuperQuery("SELECT VERSION()")
qr, err := mysqld.FetchSuperQuery("SELECT VERSION()")
if err != nil {
return nil, fmt.Errorf("couldn't SELECT VERSION(): %v", err)
}

Просмотреть файл

@ -29,7 +29,7 @@ func (*mariaDB10) VersionMatch(version string) bool {
// MasterPosition implements MysqlFlavor.MasterPosition().
func (flavor *mariaDB10) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPosition, err error) {
qr, err := mysqld.fetchSuperQuery("SELECT @@GLOBAL.gtid_binlog_pos")
qr, err := mysqld.FetchSuperQuery("SELECT @@GLOBAL.gtid_binlog_pos")
if err != nil {
return rp, err
}
@ -40,16 +40,16 @@ func (flavor *mariaDB10) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPos
}
// SlaveStatus implements MysqlFlavor.SlaveStatus().
func (flavor *mariaDB10) SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error) {
func (flavor *mariaDB10) SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error) {
fields, err := mysqld.fetchSuperQueryMap("SHOW ALL SLAVES STATUS")
if err != nil {
return nil, ErrNotSlave
return proto.ReplicationStatus{}, ErrNotSlave
}
status := parseSlaveStatus(fields)
status.Position, err = flavor.ParseReplicationPosition(fields["Gtid_Slave_Pos"])
if err != nil {
return nil, fmt.Errorf("SlaveStatus can't parse MariaDB GTID (Gtid_Slave_Pos: %#v): %v", fields["Gtid_Slave_Pos"], err)
return proto.ReplicationStatus{}, fmt.Errorf("SlaveStatus can't parse MariaDB GTID (Gtid_Slave_Pos: %#v): %v", fields["Gtid_Slave_Pos"], err)
}
return status, nil
}
@ -69,7 +69,7 @@ func (*mariaDB10) WaitMasterPos(mysqld *Mysqld, targetPos proto.ReplicationPosit
}
log.Infof("Waiting for minimum replication position with query: %v", query)
qr, err := mysqld.fetchSuperQuery(query)
qr, err := mysqld.FetchSuperQuery(query)
if err != nil {
return fmt.Errorf("MASTER_GTID_WAIT() failed: %v", err)
}
@ -138,7 +138,7 @@ func (*mariaDB10) ParseReplicationPosition(s string) (proto.ReplicationPosition,
}
// SendBinlogDumpCommand implements MysqlFlavor.SendBinlogDumpCommand().
func (*mariaDB10) SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error {
func (*mariaDB10) SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error {
const ComBinlogDump = 0x12
// Tell the server that we understand GTIDs by setting our slave capability

Просмотреть файл

@ -30,7 +30,7 @@ func (*mysql56) VersionMatch(version string) bool {
// MasterPosition implements MysqlFlavor.MasterPosition().
func (flavor *mysql56) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPosition, err error) {
qr, err := mysqld.fetchSuperQuery("SELECT @@GLOBAL.gtid_executed")
qr, err := mysqld.FetchSuperQuery("SELECT @@GLOBAL.gtid_executed")
if err != nil {
return rp, err
}
@ -41,16 +41,16 @@ func (flavor *mysql56) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPosit
}
// SlaveStatus implements MysqlFlavor.SlaveStatus().
func (flavor *mysql56) SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error) {
func (flavor *mysql56) SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error) {
fields, err := mysqld.fetchSuperQueryMap("SHOW SLAVE STATUS")
if err != nil {
return nil, ErrNotSlave
return proto.ReplicationStatus{}, ErrNotSlave
}
status := parseSlaveStatus(fields)
status.Position, err = flavor.ParseReplicationPosition(fields["Executed_Gtid_Set"])
if err != nil {
return nil, fmt.Errorf("SlaveStatus can't parse MySQL 5.6 GTID (Executed_Gtid_Set: %#v): %v", fields["Executed_Gtid_Set"], err)
return proto.ReplicationStatus{}, fmt.Errorf("SlaveStatus can't parse MySQL 5.6 GTID (Executed_Gtid_Set: %#v): %v", fields["Executed_Gtid_Set"], err)
}
return status, nil
}
@ -62,7 +62,7 @@ func (*mysql56) WaitMasterPos(mysqld *Mysqld, targetPos proto.ReplicationPositio
query = fmt.Sprintf("SELECT WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS('%s', %v)", targetPos, int(waitTimeout.Seconds()))
log.Infof("Waiting for minimum replication position with query: %v", query)
qr, err := mysqld.fetchSuperQuery(query)
qr, err := mysqld.FetchSuperQuery(query)
if err != nil {
return fmt.Errorf("WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS() failed: %v", err)
}
@ -134,7 +134,7 @@ func (*mysql56) ParseReplicationPosition(s string) (proto.ReplicationPosition, e
}
// SendBinlogDumpCommand implements MysqlFlavor.SendBinlogDumpCommand().
func (flavor *mysql56) SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error {
func (flavor *mysql56) SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error {
const ComBinlogDumpGTID = 0x1E // COM_BINLOG_DUMP_GTID
gtidSet, ok := startPos.GTIDSet.(proto.Mysql56GTIDSet)

Просмотреть файл

@ -24,7 +24,7 @@ func (fakeMysqlFlavor) MakeBinlogEvent(buf []byte) blproto.BinlogEvent { return
func (fakeMysqlFlavor) ParseReplicationPosition(string) (proto.ReplicationPosition, error) {
return proto.ReplicationPosition{}, nil
}
func (fakeMysqlFlavor) SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error {
func (fakeMysqlFlavor) SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error {
return nil
}
func (fakeMysqlFlavor) WaitMasterPos(mysqld *Mysqld, targetPos proto.ReplicationPosition, waitTimeout time.Duration) error {
@ -33,7 +33,9 @@ func (fakeMysqlFlavor) WaitMasterPos(mysqld *Mysqld, targetPos proto.Replication
func (fakeMysqlFlavor) MasterPosition(mysqld *Mysqld) (proto.ReplicationPosition, error) {
return proto.ReplicationPosition{}, nil
}
func (fakeMysqlFlavor) SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error) { return nil, nil }
func (fakeMysqlFlavor) SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error) {
return proto.ReplicationStatus{}, nil
}
func (fakeMysqlFlavor) StartReplicationCommands(params *sqldb.ConnParams, status *proto.ReplicationStatus) ([]string, error) {
return nil, nil
}

Просмотреть файл

@ -42,8 +42,6 @@ const (
)
var (
// TODO(aaijazi): for reasons I don't understand, the dba pool size needs to be fairly large (15+)
// for test/clone.py to pass.
dbaPoolSize = flag.Int("dba_pool_size", 20, "Size of the connection pool for dba connections")
dbaIdleTimeout = flag.Duration("dba_idle_timeout", time.Minute, "Idle timeout for dba connections")
appPoolSize = flag.Int("app_pool_size", 40, "Size of the connection pool for app connections")
@ -402,7 +400,7 @@ func (mysqld *Mysqld) initConfig(root string) error {
func (mysqld *Mysqld) createDirs() error {
log.Infof("creating directory %s", mysqld.TabletDir)
if err := os.MkdirAll(mysqld.TabletDir, 0775); err != nil {
if err := os.MkdirAll(mysqld.TabletDir, os.ModePerm); err != nil {
return err
}
for _, dir := range TopLevelDirs() {
@ -412,7 +410,7 @@ func (mysqld *Mysqld) createDirs() error {
}
for _, dir := range mysqld.config.directoryList() {
log.Infof("creating directory %s", dir)
if err := os.MkdirAll(dir, 0775); err != nil {
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}
// FIXME(msolomon) validate permissions?
@ -435,14 +433,14 @@ func (mysqld *Mysqld) createTopDir(dir string) error {
if os.IsNotExist(err) {
topdir := path.Join(mysqld.TabletDir, dir)
log.Infof("creating directory %s", topdir)
return os.MkdirAll(topdir, 0775)
return os.MkdirAll(topdir, os.ModePerm)
}
return err
}
linkto := path.Join(target, vtname)
source := path.Join(mysqld.TabletDir, dir)
log.Infof("creating directory %s", linkto)
err = os.MkdirAll(linkto, 0775)
err = os.MkdirAll(linkto, os.ModePerm)
if err != nil {
return err
}

Просмотреть файл

@ -8,11 +8,12 @@ import (
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
)
func (mysqld *Mysqld) GetPermissions() (*proto.Permissions, error) {
// GetPermissions lists the permissions on the mysqld
func GetPermissions(mysqld MysqlDaemon) (*proto.Permissions, error) {
permissions := &proto.Permissions{}
// get Users
qr, err := mysqld.fetchSuperQuery("SELECT * FROM mysql.user")
qr, err := mysqld.FetchSuperQuery("SELECT * FROM mysql.user")
if err != nil {
return nil, err
}
@ -21,7 +22,7 @@ func (mysqld *Mysqld) GetPermissions() (*proto.Permissions, error) {
}
// get Dbs
qr, err = mysqld.fetchSuperQuery("SELECT * FROM mysql.db")
qr, err = mysqld.FetchSuperQuery("SELECT * FROM mysql.db")
if err != nil {
return nil, err
}
@ -30,7 +31,7 @@ func (mysqld *Mysqld) GetPermissions() (*proto.Permissions, error) {
}
// get Hosts
qr, err = mysqld.fetchSuperQuery("SELECT * FROM mysql.host")
qr, err = mysqld.FetchSuperQuery("SELECT * FROM mysql.host")
if err != nil {
return nil, err
}

Просмотреть файл

@ -9,7 +9,6 @@ import (
"encoding/hex"
"fmt"
"regexp"
"sort"
"strings"
"github.com/youtube/vitess/go/jscfg"
@ -17,41 +16,38 @@ import (
)
const (
TABLE_BASE_TABLE = "BASE TABLE"
TABLE_VIEW = "VIEW"
// TableBaseTable indicates the table type is a base table.
TableBaseTable = "BASE TABLE"
// TableView indicates the table type is a view.
TableView = "VIEW"
)
// TableDefinition contains all schema information about a table.
type TableDefinition struct {
Name string // the table name
Schema string // the SQL to run to create the table
Columns []string // the columns in the order that will be used to dump and load the data
PrimaryKeyColumns []string // the columns used by the primary key, in order
Type string // TABLE_BASE_TABLE or TABLE_VIEW
Type string // TableBaseTable or TableView
DataLength uint64 // how much space the data file takes.
RowCount uint64 // how many rows in the table (may
// be approximate count)
}
// helper methods for sorting
// TableDefinitions is a list of TableDefinition.
type TableDefinitions []*TableDefinition
// Len returns TableDefinitions length.
func (tds TableDefinitions) Len() int {
return len(tds)
}
// Swap used for sorting TableDefinitions.
func (tds TableDefinitions) Swap(i, j int) {
tds[i], tds[j] = tds[j], tds[i]
}
// sort by reverse DataLength
type ByReverseDataLength struct {
TableDefinitions
}
func (bdl ByReverseDataLength) Less(i, j int) bool {
return bdl.TableDefinitions[j].DataLength < bdl.TableDefinitions[i].DataLength
}
// SchemaDefinition defines schema for a certain database.
type SchemaDefinition struct {
// the 'CREATE DATABASE...' statement, with db name as {{.DatabaseName}}
DatabaseSchema string
@ -67,10 +63,6 @@ func (sd *SchemaDefinition) String() string {
return jscfg.ToJSON(sd)
}
func (sd *SchemaDefinition) SortByReverseDataLength() {
sort.Sort(ByReverseDataLength{sd.TableDefinitions})
}
// FilterTables returns a copy which includes only
// whitelisted tables (tables), no blacklisted tables (excludeTables) and optionally views (includeViews).
func (sd *SchemaDefinition) FilterTables(tables, excludeTables []string, includeViews bool) (*SchemaDefinition, error) {
@ -126,7 +118,7 @@ func (sd *SchemaDefinition) FilterTables(tables, excludeTables []string, include
continue
}
if !includeViews && table.Type == TABLE_VIEW {
if !includeViews && table.Type == TableView {
continue
}
@ -141,6 +133,8 @@ func (sd *SchemaDefinition) FilterTables(tables, excludeTables []string, include
return &copy, nil
}
// GenerateSchemaVersion return a unique schema version string based on
// its TableDefinitions.
func (sd *SchemaDefinition) GenerateSchemaVersion() {
hasher := md5.New()
for _, td := range sd.TableDefinitions {
@ -151,6 +145,7 @@ func (sd *SchemaDefinition) GenerateSchemaVersion() {
sd.Version = hex.EncodeToString(hasher.Sum(nil))
}
// GetTable returns TableDefinition for a given table name.
func (sd *SchemaDefinition) GetTable(table string) (td *TableDefinition, ok bool) {
for _, td := range sd.TableDefinitions {
if td.Name == table {
@ -170,7 +165,7 @@ func (sd *SchemaDefinition) ToSQLStrings() []string {
sqlStrings = append(sqlStrings, sd.DatabaseSchema)
for _, td := range sd.TableDefinitions {
if td.Type == TABLE_VIEW {
if td.Type == TableView {
createViewSql = append(createViewSql, td.Schema)
} else {
lines := strings.Split(td.Schema, "\n")
@ -186,9 +181,16 @@ func (sd *SchemaDefinition) ToSQLStrings() []string {
return append(sqlStrings, createViewSql...)
}
// generates a report on what's different between two SchemaDefinition
// for now, we skip the VIEW entirely.
// DiffSchema generates a report on what's different between two SchemaDefinitions
// including views.
func DiffSchema(leftName string, left *SchemaDefinition, rightName string, right *SchemaDefinition, er concurrency.ErrorRecorder) {
if left == nil && right == nil {
return
}
if left == nil || right == nil {
er.RecordError(fmt.Errorf("%v and %v are different, %s: %v, %s: %v", leftName, rightName, leftName, left, rightName, right))
return
}
if left.DatabaseSchema != right.DatabaseSchema {
er.RecordError(fmt.Errorf("%v and %v don't agree on database creation command:\n%v\n differs from:\n%v", leftName, rightName, left.DatabaseSchema, right.DatabaseSchema))
}
@ -196,16 +198,6 @@ func DiffSchema(leftName string, left *SchemaDefinition, rightName string, right
leftIndex := 0
rightIndex := 0
for leftIndex < len(left.TableDefinitions) && rightIndex < len(right.TableDefinitions) {
// skip views
if left.TableDefinitions[leftIndex].Type == TABLE_VIEW {
leftIndex++
continue
}
if right.TableDefinitions[rightIndex].Type == TABLE_VIEW {
rightIndex++
continue
}
// extra table on the left side
if left.TableDefinitions[leftIndex].Name < right.TableDefinitions[rightIndex].Name {
er.RecordError(fmt.Errorf("%v has an extra table named %v", leftName, left.TableDefinitions[leftIndex].Name))
@ -224,34 +216,46 @@ func DiffSchema(leftName string, left *SchemaDefinition, rightName string, right
if left.TableDefinitions[leftIndex].Schema != right.TableDefinitions[rightIndex].Schema {
er.RecordError(fmt.Errorf("%v and %v disagree on schema for table %v:\n%v\n differs from:\n%v", leftName, rightName, left.TableDefinitions[leftIndex].Name, left.TableDefinitions[leftIndex].Schema, right.TableDefinitions[rightIndex].Schema))
}
if left.TableDefinitions[leftIndex].Type != right.TableDefinitions[rightIndex].Type {
er.RecordError(fmt.Errorf("%v and %v disagree on table type for table %v:\n%v\n differs from:\n%v", leftName, rightName, left.TableDefinitions[leftIndex].Name, left.TableDefinitions[leftIndex].Type, right.TableDefinitions[rightIndex].Type))
}
leftIndex++
rightIndex++
}
for leftIndex < len(left.TableDefinitions) {
if left.TableDefinitions[leftIndex].Type == TABLE_BASE_TABLE {
if left.TableDefinitions[leftIndex].Type == TableBaseTable {
er.RecordError(fmt.Errorf("%v has an extra table named %v", leftName, left.TableDefinitions[leftIndex].Name))
}
if left.TableDefinitions[leftIndex].Type == TableView {
er.RecordError(fmt.Errorf("%v has an extra view named %v", leftName, left.TableDefinitions[leftIndex].Name))
}
leftIndex++
}
for rightIndex < len(right.TableDefinitions) {
if right.TableDefinitions[rightIndex].Type == TABLE_BASE_TABLE {
if right.TableDefinitions[rightIndex].Type == TableBaseTable {
er.RecordError(fmt.Errorf("%v has an extra table named %v", rightName, right.TableDefinitions[rightIndex].Name))
}
if right.TableDefinitions[rightIndex].Type == TableView {
er.RecordError(fmt.Errorf("%v has an extra view named %v", rightName, right.TableDefinitions[rightIndex].Name))
}
rightIndex++
}
}
// DiffSchemaToArray diffs two schemas and return the schema diffs if there is any.
func DiffSchemaToArray(leftName string, left *SchemaDefinition, rightName string, right *SchemaDefinition) (result []string) {
er := concurrency.AllErrorRecorder{}
DiffSchema(leftName, left, rightName, right, &er)
if er.HasErrors() {
return er.ErrorStrings()
} else {
return nil
}
return nil
}
// SchemaChange contains all necessary information to apply a schema change.
type SchemaChange struct {
Sql string
Force bool
@ -260,6 +264,8 @@ type SchemaChange struct {
AfterSchema *SchemaDefinition
}
// SchemaChangeResult contains before and after table schemas for
// a schema change sql.
type SchemaChangeResult struct {
BeforeSchema *SchemaDefinition
AfterSchema *SchemaDefinition

Просмотреть файл

@ -6,6 +6,7 @@ package proto
import (
"errors"
"fmt"
"reflect"
"testing"
)
@ -13,12 +14,12 @@ import (
var basicTable1 = &TableDefinition{
Name: "table1",
Schema: "table schema 1",
Type: TABLE_BASE_TABLE,
Type: TableBaseTable,
}
var basicTable2 = &TableDefinition{
Name: "table2",
Schema: "table schema 2",
Type: TABLE_BASE_TABLE,
Type: TableBaseTable,
}
var table3 = &TableDefinition{
@ -26,19 +27,19 @@ var table3 = &TableDefinition{
Schema: "CREATE TABLE `table3` (\n" +
"id bigint not null,\n" +
") Engine=InnoDB",
Type: TABLE_BASE_TABLE,
Type: TableBaseTable,
}
var view1 = &TableDefinition{
Name: "view1",
Schema: "view schema 1",
Type: TABLE_VIEW,
Type: TableView,
}
var view2 = &TableDefinition{
Name: "view2",
Schema: "view schema 2",
Type: TABLE_VIEW,
Type: TableView,
}
func TestToSQLStrings(t *testing.T) {
@ -152,30 +153,89 @@ func TestSchemaDiff(t *testing.T) {
&TableDefinition{
Name: "table1",
Schema: "schema1",
Type: TABLE_BASE_TABLE,
Type: TableBaseTable,
},
&TableDefinition{
Name: "table2",
Schema: "schema2",
Type: TABLE_BASE_TABLE,
Type: TableBaseTable,
},
},
}
testDiff(t, sd1, sd1, "sd1", "sd2", []string{})
sd2 := &SchemaDefinition{TableDefinitions: make([]*TableDefinition, 0, 2)}
sd3 := &SchemaDefinition{
TableDefinitions: []*TableDefinition{
&TableDefinition{
Name: "table2",
Schema: "schema2",
Type: TableBaseTable,
},
},
}
sd4 := &SchemaDefinition{
TableDefinitions: []*TableDefinition{
&TableDefinition{
Name: "table2",
Schema: "table2",
Type: TableView,
},
},
}
sd5 := &SchemaDefinition{
TableDefinitions: []*TableDefinition{
&TableDefinition{
Name: "table2",
Schema: "table2",
Type: TableBaseTable,
},
},
}
testDiff(t, sd1, sd1, "sd1", "sd2", []string{})
testDiff(t, sd2, sd2, "sd2", "sd2", []string{})
// two schemas are considered the same if both nil
testDiff(t, nil, nil, "sd1", "sd2", nil)
testDiff(t, sd1, nil, "sd1", "sd2", []string{
fmt.Sprintf("sd1 and sd2 are different, sd1: %v, sd2: null", sd1),
})
testDiff(t, sd1, sd3, "sd1", "sd3", []string{
"sd1 has an extra table named table1",
})
testDiff(t, sd3, sd1, "sd3", "sd1", []string{
"sd1 has an extra table named table1",
})
testDiff(t, sd2, sd4, "sd2", "sd4", []string{
"sd4 has an extra view named table2",
})
testDiff(t, sd4, sd2, "sd4", "sd2", []string{
"sd4 has an extra view named table2",
})
testDiff(t, sd4, sd5, "sd4", "sd5", []string{
fmt.Sprintf("sd4 and sd5 disagree on table type for table table2:\nVIEW\n differs from:\nBASE TABLE"),
})
sd1.DatabaseSchema = "CREATE DATABASE {{.DatabaseName}}"
sd2.DatabaseSchema = "DONT CREATE DATABASE {{.DatabaseName}}"
testDiff(t, sd1, sd2, "sd1", "sd2", []string{"sd1 and sd2 don't agree on database creation command:\nCREATE DATABASE {{.DatabaseName}}\n differs from:\nDONT CREATE DATABASE {{.DatabaseName}}", "sd1 has an extra table named table1", "sd1 has an extra table named table2"})
sd2.DatabaseSchema = "CREATE DATABASE {{.DatabaseName}}"
testDiff(t, sd2, sd1, "sd2", "sd1", []string{"sd1 has an extra table named table1", "sd1 has an extra table named table2"})
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table1", Schema: "schema1", Type: TABLE_BASE_TABLE})
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table1", Schema: "schema1", Type: TableBaseTable})
testDiff(t, sd1, sd2, "sd1", "sd2", []string{"sd1 has an extra table named table2"})
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table2", Schema: "schema3", Type: TABLE_BASE_TABLE})
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table2", Schema: "schema3", Type: TableBaseTable})
testDiff(t, sd1, sd2, "sd1", "sd2", []string{"sd1 and sd2 disagree on schema for table table2:\nschema2\n differs from:\nschema3"})
}

Просмотреть файл

@ -33,8 +33,8 @@ func (mysqld *Mysqld) ExecuteSuperQueryList(queryList []string) error {
return nil
}
// fetchSuperQuery returns the results of executing a query as a super user.
func (mysqld *Mysqld) fetchSuperQuery(query string) (*mproto.QueryResult, error) {
// FetchSuperQuery returns the results of executing a query as a super user.
func (mysqld *Mysqld) FetchSuperQuery(query string) (*mproto.QueryResult, error) {
conn, connErr := mysqld.dbaPool.Get(0)
if connErr != nil {
return nil, connErr
@ -51,7 +51,7 @@ func (mysqld *Mysqld) fetchSuperQuery(query string) (*mproto.QueryResult, error)
// fetchSuperQueryMap returns a map from column names to cell data for a query
// that should return exactly 1 row.
func (mysqld *Mysqld) fetchSuperQueryMap(query string) (map[string]string, error) {
qr, err := mysqld.fetchSuperQuery(query)
qr, err := mysqld.FetchSuperQuery(query)
if err != nil {
return nil, err
}
@ -69,6 +69,9 @@ func (mysqld *Mysqld) fetchSuperQueryMap(query string) (map[string]string, error
return rowMap, nil
}
const masterPasswordStart = " MASTER_PASSWORD = '"
const masterPasswordEnd = "',\n"
func redactMasterPassword(input string) string {
i := strings.Index(input, masterPasswordStart)
if i == -1 {

Просмотреть файл

@ -52,7 +52,7 @@ func queryReparentJournal(timeCreatedNS int64) string {
// the row in the reparent_journal table.
func (mysqld *Mysqld) WaitForReparentJournal(ctx context.Context, timeCreatedNS int64) error {
for {
qr, err := mysqld.fetchSuperQuery(queryReparentJournal(timeCreatedNS))
qr, err := mysqld.FetchSuperQuery(queryReparentJournal(timeCreatedNS))
if err == nil && len(qr.Rows) == 1 {
// we have the row, we're done
return nil
@ -82,12 +82,10 @@ func (mysqld *Mysqld) DemoteMaster() (rp proto.ReplicationPosition, err error) {
return mysqld.MasterPosition()
}
// PromoteSlave will promote a slave to be the new master
// PromoteSlave will promote a slave to be the new master.
func (mysqld *Mysqld) PromoteSlave(hookExtraEnv map[string]string) (proto.ReplicationPosition, error) {
// stop replication for good
if err := mysqld.StopSlave(hookExtraEnv); err != nil {
return proto.ReplicationPosition{}, err
}
// we handle replication, just stop it
cmds := []string{SqlStopSlave}
// Promote to master.
flavor, err := mysqld.flavor()
@ -95,7 +93,7 @@ func (mysqld *Mysqld) PromoteSlave(hookExtraEnv map[string]string) (proto.Replic
err = fmt.Errorf("PromoteSlave needs flavor: %v", err)
return proto.ReplicationPosition{}, err
}
cmds := flavor.PromoteSlaveCommands()
cmds = append(cmds, flavor.PromoteSlaveCommands()...)
if err := mysqld.ExecuteSuperQueryList(cmds); err != nil {
return proto.ReplicationPosition{}, err
}

Просмотреть файл

@ -12,8 +12,6 @@ import (
"bytes"
"errors"
"fmt"
"os"
"path"
"strconv"
"strings"
"text/template"
@ -30,8 +28,13 @@ import (
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
)
var masterPasswordStart = " MASTER_PASSWORD = '"
var masterPasswordEnd = "',\n"
const (
// SqlStartSlave is the SQl command issued to start MySQL replication
SqlStartSlave = "START SLAVE"
// SqlStopSlave is the SQl command issued to stop MySQL replication
SqlStopSlave = "STOP SLAVE"
)
func fillStringTemplate(tmpl string, vars interface{}) (string, error) {
myTemplate := template.Must(template.New("").Parse(tmpl))
@ -69,8 +72,8 @@ func changeMasterArgs(params *sqldb.ConnParams, masterHost string, masterPort in
}
// parseSlaveStatus parses the common fields of SHOW SLAVE STATUS.
func parseSlaveStatus(fields map[string]string) *proto.ReplicationStatus {
status := &proto.ReplicationStatus{
func parseSlaveStatus(fields map[string]string) proto.ReplicationStatus {
status := proto.ReplicationStatus{
MasterHost: fields["Master_Host"],
SlaveIORunning: fields["Slave_IO_Running"] == "Yes",
SlaveSQLRunning: fields["Slave_SQL_Running"] == "Yes",
@ -84,8 +87,9 @@ func parseSlaveStatus(fields map[string]string) *proto.ReplicationStatus {
return status
}
// WaitForSlaveStart waits a slave until given deadline passed
func (mysqld *Mysqld) WaitForSlaveStart(slaveStartDeadline int) error {
// WaitForSlaveStart waits until the deadline for replication to start.
// This validates the current master is correct and can be connected to.
func WaitForSlaveStart(mysqld MysqlDaemon, slaveStartDeadline int) error {
var rowMap map[string]string
for slaveWait := 0; slaveWait < slaveStartDeadline; slaveWait++ {
status, err := mysqld.SlaveStatus()
@ -112,9 +116,9 @@ func (mysqld *Mysqld) WaitForSlaveStart(slaveStartDeadline int) error {
return nil
}
// StartSlave starts a slave
func (mysqld *Mysqld) StartSlave(hookExtraEnv map[string]string) error {
if err := mysqld.ExecuteSuperQuery("START SLAVE"); err != nil {
// StartSlave starts a slave on the provided MysqldDaemon
func StartSlave(md MysqlDaemon, hookExtraEnv map[string]string) error {
if err := md.ExecuteSuperQueryList([]string{SqlStartSlave}); err != nil {
return err
}
@ -123,29 +127,20 @@ func (mysqld *Mysqld) StartSlave(hookExtraEnv map[string]string) error {
return h.ExecuteOptional()
}
// StopSlave stops a slave
func (mysqld *Mysqld) StopSlave(hookExtraEnv map[string]string) error {
// StopSlave stops a slave on the provided MysqldDaemon
func StopSlave(md MysqlDaemon, hookExtraEnv map[string]string) error {
h := hook.NewSimpleHook("preflight_stop_slave")
h.ExtraEnv = hookExtraEnv
if err := h.ExecuteOptional(); err != nil {
return err
}
return mysqld.ExecuteSuperQuery("STOP SLAVE")
}
// GetMasterAddr returns master address
func (mysqld *Mysqld) GetMasterAddr() (string, error) {
slaveStatus, err := mysqld.SlaveStatus()
if err != nil {
return "", err
}
return slaveStatus.MasterAddr(), nil
return md.ExecuteSuperQueryList([]string{SqlStopSlave})
}
// GetMysqlPort returns mysql port
func (mysqld *Mysqld) GetMysqlPort() (int, error) {
qr, err := mysqld.fetchSuperQuery("SHOW VARIABLES LIKE 'port'")
qr, err := mysqld.FetchSuperQuery("SHOW VARIABLES LIKE 'port'")
if err != nil {
return 0, err
}
@ -161,7 +156,7 @@ func (mysqld *Mysqld) GetMysqlPort() (int, error) {
// IsReadOnly return true if the instance is read only
func (mysqld *Mysqld) IsReadOnly() (bool, error) {
qr, err := mysqld.fetchSuperQuery("SHOW VARIABLES LIKE 'read_only'")
qr, err := mysqld.FetchSuperQuery("SHOW VARIABLES LIKE 'read_only'")
if err != nil {
return true, err
}
@ -202,10 +197,10 @@ func (mysqld *Mysqld) WaitMasterPos(targetPos proto.ReplicationPosition, waitTim
}
// SlaveStatus returns the slave replication statuses
func (mysqld *Mysqld) SlaveStatus() (*proto.ReplicationStatus, error) {
func (mysqld *Mysqld) SlaveStatus() (proto.ReplicationStatus, error) {
flavor, err := mysqld.flavor()
if err != nil {
return nil, fmt.Errorf("SlaveStatus needs flavor: %v", err)
return proto.ReplicationStatus{}, fmt.Errorf("SlaveStatus needs flavor: %v", err)
}
return flavor.SlaveStatus(mysqld)
}
@ -250,43 +245,6 @@ func (mysqld *Mysqld) SetMasterCommands(masterHost string, masterPort int) ([]st
return flavor.SetMasterCommands(&params, masterHost, masterPort, int(masterConnectRetry.Seconds()))
}
// WaitForSlave waits for a slave if its lag is larger than given maxLag
func (mysqld *Mysqld) WaitForSlave(maxLag int) (err error) {
// FIXME(msolomon) verify that slave started based on show slave status;
var rowMap map[string]string
for {
rowMap, err = mysqld.fetchSuperQueryMap("SHOW SLAVE STATUS")
if err != nil {
return
}
if rowMap["Seconds_Behind_Master"] == "NULL" {
break
} else {
lag, err := strconv.Atoi(rowMap["Seconds_Behind_Master"])
if err != nil {
break
}
if lag < maxLag {
return nil
}
}
time.Sleep(time.Second)
}
errorKeys := []string{"Last_Error", "Last_IO_Error", "Last_SQL_Error"}
errs := make([]string, 0, len(errorKeys))
for _, key := range errorKeys {
if rowMap[key] != "" {
errs = append(errs, key+": "+rowMap[key])
}
}
if len(errs) != 0 {
return errors.New(strings.Join(errs, ", "))
}
return errors.New("replication stopped, it will never catch up")
}
// ResetReplicationCommands returns the commands to run to reset all
// replication for this host.
func (mysqld *Mysqld) ResetReplicationCommands() ([]string, error) {
@ -319,8 +277,8 @@ const (
)
// FindSlaves gets IP addresses for all currently connected slaves.
func (mysqld *Mysqld) FindSlaves() ([]string, error) {
qr, err := mysqld.fetchSuperQuery("SHOW PROCESSLIST")
func FindSlaves(mysqld MysqlDaemon) ([]string, error) {
qr, err := mysqld.FetchSuperQuery("SHOW PROCESSLIST")
if err != nil {
return nil, err
}
@ -339,25 +297,9 @@ func (mysqld *Mysqld) FindSlaves() ([]string, error) {
return addrs, nil
}
// ValidateSnapshotPath is a helper function to make sure we can write to the local snapshot area, before we actually do any action
// (can be used for both partial and full snapshots)
func (mysqld *Mysqld) ValidateSnapshotPath() error {
_path := path.Join(mysqld.SnapshotDir, "validate_test")
if err := os.RemoveAll(_path); err != nil {
return fmt.Errorf("ValidateSnapshotPath: Cannot validate snapshot directory: %v", err)
}
if err := os.MkdirAll(_path, 0775); err != nil {
return fmt.Errorf("ValidateSnapshotPath: Cannot validate snapshot directory: %v", err)
}
if err := os.RemoveAll(_path); err != nil {
return fmt.Errorf("ValidateSnapshotPath: Cannot validate snapshot directory: %v", err)
}
return nil
}
// WaitBlpPosition will wait for the filtered replication to reach at least
// the provided position.
func (mysqld *Mysqld) WaitBlpPosition(bp *blproto.BlpPosition, waitTimeout time.Duration) error {
func WaitBlpPosition(mysqld MysqlDaemon, bp *blproto.BlpPosition, waitTimeout time.Duration) error {
timeOut := time.Now().Add(waitTimeout)
for {
if time.Now().After(timeOut) {
@ -365,7 +307,7 @@ func (mysqld *Mysqld) WaitBlpPosition(bp *blproto.BlpPosition, waitTimeout time.
}
cmd := binlogplayer.QueryBlpCheckpoint(bp.Uid)
qr, err := mysqld.fetchSuperQuery(cmd)
qr, err := mysqld.FetchSuperQuery(cmd)
if err != nil {
return err
}

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше