зеркало из https://github.com/github/vitess-gh.git
Merge branch 'master' into replication
This commit is contained in:
Коммит
5dc0717d4c
9
Makefile
9
Makefile
|
@ -92,7 +92,7 @@ small_integration_test_files = \
|
|||
sharded.py \
|
||||
secure.py \
|
||||
binlog.py \
|
||||
clone.py \
|
||||
backup.py \
|
||||
update_stream.py
|
||||
|
||||
medium_integration_test_files = \
|
||||
|
@ -100,7 +100,8 @@ medium_integration_test_files = \
|
|||
reparent.py \
|
||||
vtdb_test.py \
|
||||
vtgate_utils_test.py \
|
||||
rowcache_invalidator.py
|
||||
rowcache_invalidator.py \
|
||||
worker.py
|
||||
|
||||
large_integration_test_files = \
|
||||
vtgatev2_test.py \
|
||||
|
@ -122,7 +123,8 @@ worker_integration_test_files = \
|
|||
vertical_split.py \
|
||||
vertical_split_vtgate.py \
|
||||
initial_sharding.py \
|
||||
initial_sharding_bytes.py
|
||||
initial_sharding_bytes.py \
|
||||
worker.py
|
||||
|
||||
.ONESHELL:
|
||||
SHELL = /bin/bash
|
||||
|
@ -183,6 +185,7 @@ proto:
|
|||
cd go/vt/proto/queryservice && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/queryservice.proto --go_out=plugins=grpc:.
|
||||
cd go/vt/proto/vtctl && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/vtctl.proto --go_out=plugins=grpc:.
|
||||
cd go/vt/proto/tabletmanager && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/tabletmanager.proto --go_out=plugins=grpc:.
|
||||
cd go/vt/proto/automation && $$VTROOT/dist/protobuf/bin/protoc -I../../../../proto ../../../../proto/automation.proto --go_out=plugins=grpc:.
|
||||
find go/vt/proto -name "*.pb.go" | xargs sed --in-place -r -e 's,"([a-z0-9_]+).pb","github.com/youtube/vitess/go/vt/proto/\1",g'
|
||||
cd py/vtctl && $$VTROOT/dist/protobuf/bin/protoc -I../../proto ../../proto/vtctl.proto --python_out=. --grpc_out=. --plugin=protoc-gen-grpc=$$VTROOT/dist/grpc/bin/grpc_python_plugin
|
||||
|
||||
|
|
|
@ -24,10 +24,9 @@ and a more [detailed presentation from @Scale '14](http://youtu.be/5yDO-tmIoXY).
|
|||
|
||||
### Using Vitess
|
||||
|
||||
* [Getting Started](http://vitess.io/getting-started/):
|
||||
running Vitess on Kubernetes.
|
||||
* [Building](http://vitess.io/doc/GettingStarted):
|
||||
how to manually build Vitess.
|
||||
* Getting Started
|
||||
* [On Kubernetes](http://vitess.io/getting-started/).
|
||||
* [From the ground up](http://vitess.io/doc/GettingStarted).
|
||||
* [Tools](http://vitess.io/doc/Tools):
|
||||
all Vitess tools and servers.
|
||||
* [vttablet/vtocc](http://vitess.io/doc/vtocc):
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Backup and Restore
|
||||
|
||||
This document describes Vitess Backup and Restore strategy.
|
||||
|
||||
### Overview
|
||||
|
||||
Backups are used in Vitess for two purposes: to provide a point-in-time backup for the data, and to bootstrap new instances.
|
||||
|
||||
### Backup Storage
|
||||
|
||||
Backups are stored on a Backup Storage service. Vitess core software contains an implementation that uses a local filesystem to store the files. Any network-mounted drive can then be used as the repository for backups.
|
||||
|
||||
We have plans to implement a version of the Backup Storage service for Google Cloud Storage (contact us if you are interested).
|
||||
|
||||
(The interface definition for the Backup Storage service is in [interface.go](https://github.com/youtube/vitess/blob/master/go/vt/mysqlctl/backupstorage/interface.go), see comments there for more details).
|
||||
|
||||
Concretely, the following command line flags are used for Backup Storage:
|
||||
* -backup\_storage\_implementation: which implementation of the Backup Storage interface to use.
|
||||
* -file\_backup\_storage\_root: the root of the backups if 'file' is used as a Backup Storage.
|
||||
|
||||
### Taking a Backup
|
||||
|
||||
To take a backup is very straightforward: just run the 'vtctl Backup <tablet-alias>' command. The designated tablet will take itself out of the healthy serving tablets, shutdown its mysqld process, copy the necessary files to the Backup Storage, restart mysql, restart replication, and join the cluster back.
|
||||
|
||||
With health-check enabled (the recommended default), the tablet goes back to spare state. Once it catches up on replication, it will go back to a serving state.
|
||||
|
||||
Note for this to work correctly, the tablet must be started with the right parameters to point it to the Backup Storage system (see previous section).
|
||||
|
||||
### Life of a Shard
|
||||
|
||||
To illustrate how backups are used in Vitess to bootstrap new instances, let's go through the creation and life of a Shard:
|
||||
* A shard is initially brought up with no existing backup. All instances are started as replicas. With health-check enabled (the recommended default), each instance will realize replication is not running, and just stay unhealthy as spare.
|
||||
* Once a few replicas are up, InitShardMaster is run, one host becomes the master, the others replicas. Master becomes healthy, replicas are not as no database exists.
|
||||
* Initial schema can then be applied to the Master. Either use the usual Schema change tools, or use CopySchemaShard for shards created as targets for resharding.
|
||||
* After replicating the schema creation, all replicas become healthy. At this point, we have a working and functionnal shard.
|
||||
* The initial backup is taken (that stores the data and the current replication position), and backup data is copied to a network storage.
|
||||
* When a replica comes up (either a new replica, or one whose instance was just restarted), it restores the latest backup, resets its master to the current shard master, and starts replicating.
|
||||
* A Cronjob to backup the data on a regular basis should then be run. The frequency of the backups should be high enough (compared to MySQL binlog retention), so we can always have a backup to fall back upon.
|
||||
|
||||
Restoring a backup is enabled by the --restore\_from\_backup command line option in vttablet. It can be enabled all the time for all the tablets in a shard, as it doesn't prevent vttablet from starting if no backup can be found.
|
||||
|
||||
### Backup Management
|
||||
|
||||
Two vtctl commands exist to manage the backups:
|
||||
* 'vtctl ListBackups <keyspace/shard>' will display the existing backups for a keyspace/shard in the order they were taken (oldest first).
|
||||
* 'vtctl RemoveBackup <keyspace/shard> <backup name>' will remove a backup from Backup Storage.
|
||||
|
||||
### Details
|
||||
|
||||
Both Backup and Restore copy and compress / decompress multiple files simultaneously to increase throughput. The concurrency can be controlled by command-line flags (-concurrency for 'vtctl Backup', and -restore\_concurrency for vttablet). If the network link is fast enough, the concurrency will match the CPU usage of the process during backup / restore.
|
||||
|
||||
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ Let’s assume that you’ve already got a keyspace up and running, with a singl
|
|||
|
||||
The first thing that we need to do is add a column to the soon-to-be-sharded keyspace which will be used as the "sharding key". This column will tell Vitess which shard a particular row of data should go to. You can add the column by running an alter on the unsharded keyspace - probably by running something like:
|
||||
|
||||
`vtctl ApplySchemaKeyspace -simple -sql="alter table <table name> add keyspace_id" test_keyspace`
|
||||
`vtctl ApplySchema -sql="alter table <table name> add keyspace_id" test_keyspace`
|
||||
|
||||
for each table in the keyspace. Once the column is added everywhere, each row needs to be backfilled with the appropriate keyspace ID.
|
||||
|
||||
|
|
|
@ -61,89 +61,14 @@ type SchemaChangeResult struct {
|
|||
}
|
||||
```
|
||||
|
||||
The ApplySchema action applies a schema change. It is described by the following structure (also returns a SchemaChangeResult):
|
||||
The ApplySchema action applies a schema change to a specified keyspace, the performed steps are:
|
||||
|
||||
```go
|
||||
type SchemaChange struct {
|
||||
Sql string
|
||||
Force bool
|
||||
AllowReplication bool
|
||||
BeforeSchema *SchemaDefinition
|
||||
AfterSchema *SchemaDefinition
|
||||
}
|
||||
```
|
||||
|
||||
And the associated ApplySchema remote action for a tablet. Then the performed steps are:
|
||||
|
||||
* The database to use is either derived from the tablet dbName if UseVt is false, or is the _vt database. A ‘use dbname’ is prepended to the Sql.
|
||||
* (if BeforeSchema is not nil) read the schema, make sure it is equal to BeforeSchema. If not equal: if Force is not set, we will abort, if Force is set, we’ll issue a warning and keep going.
|
||||
* if AllowReplication is false, we’ll disable replication (adding SET sql_log_bin=0 before the Sql).
|
||||
* We will then apply the Sql command.
|
||||
* (if AfterSchema is not nil) read the schema again, make sure it is equal to AfterSchema. If not equal: if Force is not set, we will issue an error, if Force is set, we’ll issue a warning.
|
||||
|
||||
We will return the following information:
|
||||
|
||||
* whether it worked or not (doh!)
|
||||
* BeforeSchema
|
||||
* AfterSchema
|
||||
|
||||
### Use case 1: Single tablet update:
|
||||
|
||||
* we first do a Preflight (to know what BeforeSchema and AfterSchema will be). This can be disabled, but is not recommended.
|
||||
* we then do the schema upgrade. We will check BeforeSchema before the upgrade, and AfterSchema after the upgrade.
|
||||
|
||||
### Use case 2: Single Shard update:
|
||||
|
||||
* need to figure out (or be told) if it’s a simple or complex schema update (does it require the shell game?). For now we'll use a command line flag.
|
||||
* in any case, do a Preflight on the master, to get the BeforeSchema and AfterSchema values.
|
||||
* in any case, gather the schema on all databases, to see which ones have been upgraded already or not. This guarantees we can interrupt and restart a schema change. Also, this makes sure no action is currently running on the databases we're about to change.
|
||||
* if simple:
|
||||
* nobody has it: apply to master, very similar to a single tablet update.
|
||||
* some tablets have it but not others: error out
|
||||
* if complex: do the shell game while disabling replication. Skip the tablets that already have it. Have an option to re-parent at the end.
|
||||
* Note the Backup, and Lag servers won't apply a complex schema change. Only the servers actively in the replication graph will.
|
||||
* the process can be interrupted at any time, restarting it as a complex schema upgrade should just work.
|
||||
|
||||
### Use case 3: Keyspace update:
|
||||
|
||||
* Similar to Single Shard, but the BeforeSchema and AfterSchema values are taken from the first shard, and used in all shards after that.
|
||||
* We don't know the new masters to use on each shard, so just skip re-parenting all together.
|
||||
|
||||
This translates into the following vtctl commands:
|
||||
* It first finds shards belong to this keyspace, including newly added shards in the presence of [resharding event](Resharding.md).
|
||||
* Validate the sql syntax and reject the schema change if the sql 1) Alter more then 100,000 rows, or 2) The targed table has more then 2,000,000 rows. The rational behind this is that ApplySchema simply applies schema changes to the masters; therefore, a big schema change that takes too much time slows down the replication and may reduce the availability of the overall system.
|
||||
* Create a temporary database that has the same schema as the targeted table. Apply the sql to it and makes sure it changes table structure.
|
||||
* Apply the Sql command to the database.
|
||||
* Read the schema again, make sure it is equal to AfterSchema.
|
||||
|
||||
```
|
||||
PreflightSchema {-sql=<sql> || -sql_file=<filename>} <tablet alias>
|
||||
ApplySchema {-sql=<sql> || -sql_file=<filename>} <keyspace>
|
||||
```
|
||||
|
||||
apply the schema change to a temporary database to gather before and after schema and validate the change. The sql can be inlined or read from a file.
|
||||
This will create a temporary database, copy the existing keyspace schema into it, apply the schema change, and re-read the resulting schema.
|
||||
|
||||
```
|
||||
$ echo "create table test_table(id int);" > change.sql
|
||||
$ vtctl PreflightSchema -sql_file=change.sql nyc-0002009001
|
||||
```
|
||||
|
||||
```
|
||||
ApplySchema {-sql=<sql> || -sql_file=<filename>} [-skip_preflight] [-stop_replication] <tablet alias>
|
||||
```
|
||||
|
||||
apply the schema change to the specific tablet (allowing replication by default). The sql can be inlined or read from a file.
|
||||
a PreflightSchema operation will first be used to make sure the schema is OK (unless skip_preflight is specified).
|
||||
|
||||
```
|
||||
ApplySchemaShard {-sql=<sql> || -sql_file=<filename>} [-simple] [-new_parent=<tablet alias>] <keyspace/shard>
|
||||
```
|
||||
|
||||
apply the schema change to the specific shard. If simple is specified, we just apply on the live master. Otherwise, we do the shell game and will optionally re-parent.
|
||||
if new_parent is set, we will also reparent (otherwise the master won't be touched at all). Using the force flag will cause a bunch of checks to be ignored, use with care.
|
||||
|
||||
```
|
||||
$ vtctl ApplySchemaShard --sql-file=change.sql -simple vtx/0
|
||||
$ vtctl ApplySchemaShard --sql-file=change.sql -new_parent=nyc-0002009002 vtx/0
|
||||
```
|
||||
|
||||
```
|
||||
ApplySchemaKeyspace {-sql=<sql> || -sql_file=<filename>} [-simple] <keyspace>
|
||||
```
|
||||
|
||||
apply the schema change to the specified shard. If simple is specified, we just apply on the live master. Otherwise we will need to do the shell game. So we will apply the schema change to every single slave.
|
||||
|
|
Двоичный файл не отображается.
|
@ -18,8 +18,9 @@ if [[ ! -f bootstrap.sh ]]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
# To avoid AUFS permission issues, files must allow access by "other"
|
||||
chmod -R o=g *
|
||||
# To avoid AUFS permission issues, files must allow access by "other" (permissions rX required).
|
||||
# Mirror permissions to "other" from the owning group (for which we assume it has at least rX permissions).
|
||||
chmod -R o=g .
|
||||
|
||||
args="$args --rm -e USER=vitess -v /dev/log:/dev/log"
|
||||
args="$args -v $PWD:/tmp/src"
|
||||
|
|
|
@ -27,10 +27,11 @@ $ go get github.com/youtube/vitess/go/cmd/vtctlclient
|
|||
### Set the path to kubectl
|
||||
|
||||
If you're running in Container Engine, set the `KUBECTL` environment variable
|
||||
to point to the `gcloud` command:
|
||||
to point to the `kubectl` command provided by the Google Cloud SDK (if you've
|
||||
already added gcloud to your PATH, you likely have kubectl):
|
||||
|
||||
```
|
||||
$ export KUBECTL='gcloud alpha container kubectl'
|
||||
$ export KUBECTL='kubectl'
|
||||
```
|
||||
|
||||
If you're running Kubernetes manually, set the `KUBECTL` environment variable
|
||||
|
|
|
@ -89,7 +89,7 @@ if [ -z "$GOPATH" ]; then
|
|||
exit -1
|
||||
fi
|
||||
|
||||
export KUBECTL='gcloud alpha container kubectl'
|
||||
export KUBECTL='kubectl'
|
||||
go get github.com/youtube/vitess/go/cmd/vtctlclient
|
||||
gcloud config set compute/zone $GKE_ZONE
|
||||
project_id=`gcloud config list project | sed -n 2p | cut -d " " -f 3`
|
||||
|
|
|
@ -42,26 +42,6 @@ func initCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) err
|
|||
return nil
|
||||
}
|
||||
|
||||
func restoreCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
|
||||
dontWaitForSlaveStart := subFlags.Bool("dont_wait_for_slave_start", false, "won't wait for replication to start (useful when restoring from master server)")
|
||||
fetchConcurrency := subFlags.Int("fetch_concurrency", 3, "how many files to fetch simultaneously")
|
||||
fetchRetryCount := subFlags.Int("fetch_retry_count", 3, "how many times to retry a failed transfer")
|
||||
subFlags.Parse(args)
|
||||
if subFlags.NArg() != 1 {
|
||||
return fmt.Errorf("Command restore requires <snapshot manifest file>")
|
||||
}
|
||||
|
||||
rs, err := mysqlctl.ReadSnapshotManifest(subFlags.Arg(0))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restore failed: ReadSnapshotManifest: %v", err)
|
||||
}
|
||||
err = mysqld.RestoreFromSnapshot(logutil.NewConsoleLogger(), rs, *fetchConcurrency, *fetchRetryCount, *dontWaitForSlaveStart, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("restore failed: RestoreFromSnapshot: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shutdownCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
|
||||
waitTime := subFlags.Duration("wait_time", mysqlctl.MysqlWaitTime, "how long to wait for shutdown")
|
||||
subFlags.Parse(args)
|
||||
|
@ -72,50 +52,6 @@ func shutdownCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string)
|
|||
return nil
|
||||
}
|
||||
|
||||
func snapshotCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
|
||||
concurrency := subFlags.Int("concurrency", 4, "how many compression jobs to run simultaneously")
|
||||
subFlags.Parse(args)
|
||||
if subFlags.NArg() != 1 {
|
||||
return fmt.Errorf("Command snapshot requires <db name>")
|
||||
}
|
||||
|
||||
filename, _, _, err := mysqld.CreateSnapshot(logutil.NewConsoleLogger(), subFlags.Arg(0), tabletAddr, false, *concurrency, false, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot failed: %v", err)
|
||||
}
|
||||
log.Infof("manifest location: %v", filename)
|
||||
return nil
|
||||
}
|
||||
|
||||
func snapshotSourceStartCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
|
||||
concurrency := subFlags.Int("concurrency", 4, "how many checksum jobs to run simultaneously")
|
||||
subFlags.Parse(args)
|
||||
if subFlags.NArg() != 1 {
|
||||
return fmt.Errorf("Command snapshotsourcestart requires <db name>")
|
||||
}
|
||||
|
||||
filename, slaveStartRequired, readOnly, err := mysqld.CreateSnapshot(logutil.NewConsoleLogger(), subFlags.Arg(0), tabletAddr, false, *concurrency, true, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshot failed: %v", err)
|
||||
}
|
||||
log.Infof("manifest location: %v", filename)
|
||||
log.Infof("slave start required: %v", slaveStartRequired)
|
||||
log.Infof("read only: %v", readOnly)
|
||||
return nil
|
||||
}
|
||||
|
||||
func snapshotSourceEndCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
|
||||
slaveStartRequired := subFlags.Bool("slave_start", false, "will restart replication")
|
||||
readWrite := subFlags.Bool("read_write", false, "will make the server read-write")
|
||||
subFlags.Parse(args)
|
||||
|
||||
err := mysqld.SnapshotSourceEnd(*slaveStartRequired, !(*readWrite), true, map[string]string{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("snapshotsourceend failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func startCmd(mysqld *mysqlctl.Mysqld, subFlags *flag.FlagSet, args []string) error {
|
||||
waitTime := subFlags.Duration("wait_time", mysqlctl.MysqlWaitTime, "how long to wait for startup")
|
||||
subFlags.Parse(args)
|
||||
|
@ -188,19 +124,6 @@ var commands = []command{
|
|||
command{"shutdown", shutdownCmd, "[-wait_time=20s]",
|
||||
"Shuts down mysqld, does not remove any file"},
|
||||
|
||||
command{"snapshot", snapshotCmd,
|
||||
"[-concurrency=4] <db name>",
|
||||
"Takes a full snapshot, copying the innodb data files"},
|
||||
command{"snapshotsourcestart", snapshotSourceStartCmd,
|
||||
"[-concurrency=4] <db name>",
|
||||
"Enters snapshot server mode (mysqld stopped, serving innodb data files)"},
|
||||
command{"snapshotsourceend", snapshotSourceEndCmd,
|
||||
"[-slave_start] [-read_write]",
|
||||
"Gets out of snapshot server mode"},
|
||||
command{"restore", restoreCmd,
|
||||
"[-fetch_concurrency=3] [-fetch_retry_count=3] [-dont_wait_for_slave_start] <snapshot manifest file>",
|
||||
"Restores a full snapshot"},
|
||||
|
||||
command{"position", positionCmd,
|
||||
"<operation> <pos1> <pos2 | gtid>",
|
||||
"Compute operations on replication positions"},
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/topo/helpers"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var fromTopo = flag.String("from", "", "topology to copy data from")
|
||||
|
@ -41,19 +42,20 @@ func main() {
|
|||
exit.Return(1)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
fromTS := topo.GetServerByName(*fromTopo)
|
||||
toTS := topo.GetServerByName(*toTopo)
|
||||
|
||||
if *doKeyspaces {
|
||||
helpers.CopyKeyspaces(fromTS, toTS)
|
||||
helpers.CopyKeyspaces(ctx, fromTS, toTS)
|
||||
}
|
||||
if *doShards {
|
||||
helpers.CopyShards(fromTS, toTS, *deleteKeyspaceShards)
|
||||
helpers.CopyShards(ctx, fromTS, toTS, *deleteKeyspaceShards)
|
||||
}
|
||||
if *doShardReplications {
|
||||
helpers.CopyShardReplications(fromTS, toTS)
|
||||
helpers.CopyShardReplications(ctx, fromTS, toTS)
|
||||
}
|
||||
if *doTablets {
|
||||
helpers.CopyTablets(fromTS, toTS)
|
||||
helpers.CopyTablets(ctx, fromTS, toTS)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc vtgateconn client
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
|
||||
)
|
|
@ -0,0 +1,189 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/exit"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
|
||||
// import the 'vitess' sql driver
|
||||
_ "github.com/youtube/vitess/go/vt/client"
|
||||
)
|
||||
|
||||
var (
|
||||
usage = `
|
||||
vtclient connects to a vtgate server using the standard go driver API.
|
||||
Version 3 of the API is used, we do not send any hint to the server.
|
||||
|
||||
For query bound variables, we assume place-holders in the query string
|
||||
in the form of :v1, :v2, etc.
|
||||
`
|
||||
server = flag.String("server", "", "vtgate server to connect to")
|
||||
tabletType = flag.String("tablet_type", "rdonly", "tablet type to direct queries to")
|
||||
timeout = flag.Duration("timeout", 30*time.Second, "timeout for queries")
|
||||
streaming = flag.Bool("streaming", false, "use a streaming query")
|
||||
bindVariables = newBindvars("bind_variables", "bind variables as a json list")
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, usage)
|
||||
}
|
||||
}
|
||||
|
||||
type bindvars []interface{}
|
||||
|
||||
func (bv *bindvars) String() string {
|
||||
b, err := json.Marshal(bv)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (bv *bindvars) Set(s string) (err error) {
|
||||
err = json.Unmarshal([]byte(s), &bv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// json reads all numbers as float64
|
||||
// So, we just ditch floats for bindvars
|
||||
for i, v := range *bv {
|
||||
if f, ok := v.(float64); ok {
|
||||
if f > 0 {
|
||||
(*bv)[i] = uint64(f)
|
||||
} else {
|
||||
(*bv)[i] = int64(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// For internal flag compatibility
|
||||
func (bv *bindvars) Get() interface{} {
|
||||
return bv
|
||||
}
|
||||
|
||||
func newBindvars(name, usage string) *bindvars {
|
||||
var bv bindvars
|
||||
flag.Var(&bv, name, usage)
|
||||
return &bv
|
||||
}
|
||||
|
||||
// FIXME(alainjobart) this is a cheap trick. Should probably use the
|
||||
// query parser if we needed this to be 100% reliable.
|
||||
func isDml(sql string) bool {
|
||||
lower := strings.TrimSpace(strings.ToLower(sql))
|
||||
return strings.HasPrefix(lower, "insert") || strings.HasPrefix(lower, "update") || strings.HasPrefix(lower, "delete")
|
||||
}
|
||||
|
||||
func main() {
|
||||
defer exit.Recover()
|
||||
defer logutil.Flush()
|
||||
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
|
||||
if len(args) == 0 {
|
||||
flag.Usage()
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
connStr := fmt.Sprintf(`{"address": "%s", "tablet_type": "%s", "streaming": %v, "timeout": %d}`, *server, *tabletType, *streaming, int64(30*(*timeout)))
|
||||
db, err := sql.Open("vitess", connStr)
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
log.Infof("Sending the query...")
|
||||
now := time.Now()
|
||||
|
||||
// handle dml
|
||||
if isDml(args[0]) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Errorf("begin failed: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
result, err := db.Exec(args[0], []interface{}(*bindVariables)...)
|
||||
if err != nil {
|
||||
log.Errorf("exec failed: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Errorf("commit failed: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
lastInsertId, err := result.LastInsertId()
|
||||
log.Infof("Total time: %v / Row affected: %v / Last Insert Id: %v", time.Now().Sub(now), rowsAffected, lastInsertId)
|
||||
} else {
|
||||
|
||||
// launch the query
|
||||
rows, err := db.Query(args[0], []interface{}(*bindVariables)...)
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// print the headers
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
line := "Index"
|
||||
for _, field := range cols {
|
||||
line += "\t" + field
|
||||
}
|
||||
fmt.Printf("%s\n", line)
|
||||
|
||||
// get the rows
|
||||
rowIndex := 0
|
||||
for rows.Next() {
|
||||
row := make([]interface{}, len(cols))
|
||||
for i := range row {
|
||||
var col string
|
||||
row[i] = &col
|
||||
}
|
||||
if err := rows.Scan(row...); err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
// print the line
|
||||
line := fmt.Sprintf("%d", rowIndex)
|
||||
for _, value := range row {
|
||||
line += fmt.Sprintf("\t%v", *(value.(*string)))
|
||||
}
|
||||
fmt.Printf("%s\n", line)
|
||||
rowIndex++
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Errorf("Error %v\n", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
log.Infof("Total time: %v / Row count: %v", time.Now().Sub(now), rowIndex)
|
||||
}
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// This plugin imports etcdtopo to register the etcd implementation of TopoServer.
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/etcdtopo"
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2013, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc tabletconn client
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/tabletserver/gorpctabletconn"
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2013, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the Zookeeper TopologyServer
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/zktopo"
|
||||
)
|
|
@ -1,193 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/db"
|
||||
"github.com/youtube/vitess/go/exit"
|
||||
"github.com/youtube/vitess/go/vt/client2"
|
||||
_ "github.com/youtube/vitess/go/vt/client2/tablet"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
)
|
||||
|
||||
var usage = `
|
||||
The parameters are first the SQL command, then the bound variables.
|
||||
For query arguments, we assume place-holders in the query string
|
||||
in the form of :v0, :v1, etc.
|
||||
`
|
||||
|
||||
var count = flag.Int("count", 1, "how many times to run the query")
|
||||
var bindvars = FlagMap("bindvars", "bind vars as a json dictionary")
|
||||
var server = flag.String("server", "localhost:6603/test_keyspace/0", "vtocc server as [user:password@]hostname:port/keyspace/shard[#keyrangestart-keyrangeend]")
|
||||
var driver = flag.String("driver", "vttablet", "which driver to use (one of vttablet, vttablet-streaming, vtdb, vtdb-streaming)")
|
||||
var verbose = flag.Bool("verbose", false, "show results")
|
||||
|
||||
func init() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, usage)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
//----------------------------------
|
||||
|
||||
type Map map[string]interface{}
|
||||
|
||||
func (m *Map) String() string {
|
||||
b, err := json.Marshal(*m)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (m *Map) Set(s string) (err error) {
|
||||
err = json.Unmarshal([]byte(s), m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// json reads all numbers as float64
|
||||
// So, we just ditch floats for bindvars
|
||||
for k, v := range *m {
|
||||
f, ok := v.(float64)
|
||||
if ok {
|
||||
if f > 0 {
|
||||
(*m)[k] = uint64(f)
|
||||
} else {
|
||||
(*m)[k] = int64(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// For internal flag compatibility
|
||||
func (m *Map) Get() interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
func FlagMap(name, usage string) (m map[string]interface{}) {
|
||||
m = make(map[string]interface{})
|
||||
mm := Map(m)
|
||||
flag.Var(&mm, name, usage)
|
||||
return m
|
||||
}
|
||||
|
||||
// FIXME(alainjobart) this is a cheap trick. Should probably use the
|
||||
// query parser if we needed this to be 100% reliable.
|
||||
func isDml(sql string) bool {
|
||||
lower := strings.TrimSpace(strings.ToLower(sql))
|
||||
return strings.HasPrefix(lower, "insert") || strings.HasPrefix(lower, "update")
|
||||
}
|
||||
|
||||
func main() {
|
||||
defer exit.Recover()
|
||||
defer logutil.Flush()
|
||||
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
|
||||
if len(args) == 0 {
|
||||
flag.Usage()
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
client2.RegisterShardedDrivers()
|
||||
conn, err := db.Open(*driver, *server)
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
log.Infof("Sending the query...")
|
||||
now := time.Now()
|
||||
|
||||
// handle dml
|
||||
if isDml(args[0]) {
|
||||
t, err := conn.Begin()
|
||||
if err != nil {
|
||||
log.Errorf("begin failed: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
r, err := conn.Exec(args[0], bindvars)
|
||||
if err != nil {
|
||||
log.Errorf("exec failed: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
err = t.Commit()
|
||||
if err != nil {
|
||||
log.Errorf("commit failed: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
n, err := r.RowsAffected()
|
||||
log.Infof("Total time: %v / Row affected: %v", time.Now().Sub(now), n)
|
||||
} else {
|
||||
|
||||
// launch the query
|
||||
r, err := conn.Exec(args[0], bindvars)
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
// get the headers
|
||||
cols := r.Columns()
|
||||
if err != nil {
|
||||
log.Errorf("client error: %v", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
// print the header
|
||||
if *verbose {
|
||||
line := "Index"
|
||||
for _, field := range cols {
|
||||
line += "\t" + field
|
||||
}
|
||||
log.Infof(line)
|
||||
}
|
||||
|
||||
// get the rows
|
||||
rowIndex := 0
|
||||
for row := r.Next(); row != nil; row = r.Next() {
|
||||
// print the line if needed
|
||||
if *verbose {
|
||||
line := fmt.Sprintf("%d", rowIndex)
|
||||
for _, value := range row {
|
||||
if value != nil {
|
||||
switch value.(type) {
|
||||
case []byte:
|
||||
line += fmt.Sprintf("\t%s", value)
|
||||
default:
|
||||
line += fmt.Sprintf("\t%v", value)
|
||||
}
|
||||
} else {
|
||||
line += "\t"
|
||||
}
|
||||
}
|
||||
log.Infof(line)
|
||||
}
|
||||
rowIndex++
|
||||
}
|
||||
if err := r.Err(); err != nil {
|
||||
log.Errorf("Error %v\n", err)
|
||||
exit.Return(1)
|
||||
}
|
||||
log.Infof("Total time: %v / Row count: %v", time.Now().Sub(now), rowIndex)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/mysqlctl/filebackupstorage"
|
||||
)
|
|
@ -10,7 +10,6 @@ import (
|
|||
"log/syslog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
@ -18,7 +17,6 @@ import (
|
|||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/exit"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
myproto "github.com/youtube/vitess/go/vt/mysqlctl/proto"
|
||||
"github.com/youtube/vitess/go/vt/tabletmanager/tmclient"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtctl"
|
||||
|
@ -54,6 +52,12 @@ func installSignalHandlers(cancel func()) {
|
|||
}()
|
||||
}
|
||||
|
||||
// hooks to register plug-ins after flag init
|
||||
|
||||
type initFunc func()
|
||||
|
||||
var initFuncs []initFunc
|
||||
|
||||
func main() {
|
||||
defer exit.RecoverAll()
|
||||
defer logutil.Flush()
|
||||
|
@ -81,6 +85,10 @@ func main() {
|
|||
wr := wrangler.New(logutil.NewConsoleLogger(), topoServer, tmclient.NewTabletManagerClient(), *lockWaitTimeout)
|
||||
installSignalHandlers(cancel)
|
||||
|
||||
for _, f := range initFuncs {
|
||||
f()
|
||||
}
|
||||
|
||||
err := vtctl.RunCommand(ctx, wr, args)
|
||||
cancel()
|
||||
switch err {
|
||||
|
@ -94,52 +102,3 @@ func main() {
|
|||
exit.Return(255)
|
||||
}
|
||||
}
|
||||
|
||||
type rTablet struct {
|
||||
*topo.TabletInfo
|
||||
*myproto.ReplicationStatus
|
||||
}
|
||||
|
||||
type rTablets []*rTablet
|
||||
|
||||
func (rts rTablets) Len() int { return len(rts) }
|
||||
|
||||
func (rts rTablets) Swap(i, j int) { rts[i], rts[j] = rts[j], rts[i] }
|
||||
|
||||
// Sort for tablet replication.
|
||||
// master first, then i/o position, then sql position
|
||||
func (rts rTablets) Less(i, j int) bool {
|
||||
// NOTE: Swap order of unpack to reverse sort
|
||||
l, r := rts[j], rts[i]
|
||||
// l or r ReplicationPosition would be nil if we failed to get
|
||||
// the position (put them at the beginning of the list)
|
||||
if l.ReplicationStatus == nil {
|
||||
return r.ReplicationStatus != nil
|
||||
}
|
||||
if r.ReplicationStatus == nil {
|
||||
return false
|
||||
}
|
||||
var lTypeMaster, rTypeMaster int
|
||||
if l.Type == topo.TYPE_MASTER {
|
||||
lTypeMaster = 1
|
||||
}
|
||||
if r.Type == topo.TYPE_MASTER {
|
||||
rTypeMaster = 1
|
||||
}
|
||||
if lTypeMaster < rTypeMaster {
|
||||
return true
|
||||
}
|
||||
if lTypeMaster == rTypeMaster {
|
||||
return !l.Position.AtLeast(r.Position)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sortReplicatingTablets(tablets []*topo.TabletInfo, stats []*myproto.ReplicationStatus) []*rTablet {
|
||||
rtablets := make([]*rTablet, len(tablets))
|
||||
for i, status := range stats {
|
||||
rtablets[i] = &rTablet{tablets[i], status}
|
||||
}
|
||||
sort.Sort(rTablets(rtablets))
|
||||
return rtablets
|
||||
}
|
||||
|
|
|
@ -48,12 +48,12 @@ func newTabletHealth(thc *tabletHealthCache, tabletAlias topo.TabletAlias) (*Tab
|
|||
func (th *TabletHealth) update(thc *tabletHealthCache, tabletAlias topo.TabletAlias) {
|
||||
defer thc.delete(tabletAlias)
|
||||
|
||||
ti, err := thc.ts.GetTablet(tabletAlias)
|
||||
ctx := context.Background()
|
||||
ti, err := thc.ts.GetTablet(ctx, tabletAlias)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := thc.tmc.HealthStream(ctx, ti)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -235,7 +235,7 @@ func (loader *TemplateLoader) ServeTemplate(templateName string, data interface{
|
|||
|
||||
var (
|
||||
modifyDbTopology func(context.Context, topo.Server, *topotools.Topology) error
|
||||
modifyDbServingGraph func(topo.Server, *topotools.ServingGraph)
|
||||
modifyDbServingGraph func(context.Context, topo.Server, *topotools.ServingGraph)
|
||||
)
|
||||
|
||||
// SetDbTopologyPostprocessor installs a hook that can modify
|
||||
|
@ -249,7 +249,7 @@ func SetDbTopologyPostprocessor(f func(context.Context, topo.Server, *topotools.
|
|||
|
||||
// SetDbServingGraphPostprocessor installs a hook that can modify
|
||||
// topotools.ServingGraph struct before it's displayed.
|
||||
func SetDbServingGraphPostprocessor(f func(topo.Server, *topotools.ServingGraph)) {
|
||||
func SetDbServingGraphPostprocessor(f func(context.Context, topo.Server, *topotools.ServingGraph)) {
|
||||
if modifyDbServingGraph != nil {
|
||||
panic("Cannot set multiple DbServingGraph postprocessors")
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// This file includes the support for serving topo data to an ajax-based
|
||||
|
@ -47,7 +48,7 @@ func (bvo *BaseVersionedObject) SetVersion(version int) {
|
|||
|
||||
// GetVersionedObjectFunc is the function the cache will call to get
|
||||
// the object itself.
|
||||
type GetVersionedObjectFunc func() (VersionedObject, error)
|
||||
type GetVersionedObjectFunc func(ctx context.Context) (VersionedObject, error)
|
||||
|
||||
// VersionedObjectCache is the main cache object. Just needs a method to get
|
||||
// the content.
|
||||
|
@ -68,7 +69,7 @@ func NewVersionedObjectCache(getObject GetVersionedObjectFunc) *VersionedObjectC
|
|||
}
|
||||
|
||||
// Get returns the versioned value from the cache.
|
||||
func (voc *VersionedObjectCache) Get() ([]byte, error) {
|
||||
func (voc *VersionedObjectCache) Get(ctx context.Context) ([]byte, error) {
|
||||
voc.mu.Lock()
|
||||
defer voc.mu.Unlock()
|
||||
|
||||
|
@ -77,7 +78,7 @@ func (voc *VersionedObjectCache) Get() ([]byte, error) {
|
|||
return voc.result, nil
|
||||
}
|
||||
|
||||
newObject, err := voc.getObject()
|
||||
newObject, err := voc.getObject(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -142,7 +143,7 @@ func NewVersionedObjectCacheMap(factory VersionedObjectCacheFactory) *VersionedO
|
|||
}
|
||||
|
||||
// Get finds the right VersionedObjectCache and returns its value
|
||||
func (vocm *VersionedObjectCacheMap) Get(key string) ([]byte, error) {
|
||||
func (vocm *VersionedObjectCacheMap) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
vocm.mu.Lock()
|
||||
voc, ok := vocm.cacheMap[key]
|
||||
if !ok {
|
||||
|
@ -151,7 +152,7 @@ func (vocm *VersionedObjectCacheMap) Get(key string) ([]byte, error) {
|
|||
}
|
||||
vocm.mu.Unlock()
|
||||
|
||||
return voc.Get()
|
||||
return voc.Get(ctx)
|
||||
}
|
||||
|
||||
// Flush will flush the entire cache
|
||||
|
@ -177,8 +178,8 @@ func (kc *KnownCells) Reset() {
|
|||
}
|
||||
|
||||
func newKnownCellsCache(ts topo.Server) *VersionedObjectCache {
|
||||
return NewVersionedObjectCache(func() (VersionedObject, error) {
|
||||
cells, err := ts.GetKnownCells()
|
||||
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
|
||||
cells, err := ts.GetKnownCells(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -202,8 +203,8 @@ func (k *Keyspaces) Reset() {
|
|||
}
|
||||
|
||||
func newKeyspacesCache(ts topo.Server) *VersionedObjectCache {
|
||||
return NewVersionedObjectCache(func() (VersionedObject, error) {
|
||||
keyspaces, err := ts.GetKeyspaces()
|
||||
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
|
||||
keyspaces, err := ts.GetKeyspaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -232,8 +233,8 @@ func (k *Keyspace) Reset() {
|
|||
|
||||
func newKeyspaceCache(ts topo.Server) *VersionedObjectCacheMap {
|
||||
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
|
||||
return NewVersionedObjectCache(func() (VersionedObject, error) {
|
||||
k, err := ts.GetKeyspace(key)
|
||||
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
|
||||
k, err := ts.GetKeyspace(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -264,8 +265,8 @@ func (s *ShardNames) Reset() {
|
|||
|
||||
func newShardNamesCache(ts topo.Server) *VersionedObjectCacheMap {
|
||||
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
|
||||
return NewVersionedObjectCache(func() (VersionedObject, error) {
|
||||
sn, err := ts.GetShardNames(key)
|
||||
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
|
||||
sn, err := ts.GetShardNames(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -301,13 +302,13 @@ func (s *Shard) Reset() {
|
|||
|
||||
func newShardCache(ts topo.Server) *VersionedObjectCacheMap {
|
||||
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
|
||||
return NewVersionedObjectCache(func() (VersionedObject, error) {
|
||||
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
|
||||
|
||||
keyspace, shard, err := topo.ParseKeyspaceShardString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := ts.GetShard(keyspace, shard)
|
||||
s, err := ts.GetShard(ctx, keyspace, shard)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -349,12 +350,12 @@ func (cst *CellShardTablets) Reset() {
|
|||
|
||||
func newCellShardTabletsCache(ts topo.Server) *VersionedObjectCacheMap {
|
||||
return NewVersionedObjectCacheMap(func(key string) *VersionedObjectCache {
|
||||
return NewVersionedObjectCache(func() (VersionedObject, error) {
|
||||
return NewVersionedObjectCache(func(ctx context.Context) (VersionedObject, error) {
|
||||
parts := strings.Split(key, "/")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("Invalid shard tablets path: %v", key)
|
||||
}
|
||||
sr, err := ts.GetShardReplication(parts[0], parts[1], parts[2])
|
||||
sr, err := ts.GetShardReplication(ctx, parts[0], parts[1], parts[2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -8,10 +8,12 @@ import (
|
|||
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/zktopo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo VersionedObject, expectedVO VersionedObject) {
|
||||
result, err := voc.Get()
|
||||
ctx := context.Background()
|
||||
result, err := voc.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -26,7 +28,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
|
|||
t.Fatalf("Got bad result: %#v expected: %#v", vo, expectedVO)
|
||||
}
|
||||
|
||||
result2, err := voc.Get()
|
||||
result2, err := voc.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -36,7 +38,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
|
|||
|
||||
// force a re-get with same content, version shouldn't change
|
||||
voc.timestamp = time.Time{}
|
||||
result2, err = voc.Get()
|
||||
result2, err = voc.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -47,7 +49,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
|
|||
// force a reget with different content, version should change
|
||||
voc.timestamp = time.Time{}
|
||||
voc.versionedObject.Reset() // poking inside the object here
|
||||
result, err = voc.Get()
|
||||
result, err = voc.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -64,7 +66,7 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
|
|||
|
||||
// force a flush and see the version increase again
|
||||
voc.Flush()
|
||||
result, err = voc.Get()
|
||||
result, err = voc.Get(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -81,7 +83,8 @@ func testVersionedObjectCache(t *testing.T, voc *VersionedObjectCache, vo Versio
|
|||
}
|
||||
|
||||
func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, key string, vo VersionedObject, expectedVO VersionedObject) {
|
||||
result, err := vocm.Get(key)
|
||||
ctx := context.Background()
|
||||
result, err := vocm.Get(ctx, key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -96,7 +99,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
|
|||
t.Fatalf("Got bad result: %#v expected: %#v", vo, expectedVO)
|
||||
}
|
||||
|
||||
result2, err := vocm.Get(key)
|
||||
result2, err := vocm.Get(ctx, key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -106,7 +109,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
|
|||
|
||||
// force a re-get with same content, version shouldn't change
|
||||
vocm.cacheMap[key].timestamp = time.Time{}
|
||||
result2, err = vocm.Get(key)
|
||||
result2, err = vocm.Get(ctx, key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -117,7 +120,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
|
|||
// force a reget with different content, version should change
|
||||
vocm.cacheMap[key].timestamp = time.Time{}
|
||||
vocm.cacheMap[key].versionedObject.Reset() // poking inside the object here
|
||||
result, err = vocm.Get(key)
|
||||
result, err = vocm.Get(ctx, key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -134,7 +137,7 @@ func testVersionedObjectCacheMap(t *testing.T, vocm *VersionedObjectCacheMap, ke
|
|||
|
||||
// force a flush and see the version increase again
|
||||
vocm.Flush()
|
||||
result, err = vocm.Get(key)
|
||||
result, err = vocm.Get(ctx, key)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -162,11 +165,12 @@ func TestKnownCellsCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestKeyspacesCache(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
|
||||
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{}); err != nil {
|
||||
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{}); err != nil {
|
||||
t.Fatalf("CreateKeyspace failed: %v", err)
|
||||
}
|
||||
if err := ts.CreateKeyspace("ks2", &topo.Keyspace{}); err != nil {
|
||||
if err := ts.CreateKeyspace(ctx, "ks2", &topo.Keyspace{}); err != nil {
|
||||
t.Fatalf("CreateKeyspace failed: %v", err)
|
||||
}
|
||||
kc := newKeyspacesCache(ts)
|
||||
|
@ -179,13 +183,14 @@ func TestKeyspacesCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestKeyspaceCache(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
|
||||
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{
|
||||
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{
|
||||
ShardingColumnName: "sharding_key",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateKeyspace failed: %v", err)
|
||||
}
|
||||
if err := ts.CreateKeyspace("ks2", &topo.Keyspace{
|
||||
if err := ts.CreateKeyspace(ctx, "ks2", &topo.Keyspace{
|
||||
SplitShardCount: 10,
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateKeyspace failed: %v", err)
|
||||
|
@ -211,18 +216,19 @@ func TestKeyspaceCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShardNamesCache(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
|
||||
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{
|
||||
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{
|
||||
ShardingColumnName: "sharding_key",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateKeyspace failed: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("ks1", "s1", &topo.Shard{
|
||||
if err := ts.CreateShard(ctx, "ks1", "s1", &topo.Shard{
|
||||
Cells: []string{"cell1", "cell2"},
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateShard failed: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("ks1", "s2", &topo.Shard{
|
||||
if err := ts.CreateShard(ctx, "ks1", "s2", &topo.Shard{
|
||||
MasterAlias: topo.TabletAlias{
|
||||
Cell: "cell1",
|
||||
Uid: 12,
|
||||
|
@ -241,18 +247,19 @@ func TestShardNamesCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShardCache(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
|
||||
if err := ts.CreateKeyspace("ks1", &topo.Keyspace{
|
||||
if err := ts.CreateKeyspace(ctx, "ks1", &topo.Keyspace{
|
||||
ShardingColumnName: "sharding_key",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateKeyspace failed: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("ks1", "s1", &topo.Shard{
|
||||
if err := ts.CreateShard(ctx, "ks1", "s1", &topo.Shard{
|
||||
Cells: []string{"cell1", "cell2"},
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateShard failed: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("ks1", "s2", &topo.Shard{
|
||||
if err := ts.CreateShard(ctx, "ks1", "s2", &topo.Shard{
|
||||
MasterAlias: topo.TabletAlias{
|
||||
Cell: "cell1",
|
||||
Uid: 12,
|
||||
|
@ -286,8 +293,9 @@ func TestShardCache(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCellShardTabletsCache(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := zktopo.NewTestServer(t, []string{"cell1", "cell2"})
|
||||
if err := ts.UpdateShardReplicationFields("cell1", "ks1", "s1", func(sr *topo.ShardReplication) error {
|
||||
if err := ts.UpdateShardReplicationFields(ctx, "cell1", "ks1", "s1", func(sr *topo.ShardReplication) error {
|
||||
sr.ReplicationLinks = []topo.ReplicationLink{
|
||||
topo.ReplicationLink{
|
||||
TabletAlias: topo.TabletAlias{
|
||||
|
|
|
@ -11,20 +11,22 @@ import (
|
|||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/acl"
|
||||
schmgr "github.com/youtube/vitess/go/vt/schemamanager"
|
||||
"github.com/youtube/vitess/go/vt/schemamanager/uihandler"
|
||||
"github.com/youtube/vitess/go/timer"
|
||||
"github.com/youtube/vitess/go/vt/schemamanager"
|
||||
"github.com/youtube/vitess/go/vt/servenv"
|
||||
"github.com/youtube/vitess/go/vt/tabletmanager/tmclient"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/topotools"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
// register gorpc vtgate client
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
|
||||
)
|
||||
|
||||
var (
|
||||
templateDir = flag.String("templates", "", "directory containing templates")
|
||||
debug = flag.Bool("debug", false, "recompile templates for every request")
|
||||
templateDir = flag.String("templates", "", "directory containing templates")
|
||||
debug = flag.Bool("debug", false, "recompile templates for every request")
|
||||
schemaChangeDir = flag.String("schema-change-dir", "", "directory contains schema changes for all keyspaces. Each keyspace has its own directory and schema changes are expected to live in '$KEYSPACE/input' dir. e.g. test_keyspace/input/*sql, each sql file represents a schema change")
|
||||
schemaChangeController = flag.String("schema-change-controller", "", "schema change controller is responsible for finding schema changes and responsing schema change events")
|
||||
schemaChangeCheckInterval = flag.Int("schema-change-check-interval", 60, "this value decides how often we check schema change dir, in seconds")
|
||||
schemaChangeUser = flag.String("schema-change-user", "", "The user who submits this schema change.")
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -114,7 +116,7 @@ func main() {
|
|||
// tablet actions
|
||||
actionRepo.RegisterTabletAction("Ping", "",
|
||||
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
|
||||
ti, err := wr.TopoServer().GetTablet(tabletAlias)
|
||||
ti, err := wr.TopoServer().GetTablet(ctx, tabletAlias)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -124,7 +126,7 @@ func main() {
|
|||
actionRepo.RegisterTabletAction("ScrapTablet", acl.ADMIN,
|
||||
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
|
||||
// refuse to scrap tablets that are not spare
|
||||
ti, err := wr.TopoServer().GetTablet(tabletAlias)
|
||||
ti, err := wr.TopoServer().GetTablet(ctx, tabletAlias)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -137,7 +139,7 @@ func main() {
|
|||
actionRepo.RegisterTabletAction("ScrapTabletForce", acl.ADMIN,
|
||||
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
|
||||
// refuse to scrap tablets that are not spare
|
||||
ti, err := wr.TopoServer().GetTablet(tabletAlias)
|
||||
ti, err := wr.TopoServer().GetTablet(ctx, tabletAlias)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -149,7 +151,7 @@ func main() {
|
|||
|
||||
actionRepo.RegisterTabletAction("DeleteTablet", acl.ADMIN,
|
||||
func(ctx context.Context, wr *wrangler.Wrangler, tabletAlias topo.TabletAlias, r *http.Request) (string, error) {
|
||||
return "", wr.DeleteTablet(tabletAlias)
|
||||
return "", wr.DeleteTablet(ctx, tabletAlias)
|
||||
})
|
||||
|
||||
// keyspace actions
|
||||
|
@ -254,7 +256,8 @@ func main() {
|
|||
|
||||
cell := parts[len(parts)-1]
|
||||
if cell == "" {
|
||||
cells, err := ts.GetKnownCells()
|
||||
ctx := context.Background()
|
||||
cells, err := ts.GetKnownCells(ctx)
|
||||
if err != nil {
|
||||
httpError(w, "cannot get known cells: %v", err)
|
||||
return
|
||||
|
@ -263,9 +266,10 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
servingGraph := topotools.DbServingGraph(ts, cell)
|
||||
ctx := context.Background()
|
||||
servingGraph := topotools.DbServingGraph(ctx, ts, cell)
|
||||
if modifyDbServingGraph != nil {
|
||||
modifyDbServingGraph(ts, servingGraph)
|
||||
modifyDbServingGraph(ctx, ts, servingGraph)
|
||||
}
|
||||
templateLoader.ServeTemplate("serving_graph.html", servingGraph, w, r)
|
||||
})
|
||||
|
@ -292,12 +296,13 @@ func main() {
|
|||
Error error
|
||||
Input, Output string
|
||||
}
|
||||
ctx := context.Background()
|
||||
switch r.Method {
|
||||
case "POST":
|
||||
data.Input = r.FormValue("vschema")
|
||||
data.Error = schemafier.SaveVSchema(data.Input)
|
||||
data.Error = schemafier.SaveVSchema(ctx, data.Input)
|
||||
}
|
||||
vschema, err := schemafier.GetVSchema()
|
||||
vschema, err := schemafier.GetVSchema(ctx)
|
||||
if err != nil {
|
||||
if data.Error == nil {
|
||||
data.Error = fmt.Errorf("Error fetching schema: %s", err)
|
||||
|
@ -330,7 +335,8 @@ func main() {
|
|||
// serve some data
|
||||
knownCellsCache := newKnownCellsCache(ts)
|
||||
http.HandleFunc("/json/KnownCells", func(w http.ResponseWriter, r *http.Request) {
|
||||
result, err := knownCellsCache.Get()
|
||||
ctx := context.Background()
|
||||
result, err := knownCellsCache.Get(ctx)
|
||||
if err != nil {
|
||||
httpError(w, "error getting known cells: %v", err)
|
||||
return
|
||||
|
@ -340,7 +346,8 @@ func main() {
|
|||
|
||||
keyspacesCache := newKeyspacesCache(ts)
|
||||
http.HandleFunc("/json/Keyspaces", func(w http.ResponseWriter, r *http.Request) {
|
||||
result, err := keyspacesCache.Get()
|
||||
ctx := context.Background()
|
||||
result, err := keyspacesCache.Get(ctx)
|
||||
if err != nil {
|
||||
httpError(w, "error getting keyspaces: %v", err)
|
||||
return
|
||||
|
@ -359,7 +366,8 @@ func main() {
|
|||
http.Error(w, "no keyspace provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
result, err := keyspaceCache.Get(keyspace)
|
||||
ctx := context.Background()
|
||||
result, err := keyspaceCache.Get(ctx, keyspace)
|
||||
if err != nil {
|
||||
httpError(w, "error getting keyspace: %v", err)
|
||||
return
|
||||
|
@ -378,7 +386,8 @@ func main() {
|
|||
http.Error(w, "no keyspace provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
result, err := shardNamesCache.Get(keyspace)
|
||||
ctx := context.Background()
|
||||
result, err := shardNamesCache.Get(ctx, keyspace)
|
||||
if err != nil {
|
||||
httpError(w, "error getting shardNames: %v", err)
|
||||
return
|
||||
|
@ -402,7 +411,8 @@ func main() {
|
|||
http.Error(w, "no shard provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
result, err := shardCache.Get(keyspace + "/" + shard)
|
||||
ctx := context.Background()
|
||||
result, err := shardCache.Get(ctx, keyspace+"/"+shard)
|
||||
if err != nil {
|
||||
httpError(w, "error getting shard: %v", err)
|
||||
return
|
||||
|
@ -431,7 +441,8 @@ func main() {
|
|||
http.Error(w, "no shard provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
result, err := cellShardTabletsCache.Get(cell + "/" + keyspace + "/" + shard)
|
||||
ctx := context.Background()
|
||||
result, err := cellShardTabletsCache.Get(ctx, cell+"/"+keyspace+"/"+shard)
|
||||
if err != nil {
|
||||
httpError(w, "error getting shard: %v", err)
|
||||
return
|
||||
|
@ -485,16 +496,47 @@ func main() {
|
|||
}
|
||||
sqlStr := r.FormValue("data")
|
||||
keyspace := r.FormValue("keyspace")
|
||||
shards, err := ts.GetShardNames(keyspace)
|
||||
if err != nil {
|
||||
httpError(w, "error getting shards for keyspace: <"+keyspace+">, error: %v", err)
|
||||
}
|
||||
schmgr.Run(
|
||||
schmgr.NewSimepleDataSourcer(sqlStr),
|
||||
schmgr.NewVtGateExecutor(
|
||||
keyspace, nil, 1*time.Second),
|
||||
uihandler.NewUIEventHandler(w),
|
||||
shards)
|
||||
executor := schemamanager.NewTabletExecutor(
|
||||
tmclient.NewTabletManagerClient(),
|
||||
ts)
|
||||
|
||||
schemamanager.Run(
|
||||
context.Background(),
|
||||
schemamanager.NewUIController(sqlStr, keyspace, w),
|
||||
executor,
|
||||
)
|
||||
})
|
||||
if *schemaChangeDir != "" {
|
||||
interval := 60
|
||||
if *schemaChangeCheckInterval > 0 {
|
||||
interval = *schemaChangeCheckInterval
|
||||
}
|
||||
timer := timer.NewTimer(time.Duration(interval) * time.Second)
|
||||
controllerFactory, err :=
|
||||
schemamanager.GetControllerFactory(*schemaChangeController)
|
||||
if err != nil {
|
||||
log.Fatalf("unable to get a controller factory, error: %v", err)
|
||||
}
|
||||
|
||||
timer.Start(func() {
|
||||
controller, err := controllerFactory(map[string]string{
|
||||
schemamanager.SchemaChangeDirName: *schemaChangeDir,
|
||||
schemamanager.SchemaChangeUser: *schemaChangeUser,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to get controller, error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = schemamanager.Run(
|
||||
context.Background(),
|
||||
controller,
|
||||
schemamanager.NewTabletExecutor(
|
||||
tmclient.NewTabletManagerClient(), ts),
|
||||
)
|
||||
log.Errorf("Schema change failed, error: %v", err)
|
||||
})
|
||||
servenv.OnClose(func() { timer.Stop() })
|
||||
}
|
||||
servenv.RunDefault()
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtgate"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/planbuilder"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -48,7 +49,6 @@ func main() {
|
|||
defer topo.CloseServers()
|
||||
|
||||
var schema *planbuilder.Schema
|
||||
log.Info(*cell, *schemaFile)
|
||||
if *schemaFile != "" {
|
||||
var err error
|
||||
if schema, err = planbuilder.LoadFile(*schemaFile); err != nil {
|
||||
|
@ -62,7 +62,8 @@ func main() {
|
|||
log.Infof("Skipping v3 initialization: topo does not suppurt schemafier interface")
|
||||
goto startServer
|
||||
}
|
||||
schemaJSON, err := schemafier.GetVSchema()
|
||||
ctx := context.Background()
|
||||
schemaJSON, err := schemafier.GetVSchema(ctx)
|
||||
if err != nil {
|
||||
log.Warningf("Skipping v3 initialization: GetVSchema failed: %v", err)
|
||||
goto startServer
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/mysqlctl/filebackupstorage"
|
||||
)
|
|
@ -108,7 +108,6 @@ func main() {
|
|||
exit.Return(1)
|
||||
}
|
||||
|
||||
tabletmanager.HttpHandleSnapshots(mycnf, tabletAlias.Uid)
|
||||
servenv.OnRun(func() {
|
||||
addStatusParts(qsc)
|
||||
})
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/vt/worker"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -23,7 +24,7 @@ var (
|
|||
type command struct {
|
||||
Name string
|
||||
method func(wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []string) (worker.Worker, error)
|
||||
interactive func(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request)
|
||||
interactive func(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request)
|
||||
params string
|
||||
Help string // if help is empty, won't list the command
|
||||
}
|
||||
|
@ -112,7 +113,11 @@ func runCommand(args []string) error {
|
|||
case <-done:
|
||||
log.Infof("Command is done:")
|
||||
log.Info(wrk.StatusAsText())
|
||||
if wrk.Error() != nil {
|
||||
currentWorkerMutex.Lock()
|
||||
err := lastRunError
|
||||
currentWorkerMutex.Unlock()
|
||||
if err != nil {
|
||||
log.Errorf("Ended with an error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const indexHTML = `
|
||||
|
@ -83,7 +84,8 @@ func initInteractiveMode() {
|
|||
// closure.
|
||||
pc := c
|
||||
http.HandleFunc("/"+cg.Name+"/"+c.Name, func(w http.ResponseWriter, r *http.Request) {
|
||||
pc.interactive(wr, w, r)
|
||||
ctx := context.Background()
|
||||
pc.interactive(ctx, wr, w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/topotools"
|
||||
"github.com/youtube/vitess/go/vt/worker"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const splitCloneHTML = `
|
||||
|
@ -107,8 +108,8 @@ func commandSplitClone(wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []str
|
|||
return worker, nil
|
||||
}
|
||||
|
||||
func keyspacesWithOverlappingShards(wr *wrangler.Wrangler) ([]map[string]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces()
|
||||
func keyspacesWithOverlappingShards(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -121,7 +122,7 @@ func keyspacesWithOverlappingShards(wr *wrangler.Wrangler) ([]map[string]string,
|
|||
wg.Add(1)
|
||||
go func(keyspace string) {
|
||||
defer wg.Done()
|
||||
osList, err := topotools.FindOverlappingShards(wr.TopoServer(), keyspace)
|
||||
osList, err := topotools.FindOverlappingShards(ctx, wr.TopoServer(), keyspace)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
|
@ -147,7 +148,7 @@ func keyspacesWithOverlappingShards(wr *wrangler.Wrangler) ([]map[string]string,
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func interactiveSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
func interactiveSplitClone(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httpError(w, "cannot parse form: %s", err)
|
||||
return
|
||||
|
@ -159,7 +160,7 @@ func interactiveSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter, r *http
|
|||
// display the list of possible splits to choose from
|
||||
// (just find all the overlapping guys)
|
||||
result := make(map[string]interface{})
|
||||
choices, err := keyspacesWithOverlappingShards(wr)
|
||||
choices, err := keyspacesWithOverlappingShards(ctx, wr)
|
||||
if err != nil {
|
||||
result["Error"] = err.Error()
|
||||
} else {
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/worker"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const splitDiffHTML = `
|
||||
|
@ -76,8 +77,8 @@ func commandSplitDiff(wr *wrangler.Wrangler, subFlags *flag.FlagSet, args []stri
|
|||
|
||||
// shardsWithSources returns all the shards that have SourceShards set
|
||||
// with no Tables list.
|
||||
func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces()
|
||||
func shardsWithSources(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -90,7 +91,7 @@ func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
|
|||
wg.Add(1)
|
||||
go func(keyspace string) {
|
||||
defer wg.Done()
|
||||
shards, err := wr.TopoServer().GetShardNames(keyspace)
|
||||
shards, err := wr.TopoServer().GetShardNames(ctx, keyspace)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
|
@ -99,7 +100,7 @@ func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
|
|||
wg.Add(1)
|
||||
go func(keyspace, shard string) {
|
||||
defer wg.Done()
|
||||
si, err := wr.TopoServer().GetShard(keyspace, shard)
|
||||
si, err := wr.TopoServer().GetShard(ctx, keyspace, shard)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
|
@ -128,7 +129,7 @@ func shardsWithSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func interactiveSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
func interactiveSplitDiff(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httpError(w, "cannot parse form: %s", err)
|
||||
return
|
||||
|
@ -139,7 +140,7 @@ func interactiveSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.
|
|||
if keyspace == "" || shard == "" {
|
||||
// display the list of possible shards to chose from
|
||||
result := make(map[string]interface{})
|
||||
shards, err := shardsWithSources(wr)
|
||||
shards, err := shardsWithSources(ctx, wr)
|
||||
if err != nil {
|
||||
result["Error"] = err.Error()
|
||||
} else {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -68,17 +69,20 @@ func initStatusHandling() {
|
|||
currentWorkerMutex.Lock()
|
||||
wrk := currentWorker
|
||||
logger := currentMemoryLogger
|
||||
done := currentDone
|
||||
ctx := currentContext
|
||||
err := lastRunError
|
||||
currentWorkerMutex.Unlock()
|
||||
|
||||
data := make(map[string]interface{})
|
||||
if wrk != nil {
|
||||
data["Status"] = wrk.StatusAsHTML()
|
||||
select {
|
||||
case <-done:
|
||||
status := template.HTML("Current worker:<br>\n") + wrk.StatusAsHTML()
|
||||
if ctx == nil {
|
||||
data["Done"] = true
|
||||
default:
|
||||
if err != nil {
|
||||
status += template.HTML(fmt.Sprintf("<br>\nEnded with an error: %v<br>\n", err))
|
||||
}
|
||||
}
|
||||
data["Status"] = status
|
||||
if logger != nil {
|
||||
data["Logs"] = template.HTML(strings.Replace(logger.String(), "\n", "</br>\n", -1))
|
||||
} else {
|
||||
|
@ -99,29 +103,27 @@ func initStatusHandling() {
|
|||
acl.SendError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
currentWorkerMutex.Lock()
|
||||
wrk := currentWorker
|
||||
done := currentDone
|
||||
currentWorkerMutex.Unlock()
|
||||
|
||||
// no worker, we go to the menu
|
||||
if wrk == nil {
|
||||
if currentWorker == nil {
|
||||
currentWorkerMutex.Unlock()
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
// check the worker is really done
|
||||
select {
|
||||
case <-done:
|
||||
currentWorkerMutex.Lock()
|
||||
if currentContext == nil {
|
||||
currentWorker = nil
|
||||
currentMemoryLogger = nil
|
||||
currentDone = nil
|
||||
currentWorkerMutex.Unlock()
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
default:
|
||||
httpError(w, "worker still executing", nil)
|
||||
return
|
||||
}
|
||||
|
||||
currentWorkerMutex.Unlock()
|
||||
httpError(w, "worker still executing", nil)
|
||||
})
|
||||
|
||||
// cancel handler
|
||||
|
@ -130,18 +132,20 @@ func initStatusHandling() {
|
|||
acl.SendError(w, err)
|
||||
return
|
||||
}
|
||||
currentWorkerMutex.Lock()
|
||||
wrk := currentWorker
|
||||
currentWorkerMutex.Unlock()
|
||||
|
||||
// no worker, we go to the menu
|
||||
if wrk == nil {
|
||||
currentWorkerMutex.Lock()
|
||||
|
||||
// no worker, or not running, we go to the menu
|
||||
if currentWorker == nil || currentCancelFunc == nil {
|
||||
currentWorkerMutex.Unlock()
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
// otherwise, cancel the running worker and go back to the status page
|
||||
wrk.Cancel()
|
||||
cancel := currentCancelFunc
|
||||
currentWorkerMutex.Unlock()
|
||||
cancel()
|
||||
http.Redirect(w, r, servenv.StatusURLPath(), http.StatusTemporaryRedirect)
|
||||
|
||||
})
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/worker"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -114,8 +115,8 @@ func commandVerticalSplitClone(wr *wrangler.Wrangler, subFlags *flag.FlagSet, ar
|
|||
|
||||
// keyspacesWithServedFrom returns all the keyspaces that have ServedFrom set
|
||||
// to one value.
|
||||
func keyspacesWithServedFrom(wr *wrangler.Wrangler) ([]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces()
|
||||
func keyspacesWithServedFrom(ctx context.Context, wr *wrangler.Wrangler) ([]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -128,7 +129,7 @@ func keyspacesWithServedFrom(wr *wrangler.Wrangler) ([]string, error) {
|
|||
wg.Add(1)
|
||||
go func(keyspace string) {
|
||||
defer wg.Done()
|
||||
ki, err := wr.TopoServer().GetKeyspace(keyspace)
|
||||
ki, err := wr.TopoServer().GetKeyspace(ctx, keyspace)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
|
@ -151,7 +152,7 @@ func keyspacesWithServedFrom(wr *wrangler.Wrangler) ([]string, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func interactiveVerticalSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
func interactiveVerticalSplitClone(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httpError(w, "cannot parse form: %s", err)
|
||||
return
|
||||
|
@ -161,7 +162,7 @@ func interactiveVerticalSplitClone(wr *wrangler.Wrangler, w http.ResponseWriter,
|
|||
if keyspace == "" {
|
||||
// display the list of possible keyspaces to choose from
|
||||
result := make(map[string]interface{})
|
||||
keyspaces, err := keyspacesWithServedFrom(wr)
|
||||
keyspaces, err := keyspacesWithServedFrom(ctx, wr)
|
||||
if err != nil {
|
||||
result["Error"] = err.Error()
|
||||
} else {
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/worker"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const verticalSplitDiffHTML = `
|
||||
|
@ -75,8 +76,8 @@ func commandVerticalSplitDiff(wr *wrangler.Wrangler, subFlags *flag.FlagSet, arg
|
|||
|
||||
// shardsWithTablesSources returns all the shards that have SourceShards set
|
||||
// to one value, with an array of Tables.
|
||||
func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces()
|
||||
func shardsWithTablesSources(ctx context.Context, wr *wrangler.Wrangler) ([]map[string]string, error) {
|
||||
keyspaces, err := wr.TopoServer().GetKeyspaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -89,7 +90,7 @@ func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error)
|
|||
wg.Add(1)
|
||||
go func(keyspace string) {
|
||||
defer wg.Done()
|
||||
shards, err := wr.TopoServer().GetShardNames(keyspace)
|
||||
shards, err := wr.TopoServer().GetShardNames(ctx, keyspace)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
|
@ -98,7 +99,7 @@ func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error)
|
|||
wg.Add(1)
|
||||
go func(keyspace, shard string) {
|
||||
defer wg.Done()
|
||||
si, err := wr.TopoServer().GetShard(keyspace, shard)
|
||||
si, err := wr.TopoServer().GetShard(ctx, keyspace, shard)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
|
@ -127,7 +128,7 @@ func shardsWithTablesSources(wr *wrangler.Wrangler) ([]map[string]string, error)
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func interactiveVerticalSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
func interactiveVerticalSplitDiff(ctx context.Context, wr *wrangler.Wrangler, w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httpError(w, "cannot parse form: %s", err)
|
||||
return
|
||||
|
@ -138,7 +139,7 @@ func interactiveVerticalSplitDiff(wr *wrangler.Wrangler, w http.ResponseWriter,
|
|||
if keyspace == "" || shard == "" {
|
||||
// display the list of possible shards to chose from
|
||||
result := make(map[string]interface{})
|
||||
shards, err := shardsWithTablesSources(wr)
|
||||
shards, err := shardsWithTablesSources(ctx, wr)
|
||||
if err != nil {
|
||||
result["Error"] = err.Error()
|
||||
} else {
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/worker"
|
||||
"github.com/youtube/vitess/go/vt/wrangler"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -44,10 +45,20 @@ var (
|
|||
wr *wrangler.Wrangler
|
||||
|
||||
// mutex is protecting all the following variables
|
||||
// 3 states here:
|
||||
// - no job ever ran (or reset was run): currentWorker is nil,
|
||||
// currentContext/currentCancelFunc is nil, lastRunError is nil
|
||||
// - one worker running: currentWorker is set,
|
||||
// currentContext/currentCancelFunc is set, lastRunError is nil
|
||||
// - (at least) one worker already ran, none is running atm:
|
||||
// currentWorker is set, currentContext is nil, lastRunError
|
||||
// has the error returned by the worker.
|
||||
currentWorkerMutex sync.Mutex
|
||||
currentWorker worker.Worker
|
||||
currentMemoryLogger *logutil.MemoryLogger
|
||||
currentDone chan struct{}
|
||||
currentContext context.Context
|
||||
currentCancelFunc context.CancelFunc
|
||||
lastRunError error
|
||||
)
|
||||
|
||||
// signal handling, centralized here
|
||||
|
@ -59,7 +70,9 @@ func installSignalHandlers() {
|
|||
// we got a signal, notify our modules
|
||||
currentWorkerMutex.Lock()
|
||||
defer currentWorkerMutex.Unlock()
|
||||
currentWorker.Cancel()
|
||||
if currentCancelFunc != nil {
|
||||
currentCancelFunc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
@ -75,17 +88,27 @@ func setAndStartWorker(wrk worker.Worker) (chan struct{}, error) {
|
|||
|
||||
currentWorker = wrk
|
||||
currentMemoryLogger = logutil.NewMemoryLogger()
|
||||
currentDone = make(chan struct{})
|
||||
currentContext, currentCancelFunc = context.WithCancel(context.Background())
|
||||
lastRunError = nil
|
||||
done := make(chan struct{})
|
||||
wr.SetLogger(logutil.NewTeeLogger(currentMemoryLogger, logutil.NewConsoleLogger()))
|
||||
|
||||
// one go function runs the worker, closes 'done' when done
|
||||
// one go function runs the worker, changes state when done
|
||||
go func() {
|
||||
// run will take a long time
|
||||
log.Infof("Starting worker...")
|
||||
wrk.Run()
|
||||
close(currentDone)
|
||||
err := wrk.Run(currentContext)
|
||||
|
||||
// it's done, let's save our state
|
||||
currentWorkerMutex.Lock()
|
||||
currentContext = nil
|
||||
currentCancelFunc = nil
|
||||
lastRunError = err
|
||||
currentWorkerMutex.Unlock()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
return currentDone, nil
|
||||
return done, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -50,12 +50,12 @@ func connect() *rpcplus.Client {
|
|||
return rpcClient
|
||||
}
|
||||
|
||||
func getSrvKeyspaceNames(rpcClient *rpcplus.Client, cell string, verbose bool) {
|
||||
func getSrvKeyspaceNames(ctx context.Context, rpcClient *rpcplus.Client, cell string, verbose bool) {
|
||||
req := &topo.GetSrvKeyspaceNamesArgs{
|
||||
Cell: cell,
|
||||
}
|
||||
reply := &topo.SrvKeyspaceNames{}
|
||||
if err := rpcClient.Call(context.TODO(), "TopoReader.GetSrvKeyspaceNames", req, reply); err != nil {
|
||||
if err := rpcClient.Call(ctx, "TopoReader.GetSrvKeyspaceNames", req, reply); err != nil {
|
||||
log.Fatalf("TopoReader.GetSrvKeyspaceNames error: %v", err)
|
||||
}
|
||||
if verbose {
|
||||
|
@ -65,13 +65,13 @@ func getSrvKeyspaceNames(rpcClient *rpcplus.Client, cell string, verbose bool) {
|
|||
}
|
||||
}
|
||||
|
||||
func getSrvKeyspace(rpcClient *rpcplus.Client, cell, keyspace string, verbose bool) {
|
||||
func getSrvKeyspace(ctx context.Context, rpcClient *rpcplus.Client, cell, keyspace string, verbose bool) {
|
||||
req := &topo.GetSrvKeyspaceArgs{
|
||||
Cell: cell,
|
||||
Keyspace: keyspace,
|
||||
}
|
||||
reply := &topo.SrvKeyspace{}
|
||||
if err := rpcClient.Call(context.TODO(), "TopoReader.GetSrvKeyspace", req, reply); err != nil {
|
||||
if err := rpcClient.Call(ctx, "TopoReader.GetSrvKeyspace", req, reply); err != nil {
|
||||
log.Fatalf("TopoReader.GetSrvKeyspace error: %v", err)
|
||||
}
|
||||
if verbose {
|
||||
|
@ -89,7 +89,7 @@ func getSrvKeyspace(rpcClient *rpcplus.Client, cell, keyspace string, verbose bo
|
|||
}
|
||||
}
|
||||
|
||||
func getEndPoints(rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType string, verbose bool) {
|
||||
func getEndPoints(ctx context.Context, rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType string, verbose bool) {
|
||||
req := &topo.GetEndPointsArgs{
|
||||
Cell: cell,
|
||||
Keyspace: keyspace,
|
||||
|
@ -97,7 +97,7 @@ func getEndPoints(rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType s
|
|||
TabletType: topo.TabletType(tabletType),
|
||||
}
|
||||
reply := &topo.EndPoints{}
|
||||
if err := rpcClient.Call(context.TODO(), "TopoReader.GetEndPoints", req, reply); err != nil {
|
||||
if err := rpcClient.Call(ctx, "TopoReader.GetEndPoints", req, reply); err != nil {
|
||||
log.Fatalf("TopoReader.GetEndPoints error: %v", err)
|
||||
}
|
||||
if verbose {
|
||||
|
@ -109,14 +109,14 @@ func getEndPoints(rpcClient *rpcplus.Client, cell, keyspace, shard, tabletType s
|
|||
|
||||
// qps is a function used by tests to run a vtgate load check.
|
||||
// It will get the same srvKeyspaces as fast as possible and display the QPS.
|
||||
func qps(cell string, keyspaces []string) {
|
||||
func qps(ctx context.Context, cell string, keyspaces []string) {
|
||||
var count sync2.AtomicInt32
|
||||
for _, keyspace := range keyspaces {
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
rpcClient := connect()
|
||||
for true {
|
||||
getSrvKeyspace(rpcClient, cell, keyspace, false)
|
||||
getSrvKeyspace(ctx, rpcClient, cell, keyspace, false)
|
||||
count.Add(1)
|
||||
}
|
||||
}()
|
||||
|
@ -157,10 +157,11 @@ func main() {
|
|||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if *mode == "getSrvKeyspaceNames" {
|
||||
rpcClient := connect()
|
||||
if len(args) == 1 {
|
||||
getSrvKeyspaceNames(rpcClient, args[0], true)
|
||||
getSrvKeyspaceNames(ctx, rpcClient, args[0], true)
|
||||
} else {
|
||||
log.Errorf("getSrvKeyspaceNames only takes one argument")
|
||||
exit.Return(1)
|
||||
|
@ -169,7 +170,7 @@ func main() {
|
|||
} else if *mode == "getSrvKeyspace" {
|
||||
rpcClient := connect()
|
||||
if len(args) == 2 {
|
||||
getSrvKeyspace(rpcClient, args[0], args[1], true)
|
||||
getSrvKeyspace(ctx, rpcClient, args[0], args[1], true)
|
||||
} else {
|
||||
log.Errorf("getSrvKeyspace only takes two arguments")
|
||||
exit.Return(1)
|
||||
|
@ -178,14 +179,14 @@ func main() {
|
|||
} else if *mode == "getEndPoints" {
|
||||
rpcClient := connect()
|
||||
if len(args) == 4 {
|
||||
getEndPoints(rpcClient, args[0], args[1], args[2], args[3], true)
|
||||
getEndPoints(ctx, rpcClient, args[0], args[1], args[2], args[3], true)
|
||||
} else {
|
||||
log.Errorf("getEndPoints only takes four arguments")
|
||||
exit.Return(1)
|
||||
}
|
||||
|
||||
} else if *mode == "qps" {
|
||||
qps(args[0], args[1:])
|
||||
qps(ctx, args[0], args[1:])
|
||||
|
||||
} else {
|
||||
flag.Usage()
|
||||
|
|
|
@ -99,27 +99,28 @@ type Charset struct {
|
|||
|
||||
// Convert takes a type and a value, and returns the type:
|
||||
// - nil for NULL value
|
||||
// - int64 if possible, otherwise, uint64
|
||||
// - uint64 for unsigned BIGINT values
|
||||
// - int64 for all other integer values (signed and unsigned)
|
||||
// - float64 for floating point values that fit in a float
|
||||
// - []byte for everything else
|
||||
// TODO(mberlin): Make this a method of "Field" and consider VT_UNSIGNED_FLAG in "Flags" as well.
|
||||
func Convert(mysqlType int64, val sqltypes.Value) (interface{}, error) {
|
||||
func Convert(field Field, val sqltypes.Value) (interface{}, error) {
|
||||
if val.IsNull() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch mysqlType {
|
||||
case VT_TINY, VT_SHORT, VT_LONG, VT_LONGLONG, VT_INT24:
|
||||
val := val.String()
|
||||
signed, err := strconv.ParseInt(val, 0, 64)
|
||||
if err == nil {
|
||||
return signed, nil
|
||||
switch field.Type {
|
||||
case VT_LONGLONG:
|
||||
if field.Flags&VT_UNSIGNED_FLAG == VT_UNSIGNED_FLAG {
|
||||
return strconv.ParseUint(val.String(), 0, 64)
|
||||
}
|
||||
unsigned, err := strconv.ParseUint(val, 0, 64)
|
||||
if err == nil {
|
||||
return unsigned, nil
|
||||
}
|
||||
return nil, err
|
||||
return strconv.ParseInt(val.String(), 0, 64)
|
||||
case VT_TINY, VT_SHORT, VT_LONG, VT_INT24:
|
||||
// Regardless of whether UNSIGNED_FLAG is set in field.Flags, we map all
|
||||
// signed and unsigned values to a signed Go type because
|
||||
// - Go doesn't officially support uint64 in their SQL interface
|
||||
// - there is no loss of the value
|
||||
// The only exception we make are for unsigned BIGINTs, see VT_LONGLONG above.
|
||||
return strconv.ParseInt(val.String(), 0, 64)
|
||||
case VT_FLOAT, VT_DOUBLE:
|
||||
return strconv.ParseFloat(val.String(), 64)
|
||||
}
|
||||
|
|
|
@ -12,81 +12,79 @@ import (
|
|||
|
||||
func TestConvert(t *testing.T) {
|
||||
cases := []struct {
|
||||
Desc string
|
||||
Typ int64
|
||||
Val sqltypes.Value
|
||||
Want interface{}
|
||||
Field Field
|
||||
Val sqltypes.Value
|
||||
Want interface{}
|
||||
}{{
|
||||
Desc: "null",
|
||||
Typ: VT_LONG,
|
||||
Val: sqltypes.Value{},
|
||||
Want: nil,
|
||||
Field: Field{"null", VT_LONG, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.Value{},
|
||||
Want: nil,
|
||||
}, {
|
||||
Desc: "decimal",
|
||||
Typ: VT_DECIMAL,
|
||||
Val: sqltypes.MakeString([]byte("aa")),
|
||||
Want: "aa",
|
||||
Field: Field{"decimal", VT_DECIMAL, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("aa")),
|
||||
Want: "aa",
|
||||
}, {
|
||||
Desc: "tiny",
|
||||
Typ: VT_TINY,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Field: Field{"tiny", VT_TINY, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
}, {
|
||||
Field: Field{"short", VT_SHORT, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
}, {
|
||||
Field: Field{"long", VT_LONG, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
}, {
|
||||
Field: Field{"unsigned long", VT_LONG, VT_UNSIGNED_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
// Unsigned types which aren't VT_LONGLONG are mapped to int64.
|
||||
Want: int64(1),
|
||||
}, {
|
||||
Desc: "short",
|
||||
Typ: VT_SHORT,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
Field: Field{"longlong", VT_LONGLONG, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
}, {
|
||||
Desc: "long",
|
||||
Typ: VT_LONG,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
Field: Field{"int24", VT_INT24, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
}, {
|
||||
Desc: "longlong",
|
||||
Typ: VT_LONGLONG,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
Field: Field{"float", VT_FLOAT, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: float64(1),
|
||||
}, {
|
||||
Desc: "int24",
|
||||
Typ: VT_INT24,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: int64(1),
|
||||
Field: Field{"double", VT_DOUBLE, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: float64(1),
|
||||
}, {
|
||||
Desc: "float",
|
||||
Typ: VT_FLOAT,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: float64(1),
|
||||
Field: Field{"large int out of range for int64", VT_LONGLONG, VT_ZEROVALUE_FLAG},
|
||||
// 2^63, out of range for int64
|
||||
Val: sqltypes.MakeString([]byte("9223372036854775808")),
|
||||
Want: `strconv.ParseInt: parsing "9223372036854775808": value out of range`,
|
||||
}, {
|
||||
Desc: "double",
|
||||
Typ: VT_DOUBLE,
|
||||
Val: sqltypes.MakeString([]byte("1")),
|
||||
Want: float64(1),
|
||||
}, {
|
||||
Desc: "large int",
|
||||
Typ: VT_LONGLONG,
|
||||
Field: Field{"large int", VT_LONGLONG, VT_UNSIGNED_FLAG},
|
||||
// 2^63, not out of range for uint64
|
||||
Val: sqltypes.MakeString([]byte("9223372036854775808")),
|
||||
Want: uint64(9223372036854775808),
|
||||
}, {
|
||||
Desc: "float for int",
|
||||
Typ: VT_LONGLONG,
|
||||
Val: sqltypes.MakeString([]byte("1.1")),
|
||||
Want: `strconv.ParseUint: parsing "1.1": invalid syntax`,
|
||||
Field: Field{"float for int", VT_LONGLONG, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("1.1")),
|
||||
Want: `strconv.ParseInt: parsing "1.1": invalid syntax`,
|
||||
}, {
|
||||
Desc: "string for float",
|
||||
Typ: VT_FLOAT,
|
||||
Val: sqltypes.MakeString([]byte("aa")),
|
||||
Want: `strconv.ParseFloat: parsing "aa": invalid syntax`,
|
||||
Field: Field{"string for float", VT_FLOAT, VT_ZEROVALUE_FLAG},
|
||||
Val: sqltypes.MakeString([]byte("aa")),
|
||||
Want: `strconv.ParseFloat: parsing "aa": invalid syntax`,
|
||||
}}
|
||||
|
||||
for _, c := range cases {
|
||||
r, err := Convert(c.Typ, c.Val)
|
||||
r, err := Convert(c.Field, c.Val)
|
||||
if err != nil {
|
||||
r = err.Error()
|
||||
} else if _, ok := r.([]byte); ok {
|
||||
r = string(r.([]byte))
|
||||
}
|
||||
if r != c.Want {
|
||||
t.Errorf("%s: %+v, want %+v", c.Desc, r, c.Want)
|
||||
t.Errorf("%s: %+v, want %+v", c.Field.Name, r, c.Want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -77,7 +77,7 @@ func NewResourcePool(factory Factory, capacity, maxCap int, idleTimeout time.Dur
|
|||
// Close empties the pool calling Close on all its resources.
|
||||
// You can call Close while there are outstanding resources.
|
||||
// It waits for all resources to be returned (Put).
|
||||
// After a Close, Get and TryGet are not allowed.
|
||||
// After a Close, Get is not allowed.
|
||||
func (rp *ResourcePool) Close() {
|
||||
_ = rp.SetCapacity(0)
|
||||
}
|
||||
|
@ -95,13 +95,6 @@ func (rp *ResourcePool) Get(ctx context.Context) (resource Resource, err error)
|
|||
return rp.get(ctx, true)
|
||||
}
|
||||
|
||||
// TryGet will return the next available resource. If none is available, and capacity
|
||||
// has not been reached, it will create a new one using the factory. Otherwise,
|
||||
// it will return nil with no error.
|
||||
func (rp *ResourcePool) TryGet() (resource Resource, err error) {
|
||||
return rp.get(context.TODO(), false)
|
||||
}
|
||||
|
||||
func (rp *ResourcePool) get(ctx context.Context, wait bool) (resource Resource, err error) {
|
||||
// If ctx has already expired, avoid racing with rp's resource channel.
|
||||
select {
|
||||
|
|
|
@ -74,38 +74,6 @@ func TestOpen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test TryGet
|
||||
r, err := p.TryGet()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v", err)
|
||||
}
|
||||
if r != nil {
|
||||
t.Errorf("Expecting nil")
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
p.Put(resources[i])
|
||||
_, available, _, _, _, _ := p.Stats()
|
||||
if available != int64(i+1) {
|
||||
t.Errorf("expecting %d, received %d", 5-i-1, available)
|
||||
}
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
r, err := p.TryGet()
|
||||
resources[i] = r
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v", err)
|
||||
}
|
||||
if r == nil {
|
||||
t.Errorf("Expecting non-nil")
|
||||
}
|
||||
if lastID.Get() != 5 {
|
||||
t.Errorf("Expecting 5, received %d", lastID.Get())
|
||||
}
|
||||
if count.Get() != 5 {
|
||||
t.Errorf("Expecting 5, received %d", count.Get())
|
||||
}
|
||||
}
|
||||
|
||||
// Test that Get waits
|
||||
ch := make(chan bool)
|
||||
go func() {
|
||||
|
@ -139,7 +107,7 @@ func TestOpen(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test Close resource
|
||||
r, err = p.Get(ctx)
|
||||
r, err := p.Get(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v", err)
|
||||
}
|
||||
|
@ -241,15 +209,6 @@ func TestShrinking(t *testing.T) {
|
|||
t.Errorf(`expecting '%s', received '%s'`, expected, stats)
|
||||
}
|
||||
|
||||
// TryGet is allowed when shrinking
|
||||
r, err := p.TryGet()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v", err)
|
||||
}
|
||||
if r != nil {
|
||||
t.Errorf("Expecting nil")
|
||||
}
|
||||
|
||||
// Get is allowed when shrinking, but it will wait
|
||||
getdone := make(chan bool)
|
||||
go func() {
|
||||
|
@ -278,6 +237,7 @@ func TestShrinking(t *testing.T) {
|
|||
|
||||
// Ensure no deadlock if SetCapacity is called after we start
|
||||
// waiting for a resource
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
resources[i], err = p.Get(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
// ClusterOperationInstance is a runtime type which enhances the protobuf message "ClusterOperation" with runtime specific data.
|
||||
// Unlike the protobuf message, the additional runtime data will not be part of a checkpoint.
|
||||
type ClusterOperationInstance struct {
|
||||
pb.ClusterOperation
|
||||
taskIDGenerator *IDGenerator
|
||||
}
|
||||
|
||||
// NewClusterOperationInstance creates a new cluster operation instance with one initial task.
|
||||
func NewClusterOperationInstance(clusterOpID string, initialTask *pb.TaskContainer, taskIDGenerator *IDGenerator) *ClusterOperationInstance {
|
||||
c := &ClusterOperationInstance{
|
||||
pb.ClusterOperation{
|
||||
Id: clusterOpID,
|
||||
SerialTasks: []*pb.TaskContainer{},
|
||||
State: pb.ClusterOperationState_CLUSTER_OPERATION_NOT_STARTED,
|
||||
},
|
||||
taskIDGenerator,
|
||||
}
|
||||
c.InsertTaskContainers([]*pb.TaskContainer{initialTask}, 0)
|
||||
return c
|
||||
}
|
||||
|
||||
// addMissingTaskID assigns a task id to each task in "tc".
|
||||
func (c *ClusterOperationInstance) addMissingTaskID(tc []*pb.TaskContainer) {
|
||||
for _, taskContainer := range tc {
|
||||
for _, task := range taskContainer.ParallelTasks {
|
||||
if task.Id == "" {
|
||||
task.Id = c.taskIDGenerator.GetNextID()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InsertTaskContainers inserts "newTaskContainers" at pos in the current list of task containers. Existing task containers will be moved after the new task containers.
|
||||
func (c *ClusterOperationInstance) InsertTaskContainers(newTaskContainers []*pb.TaskContainer, pos int) {
|
||||
c.addMissingTaskID(newTaskContainers)
|
||||
|
||||
newSerialTasks := make([]*pb.TaskContainer, len(c.SerialTasks)+len(newTaskContainers))
|
||||
copy(newSerialTasks, c.SerialTasks[:pos])
|
||||
copy(newSerialTasks[pos:], newTaskContainers)
|
||||
copy(newSerialTasks[pos+len(newTaskContainers):], c.SerialTasks[pos:])
|
||||
c.SerialTasks = newSerialTasks
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
// HorizontalReshardingTask is a cluster operation which allows to increase the number of shards.
|
||||
type HorizontalReshardingTask struct {
|
||||
}
|
||||
|
||||
// TODO(mberlin): Uncomment/remove when "ForceReparent" and "CopySchemaShard" will be implemented.
|
||||
//func selectAnyTabletFromShardByType(shard string, tabletType string) string {
|
||||
// return ""
|
||||
//}
|
||||
|
||||
func (t *HorizontalReshardingTask) run(parameters map[string]string) ([]*pb.TaskContainer, string, error) {
|
||||
// Example: test_keyspace
|
||||
keyspace := parameters["keyspace"]
|
||||
// Example: 10-20
|
||||
sourceShards := strings.Split(parameters["source_shard_list"], ",")
|
||||
// Example: 10-18,18-20
|
||||
destShards := strings.Split(parameters["dest_shard_list"], ",")
|
||||
// Example: cell1-0000062352
|
||||
sourceRdonlyTablets := strings.Split(parameters["source_shard_rdonly_list"], ",")
|
||||
|
||||
var newTasks []*pb.TaskContainer
|
||||
// TODO(mberlin): Implement "ForceParent" task and uncomment this.
|
||||
// reparentTasks := NewTaskContainer()
|
||||
// for _, destShard := range destShards {
|
||||
// newMaster := selectAnyTabletFromShardByType(destShard, "master")
|
||||
// AddTask(reparentTasks, "ForceReparent", map[string]string{
|
||||
// "shard": destShard,
|
||||
// "master": newMaster,
|
||||
// })
|
||||
// }
|
||||
// newTasks = append(newTasks, reparentTasks)
|
||||
|
||||
// TODO(mberlin): Implement "CopySchemaShard" task and uncomment this.
|
||||
// copySchemaTasks := NewTaskContainer()
|
||||
// sourceRdonlyTablet := selectAnyTabletFromShardByType(sourceShards[0], "rdonly")
|
||||
// for _, destShard := range destShards {
|
||||
// AddTask(copySchemaTasks, "CopySchemaShard", map[string]string{
|
||||
// "shard": destShard,
|
||||
// "source_rdonly_tablet": sourceRdonlyTablet,
|
||||
// })
|
||||
// }
|
||||
// newTasks = append(newTasks, copySchemaTasks)
|
||||
|
||||
splitCloneTasks := NewTaskContainer()
|
||||
for _, sourceShard := range sourceShards {
|
||||
// TODO(mberlin): Add a semaphore as argument to limit the parallism.
|
||||
AddTask(splitCloneTasks, "vtworker", map[string]string{
|
||||
"command": "SplitClone",
|
||||
"keyspace": keyspace,
|
||||
"shard": sourceShard,
|
||||
"vtworker_endpoint": parameters["vtworker_endpoint"],
|
||||
})
|
||||
}
|
||||
newTasks = append(newTasks, splitCloneTasks)
|
||||
|
||||
// TODO(mberlin): Remove this once SplitClone does this on its own.
|
||||
restoreTypeTasks := NewTaskContainer()
|
||||
for _, sourceRdonlyTablet := range sourceRdonlyTablets {
|
||||
AddTask(restoreTypeTasks, "vtctl", map[string]string{
|
||||
"command": fmt.Sprintf("ChangeSlaveType %v rdonly", sourceRdonlyTablet),
|
||||
})
|
||||
}
|
||||
newTasks = append(newTasks, restoreTypeTasks)
|
||||
|
||||
splitDiffTasks := NewTaskContainer()
|
||||
for _, destShard := range destShards {
|
||||
AddTask(splitDiffTasks, "vtworker", map[string]string{
|
||||
"command": "SplitDiff",
|
||||
"keyspace": keyspace,
|
||||
"shard": destShard,
|
||||
"vtworker_endpoint": parameters["vtworker_endpoint"],
|
||||
})
|
||||
}
|
||||
newTasks = append(newTasks, splitDiffTasks)
|
||||
|
||||
// TODO(mberlin): Implement "CopySchemaShard" task and uncomment this.
|
||||
// for _, servedType := range []string{"rdonly", "replica", "master"} {
|
||||
// migrateServedTypesTasks := NewTaskContainer()
|
||||
// for _, sourceShard := range sourceShards {
|
||||
// AddTask(migrateServedTypesTasks, "MigrateServedTypes", map[string]string{
|
||||
// "keyspace": keyspace,
|
||||
// "shard": sourceShard,
|
||||
// "served_type": servedType,
|
||||
// })
|
||||
// }
|
||||
// newTasks = append(newTasks, migrateServedTypesTasks)
|
||||
// }
|
||||
|
||||
return newTasks, "", nil
|
||||
}
|
||||
|
||||
func (t *HorizontalReshardingTask) requiredParameters() []string {
|
||||
return []string{"keyspace", "source_shard_list", "source_shard_rdonly_list", "dest_shard_list"}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestHorizontalReshardingTaskEmittedTasks(t *testing.T) {
|
||||
reshardingTask := &HorizontalReshardingTask{}
|
||||
|
||||
parameters := map[string]string{
|
||||
"source_shard_rdonly_list": "cell1-0000062352",
|
||||
"keyspace": "test_keyspace",
|
||||
"source_shard_list": "10-20",
|
||||
"dest_shard_list": "10-18,18-20",
|
||||
"vtworker_endpoint": "localhost:12345",
|
||||
}
|
||||
|
||||
err := checkRequiredParameters(reshardingTask, parameters)
|
||||
if err != nil {
|
||||
t.Fatalf("Not all required parameters were specified: %v", err)
|
||||
}
|
||||
|
||||
newTaskContainers, _, _ := reshardingTask.run(parameters)
|
||||
|
||||
// TODO(mberlin): Check emitted tasks against expected output.
|
||||
for _, tc := range newTaskContainers {
|
||||
t.Logf("new tasks: %v", proto.MarshalTextString(tc))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import "strconv"
|
||||
import "sync/atomic"
|
||||
|
||||
// IDGenerator generates unique task and cluster operation IDs.
|
||||
type IDGenerator struct {
|
||||
counter int64
|
||||
}
|
||||
|
||||
// GetNextID returns an ID which wasn't returned before.
|
||||
func (ig *IDGenerator) GetNextID() string {
|
||||
return strconv.FormatInt(atomic.AddInt64(&ig.counter, 1), 10)
|
||||
}
|
|
@ -0,0 +1,317 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package automation contains code to execute high-level cluster operations
|
||||
(e.g. resharding) as a series of low-level operations
|
||||
(e.g. vtctl, shell commands, ...).
|
||||
*/
|
||||
package automation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type schedulerState int32
|
||||
|
||||
const (
|
||||
stateNotRunning schedulerState = iota
|
||||
stateRunning
|
||||
stateShuttingDown
|
||||
stateShutdown
|
||||
)
|
||||
|
||||
type taskCreator func(string) Task
|
||||
|
||||
// Scheduler executes automation tasks and maintains the execution state.
|
||||
type Scheduler struct {
|
||||
idGenerator IDGenerator
|
||||
|
||||
mu sync.Mutex
|
||||
// Guarded by "mu".
|
||||
registeredClusterOperations map[string]bool
|
||||
// Guarded by "mu".
|
||||
toBeScheduledClusterOperations chan *ClusterOperationInstance
|
||||
// Guarded by "mu".
|
||||
state schedulerState
|
||||
|
||||
// Guarded by "taskCreatorMu". May be overriden by testing code.
|
||||
taskCreator taskCreator
|
||||
taskCreatorMu sync.Mutex
|
||||
|
||||
pendingOpsWg *sync.WaitGroup
|
||||
|
||||
muOpList sync.Mutex
|
||||
// Guarded by "muOpList".
|
||||
activeClusterOperations map[string]*ClusterOperationInstance
|
||||
// Guarded by "muOpList".
|
||||
finishedClusterOperations map[string]*ClusterOperationInstance
|
||||
}
|
||||
|
||||
// NewScheduler creates a new instance.
|
||||
func NewScheduler() (*Scheduler, error) {
|
||||
defaultClusterOperations := map[string]bool{
|
||||
"HorizontalReshardingTask": true,
|
||||
}
|
||||
|
||||
s := &Scheduler{
|
||||
registeredClusterOperations: defaultClusterOperations,
|
||||
idGenerator: IDGenerator{},
|
||||
toBeScheduledClusterOperations: make(chan *ClusterOperationInstance, 10),
|
||||
state: stateNotRunning,
|
||||
taskCreator: defaultTaskCreator,
|
||||
pendingOpsWg: &sync.WaitGroup{},
|
||||
activeClusterOperations: make(map[string]*ClusterOperationInstance),
|
||||
finishedClusterOperations: make(map[string]*ClusterOperationInstance),
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) registerClusterOperation(clusterOperationName string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.registeredClusterOperations[clusterOperationName] = true
|
||||
}
|
||||
|
||||
// Run processes queued cluster operations.
|
||||
func (s *Scheduler) Run() {
|
||||
s.mu.Lock()
|
||||
s.state = stateRunning
|
||||
s.mu.Unlock()
|
||||
|
||||
s.startProcessRequestsLoop()
|
||||
}
|
||||
|
||||
func (s *Scheduler) startProcessRequestsLoop() {
|
||||
// Use a WaitGroup instead of just a done channel, because we want
|
||||
// to be able to shut down the scheduler even if Run() was never executed.
|
||||
s.pendingOpsWg.Add(1)
|
||||
go s.processRequestsLoop()
|
||||
}
|
||||
|
||||
func (s *Scheduler) processRequestsLoop() {
|
||||
defer s.pendingOpsWg.Done()
|
||||
|
||||
for op := range s.toBeScheduledClusterOperations {
|
||||
s.processClusterOperation(op)
|
||||
}
|
||||
log.Infof("Stopped processing loop for ClusterOperations.")
|
||||
}
|
||||
|
||||
func (s *Scheduler) processClusterOperation(clusterOp *ClusterOperationInstance) {
|
||||
if clusterOp.State == pb.ClusterOperationState_CLUSTER_OPERATION_DONE {
|
||||
log.Infof("ClusterOperation: %v skipping because it is already done. Details: %v", clusterOp.Id, clusterOp)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("ClusterOperation: %v running. Details: %v", clusterOp.Id, clusterOp)
|
||||
|
||||
var lastTaskError string
|
||||
for i := 0; i < len(clusterOp.SerialTasks); i++ {
|
||||
taskContainer := clusterOp.SerialTasks[i]
|
||||
for _, taskProto := range taskContainer.ParallelTasks {
|
||||
if taskProto.State == pb.TaskState_DONE {
|
||||
if taskProto.Error != "" {
|
||||
log.Errorf("Task: %v (%v/%v) failed before. Aborting the ClusterOperation. Error: %v Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, taskProto.Error, taskProto)
|
||||
lastTaskError = taskProto.Error
|
||||
break
|
||||
} else {
|
||||
log.Infof("Task: %v (%v/%v) skipped because it is already done. Full Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, taskProto)
|
||||
}
|
||||
}
|
||||
|
||||
task, err := s.createTaskInstance(taskProto.Name)
|
||||
if err != nil {
|
||||
log.Errorf("Task: %v (%v/%v) could not be instantiated. Error: %v Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, err, taskProto)
|
||||
MarkTaskFailed(taskProto, "", err)
|
||||
lastTaskError = err.Error()
|
||||
break
|
||||
}
|
||||
|
||||
taskProto.State = pb.TaskState_RUNNING
|
||||
log.Infof("Task: %v (%v/%v) running. Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, taskProto)
|
||||
newTaskContainers, output, errRun := task.run(taskProto.Parameters)
|
||||
log.Infof("Task: %v (%v/%v) finished. newTaskContainers: %v, output: %v, error: %v", taskProto.Name, clusterOp.Id, taskProto.Id, newTaskContainers, output, errRun)
|
||||
|
||||
if errRun != nil {
|
||||
MarkTaskFailed(taskProto, output, errRun)
|
||||
lastTaskError = errRun.Error()
|
||||
break
|
||||
}
|
||||
MarkTaskSucceeded(taskProto, output)
|
||||
|
||||
if newTaskContainers != nil {
|
||||
// Make sure all new tasks do not miss any required parameters.
|
||||
for _, newTaskContainer := range newTaskContainers {
|
||||
for _, newTaskProto := range newTaskContainer.ParallelTasks {
|
||||
err := s.validateTaskSpecification(newTaskProto.Name, newTaskProto.Parameters)
|
||||
if err != nil {
|
||||
log.Errorf("Task: %v (%v/%v) emitted a new task which is not valid. Error: %v Details: %v", taskProto.Name, clusterOp.Id, taskProto.Id, err, newTaskProto)
|
||||
MarkTaskFailed(taskProto, output, err)
|
||||
lastTaskError = err.Error()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastTaskError == "" {
|
||||
clusterOp.InsertTaskContainers(newTaskContainers, i+1)
|
||||
log.Infof("ClusterOperation: %v %d new task containers added by %v (%v/%v). Updated ClusterOperation: %v",
|
||||
clusterOp.Id, len(newTaskContainers), taskProto.Name, clusterOp.Id, taskProto.Id, clusterOp)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clusterOp.State = pb.ClusterOperationState_CLUSTER_OPERATION_DONE
|
||||
if lastTaskError != "" {
|
||||
clusterOp.Error = lastTaskError
|
||||
}
|
||||
log.Infof("ClusterOperation: %v finished. Details: %v", clusterOp.Id, clusterOp)
|
||||
|
||||
// Move operation from active to finished.
|
||||
s.muOpList.Lock()
|
||||
if s.activeClusterOperations[clusterOp.Id] != clusterOp {
|
||||
panic("Pending ClusterOperation was not recorded as active, but should have.")
|
||||
}
|
||||
delete(s.activeClusterOperations, clusterOp.Id)
|
||||
s.finishedClusterOperations[clusterOp.Id] = clusterOp
|
||||
s.muOpList.Unlock()
|
||||
}
|
||||
|
||||
func defaultTaskCreator(taskName string) Task {
|
||||
switch taskName {
|
||||
case "HorizontalReshardingTask":
|
||||
return &HorizontalReshardingTask{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) setTaskCreator(creator taskCreator) {
|
||||
s.taskCreatorMu.Lock()
|
||||
defer s.taskCreatorMu.Unlock()
|
||||
|
||||
s.taskCreator = creator
|
||||
}
|
||||
|
||||
func (s *Scheduler) validateTaskSpecification(taskName string, parameters map[string]string) error {
|
||||
taskInstanceForParametersCheck, err := s.createTaskInstance(taskName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
errParameters := checkRequiredParameters(taskInstanceForParametersCheck, parameters)
|
||||
if errParameters != nil {
|
||||
return errParameters
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) createTaskInstance(taskName string) (Task, error) {
|
||||
s.taskCreatorMu.Lock()
|
||||
taskCreator := s.taskCreator
|
||||
s.taskCreatorMu.Unlock()
|
||||
|
||||
task := taskCreator(taskName)
|
||||
if task == nil {
|
||||
return nil, fmt.Errorf("No implementation found for: %v", taskName)
|
||||
}
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// checkRequiredParameters returns an error if not all required parameters are provided in "parameters".
|
||||
func checkRequiredParameters(task Task, parameters map[string]string) error {
|
||||
for _, requiredParameter := range task.requiredParameters() {
|
||||
if _, ok := parameters[requiredParameter]; !ok {
|
||||
return fmt.Errorf("Parameter %v is required, but not provided", requiredParameter)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueClusterOperation can be used to start a new cluster operation.
|
||||
func (s *Scheduler) EnqueueClusterOperation(ctx context.Context, req *pb.EnqueueClusterOperationRequest) (*pb.EnqueueClusterOperationResponse, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.state != stateRunning {
|
||||
return nil, fmt.Errorf("Scheduler is not running. State: %v", s.state)
|
||||
}
|
||||
|
||||
if s.registeredClusterOperations[req.Name] != true {
|
||||
return nil, fmt.Errorf("No ClusterOperation with name: %v is registered", req.Name)
|
||||
}
|
||||
|
||||
err := s.validateTaskSpecification(req.Name, req.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clusterOpID := s.idGenerator.GetNextID()
|
||||
taskIDGenerator := IDGenerator{}
|
||||
initialTask := NewTaskContainerWithSingleTask(req.Name, req.Parameters)
|
||||
clusterOp := NewClusterOperationInstance(clusterOpID, initialTask, &taskIDGenerator)
|
||||
|
||||
s.muOpList.Lock()
|
||||
s.toBeScheduledClusterOperations <- clusterOp
|
||||
s.activeClusterOperations[clusterOpID] = clusterOp
|
||||
s.muOpList.Unlock()
|
||||
|
||||
return &pb.EnqueueClusterOperationResponse{
|
||||
Id: clusterOp.Id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// findClusterOp checks for a given ClusterOperation ID if it's in the list of active or finished operations.
|
||||
func (s *Scheduler) findClusterOp(id string) (*ClusterOperationInstance, error) {
|
||||
var ok bool
|
||||
var clusterOp *ClusterOperationInstance
|
||||
|
||||
s.muOpList.Lock()
|
||||
defer s.muOpList.Unlock()
|
||||
clusterOp, ok = s.activeClusterOperations[id]
|
||||
if !ok {
|
||||
clusterOp, ok = s.finishedClusterOperations[id]
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("ClusterOperation with id: %v not found", id)
|
||||
}
|
||||
return clusterOp, nil
|
||||
}
|
||||
|
||||
// GetClusterOperationDetails can be used to query the full details of active or finished operations.
|
||||
func (s *Scheduler) GetClusterOperationDetails(ctx context.Context, req *pb.GetClusterOperationDetailsRequest) (*pb.GetClusterOperationDetailsResponse, error) {
|
||||
clusterOp, err := s.findClusterOp(req.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pb.GetClusterOperationDetailsResponse{
|
||||
ClusterOp: &clusterOp.ClusterOperation,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ShutdownAndWait shuts down the scheduler and waits infinitely until all pending cluster operations have finished.
|
||||
func (s *Scheduler) ShutdownAndWait() {
|
||||
s.mu.Lock()
|
||||
if s.state != stateShuttingDown {
|
||||
s.state = stateShuttingDown
|
||||
close(s.toBeScheduledClusterOperations)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("Scheduler was shut down. Waiting for pending ClusterOperations to finish.")
|
||||
s.pendingOpsWg.Wait()
|
||||
|
||||
s.mu.Lock()
|
||||
s.state = stateShutdown
|
||||
s.mu.Unlock()
|
||||
log.Infof("All pending ClusterOperations finished.")
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
context "golang.org/x/net/context"
|
||||
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
// newTestScheduler constructs a scheduler with test tasks.
|
||||
// If tasks should be available as cluster operation, they still have to be registered manually with scheduler.registerClusterOperation.
|
||||
func newTestScheduler(t *testing.T) *Scheduler {
|
||||
scheduler, err := NewScheduler()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create scheduler: %v", err)
|
||||
}
|
||||
scheduler.setTaskCreator(testingTaskCreator)
|
||||
return scheduler
|
||||
}
|
||||
|
||||
// waitForClusterOperation is a helper function which blocks until the Cluster Operation has finished.
|
||||
func waitForClusterOperation(t *testing.T, scheduler *Scheduler, id string, expectedOutputLastTask string, expectedErrorLastTask string) *pb.ClusterOperation {
|
||||
if expectedOutputLastTask == "" && expectedErrorLastTask == "" {
|
||||
t.Fatal("Error in test: Cannot wait for an operation where both output and error are expected to be empty.")
|
||||
}
|
||||
|
||||
getDetailsRequest := &pb.GetClusterOperationDetailsRequest{
|
||||
Id: id,
|
||||
}
|
||||
|
||||
for {
|
||||
getDetailsResponse, err := scheduler.GetClusterOperationDetails(context.TODO(), getDetailsRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get details for cluster operation. Request: %v Error: %v", getDetailsRequest, err)
|
||||
}
|
||||
if getDetailsResponse.ClusterOp.State == pb.ClusterOperationState_CLUSTER_OPERATION_DONE {
|
||||
tc := getDetailsResponse.ClusterOp.SerialTasks
|
||||
lastTc := tc[len(tc)-1]
|
||||
if expectedOutputLastTask != "" {
|
||||
if lastTc.ParallelTasks[len(lastTc.ParallelTasks)-1].Output != expectedOutputLastTask {
|
||||
t.Fatalf("ClusterOperation finished but did not return expected output. want: %v Full ClusterOperation details: %v", expectedOutputLastTask, proto.MarshalTextString(getDetailsResponse.ClusterOp))
|
||||
}
|
||||
}
|
||||
if expectedErrorLastTask != "" {
|
||||
if lastTc.ParallelTasks[len(lastTc.ParallelTasks)-1].Error != expectedErrorLastTask {
|
||||
t.Fatalf("ClusterOperation finished but did not return expected error. Full ClusterOperation details: %v", getDetailsResponse.ClusterOp)
|
||||
}
|
||||
}
|
||||
return getDetailsResponse.ClusterOp
|
||||
}
|
||||
|
||||
t.Logf("Waiting for clusterOp: %v", getDetailsResponse.ClusterOp)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerImmediateShutdown(t *testing.T) {
|
||||
// Make sure that the scheduler shuts down cleanly when it was instantiated, but not started with Run().
|
||||
scheduler, err := NewScheduler()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create scheduler: %v", err)
|
||||
}
|
||||
scheduler.ShutdownAndWait()
|
||||
}
|
||||
|
||||
func enqueueClusterOperationAndCheckOutput(t *testing.T, taskName string, expectedOutput string) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
scheduler.registerClusterOperation("TestingEchoTask")
|
||||
scheduler.registerClusterOperation("TestingEmitEchoTask")
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
enqueueRequest := &pb.EnqueueClusterOperationRequest{
|
||||
Name: taskName,
|
||||
Parameters: map[string]string{
|
||||
"echo_text": expectedOutput,
|
||||
},
|
||||
}
|
||||
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start cluster operation. Request: %v Error: %v", enqueueRequest, err)
|
||||
}
|
||||
|
||||
waitForClusterOperation(t, scheduler, enqueueResponse.Id, expectedOutput, "")
|
||||
}
|
||||
|
||||
func TestEnqueueSingleTask(t *testing.T) {
|
||||
enqueueClusterOperationAndCheckOutput(t, "TestingEchoTask", "echoed text")
|
||||
}
|
||||
|
||||
func TestEnqueueEmittingTask(t *testing.T) {
|
||||
enqueueClusterOperationAndCheckOutput(t, "TestingEmitEchoTask", "echoed text from emitted task")
|
||||
}
|
||||
|
||||
func TestEnqueueFailsDueToMissingParameter(t *testing.T) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
scheduler.registerClusterOperation("TestingEchoTask")
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
enqueueRequest := &pb.EnqueueClusterOperationRequest{
|
||||
Name: "TestingEchoTask",
|
||||
Parameters: map[string]string{
|
||||
"unrelevant-parameter": "value",
|
||||
},
|
||||
}
|
||||
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Scheduler should have failed to start cluster operation because not all required parameters were provided. Request: %v Error: %v Response: %v", enqueueRequest, err, enqueueResponse)
|
||||
}
|
||||
want := "Parameter echo_text is required, but not provided"
|
||||
if err.Error() != want {
|
||||
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailedTaskFailsClusterOperation(t *testing.T) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
scheduler.registerClusterOperation("TestingFailTask")
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
enqueueRequest := &pb.EnqueueClusterOperationRequest{
|
||||
Name: "TestingFailTask",
|
||||
}
|
||||
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start cluster operation. Request: %v Error: %v", enqueueRequest, err)
|
||||
}
|
||||
|
||||
waitForClusterOperation(t, scheduler, enqueueResponse.Id, "something went wrong", "full error message")
|
||||
}
|
||||
|
||||
func TestEnqueueFailsDueToUnregisteredClusterOperation(t *testing.T) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
enqueueRequest := &pb.EnqueueClusterOperationRequest{
|
||||
Name: "TestingEchoTask",
|
||||
Parameters: map[string]string{
|
||||
"unrelevant-parameter": "value",
|
||||
},
|
||||
}
|
||||
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Scheduler should have failed to start cluster operation because it should not have been registered. Request: %v Error: %v Response: %v", enqueueRequest, err, enqueueResponse)
|
||||
}
|
||||
want := "No ClusterOperation with name: TestingEchoTask is registered"
|
||||
if err.Error() != want {
|
||||
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDetailsFailsUnknownId(t *testing.T) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
getDetailsRequest := &pb.GetClusterOperationDetailsRequest{
|
||||
Id: "-1", // There will never be a ClusterOperation with this id.
|
||||
}
|
||||
|
||||
getDetailsResponse, err := scheduler.GetClusterOperationDetails(context.TODO(), getDetailsRequest)
|
||||
if err == nil {
|
||||
t.Fatalf("Did not fail to get details for invalid ClusterOperation id. Request: %v Response: %v Error: %v", getDetailsRequest, getDetailsResponse, err)
|
||||
}
|
||||
want := "ClusterOperation with id: -1 not found"
|
||||
if err.Error() != want {
|
||||
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueFailsBecauseTaskInstanceCannotBeCreated(t *testing.T) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
scheduler.setTaskCreator(defaultTaskCreator)
|
||||
// TestingEchoTask is registered as cluster operation, but its task cannot be instantied because "testingTaskCreator" was not set.
|
||||
scheduler.registerClusterOperation("TestingEchoTask")
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
enqueueRequest := &pb.EnqueueClusterOperationRequest{
|
||||
Name: "TestingEchoTask",
|
||||
Parameters: map[string]string{
|
||||
"unrelevant-parameter": "value",
|
||||
},
|
||||
}
|
||||
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("Scheduler should have failed to start cluster operation because the task could not be instantiated. Request: %v Error: %v Response: %v", enqueueRequest, err, enqueueResponse)
|
||||
}
|
||||
want := "No implementation found for: TestingEchoTask"
|
||||
if err.Error() != want {
|
||||
t.Fatalf("Wrong error message. got: '%v' want: '%v'", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskEmitsTaskWhichCannotBeInstantiated(t *testing.T) {
|
||||
scheduler := newTestScheduler(t)
|
||||
defer scheduler.ShutdownAndWait()
|
||||
scheduler.setTaskCreator(func(taskName string) Task {
|
||||
// TaskCreator which doesn't know TestingEchoTask (but emitted by TestingEmitEchoTask).
|
||||
switch taskName {
|
||||
case "TestingEmitEchoTask":
|
||||
return &TestingEmitEchoTask{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
})
|
||||
scheduler.registerClusterOperation("TestingEmitEchoTask")
|
||||
|
||||
scheduler.Run()
|
||||
|
||||
enqueueRequest := &pb.EnqueueClusterOperationRequest{
|
||||
Name: "TestingEmitEchoTask",
|
||||
}
|
||||
enqueueResponse, err := scheduler.EnqueueClusterOperation(context.TODO(), enqueueRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start cluster operation. Request: %v Error: %v", enqueueRequest, err)
|
||||
}
|
||||
|
||||
details := waitForClusterOperation(t, scheduler, enqueueResponse.Id, "emitted TestingEchoTask", "No implementation found for: TestingEchoTask")
|
||||
if len(details.SerialTasks) != 1 {
|
||||
t.Errorf("A task has been emitted, but it shouldn't. Details:\n%v", proto.MarshalTextString(details))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
// Task implementations can be executed by the scheduler.
|
||||
type Task interface {
|
||||
// run executes the task using the key/values from parameters.
|
||||
// "newTaskContainers" contains new tasks which the task can emit. They'll be inserted in the cluster operation directly after this task. It may be "nil".
|
||||
// "output" may be empty. It contains any text which maybe must e.g. to debug the task or show it in the UI.
|
||||
run(parameters map[string]string) (newTaskContainers []*pb.TaskContainer, output string, err error)
|
||||
|
||||
// requiredParameters() returns a list of parameter keys which must be provided as input for run().
|
||||
requiredParameters() []string
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
// Helper functions for "TaskContainer" protobuf message.
|
||||
|
||||
// NewTaskContainerWithSingleTask creates a new task container with exactly one task.
|
||||
func NewTaskContainerWithSingleTask(taskName string, parameters map[string]string) *pb.TaskContainer {
|
||||
return &pb.TaskContainer{
|
||||
ParallelTasks: []*pb.Task{
|
||||
NewTask(taskName, parameters),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTaskContainer creates an empty task container. Use AddTask() to add tasks to it.
|
||||
func NewTaskContainer() *pb.TaskContainer {
|
||||
return &pb.TaskContainer{
|
||||
ParallelTasks: []*pb.Task{},
|
||||
}
|
||||
}
|
||||
|
||||
// AddTask adds a new task to an existing task container.
|
||||
func AddTask(t *pb.TaskContainer, taskName string, parameters map[string]string) {
|
||||
t.ParallelTasks = append(t.ParallelTasks, NewTask(taskName, parameters))
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
// Helper functions for "Task" protobuf message.
|
||||
|
||||
// MarkTaskSucceeded marks the task as done.
|
||||
func MarkTaskSucceeded(t *pb.Task, output string) {
|
||||
t.State = pb.TaskState_DONE
|
||||
t.Output = output
|
||||
}
|
||||
|
||||
// MarkTaskFailed marks the task as failed.
|
||||
func MarkTaskFailed(t *pb.Task, output string, err error) {
|
||||
t.State = pb.TaskState_DONE
|
||||
t.Output = output
|
||||
t.Error = err.Error()
|
||||
}
|
||||
|
||||
// NewTask creates a new task protobuf message for "taskName" with "parameters".
|
||||
func NewTask(taskName string, parameters map[string]string) *pb.Task {
|
||||
return &pb.Task{
|
||||
State: pb.TaskState_NOT_STARTED,
|
||||
Name: taskName,
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package automation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
pb "github.com/youtube/vitess/go/vt/proto/automation"
|
||||
)
|
||||
|
||||
func testingTaskCreator(taskName string) Task {
|
||||
switch taskName {
|
||||
// Tasks for testing only.
|
||||
case "TestingEchoTask":
|
||||
return &TestingEchoTask{}
|
||||
case "TestingEmitEchoTask":
|
||||
return &TestingEmitEchoTask{}
|
||||
case "TestingFailTask":
|
||||
return &TestingFailTask{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// TestingEchoTask is used only for testing. It returns the join of all parameter values.
|
||||
type TestingEchoTask struct {
|
||||
}
|
||||
|
||||
func (t *TestingEchoTask) run(parameters map[string]string) (newTasks []*pb.TaskContainer, output string, err error) {
|
||||
for _, v := range parameters {
|
||||
output += v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t *TestingEchoTask) requiredParameters() []string {
|
||||
return []string{"echo_text"}
|
||||
}
|
||||
|
||||
// TestingEmitEchoTask is used only for testing. It emits a TestingEchoTask.
|
||||
type TestingEmitEchoTask struct {
|
||||
}
|
||||
|
||||
func (t *TestingEmitEchoTask) run(parameters map[string]string) (newTasks []*pb.TaskContainer, output string, err error) {
|
||||
return []*pb.TaskContainer{
|
||||
NewTaskContainerWithSingleTask("TestingEchoTask", parameters),
|
||||
}, "emitted TestingEchoTask", nil
|
||||
}
|
||||
|
||||
func (t *TestingEmitEchoTask) requiredParameters() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// TestingFailTask is used only for testing. It always fails.
|
||||
type TestingFailTask struct {
|
||||
}
|
||||
|
||||
func (t *TestingFailTask) run(parameters map[string]string) (newTasks []*pb.TaskContainer, output string, err error) {
|
||||
return nil, "something went wrong", errors.New("full error message")
|
||||
}
|
||||
|
||||
func (t *TestingFailTask) requiredParameters() []string {
|
||||
return []string{}
|
||||
}
|
|
@ -64,7 +64,7 @@ func getStatementCategory(sql []byte) int {
|
|||
type BinlogStreamer struct {
|
||||
// dbname and mysqld are set at creation.
|
||||
dbname string
|
||||
mysqld *mysqlctl.Mysqld
|
||||
mysqld mysqlctl.MysqlDaemon
|
||||
clientCharset *mproto.Charset
|
||||
startPos myproto.ReplicationPosition
|
||||
sendTransaction sendTransactionFunc
|
||||
|
@ -79,7 +79,7 @@ type BinlogStreamer struct {
|
|||
// charset is the default character set on the BinlogPlayer side.
|
||||
// startPos is the position to start streaming at.
|
||||
// sendTransaction is called each time a transaction is committed or rolled back.
|
||||
func NewBinlogStreamer(dbname string, mysqld *mysqlctl.Mysqld, clientCharset *mproto.Charset, startPos myproto.ReplicationPosition, sendTransaction sendTransactionFunc) *BinlogStreamer {
|
||||
func NewBinlogStreamer(dbname string, mysqld mysqlctl.MysqlDaemon, clientCharset *mproto.Charset, startPos myproto.ReplicationPosition, sendTransaction sendTransactionFunc) *BinlogStreamer {
|
||||
return &BinlogStreamer{
|
||||
dbname: dbname,
|
||||
mysqld: mysqld,
|
||||
|
@ -99,7 +99,7 @@ func (bls *BinlogStreamer) Stream(ctx *sync2.ServiceContext) (err error) {
|
|||
log.Infof("stream ended @ %v, err = %v", stopPos, err)
|
||||
}()
|
||||
|
||||
if bls.conn, err = mysqlctl.NewSlaveConnection(bls.mysqld); err != nil {
|
||||
if bls.conn, err = bls.mysqld.NewSlaveConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer bls.conn.Close()
|
||||
|
|
|
@ -39,7 +39,7 @@ type EventStreamer struct {
|
|||
sendEvent sendEventFunc
|
||||
}
|
||||
|
||||
func NewEventStreamer(dbname string, mysqld *mysqlctl.Mysqld, startPos myproto.ReplicationPosition, sendEvent sendEventFunc) *EventStreamer {
|
||||
func NewEventStreamer(dbname string, mysqld mysqlctl.MysqlDaemon, startPos myproto.ReplicationPosition, sendEvent sendEventFunc) *EventStreamer {
|
||||
evs := &EventStreamer{
|
||||
sendEvent: sendEvent,
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ type UpdateStream struct {
|
|||
|
||||
actionLock sync.Mutex
|
||||
state sync2.AtomicInt64
|
||||
mysqld *mysqlctl.Mysqld
|
||||
mysqld mysqlctl.MysqlDaemon
|
||||
stateWaitGroup sync.WaitGroup
|
||||
dbname string
|
||||
streams streamList
|
||||
|
@ -121,7 +121,7 @@ func logError() {
|
|||
}
|
||||
|
||||
// EnableUpdateStreamService enables the RPC service for UpdateStream
|
||||
func EnableUpdateStreamService(dbname string, mysqld *mysqlctl.Mysqld) {
|
||||
func EnableUpdateStreamService(dbname string, mysqld mysqlctl.MysqlDaemon) {
|
||||
defer logError()
|
||||
UpdateStreamRpcService.enable(dbname, mysqld)
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ func GetReplicationPosition() (myproto.ReplicationPosition, error) {
|
|||
return UpdateStreamRpcService.getReplicationPosition()
|
||||
}
|
||||
|
||||
func (updateStream *UpdateStream) enable(dbname string, mysqld *mysqlctl.Mysqld) {
|
||||
func (updateStream *UpdateStream) enable(dbname string, mysqld mysqlctl.MysqlDaemon) {
|
||||
updateStream.actionLock.Lock()
|
||||
defer updateStream.actionLock.Unlock()
|
||||
if updateStream.isEnabled() {
|
||||
|
|
|
@ -52,17 +52,17 @@ type conn struct {
|
|||
TabletType topo.TabletType `json:"tablet_type"`
|
||||
Streaming bool
|
||||
Timeout time.Duration
|
||||
vtgateConn vtgateconn.VTGateConn
|
||||
tx vtgateconn.VTGateTx
|
||||
vtgateConn *vtgateconn.VTGateConn
|
||||
tx *vtgateconn.VTGateTx
|
||||
}
|
||||
|
||||
func (c *conn) dial() error {
|
||||
dialer := vtgateconn.GetDialerWithProtocol(c.Protocol)
|
||||
if dialer == nil {
|
||||
return fmt.Errorf("could not find dialer for protocol %s", c.Protocol)
|
||||
}
|
||||
var err error
|
||||
c.vtgateConn, err = dialer(context.Background(), c.Address, c.Timeout)
|
||||
if c.Protocol == "" {
|
||||
c.vtgateConn, err = vtgateconn.Dial(context.Background(), c.Address, c.Timeout)
|
||||
} else {
|
||||
c.vtgateConn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address, c.Timeout)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -150,7 +150,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
defer cancel()
|
||||
if s.c.Streaming {
|
||||
qrc, errFunc := s.c.vtgateConn.StreamExecute(ctx, s.query, makeBindVars(args), s.c.TabletType)
|
||||
return vtgateconn.NewStreamingRows(qrc, errFunc), nil
|
||||
return newStreamingRows(qrc, errFunc), nil
|
||||
}
|
||||
var qr *mproto.QueryResult
|
||||
var err error
|
||||
|
@ -162,7 +162,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vtgateconn.NewRows(qr), nil
|
||||
return newRows(qr), nil
|
||||
}
|
||||
|
||||
func makeBindVars(args []driver.Value) map[string]interface{} {
|
||||
|
|
|
@ -94,7 +94,7 @@ func TestDial(t *testing.T) {
|
|||
_ = c.Close()
|
||||
|
||||
_, err = drv{}.Open(`{"protocol": "none"}`)
|
||||
want := "could not find dialer for protocol none"
|
||||
want := "no dialer registered for VTGate protocol none"
|
||||
if err == nil || !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("err: %v, want %s", err, want)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vtgateconn
|
||||
package client
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
@ -20,8 +20,8 @@ type rows struct {
|
|||
index int
|
||||
}
|
||||
|
||||
// NewRows creates a new rows from qr.
|
||||
func NewRows(qr *mproto.QueryResult) driver.Rows {
|
||||
// newRows creates a new rows from qr.
|
||||
func newRows(qr *mproto.QueryResult) driver.Rows {
|
||||
return &rows{qr: qr}
|
||||
}
|
||||
|
||||
|
@ -46,6 +46,10 @@ func (ri *rows) Next(dest []driver.Value) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// populateRow populates a row of data using the table's field descriptions.
|
||||
// The returned types for "dest" include the list from the interface
|
||||
// specification at https://golang.org/pkg/database/sql/driver/#Value
|
||||
// and in addition the type "uint64" for unsigned BIGINT MySQL records.
|
||||
func populateRow(dest []driver.Value, fields []mproto.Field, row []sqltypes.Value) error {
|
||||
if len(dest) != len(fields) {
|
||||
return fmt.Errorf("length mismatch: dest is %d, fields are %d", len(dest), len(fields))
|
||||
|
@ -55,7 +59,7 @@ func populateRow(dest []driver.Value, fields []mproto.Field, row []sqltypes.Valu
|
|||
}
|
||||
var err error
|
||||
for i := range dest {
|
||||
dest[i], err = mproto.Convert(fields[i].Type, row[i])
|
||||
dest[i], err = mproto.Convert(fields[i], row[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("conversion error: field: %v, val: %v: %v", fields[i], row[i], err)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package vtgateconn
|
||||
package client
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/youtube/vitess/go/sqltypes"
|
||||
)
|
||||
|
||||
var result1 = mproto.QueryResult{
|
||||
var rowsResult1 = mproto.QueryResult{
|
||||
Fields: []mproto.Field{
|
||||
mproto.Field{
|
||||
Name: "field1",
|
||||
|
@ -24,29 +24,59 @@ var result1 = mproto.QueryResult{
|
|||
Name: "field3",
|
||||
Type: mproto.VT_VAR_STRING,
|
||||
},
|
||||
// Signed types which are smaller than uint64, will become an int64.
|
||||
mproto.Field{
|
||||
Name: "field4",
|
||||
Type: mproto.VT_LONG,
|
||||
Flags: mproto.VT_UNSIGNED_FLAG,
|
||||
},
|
||||
// Signed uint64 values must be mapped to uint64.
|
||||
mproto.Field{
|
||||
Name: "field5",
|
||||
Type: mproto.VT_LONGLONG,
|
||||
Flags: mproto.VT_UNSIGNED_FLAG,
|
||||
},
|
||||
},
|
||||
RowsAffected: 3,
|
||||
RowsAffected: 2,
|
||||
InsertId: 0,
|
||||
Rows: [][]sqltypes.Value{
|
||||
[]sqltypes.Value{
|
||||
sqltypes.MakeString([]byte("1")),
|
||||
sqltypes.MakeString([]byte("1.1")),
|
||||
sqltypes.MakeString([]byte("value1")),
|
||||
sqltypes.MakeString([]byte("2147483647")), // 2^31-1, NOT out of range for int32 => should become int64
|
||||
sqltypes.MakeString([]byte("9223372036854775807")), // 2^63-1, NOT out of range for int64
|
||||
},
|
||||
[]sqltypes.Value{
|
||||
sqltypes.MakeString([]byte("2")),
|
||||
sqltypes.MakeString([]byte("2.2")),
|
||||
sqltypes.MakeString([]byte("value2")),
|
||||
sqltypes.MakeString([]byte("4294967295")), // 2^32, out of range for int32 => should become int64
|
||||
sqltypes.MakeString([]byte("18446744073709551615")), // 2^64, out of range for int64
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func logMismatchedTypes(t *testing.T, gotRow, wantRow []driver.Value) {
|
||||
for i := 1; i < len(wantRow); i++ {
|
||||
got := gotRow[i]
|
||||
want := wantRow[i]
|
||||
v1 := reflect.ValueOf(got)
|
||||
v2 := reflect.ValueOf(want)
|
||||
if v1.Type() != v2.Type() {
|
||||
t.Errorf("Wrong type: field: %d got: %T want: %T", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRows(t *testing.T) {
|
||||
ri := NewRows(&result1)
|
||||
ri := newRows(&rowsResult1)
|
||||
wantCols := []string{
|
||||
"field1",
|
||||
"field2",
|
||||
"field3",
|
||||
"field4",
|
||||
"field5",
|
||||
}
|
||||
gotCols := ri.Columns()
|
||||
if !reflect.DeepEqual(gotCols, wantCols) {
|
||||
|
@ -57,20 +87,25 @@ func TestRows(t *testing.T) {
|
|||
int64(1),
|
||||
float64(1.1),
|
||||
[]byte("value1"),
|
||||
int64(2147483647),
|
||||
uint64(9223372036854775807),
|
||||
}
|
||||
gotRow := make([]driver.Value, 3)
|
||||
gotRow := make([]driver.Value, len(wantRow))
|
||||
err := ri.Next(gotRow)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotRow, wantRow) {
|
||||
t.Errorf("row1: %v, want %v", gotRow, wantRow)
|
||||
t.Errorf("row1: %v, want %v type: %T", gotRow, wantRow, wantRow[3])
|
||||
logMismatchedTypes(t, gotRow, wantRow)
|
||||
}
|
||||
|
||||
wantRow = []driver.Value{
|
||||
int64(2),
|
||||
float64(2.2),
|
||||
[]byte("value2"),
|
||||
int64(4294967295),
|
||||
uint64(18446744073709551615),
|
||||
}
|
||||
err = ri.Next(gotRow)
|
||||
if err != nil {
|
||||
|
@ -78,6 +113,7 @@ func TestRows(t *testing.T) {
|
|||
}
|
||||
if !reflect.DeepEqual(gotRow, wantRow) {
|
||||
t.Errorf("row1: %v, want %v", gotRow, wantRow)
|
||||
logMismatchedTypes(t, gotRow, wantRow)
|
||||
}
|
||||
|
||||
err = ri.Next(gotRow)
|
||||
|
@ -112,7 +148,7 @@ var badResult2 = mproto.QueryResult{
|
|||
}
|
||||
|
||||
func TestRowsFail(t *testing.T) {
|
||||
ri := NewRows(&badResult1)
|
||||
ri := newRows(&badResult1)
|
||||
var dest []driver.Value
|
||||
err := ri.Next(dest)
|
||||
want := "length mismatch: dest is 0, fields are 1"
|
||||
|
@ -120,7 +156,7 @@ func TestRowsFail(t *testing.T) {
|
|||
t.Errorf("Next: %v, want %s", err, want)
|
||||
}
|
||||
|
||||
ri = NewRows(&badResult1)
|
||||
ri = newRows(&badResult1)
|
||||
dest = make([]driver.Value, 1)
|
||||
err = ri.Next(dest)
|
||||
want = "internal error: length mismatch: dest is 1, fields are 0"
|
||||
|
@ -128,10 +164,10 @@ func TestRowsFail(t *testing.T) {
|
|||
t.Errorf("Next: %v, want %s", err, want)
|
||||
}
|
||||
|
||||
ri = NewRows(&badResult2)
|
||||
ri = newRows(&badResult2)
|
||||
dest = make([]driver.Value, 1)
|
||||
err = ri.Next(dest)
|
||||
want = `conversion error: field: {field1 3 0}, val: value: strconv.ParseUint: parsing "value": invalid syntax`
|
||||
want = `conversion error: field: {field1 3 0}, val: value: strconv.ParseInt: parsing "value": invalid syntax`
|
||||
if err == nil || err.Error() != want {
|
||||
t.Errorf("Next: %v, want %s", err, want)
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vtgateconn
|
||||
package client
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
@ -10,21 +10,22 @@ import (
|
|||
"io"
|
||||
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
|
||||
)
|
||||
|
||||
// streamingRows creates a database/sql/driver compliant Row iterator
|
||||
// for a streaming query.
|
||||
type streamingRows struct {
|
||||
qrc <-chan *mproto.QueryResult
|
||||
errFunc ErrFunc
|
||||
errFunc vtgateconn.ErrFunc
|
||||
failed error
|
||||
fields []mproto.Field
|
||||
qr *mproto.QueryResult
|
||||
index int
|
||||
}
|
||||
|
||||
// NewStreamingRows creates a new streamingRows from qrc and errFunc.
|
||||
func NewStreamingRows(qrc <-chan *mproto.QueryResult, errFunc ErrFunc) driver.Rows {
|
||||
// newStreamingRows creates a new streamingRows from qrc and errFunc.
|
||||
func newStreamingRows(qrc <-chan *mproto.QueryResult, errFunc vtgateconn.ErrFunc) driver.Rows {
|
||||
return &streamingRows{qrc: qrc, errFunc: errFunc}
|
||||
}
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vtgateconn
|
||||
package client
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
|
||||
)
|
||||
|
||||
var packet1 = mproto.QueryResult{
|
||||
|
@ -54,7 +55,7 @@ var packet3 = mproto.QueryResult{
|
|||
}
|
||||
|
||||
func TestStreamingRows(t *testing.T) {
|
||||
qrc, errFunc := func() (<-chan *mproto.QueryResult, ErrFunc) {
|
||||
qrc, errFunc := func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
|
||||
ch := make(chan *mproto.QueryResult)
|
||||
go func() {
|
||||
ch <- &packet1
|
||||
|
@ -64,7 +65,7 @@ func TestStreamingRows(t *testing.T) {
|
|||
}()
|
||||
return ch, func() error { return nil }
|
||||
}()
|
||||
ri := NewStreamingRows(qrc, errFunc)
|
||||
ri := newStreamingRows(qrc, errFunc)
|
||||
wantCols := []string{
|
||||
"field1",
|
||||
"field2",
|
||||
|
@ -111,7 +112,7 @@ func TestStreamingRows(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestStreamingRowsReversed(t *testing.T) {
|
||||
qrc, errFunc := func() (<-chan *mproto.QueryResult, ErrFunc) {
|
||||
qrc, errFunc := func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
|
||||
ch := make(chan *mproto.QueryResult)
|
||||
go func() {
|
||||
ch <- &packet1
|
||||
|
@ -121,7 +122,7 @@ func TestStreamingRowsReversed(t *testing.T) {
|
|||
}()
|
||||
return ch, func() error { return nil }
|
||||
}()
|
||||
ri := NewStreamingRows(qrc, errFunc)
|
||||
ri := newStreamingRows(qrc, errFunc)
|
||||
|
||||
wantRow := []driver.Value{
|
||||
int64(1),
|
||||
|
@ -151,14 +152,14 @@ func TestStreamingRowsReversed(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestStreamingRowsError(t *testing.T) {
|
||||
qrc, errFunc := func() (<-chan *mproto.QueryResult, ErrFunc) {
|
||||
qrc, errFunc := func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
|
||||
ch := make(chan *mproto.QueryResult)
|
||||
go func() {
|
||||
close(ch)
|
||||
}()
|
||||
return ch, func() error { return errors.New("error before fields") }
|
||||
}()
|
||||
ri := NewStreamingRows(qrc, errFunc)
|
||||
ri := newStreamingRows(qrc, errFunc)
|
||||
gotCols := ri.Columns()
|
||||
if gotCols != nil {
|
||||
t.Errorf("cols: %v, want nil", gotCols)
|
||||
|
@ -171,7 +172,7 @@ func TestStreamingRowsError(t *testing.T) {
|
|||
}
|
||||
_ = ri.Close()
|
||||
|
||||
qrc, errFunc = func() (<-chan *mproto.QueryResult, ErrFunc) {
|
||||
qrc, errFunc = func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
|
||||
ch := make(chan *mproto.QueryResult)
|
||||
go func() {
|
||||
ch <- &packet1
|
||||
|
@ -179,7 +180,7 @@ func TestStreamingRowsError(t *testing.T) {
|
|||
}()
|
||||
return ch, func() error { return errors.New("error after fields") }
|
||||
}()
|
||||
ri = NewStreamingRows(qrc, errFunc)
|
||||
ri = newStreamingRows(qrc, errFunc)
|
||||
wantCols := []string{
|
||||
"field1",
|
||||
"field2",
|
||||
|
@ -202,7 +203,7 @@ func TestStreamingRowsError(t *testing.T) {
|
|||
}
|
||||
_ = ri.Close()
|
||||
|
||||
qrc, errFunc = func() (<-chan *mproto.QueryResult, ErrFunc) {
|
||||
qrc, errFunc = func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
|
||||
ch := make(chan *mproto.QueryResult)
|
||||
go func() {
|
||||
ch <- &packet1
|
||||
|
@ -211,7 +212,7 @@ func TestStreamingRowsError(t *testing.T) {
|
|||
}()
|
||||
return ch, func() error { return errors.New("error after rows") }
|
||||
}()
|
||||
ri = NewStreamingRows(qrc, errFunc)
|
||||
ri = newStreamingRows(qrc, errFunc)
|
||||
gotRow = make([]driver.Value, 3)
|
||||
err = ri.Next(gotRow)
|
||||
if err != nil {
|
||||
|
@ -224,7 +225,7 @@ func TestStreamingRowsError(t *testing.T) {
|
|||
}
|
||||
_ = ri.Close()
|
||||
|
||||
qrc, errFunc = func() (<-chan *mproto.QueryResult, ErrFunc) {
|
||||
qrc, errFunc = func() (<-chan *mproto.QueryResult, vtgateconn.ErrFunc) {
|
||||
ch := make(chan *mproto.QueryResult)
|
||||
go func() {
|
||||
ch <- &packet2
|
||||
|
@ -232,7 +233,7 @@ func TestStreamingRowsError(t *testing.T) {
|
|||
}()
|
||||
return ch, func() error { return nil }
|
||||
}()
|
||||
ri = NewStreamingRows(qrc, errFunc)
|
||||
ri = newStreamingRows(qrc, errFunc)
|
||||
gotRow = make([]driver.Value, 3)
|
||||
err = ri.Next(gotRow)
|
||||
wantErr = "first packet did not return fields"
|
|
@ -1,290 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package client2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
const (
|
||||
EID_NODE = iota
|
||||
VALUE_NODE
|
||||
LIST_NODE
|
||||
OTHER_NODE
|
||||
)
|
||||
|
||||
type RoutingPlan struct {
|
||||
criteria sqlparser.SQLNode
|
||||
}
|
||||
|
||||
func GetShardList(sql string, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardlist []int, err error) {
|
||||
plan, err := buildPlan(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return shardListFromPlan(plan, bindVariables, tabletKeys)
|
||||
}
|
||||
|
||||
func buildPlan(sql string) (plan *RoutingPlan, err error) {
|
||||
statement, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getRoutingPlan(statement)
|
||||
}
|
||||
|
||||
func shardListFromPlan(plan *RoutingPlan, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardList []int, err error) {
|
||||
if plan.criteria == nil {
|
||||
return makeList(0, len(tabletKeys)), nil
|
||||
}
|
||||
|
||||
switch criteria := plan.criteria.(type) {
|
||||
case sqlparser.Values:
|
||||
index, err := findInsertShard(criteria, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []int{index}, nil
|
||||
case *sqlparser.ComparisonExpr:
|
||||
switch criteria.Operator {
|
||||
case "=", "<=>":
|
||||
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []int{index}, nil
|
||||
case "<", "<=":
|
||||
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return makeList(0, index+1), nil
|
||||
case ">", ">=":
|
||||
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return makeList(index, len(tabletKeys)), nil
|
||||
case "in":
|
||||
return findShardList(criteria.Right, bindVariables, tabletKeys)
|
||||
}
|
||||
case *sqlparser.RangeCond:
|
||||
if criteria.Operator == "between" {
|
||||
start, err := findShard(criteria.From, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
last, err := findShard(criteria.To, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if last < start {
|
||||
start, last = last, start
|
||||
}
|
||||
return makeList(start, last+1), nil
|
||||
}
|
||||
}
|
||||
return makeList(0, len(tabletKeys)), nil
|
||||
}
|
||||
|
||||
func getRoutingPlan(statement sqlparser.Statement) (plan *RoutingPlan, err error) {
|
||||
plan = &RoutingPlan{}
|
||||
if ins, ok := statement.(*sqlparser.Insert); ok {
|
||||
if sel, ok := ins.Rows.(sqlparser.SelectStatement); ok {
|
||||
return getRoutingPlan(sel)
|
||||
}
|
||||
plan.criteria, err = routingAnalyzeValues(ins.Rows.(sqlparser.Values))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
var where *sqlparser.Where
|
||||
switch stmt := statement.(type) {
|
||||
case *sqlparser.Select:
|
||||
where = stmt.Where
|
||||
case *sqlparser.Update:
|
||||
where = stmt.Where
|
||||
case *sqlparser.Delete:
|
||||
where = stmt.Where
|
||||
}
|
||||
if where != nil {
|
||||
plan.criteria = routingAnalyzeBoolean(where.Expr)
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func routingAnalyzeValues(vals sqlparser.Values) (sqlparser.Values, error) {
|
||||
// Analyze first value of every item in the list
|
||||
for i := 0; i < len(vals); i++ {
|
||||
switch tuple := vals[i].(type) {
|
||||
case sqlparser.ValTuple:
|
||||
result := routingAnalyzeValue(tuple[0])
|
||||
if result != VALUE_NODE {
|
||||
return nil, fmt.Errorf("insert is too complex")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("insert is too complex")
|
||||
}
|
||||
}
|
||||
return vals, nil
|
||||
}
|
||||
|
||||
func routingAnalyzeBoolean(node sqlparser.BoolExpr) sqlparser.BoolExpr {
|
||||
switch node := node.(type) {
|
||||
case *sqlparser.AndExpr:
|
||||
left := routingAnalyzeBoolean(node.Left)
|
||||
right := routingAnalyzeBoolean(node.Right)
|
||||
if left != nil && right != nil {
|
||||
return nil
|
||||
} else if left != nil {
|
||||
return left
|
||||
} else {
|
||||
return right
|
||||
}
|
||||
case *sqlparser.ParenBoolExpr:
|
||||
return routingAnalyzeBoolean(node.Expr)
|
||||
case *sqlparser.ComparisonExpr:
|
||||
switch {
|
||||
case sqlparser.StringIn(node.Operator, "=", "<", ">", "<=", ">=", "<=>"):
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
right := routingAnalyzeValue(node.Right)
|
||||
if (left == EID_NODE && right == VALUE_NODE) || (left == VALUE_NODE && right == EID_NODE) {
|
||||
return node
|
||||
}
|
||||
case node.Operator == "in":
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
right := routingAnalyzeValue(node.Right)
|
||||
if left == EID_NODE && right == LIST_NODE {
|
||||
return node
|
||||
}
|
||||
}
|
||||
case *sqlparser.RangeCond:
|
||||
if node.Operator != "between" {
|
||||
return nil
|
||||
}
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
from := routingAnalyzeValue(node.From)
|
||||
to := routingAnalyzeValue(node.To)
|
||||
if left == EID_NODE && from == VALUE_NODE && to == VALUE_NODE {
|
||||
return node
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func routingAnalyzeValue(valExpr sqlparser.ValExpr) int {
|
||||
switch node := valExpr.(type) {
|
||||
case *sqlparser.ColName:
|
||||
if string(node.Name) == "entity_id" {
|
||||
return EID_NODE
|
||||
}
|
||||
case sqlparser.ValTuple:
|
||||
for _, n := range node {
|
||||
if routingAnalyzeValue(n) != VALUE_NODE {
|
||||
return OTHER_NODE
|
||||
}
|
||||
}
|
||||
return LIST_NODE
|
||||
case sqlparser.StrVal, sqlparser.NumVal, sqlparser.ValArg:
|
||||
return VALUE_NODE
|
||||
}
|
||||
return OTHER_NODE
|
||||
}
|
||||
|
||||
func findShardList(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) ([]int, error) {
|
||||
shardset := make(map[int]bool)
|
||||
switch node := valExpr.(type) {
|
||||
case sqlparser.ValTuple:
|
||||
for _, n := range node {
|
||||
index, err := findShard(n, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shardset[index] = true
|
||||
}
|
||||
}
|
||||
shardlist := make([]int, len(shardset))
|
||||
index := 0
|
||||
for k := range shardset {
|
||||
shardlist[index] = k
|
||||
index++
|
||||
}
|
||||
return shardlist, nil
|
||||
}
|
||||
|
||||
func findInsertShard(vals sqlparser.Values, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
|
||||
index := -1
|
||||
for i := 0; i < len(vals); i++ {
|
||||
first_value_expression := vals[i].(sqlparser.ValTuple)[0]
|
||||
newIndex, err := findShard(first_value_expression, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
if index == -1 {
|
||||
index = newIndex
|
||||
} else if index != newIndex {
|
||||
return -1, fmt.Errorf("insert has multiple shard targets")
|
||||
}
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func findShard(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
|
||||
value, err := getBoundValue(valExpr, bindVariables)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return key.FindShardForValue(value, tabletKeys), nil
|
||||
}
|
||||
|
||||
func getBoundValue(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}) (string, error) {
|
||||
switch node := valExpr.(type) {
|
||||
case sqlparser.ValTuple:
|
||||
if len(node) != 1 {
|
||||
return "", fmt.Errorf("tuples not allowed as insert values")
|
||||
}
|
||||
// TODO: Change parser to create single value tuples into non-tuples.
|
||||
return getBoundValue(node[0], bindVariables)
|
||||
case sqlparser.StrVal:
|
||||
return string(node), nil
|
||||
case sqlparser.NumVal:
|
||||
val, err := strconv.ParseInt(string(node), 10, 64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return key.Uint64Key(val).String(), nil
|
||||
case sqlparser.ValArg:
|
||||
value, err := findBindValue(node, bindVariables)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return key.EncodeValue(value), nil
|
||||
}
|
||||
panic("Unexpected token")
|
||||
}
|
||||
|
||||
func findBindValue(valArg sqlparser.ValArg, bindVariables map[string]interface{}) (interface{}, error) {
|
||||
if bindVariables == nil {
|
||||
return nil, fmt.Errorf("No bind variable for " + string(valArg))
|
||||
}
|
||||
value, ok := bindVariables[string(valArg[1:])]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("No bind variable for " + string(valArg))
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func makeList(start, end int) []int {
|
||||
list := make([]int, end-start)
|
||||
for i := start; i < end; i++ {
|
||||
list[i-start] = i
|
||||
}
|
||||
return list
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package client2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
)
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
tabletkeys := []key.KeyspaceId{
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x02",
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x04",
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x06",
|
||||
"a",
|
||||
"b",
|
||||
"d",
|
||||
}
|
||||
bindVariables := make(map[string]interface{})
|
||||
bindVariables["id0"] = 0
|
||||
bindVariables["id2"] = 2
|
||||
bindVariables["id3"] = 3
|
||||
bindVariables["id4"] = 4
|
||||
bindVariables["id6"] = 6
|
||||
bindVariables["id8"] = 8
|
||||
bindVariables["ids"] = []interface{}{1, 4}
|
||||
bindVariables["a"] = "a"
|
||||
bindVariables["b"] = "b"
|
||||
bindVariables["c"] = "c"
|
||||
bindVariables["d"] = "d"
|
||||
bindVariables["e"] = "e"
|
||||
for tcase := range iterateFiles("sqlparser_test/routing_cases.txt") {
|
||||
if tcase.output == "" {
|
||||
tcase.output = tcase.input
|
||||
}
|
||||
out, err := GetShardList(tcase.input, bindVariables, tabletkeys)
|
||||
if err != nil {
|
||||
if err.Error() != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.input, err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
sort.Ints(out)
|
||||
outstr := fmt.Sprintf("%v", out)
|
||||
if outstr != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, outstr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(sougou): This is now duplicated in three plcaes. Refactor.
|
||||
type testCase struct {
|
||||
file string
|
||||
lineno int
|
||||
input string
|
||||
output string
|
||||
}
|
||||
|
||||
func iterateFiles(pattern string) (testCaseIterator chan testCase) {
|
||||
names := testfiles.Glob(pattern)
|
||||
testCaseIterator = make(chan testCase)
|
||||
go func() {
|
||||
defer close(testCaseIterator)
|
||||
for _, name := range names {
|
||||
fd, err := os.OpenFile(name, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Could not open file %s", name))
|
||||
}
|
||||
|
||||
r := bufio.NewReader(fd)
|
||||
lineno := 0
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
lines := strings.Split(strings.TrimRight(line, "\n"), "#")
|
||||
lineno++
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
panic(fmt.Sprintf("Error reading file %s: %s", name, err.Error()))
|
||||
}
|
||||
break
|
||||
}
|
||||
input := lines[0]
|
||||
output := ""
|
||||
if len(lines) > 1 {
|
||||
output = lines[1]
|
||||
}
|
||||
if input == "" {
|
||||
continue
|
||||
}
|
||||
testCaseIterator <- testCase{name, lineno, input, output}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return testCaseIterator
|
||||
}
|
|
@ -1,589 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package client2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/db"
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/vt/client2/tablet"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/zktopo"
|
||||
"github.com/youtube/vitess/go/zk"
|
||||
)
|
||||
|
||||
// The sharded client handles writing to multiple shards across the
|
||||
// database.
|
||||
//
|
||||
// The ShardedConn can handles several separate aspects:
|
||||
// * loading/reloading tablet addresses on demand from zk
|
||||
// * maintaining at most one connection to each tablet as required
|
||||
// * transaction tracking across shards
|
||||
// * preflight checking all transactions before attempting to commit
|
||||
// (reduce partial commit probability)
|
||||
//
|
||||
// NOTE: Queries with aggregate results will not produce expected
|
||||
// results right now. For instance, running a count(*) on a table
|
||||
// across all tablets will return one row per tablet. In the future,
|
||||
// the SQL parser and query engine can handle these more
|
||||
// automatically. For now, clients will have to do the rollup at a
|
||||
// higher level.
|
||||
|
||||
var (
|
||||
ErrNotConnected = fmt.Errorf("vt: not connected")
|
||||
)
|
||||
|
||||
type VtClientError struct {
|
||||
msg string
|
||||
partial bool
|
||||
}
|
||||
|
||||
func (err VtClientError) Error() string {
|
||||
return err.msg
|
||||
}
|
||||
|
||||
func (err VtClientError) Partial() bool {
|
||||
return err.partial
|
||||
}
|
||||
|
||||
// Not thread safe, as per sql package.
|
||||
type ShardedConn struct {
|
||||
ts topo.Server
|
||||
cell string
|
||||
keyspace string
|
||||
tabletType topo.TabletType
|
||||
stream bool // Use streaming RPC
|
||||
|
||||
srvKeyspace *topo.SrvKeyspace
|
||||
// Keep a map per shard mapping tabletType to a real connection.
|
||||
// connByType []map[string]*Conn
|
||||
|
||||
// Sorted list of the max keys for each shard.
|
||||
shardMaxKeys []key.KeyspaceId
|
||||
conns []*tablet.VtConn
|
||||
|
||||
timeout time.Duration // How long should we wait for a given operation?
|
||||
|
||||
// Currently running transaction (or nil if not inside a transaction)
|
||||
currentTransaction *MetaTx
|
||||
}
|
||||
|
||||
// FIXME(msolomon) Normally a connect method would actually connect up
|
||||
// to the appropriate endpoints. In the distributed case, it's unclear
|
||||
// that this is necessary. You have to deal with transient failures
|
||||
// anyway, so the whole system degenerates to managing connections on
|
||||
// demand.
|
||||
func Dial(ts topo.Server, cell, keyspace string, tabletType topo.TabletType, stream bool, timeout time.Duration) (*ShardedConn, error) {
|
||||
sc := &ShardedConn{
|
||||
ts: ts,
|
||||
cell: cell,
|
||||
keyspace: keyspace,
|
||||
tabletType: tabletType,
|
||||
stream: stream,
|
||||
}
|
||||
err := sc.readKeyspace()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) Close() error {
|
||||
if sc.conns == nil {
|
||||
return nil
|
||||
}
|
||||
if sc.currentTransaction != nil {
|
||||
sc.rollback()
|
||||
}
|
||||
|
||||
for _, conn := range sc.conns {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
sc.conns = nil
|
||||
sc.srvKeyspace = nil
|
||||
sc.shardMaxKeys = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) readKeyspace() error {
|
||||
sc.Close()
|
||||
var err error
|
||||
sc.srvKeyspace, err = sc.ts.GetSrvKeyspace(sc.cell, sc.keyspace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vt: GetSrvKeyspace failed %v", err)
|
||||
}
|
||||
|
||||
sc.conns = make([]*tablet.VtConn, len(sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences))
|
||||
sc.shardMaxKeys = make([]key.KeyspaceId, len(sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences))
|
||||
|
||||
for i, shardReference := range sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences {
|
||||
sc.shardMaxKeys[i] = shardReference.KeyRange.End
|
||||
}
|
||||
|
||||
// Disabled for now.
|
||||
// sc.connByType = make([]map[string]*Conn, len(sc.srvKeyspace.ShardReferences))
|
||||
// for i := 0; i < len(sc.connByType); i++ {
|
||||
// sc.connByType[i] = make(map[string]*Conn, 8)
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
// A "transaction" that may be across and thus, not transactional at
|
||||
// this point.
|
||||
type MetaTx struct {
|
||||
// The connections involved in this transaction, in the order they
|
||||
// were added to the transaction.
|
||||
shardedConn *ShardedConn
|
||||
conns []*tablet.VtConn
|
||||
}
|
||||
|
||||
// makes sure the given transaction was issued a Begin() call
|
||||
func (tx *MetaTx) begin(conn *tablet.VtConn) (err error) {
|
||||
for _, v := range tx.conns {
|
||||
if v == conn {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Begin()
|
||||
if err != nil {
|
||||
// the caller will need to take care of the rollback,
|
||||
// and therefore issue a rollback on all pre-existing
|
||||
// transactions
|
||||
return err
|
||||
}
|
||||
tx.conns = append(tx.conns, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *MetaTx) Commit() (err error) {
|
||||
if tx.shardedConn.currentTransaction == nil {
|
||||
return tablet.ErrBadRollback
|
||||
}
|
||||
|
||||
commit := true
|
||||
for _, conn := range tx.conns {
|
||||
if commit {
|
||||
if err = conn.Commit(); err != nil {
|
||||
commit = false
|
||||
}
|
||||
}
|
||||
if !commit {
|
||||
conn.Rollback()
|
||||
}
|
||||
}
|
||||
tx.shardedConn.currentTransaction = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (tx *MetaTx) Rollback() error {
|
||||
if tx.shardedConn.currentTransaction == nil {
|
||||
return tablet.ErrBadRollback
|
||||
}
|
||||
var someErr error
|
||||
for _, conn := range tx.conns {
|
||||
if err := conn.Rollback(); err != nil {
|
||||
someErr = err
|
||||
}
|
||||
}
|
||||
tx.shardedConn.currentTransaction = nil
|
||||
return someErr
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) Begin() (db.Tx, error) {
|
||||
if sc.srvKeyspace == nil {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
if sc.currentTransaction != nil {
|
||||
return nil, tablet.ErrNoNestedTxn
|
||||
}
|
||||
tx := &MetaTx{sc, make([]*tablet.VtConn, 0, 32)}
|
||||
sc.currentTransaction = tx
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) rollback() error {
|
||||
if sc.currentTransaction == nil {
|
||||
return tablet.ErrBadRollback
|
||||
}
|
||||
var someErr error
|
||||
for _, conn := range sc.conns {
|
||||
if conn.TransactionId != 0 {
|
||||
if err := conn.Rollback(); err != nil {
|
||||
someErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
sc.currentTransaction = nil
|
||||
return someErr
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) Exec(query string, bindVars map[string]interface{}) (db.Result, error) {
|
||||
if sc.srvKeyspace == nil {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
shards, err := GetShardList(query, bindVars, sc.shardMaxKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sc.stream {
|
||||
return sc.execOnShardsStream(query, bindVars, shards)
|
||||
}
|
||||
return sc.execOnShards(query, bindVars, shards)
|
||||
}
|
||||
|
||||
// FIXME(msolomon) define key interface "Keyer" or force a concrete type?
|
||||
func (sc *ShardedConn) ExecWithKey(query string, bindVars map[string]interface{}, keyVal interface{}) (db.Result, error) {
|
||||
shardIdx, err := key.FindShardForKey(keyVal, sc.shardMaxKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sc.stream {
|
||||
return sc.execOnShardsStream(query, bindVars, []int{shardIdx})
|
||||
}
|
||||
return sc.execOnShards(query, bindVars, []int{shardIdx})
|
||||
}
|
||||
|
||||
type tabletResult struct {
|
||||
error
|
||||
*tablet.Result
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) execOnShards(query string, bindVars map[string]interface{}, shards []int) (metaResult *tablet.Result, err error) {
|
||||
rchan := make(chan tabletResult, len(shards))
|
||||
for _, shardIdx := range shards {
|
||||
go func(shardIdx int) {
|
||||
qr, err := sc.execOnShard(query, bindVars, shardIdx)
|
||||
if err != nil {
|
||||
rchan <- tabletResult{error: err}
|
||||
} else {
|
||||
rchan <- tabletResult{Result: qr.(*tablet.Result)}
|
||||
}
|
||||
}(shardIdx)
|
||||
}
|
||||
|
||||
results := make([]tabletResult, len(shards))
|
||||
rowCount := int64(0)
|
||||
rowsAffected := int64(0)
|
||||
lastInsertId := int64(0)
|
||||
var hasError error
|
||||
for i := range results {
|
||||
results[i] = <-rchan
|
||||
if results[i].error != nil {
|
||||
hasError = results[i].error
|
||||
continue
|
||||
}
|
||||
affected, _ := results[i].RowsAffected()
|
||||
insertId, _ := results[i].LastInsertId()
|
||||
rowsAffected += affected
|
||||
if insertId > 0 {
|
||||
if lastInsertId == 0 {
|
||||
lastInsertId = insertId
|
||||
}
|
||||
// FIXME(msolomon) issue an error when you have multiple last inserts?
|
||||
}
|
||||
rowCount += results[i].RowsRetrieved()
|
||||
}
|
||||
|
||||
// FIXME(msolomon) allow partial result set?
|
||||
if hasError != nil {
|
||||
return nil, fmt.Errorf("vt: partial result set (%v)", hasError)
|
||||
}
|
||||
|
||||
for _, tr := range results {
|
||||
if tr.error != nil {
|
||||
return nil, tr.error
|
||||
}
|
||||
// FIXME(msolomon) This error message should be a const. Should this
|
||||
// be deferred until we get a next query?
|
||||
if tr.error != nil && tr.error.Error() == "retry: unavailable" {
|
||||
sc.readKeyspace()
|
||||
}
|
||||
}
|
||||
|
||||
var fields []mproto.Field
|
||||
if len(results) > 0 {
|
||||
fields = results[0].Fields()
|
||||
}
|
||||
|
||||
// check the schemas all match (both names and types)
|
||||
if len(results) > 1 {
|
||||
firstFields := results[0].Fields()
|
||||
for _, r := range results[1:] {
|
||||
fields := r.Fields()
|
||||
if len(fields) != len(firstFields) {
|
||||
return nil, fmt.Errorf("vt: column count mismatch: %v != %v", len(firstFields), len(fields))
|
||||
}
|
||||
for i, name := range fields {
|
||||
if name.Name != firstFields[i].Name {
|
||||
return nil, fmt.Errorf("vt: column[%v] name mismatch: %v != %v", i, name.Name, firstFields[i].Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine results.
|
||||
metaResult = tablet.NewResult(rowCount, rowsAffected, lastInsertId, fields)
|
||||
curIndex := 0
|
||||
rows := metaResult.Rows()
|
||||
for _, tr := range results {
|
||||
for _, row := range tr.Rows() {
|
||||
rows[curIndex] = row
|
||||
curIndex++
|
||||
}
|
||||
}
|
||||
|
||||
return metaResult, nil
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) execOnShard(query string, bindVars map[string]interface{}, shardIdx int) (db.Result, error) {
|
||||
if sc.conns[shardIdx] == nil {
|
||||
conn, err := sc.dial(shardIdx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.conns[shardIdx] = conn
|
||||
}
|
||||
conn := sc.conns[shardIdx]
|
||||
|
||||
// if we haven't started the transaction on that shard and need to, now is the time
|
||||
if sc.currentTransaction != nil {
|
||||
err := sc.currentTransaction.begin(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Retries should have already taken place inside the tablet connection.
|
||||
// At this point, all that's left are more sinister failures.
|
||||
// FIXME(msolomon) reload just this shard unless the failure pertains to
|
||||
// needing to reload the entire keyspace.
|
||||
return conn.Exec(query, bindVars)
|
||||
}
|
||||
|
||||
// when doing a streaming query, we send this structure back
|
||||
type streamTabletResult struct {
|
||||
error
|
||||
row []interface{}
|
||||
}
|
||||
|
||||
// our streaming result, just aggregates from all streaming results
|
||||
// it implements both driver.Result and driver.Rows
|
||||
type multiStreamResult struct {
|
||||
cols []string
|
||||
|
||||
// results flow through this, maybe with errors
|
||||
rows chan streamTabletResult
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
// driver.Result interface
|
||||
func (*multiStreamResult) LastInsertId() (int64, error) {
|
||||
return 0, tablet.ErrNoLastInsertId
|
||||
}
|
||||
|
||||
func (*multiStreamResult) RowsAffected() (int64, error) {
|
||||
return 0, tablet.ErrNoRowsAffected
|
||||
}
|
||||
|
||||
// driver.Rows interface
|
||||
func (sr *multiStreamResult) Columns() []string {
|
||||
return sr.cols
|
||||
}
|
||||
|
||||
func (sr *multiStreamResult) Close() error {
|
||||
close(sr.rows)
|
||||
return nil
|
||||
}
|
||||
|
||||
// read from the stream and gets the next value
|
||||
// if one of the go routines returns an error, we want to save it and return it
|
||||
// eventually. (except if it's EOF, then we just know that routine is done)
|
||||
func (sr *multiStreamResult) Next() (row []interface{}) {
|
||||
for {
|
||||
str, ok := <-sr.rows
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if str.error != nil {
|
||||
sr.err = str.error
|
||||
continue
|
||||
}
|
||||
return str.row
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *multiStreamResult) Err() error {
|
||||
return sr.err
|
||||
}
|
||||
|
||||
func (sc *ShardedConn) execOnShardsStream(query string, bindVars map[string]interface{}, shards []int) (msr *multiStreamResult, err error) {
|
||||
// we synchronously do the exec on each shard
|
||||
// so we can get the Columns from the first one
|
||||
// and check the others match them
|
||||
var cols []string
|
||||
qrs := make([]db.Result, len(shards))
|
||||
for i, shardIdx := range shards {
|
||||
qr, err := sc.execOnShard(query, bindVars, shardIdx)
|
||||
if err != nil {
|
||||
// FIXME(alainjobart) if the first queries went through
|
||||
// we need to cancel them
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we know the result is a tablet.StreamResult,
|
||||
// and we use it as a driver.Rows
|
||||
qrs[i] = qr.(db.Result)
|
||||
|
||||
// save the columns or check they match
|
||||
if i == 0 {
|
||||
cols = qrs[i].Columns()
|
||||
} else {
|
||||
ncols := qrs[i].Columns()
|
||||
if len(ncols) != len(cols) {
|
||||
return nil, fmt.Errorf("vt: column count mismatch: %v != %v", len(ncols), len(cols))
|
||||
}
|
||||
for i, name := range cols {
|
||||
if name != ncols[i] {
|
||||
return nil, fmt.Errorf("vt: column[%v] name mismatch: %v != %v", i, name, ncols[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// now we create the result, its channel, and run background
|
||||
// routines to stream results
|
||||
msr = &multiStreamResult{cols: cols, rows: make(chan streamTabletResult, 10*len(shards))}
|
||||
var wg sync.WaitGroup
|
||||
for i, shardIdx := range shards {
|
||||
wg.Add(1)
|
||||
go func(i, shardIdx int) {
|
||||
defer wg.Done()
|
||||
for row := qrs[i].Next(); row != nil; row = qrs[i].Next() {
|
||||
msr.rows <- streamTabletResult{row: row}
|
||||
}
|
||||
if err := qrs[i].Err(); err != nil {
|
||||
msr.rows <- streamTabletResult{error: err}
|
||||
}
|
||||
}(i, shardIdx)
|
||||
}
|
||||
|
||||
// Close channel once all data has been sent
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(msr.rows)
|
||||
}()
|
||||
|
||||
return msr, nil
|
||||
}
|
||||
|
||||
/*
|
||||
type ClientQuery struct {
|
||||
Sql string
|
||||
BindVariables map[string]interface{}
|
||||
}
|
||||
|
||||
// FIXME(msolomon) There are multiple options for an efficient ExecMulti.
|
||||
// * Use a special stmt object, buffer all statements, connections, etc and send when it's ready.
|
||||
// * Take a list of (sql, bind) pairs and just send that - have to parse and route that anyway.
|
||||
// * Probably need separate support for the a MultiTx too.
|
||||
func (sc *ShardedConn) ExecuteBatch(queryList []ClientQuery, keyVal interface{}) (*tabletserver.QueryResult, error) {
|
||||
shardIdx, err := key.FindShardForKey(keyVal, sc.shardMaxKeys)
|
||||
shards := []int{shardIdx}
|
||||
|
||||
if err = sc.tabletPrepare(shardIdx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqs := make([]tabletserver.Query, len(queryList))
|
||||
for i, cq := range queryList {
|
||||
reqs[i] = tabletserver.Query{
|
||||
Sql: cq.Sql,
|
||||
BindVariables: cq.BindVariables,
|
||||
TransactionId: sc.conns[shardIdx].TransactionId,
|
||||
SessionId: sc.conns[shardIdx].SessionId,
|
||||
}
|
||||
}
|
||||
res := new(tabletserver.QueryResult)
|
||||
err = sc.conns[shardIdx].Call("SqlQuery.ExecuteBatch", reqs, res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
*/
|
||||
|
||||
func (sc *ShardedConn) dial(shardIdx int) (conn *tablet.VtConn, err error) {
|
||||
shardReference := &(sc.srvKeyspace.Partitions[sc.tabletType].ShardReferences[shardIdx])
|
||||
addrs, err := sc.ts.GetEndPoints(sc.cell, sc.keyspace, shardReference.Name, sc.tabletType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vt: GetEndPoints failed %v", err)
|
||||
}
|
||||
|
||||
srvs, err := topo.SrvEntries(addrs, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to connect to any address.
|
||||
for _, srv := range srvs {
|
||||
name := topo.SrvAddr(srv) + "/" + sc.keyspace + "/" + shardReference.Name
|
||||
conn, err = tablet.DialVtdb(name, sc.stream, tablet.DefaultTimeout)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type sDriver struct {
|
||||
ts topo.Server
|
||||
stream bool
|
||||
}
|
||||
|
||||
// for direct zk connection: vtzk://host:port/cell/keyspace/tabletType
|
||||
// we always use a MetaConn, host and port are ignored.
|
||||
// the driver name dictates if we streaming or not
|
||||
func (driver *sDriver) Open(name string) (sc db.Conn, err error) {
|
||||
if !strings.HasPrefix(name, "vtzk://") {
|
||||
// add a default protocol talking to zk
|
||||
name = "vtzk://" + name
|
||||
}
|
||||
u, err := url.Parse(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbi, tabletType := path.Split(u.Path)
|
||||
dbi = strings.Trim(dbi, "/")
|
||||
tabletType = strings.Trim(tabletType, "/")
|
||||
cell, keyspace := path.Split(dbi)
|
||||
cell = strings.Trim(cell, "/")
|
||||
keyspace = strings.Trim(keyspace, "/")
|
||||
return Dial(driver.ts, cell, keyspace, topo.TabletType(tabletType), driver.stream, tablet.DefaultTimeout)
|
||||
}
|
||||
|
||||
func RegisterShardedDrivers() {
|
||||
// default topo server
|
||||
ts := topo.GetServer()
|
||||
db.Register("vtdb", &sDriver{ts, false})
|
||||
db.Register("vtdb-streaming", &sDriver{ts, true})
|
||||
|
||||
// forced zk topo server
|
||||
zconn := zk.NewMetaConn()
|
||||
zkts := zktopo.NewServer(zconn)
|
||||
db.Register("vtdb-zk", &sDriver{zkts, false})
|
||||
db.Register("vtdb-zk-streaming", &sDriver{zkts, true})
|
||||
}
|
|
@ -1,332 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tablet is an API compliant to the requirements of database/sql
|
||||
// Open expects name to be "hostname:port/keyspace/shard"
|
||||
// For query arguments, we assume place-holders in the query string
|
||||
// in the form of :v0, :v1, etc.
|
||||
package tablet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/db"
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/netutil"
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoNestedTxn = errors.New("vt: no nested transactions")
|
||||
ErrBadCommit = errors.New("vt: commit without corresponding begin")
|
||||
ErrBadRollback = errors.New("vt: rollback without corresponding begin")
|
||||
ErrNoLastInsertId = errors.New("vt: no LastInsertId available after streaming statement")
|
||||
ErrNoRowsAffected = errors.New("vt: no RowsAffected available after streaming statement")
|
||||
ErrFieldLengthMismatch = errors.New("vt: no RowsAffected available after streaming statement")
|
||||
)
|
||||
|
||||
type TabletError struct {
|
||||
err error
|
||||
addr string
|
||||
}
|
||||
|
||||
func (te TabletError) Error() string {
|
||||
return fmt.Sprintf("vt: client error on %v %v", te.addr, te.err)
|
||||
}
|
||||
|
||||
// Not thread safe, as per sql package.
|
||||
type Conn struct {
|
||||
dbi *url.URL
|
||||
stream bool
|
||||
tabletConn tabletconn.TabletConn
|
||||
TransactionId int64
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
type Tx struct {
|
||||
conn *Conn
|
||||
}
|
||||
|
||||
type StreamResult struct {
|
||||
errFunc tabletconn.ErrFunc
|
||||
sr <-chan *mproto.QueryResult
|
||||
columns *mproto.QueryResult
|
||||
// current result and index on it
|
||||
qr *mproto.QueryResult
|
||||
index int
|
||||
err error
|
||||
}
|
||||
|
||||
func (conn *Conn) keyspace() string {
|
||||
return strings.Split(conn.dbi.Path, "/")[1]
|
||||
}
|
||||
|
||||
func (conn *Conn) shard() string {
|
||||
return strings.Split(conn.dbi.Path, "/")[2]
|
||||
}
|
||||
|
||||
// parseDbi parses the dbi and a URL. The dbi may or may not contain
|
||||
// the scheme part.
|
||||
func parseDbi(dbi string) (*url.URL, error) {
|
||||
if !strings.HasPrefix(dbi, "vttp://") {
|
||||
dbi = "vttp://" + dbi
|
||||
}
|
||||
return url.Parse(dbi)
|
||||
}
|
||||
|
||||
func DialTablet(dbi string, stream bool, timeout time.Duration) (conn *Conn, err error) {
|
||||
conn = new(Conn)
|
||||
if conn.dbi, err = parseDbi(dbi); err != nil {
|
||||
return
|
||||
}
|
||||
conn.stream = stream
|
||||
conn.timeout = timeout
|
||||
if err = conn.dial(); err != nil {
|
||||
return nil, conn.fmtErr(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Format error for exported methods to give callers more information.
|
||||
func (conn *Conn) fmtErr(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return TabletError{err, conn.dbi.Host}
|
||||
}
|
||||
|
||||
func (conn *Conn) dial() (err error) {
|
||||
// build the endpoint in the right format
|
||||
host, port, err := netutil.SplitHostPort(conn.dbi.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endPoint := topo.EndPoint{
|
||||
Host: host,
|
||||
NamedPortMap: map[string]int{
|
||||
"vt": port,
|
||||
},
|
||||
}
|
||||
|
||||
// and dial
|
||||
tabletConn, err := tabletconn.GetDialer()(context.TODO(), endPoint, conn.keyspace(), conn.shard(), conn.timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.tabletConn = tabletConn
|
||||
return
|
||||
}
|
||||
|
||||
func (conn *Conn) Close() error {
|
||||
conn.tabletConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) Exec(query string, bindVars map[string]interface{}) (db.Result, error) {
|
||||
if conn.stream {
|
||||
sr, errFunc, err := conn.tabletConn.StreamExecute(context.TODO(), query, bindVars, conn.TransactionId)
|
||||
if err != nil {
|
||||
return nil, conn.fmtErr(err)
|
||||
}
|
||||
// read the columns, or grab the error
|
||||
cols, ok := <-sr
|
||||
if !ok {
|
||||
return nil, conn.fmtErr(errFunc())
|
||||
}
|
||||
return &StreamResult{errFunc, sr, cols, nil, 0, nil}, nil
|
||||
}
|
||||
|
||||
qr, err := conn.tabletConn.Execute(context.TODO(), query, bindVars, conn.TransactionId)
|
||||
if err != nil {
|
||||
return nil, conn.fmtErr(err)
|
||||
}
|
||||
return &Result{qr, 0, nil}, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) Begin() (db.Tx, error) {
|
||||
if conn.TransactionId != 0 {
|
||||
return &Tx{}, ErrNoNestedTxn
|
||||
}
|
||||
if transactionId, err := conn.tabletConn.Begin(context.TODO()); err != nil {
|
||||
return &Tx{}, conn.fmtErr(err)
|
||||
} else {
|
||||
conn.TransactionId = transactionId
|
||||
}
|
||||
return &Tx{conn}, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) Commit() error {
|
||||
if conn.TransactionId == 0 {
|
||||
return ErrBadCommit
|
||||
}
|
||||
// NOTE(msolomon) Unset the transaction_id irrespective of the RPC's
|
||||
// response. The intent of commit is that no more statements can be
|
||||
// made on this transaction, so we guarantee that. Transient errors
|
||||
// between the db and the client shouldn't affect this part of the
|
||||
// bookkeeping. According to the Go Driver API, this will not be
|
||||
// called concurrently. Defer this because we this affects the
|
||||
// session referenced in the request.
|
||||
defer func() { conn.TransactionId = 0 }()
|
||||
return conn.fmtErr(conn.tabletConn.Commit(context.TODO(), conn.TransactionId))
|
||||
}
|
||||
|
||||
func (conn *Conn) Rollback() error {
|
||||
if conn.TransactionId == 0 {
|
||||
return ErrBadRollback
|
||||
}
|
||||
// See note in Commit about the behavior of TransactionId.
|
||||
defer func() { conn.TransactionId = 0 }()
|
||||
return conn.fmtErr(conn.tabletConn.Rollback(context.TODO(), conn.TransactionId))
|
||||
}
|
||||
|
||||
// driver.Tx interface (forwarded to Conn)
|
||||
func (tx *Tx) Commit() error {
|
||||
return tx.conn.Commit()
|
||||
}
|
||||
|
||||
func (tx *Tx) Rollback() error {
|
||||
return tx.conn.Rollback()
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
qr *mproto.QueryResult
|
||||
index int
|
||||
err error
|
||||
}
|
||||
|
||||
// TODO(mberlin): Populate flags here as well (e.g. to correctly identify unsigned integer type)?
|
||||
func NewResult(rowCount, rowsAffected, insertId int64, fields []mproto.Field) *Result {
|
||||
return &Result{
|
||||
qr: &mproto.QueryResult{
|
||||
Rows: make([][]sqltypes.Value, int(rowCount)),
|
||||
Fields: fields,
|
||||
RowsAffected: uint64(rowsAffected),
|
||||
InsertId: uint64(insertId),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (result *Result) RowsRetrieved() int64 {
|
||||
return int64(len(result.qr.Rows))
|
||||
}
|
||||
|
||||
func (result *Result) LastInsertId() (int64, error) {
|
||||
return int64(result.qr.InsertId), nil
|
||||
}
|
||||
|
||||
func (result *Result) RowsAffected() (int64, error) {
|
||||
return int64(result.qr.RowsAffected), nil
|
||||
}
|
||||
|
||||
// driver.Rows interface
|
||||
func (result *Result) Columns() []string {
|
||||
cols := make([]string, len(result.qr.Fields))
|
||||
for i, f := range result.qr.Fields {
|
||||
cols[i] = f.Name
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (result *Result) Rows() [][]sqltypes.Value {
|
||||
return result.qr.Rows
|
||||
}
|
||||
|
||||
// FIXME(msolomon) This should be intependent of the mysql module.
|
||||
func (result *Result) Fields() []mproto.Field {
|
||||
return result.qr.Fields
|
||||
}
|
||||
|
||||
func (result *Result) Close() error {
|
||||
result.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (result *Result) Next() (row []interface{}) {
|
||||
if result.index >= len(result.qr.Rows) {
|
||||
return nil
|
||||
}
|
||||
row = make([]interface{}, len(result.qr.Rows[result.index]))
|
||||
for i, v := range result.qr.Rows[result.index] {
|
||||
var err error
|
||||
row[i], err = mproto.Convert(result.qr.Fields[i].Type, v)
|
||||
if err != nil {
|
||||
panic(err) // unexpected
|
||||
}
|
||||
}
|
||||
result.index++
|
||||
return row
|
||||
}
|
||||
|
||||
func (result *Result) Err() error {
|
||||
return result.err
|
||||
}
|
||||
|
||||
// driver.Result interface
|
||||
func (*StreamResult) LastInsertId() (int64, error) {
|
||||
return 0, ErrNoLastInsertId
|
||||
}
|
||||
|
||||
func (*StreamResult) RowsAffected() (int64, error) {
|
||||
return 0, ErrNoRowsAffected
|
||||
}
|
||||
|
||||
// driver.Rows interface
|
||||
func (sr *StreamResult) Columns() (cols []string) {
|
||||
cols = make([]string, len(sr.columns.Fields))
|
||||
for i, f := range sr.columns.Fields {
|
||||
cols[i] = f.Name
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (*StreamResult) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sr *StreamResult) Next() (row []interface{}) {
|
||||
if sr.qr == nil {
|
||||
// we need to read the next record that may contain
|
||||
// multiple rows
|
||||
qr, ok := <-sr.sr
|
||||
if !ok {
|
||||
if sr.errFunc() != nil {
|
||||
log.Warningf("vt: error reading the next value %v", sr.errFunc())
|
||||
sr.err = sr.errFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
sr.qr = qr
|
||||
sr.index = 0
|
||||
}
|
||||
|
||||
row = make([]interface{}, len(sr.qr.Rows[sr.index]))
|
||||
for i, v := range sr.qr.Rows[sr.index] {
|
||||
var err error
|
||||
row[i], err = mproto.Convert(sr.columns.Fields[i].Type, v)
|
||||
if err != nil {
|
||||
panic(err) // unexpected
|
||||
}
|
||||
}
|
||||
|
||||
sr.index++
|
||||
if sr.index == len(sr.qr.Rows) {
|
||||
// we reached the end of our rows, nil it so next run
|
||||
// will fetch the next one
|
||||
sr.qr = nil
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
func (sr *StreamResult) Err() error {
|
||||
return sr.err
|
||||
}
|
|
@ -1,178 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tablet implements some additional error handling logic to
|
||||
// make the client more robust in the face of transient problems with
|
||||
// easy solutions.
|
||||
package tablet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/db"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrTypeFatal = 1 //errors.New("vt: fatal: reresolve endpoint")
|
||||
ErrTypeRetry = 2 //errors.New("vt: retry: reconnect endpoint")
|
||||
ErrTypeApp = 3 //errors.New("vt: app level error")
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultReconnectDelay = 2 * time.Millisecond
|
||||
DefaultMaxAttempts = 2
|
||||
DefaultTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var zeroTime time.Time
|
||||
|
||||
// Layer some logic on top of the basic tablet protocol to support
|
||||
// fast retry when we can.
|
||||
|
||||
type VtConn struct {
|
||||
Conn
|
||||
maxAttempts int // How many times should try each retriable operation?
|
||||
timeFailed time.Time // This is the time a client transitioned from presumable health to failure.
|
||||
reconnectDelay time.Duration
|
||||
}
|
||||
|
||||
// How long should we wait to try to recover?
|
||||
// FIXME(msolomon) not sure if maxAttempts is still useful
|
||||
func (vtc *VtConn) recoveryTimeout() time.Duration {
|
||||
return vtc.timeout * 2
|
||||
}
|
||||
|
||||
func (vtc *VtConn) handleErr(err error) (int, error) {
|
||||
now := time.Now()
|
||||
if vtc.timeFailed.IsZero() {
|
||||
vtc.timeFailed = now
|
||||
} else if now.Sub(vtc.timeFailed) > vtc.recoveryTimeout() {
|
||||
vtc.Close()
|
||||
return ErrTypeFatal, fmt.Errorf("vt: max recovery time exceeded: %v", err)
|
||||
}
|
||||
|
||||
errType := ErrTypeApp
|
||||
if tabletErr, ok := err.(TabletError); ok {
|
||||
msg := strings.ToLower(tabletErr.err.Error())
|
||||
if strings.HasPrefix(msg, "fatal") {
|
||||
errType = ErrTypeFatal
|
||||
} else if strings.HasPrefix(msg, "retry") {
|
||||
errType = ErrTypeRetry
|
||||
}
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
|
||||
errType = ErrTypeRetry
|
||||
}
|
||||
|
||||
if errType == ErrTypeRetry && vtc.TransactionId != 0 {
|
||||
errType = ErrTypeApp
|
||||
err = fmt.Errorf("vt: cannot retry within a transaction: %v", err)
|
||||
time.Sleep(vtc.reconnectDelay)
|
||||
vtc.Close()
|
||||
dialErr := vtc.dial()
|
||||
log.Warningf("vt: redial error %v", dialErr)
|
||||
}
|
||||
|
||||
return errType, err
|
||||
}
|
||||
|
||||
func (vtc *VtConn) Exec(query string, bindVars map[string]interface{}) (db.Result, error) {
|
||||
attempt := 0
|
||||
for {
|
||||
result, err := vtc.Conn.Exec(query, bindVars)
|
||||
if err == nil {
|
||||
vtc.timeFailed = zeroTime
|
||||
return result, nil
|
||||
}
|
||||
|
||||
errType, err := vtc.handleErr(err)
|
||||
if errType != ErrTypeRetry {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
attempt++
|
||||
if attempt > vtc.maxAttempts {
|
||||
return nil, fmt.Errorf("vt: max recovery attempts exceeded: %v", err)
|
||||
}
|
||||
vtc.Close()
|
||||
time.Sleep(vtc.reconnectDelay)
|
||||
if err := vtc.dial(); err == nil {
|
||||
break
|
||||
}
|
||||
log.Warningf("vt: error dialing on exec %v", vtc.Conn.dbi.Host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vtc *VtConn) Begin() (db.Tx, error) {
|
||||
attempt := 0
|
||||
for {
|
||||
tx, err := vtc.Conn.Begin()
|
||||
if err == nil {
|
||||
vtc.timeFailed = zeroTime
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
errType, err := vtc.handleErr(err)
|
||||
if errType != ErrTypeRetry {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
attempt++
|
||||
if attempt > vtc.maxAttempts {
|
||||
return nil, fmt.Errorf("vt: max recovery attempts exceeded: %v", err)
|
||||
}
|
||||
vtc.Close()
|
||||
time.Sleep(vtc.reconnectDelay)
|
||||
if err := vtc.dial(); err == nil {
|
||||
break
|
||||
}
|
||||
log.Warningf("vt: error dialing on begin %v", vtc.Conn.dbi.Host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vtc *VtConn) Commit() (err error) {
|
||||
if err = vtc.Conn.Commit(); err == nil {
|
||||
vtc.timeFailed = zeroTime
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not much we can do at this point, just annotate the error and return.
|
||||
_, err = vtc.handleErr(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func DialVtdb(dbi string, stream bool, timeout time.Duration) (*VtConn, error) {
|
||||
url, err := parseDbi(dbi)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &VtConn{
|
||||
Conn: Conn{dbi: url, stream: stream, timeout: timeout},
|
||||
maxAttempts: DefaultMaxAttempts,
|
||||
reconnectDelay: DefaultReconnectDelay,
|
||||
}
|
||||
|
||||
if err := conn.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type vDriver struct {
|
||||
stream bool
|
||||
}
|
||||
|
||||
func (driver *vDriver) Open(name string) (db.Conn, error) {
|
||||
return DialVtdb(name, driver.stream, DefaultTimeout)
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.Register("vttablet", &vDriver{})
|
||||
db.Register("vttablet-streaming", &vDriver{true})
|
||||
}
|
|
@ -1,106 +0,0 @@
|
|||
// Copyright 2013, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package concurrency
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/youtube/vitess/go/sync2"
|
||||
)
|
||||
|
||||
// ResourceConstraint combines 3 different features:
|
||||
// - a WaitGroup to wait for all tasks to be done
|
||||
// - a Semaphore to control concurrency
|
||||
// - an ErrorRecorder
|
||||
type ResourceConstraint struct {
|
||||
semaphore *sync2.Semaphore
|
||||
wg sync.WaitGroup
|
||||
FirstErrorRecorder
|
||||
}
|
||||
|
||||
// NewResourceConstraint creates a ResourceConstraint with
|
||||
// max concurrency.
|
||||
func NewResourceConstraint(max int) *ResourceConstraint {
|
||||
return &ResourceConstraint{semaphore: sync2.NewSemaphore(max, 0)}
|
||||
}
|
||||
|
||||
func (rc *ResourceConstraint) Add(n int) {
|
||||
rc.wg.Add(n)
|
||||
}
|
||||
|
||||
func (rc *ResourceConstraint) Done() {
|
||||
rc.wg.Done()
|
||||
}
|
||||
|
||||
// Wait waits for the WG and returns the firstError we encountered, or nil
|
||||
func (rc *ResourceConstraint) Wait() error {
|
||||
rc.wg.Wait()
|
||||
return rc.Error()
|
||||
}
|
||||
|
||||
// Acquire will wait until we have a resource to use
|
||||
func (rc *ResourceConstraint) Acquire() {
|
||||
rc.semaphore.Acquire()
|
||||
}
|
||||
|
||||
func (rc *ResourceConstraint) Release() {
|
||||
rc.semaphore.Release()
|
||||
}
|
||||
|
||||
func (rc *ResourceConstraint) ReleaseAndDone() {
|
||||
rc.Release()
|
||||
rc.Done()
|
||||
}
|
||||
|
||||
// MultiResourceConstraint combines 3 different features:
|
||||
// - a WaitGroup to wait for all tasks to be done
|
||||
// - a Semaphore map to control multiple concurrencies
|
||||
// - an ErrorRecorder
|
||||
type MultiResourceConstraint struct {
|
||||
semaphoreMap map[string]*sync2.Semaphore
|
||||
wg sync.WaitGroup
|
||||
FirstErrorRecorder
|
||||
}
|
||||
|
||||
func NewMultiResourceConstraint(semaphoreMap map[string]*sync2.Semaphore) *MultiResourceConstraint {
|
||||
return &MultiResourceConstraint{semaphoreMap: semaphoreMap}
|
||||
}
|
||||
|
||||
func (mrc *MultiResourceConstraint) Add(n int) {
|
||||
mrc.wg.Add(n)
|
||||
}
|
||||
|
||||
func (mrc *MultiResourceConstraint) Done() {
|
||||
mrc.wg.Done()
|
||||
}
|
||||
|
||||
// Returns the firstError we encountered, or nil
|
||||
func (mrc *MultiResourceConstraint) Wait() error {
|
||||
mrc.wg.Wait()
|
||||
return mrc.Error()
|
||||
}
|
||||
|
||||
// Acquire will wait until we have a resource to use
|
||||
func (mrc *MultiResourceConstraint) Acquire(name string) {
|
||||
s, ok := mrc.semaphoreMap[name]
|
||||
if !ok {
|
||||
panic(fmt.Errorf("MultiResourceConstraint: No resource named %v in semaphore map", name))
|
||||
}
|
||||
s.Acquire()
|
||||
}
|
||||
|
||||
func (mrc *MultiResourceConstraint) Release(name string) {
|
||||
s, ok := mrc.semaphoreMap[name]
|
||||
if !ok {
|
||||
panic(fmt.Errorf("MultiResourceConstraint: No resource named %v in semaphore map", name))
|
||||
}
|
||||
s.Release()
|
||||
}
|
||||
|
||||
func (mrc *MultiResourceConstraint) ReleaseAndDone(name string) {
|
||||
mrc.Release(name)
|
||||
mrc.Done()
|
||||
}
|
|
@ -107,7 +107,7 @@ func (cp *ConnectionPool) Get(timeout time.Duration) (PoolConnection, error) {
|
|||
ctx := context.Background()
|
||||
if timeout != 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(context.Background(), timeout)
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
r, err := p.Get(ctx)
|
||||
|
@ -117,20 +117,6 @@ func (cp *ConnectionPool) Get(timeout time.Duration) (PoolConnection, error) {
|
|||
return r.(PoolConnection), nil
|
||||
}
|
||||
|
||||
// TryGet returns a connection, or nil.
|
||||
// You must call Recycle on the PoolConnection once done.
|
||||
func (cp *ConnectionPool) TryGet() (PoolConnection, error) {
|
||||
p := cp.pool()
|
||||
if p == nil {
|
||||
return nil, ErrConnPoolClosed
|
||||
}
|
||||
r, err := p.TryGet()
|
||||
if err != nil || r == nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.(PoolConnection), nil
|
||||
}
|
||||
|
||||
// Put puts a connection into the pool.
|
||||
func (cp *ConnectionPool) Put(conn PoolConnection) {
|
||||
p := cp.pool()
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/youtube/vitess/go/jscfg"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestSplitCellPath(t *testing.T) {
|
||||
|
@ -78,14 +79,15 @@ func TestHandlePathKeyspace(t *testing.T) {
|
|||
shard := &topo.Shard{}
|
||||
want := jscfg.ToJSON(keyspace)
|
||||
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, cells)
|
||||
if err := ts.CreateKeyspace("test_keyspace", keyspace); err != nil {
|
||||
if err := ts.CreateKeyspace(ctx, "test_keyspace", keyspace); err != nil {
|
||||
t.Fatalf("CreateKeyspace error: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("test_keyspace", "10-20", shard); err != nil {
|
||||
if err := ts.CreateShard(ctx, "test_keyspace", "10-20", shard); err != nil {
|
||||
t.Fatalf("CreateShard error: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("test_keyspace", "20-30", shard); err != nil {
|
||||
if err := ts.CreateShard(ctx, "test_keyspace", "20-30", shard); err != nil {
|
||||
t.Fatalf("CreateShard error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -114,11 +116,12 @@ func TestHandlePathShard(t *testing.T) {
|
|||
shard := &topo.Shard{}
|
||||
want := jscfg.ToJSON(shard)
|
||||
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, cells)
|
||||
if err := ts.CreateKeyspace("test_keyspace", keyspace); err != nil {
|
||||
if err := ts.CreateKeyspace(ctx, "test_keyspace", keyspace); err != nil {
|
||||
t.Fatalf("CreateKeyspace error: %v", err)
|
||||
}
|
||||
if err := ts.CreateShard("test_keyspace", "-80", shard); err != nil {
|
||||
if err := ts.CreateShard(ctx, "test_keyspace", "-80", shard); err != nil {
|
||||
t.Fatalf("CreateShard error: %v", err)
|
||||
}
|
||||
|
||||
|
@ -150,8 +153,9 @@ func TestHandlePathTablet(t *testing.T) {
|
|||
}
|
||||
want := jscfg.ToJSON(tablet)
|
||||
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, cells)
|
||||
if err := ts.CreateTablet(tablet); err != nil {
|
||||
if err := ts.CreateTablet(ctx, tablet); err != nil {
|
||||
t.Fatalf("CreateTablet error: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,10 +14,11 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/concurrency"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/topo/events"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// CreateKeyspace implements topo.Server.
|
||||
func (s *Server) CreateKeyspace(keyspace string, value *topo.Keyspace) error {
|
||||
func (s *Server) CreateKeyspace(ctx context.Context, keyspace string, value *topo.Keyspace) error {
|
||||
data := jscfg.ToJSON(value)
|
||||
global := s.getGlobal()
|
||||
|
||||
|
@ -44,7 +45,7 @@ func (s *Server) CreateKeyspace(keyspace string, value *topo.Keyspace) error {
|
|||
}
|
||||
|
||||
// UpdateKeyspace implements topo.Server.
|
||||
func (s *Server) UpdateKeyspace(ki *topo.KeyspaceInfo, existingVersion int64) (int64, error) {
|
||||
func (s *Server) UpdateKeyspace(ctx context.Context, ki *topo.KeyspaceInfo, existingVersion int64) (int64, error) {
|
||||
data := jscfg.ToJSON(ki.Keyspace)
|
||||
|
||||
resp, err := s.getGlobal().CompareAndSwap(keyspaceFilePath(ki.KeyspaceName()),
|
||||
|
@ -64,7 +65,7 @@ func (s *Server) UpdateKeyspace(ki *topo.KeyspaceInfo, existingVersion int64) (i
|
|||
}
|
||||
|
||||
// GetKeyspace implements topo.Server.
|
||||
func (s *Server) GetKeyspace(keyspace string) (*topo.KeyspaceInfo, error) {
|
||||
func (s *Server) GetKeyspace(ctx context.Context, keyspace string) (*topo.KeyspaceInfo, error) {
|
||||
resp, err := s.getGlobal().Get(keyspaceFilePath(keyspace), false /* sort */, false /* recursive */)
|
||||
if err != nil {
|
||||
return nil, convertError(err)
|
||||
|
@ -82,7 +83,7 @@ func (s *Server) GetKeyspace(keyspace string) (*topo.KeyspaceInfo, error) {
|
|||
}
|
||||
|
||||
// GetKeyspaces implements topo.Server.
|
||||
func (s *Server) GetKeyspaces() ([]string, error) {
|
||||
func (s *Server) GetKeyspaces(ctx context.Context) ([]string, error) {
|
||||
resp, err := s.getGlobal().Get(keyspacesDirPath, true /* sort */, false /* recursive */)
|
||||
if err != nil {
|
||||
err = convertError(err)
|
||||
|
@ -95,8 +96,8 @@ func (s *Server) GetKeyspaces() ([]string, error) {
|
|||
}
|
||||
|
||||
// DeleteKeyspaceShards implements topo.Server.
|
||||
func (s *Server) DeleteKeyspaceShards(keyspace string) error {
|
||||
shards, err := s.GetShardNames(keyspace)
|
||||
func (s *Server) DeleteKeyspaceShards(ctx context.Context, keyspace string) error {
|
||||
shards, err := s.GetShardNames(ctx, keyspace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -185,7 +185,7 @@ func (s *Server) LockSrvShardForAction(ctx context.Context, cellName, keyspace,
|
|||
}
|
||||
|
||||
// UnlockSrvShardForAction implements topo.Server.
|
||||
func (s *Server) UnlockSrvShardForAction(cellName, keyspace, shard, actionPath, results string) error {
|
||||
func (s *Server) UnlockSrvShardForAction(ctx context.Context, cellName, keyspace, shard, actionPath, results string) error {
|
||||
log.Infof("results of %v: %v", actionPath, results)
|
||||
|
||||
cell, err := s.getCell(cellName)
|
||||
|
@ -204,7 +204,7 @@ func (s *Server) LockKeyspaceForAction(ctx context.Context, keyspace, contents s
|
|||
}
|
||||
|
||||
// UnlockKeyspaceForAction implements topo.Server.
|
||||
func (s *Server) UnlockKeyspaceForAction(keyspace, actionPath, results string) error {
|
||||
func (s *Server) UnlockKeyspaceForAction(ctx context.Context, keyspace, actionPath, results string) error {
|
||||
log.Infof("results of %v: %v", actionPath, results)
|
||||
|
||||
return unlock(s.getGlobal(), keyspaceDirPath(keyspace), actionPath,
|
||||
|
@ -218,7 +218,7 @@ func (s *Server) LockShardForAction(ctx context.Context, keyspace, shard, conten
|
|||
}
|
||||
|
||||
// UnlockShardForAction implements topo.Server.
|
||||
func (s *Server) UnlockShardForAction(keyspace, shard, actionPath, results string) error {
|
||||
func (s *Server) UnlockShardForAction(ctx context.Context, keyspace, shard, actionPath, results string) error {
|
||||
log.Infof("results of %v: %v", actionPath, results)
|
||||
|
||||
return unlock(s.getGlobal(), shardDirPath(keyspace, shard), actionPath,
|
||||
|
|
|
@ -10,10 +10,11 @@ import (
|
|||
|
||||
"github.com/youtube/vitess/go/jscfg"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// UpdateShardReplicationFields implements topo.Server.
|
||||
func (s *Server) UpdateShardReplicationFields(cell, keyspace, shard string, updateFunc func(*topo.ShardReplication) error) error {
|
||||
func (s *Server) UpdateShardReplicationFields(ctx context.Context, cell, keyspace, shard string, updateFunc func(*topo.ShardReplication) error) error {
|
||||
var sri *topo.ShardReplicationInfo
|
||||
var version int64
|
||||
var err error
|
||||
|
@ -82,7 +83,7 @@ func (s *Server) createShardReplication(sri *topo.ShardReplicationInfo) (int64,
|
|||
}
|
||||
|
||||
// GetShardReplication implements topo.Server.
|
||||
func (s *Server) GetShardReplication(cell, keyspace, shard string) (*topo.ShardReplicationInfo, error) {
|
||||
func (s *Server) GetShardReplication(ctx context.Context, cell, keyspace, shard string) (*topo.ShardReplicationInfo, error) {
|
||||
sri, _, err := s.getShardReplication(cell, keyspace, shard)
|
||||
return sri, err
|
||||
}
|
||||
|
@ -110,7 +111,7 @@ func (s *Server) getShardReplication(cellName, keyspace, shard string) (*topo.Sh
|
|||
}
|
||||
|
||||
// DeleteShardReplication implements topo.Server.
|
||||
func (s *Server) DeleteShardReplication(cellName, keyspace, shard string) error {
|
||||
func (s *Server) DeleteShardReplication(ctx context.Context, cellName, keyspace, shard string) error {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Server is the implementation of topo.Server for etcd.
|
||||
|
@ -52,7 +53,7 @@ func (s *Server) Close() {
|
|||
}
|
||||
|
||||
// GetKnownCells implements topo.Server.
|
||||
func (s *Server) GetKnownCells() ([]string, error) {
|
||||
func (s *Server) GetKnownCells(ctx context.Context) ([]string, error) {
|
||||
resp, err := s.getGlobal().Get(cellsDirPath, true /* sort */, false /* recursive */)
|
||||
if err != nil {
|
||||
return nil, convertError(err)
|
||||
|
|
|
@ -31,73 +31,83 @@ func newTestServer(t *testing.T, cells []string) *Server {
|
|||
}
|
||||
|
||||
func TestKeyspace(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckKeyspace(t, ts)
|
||||
test.CheckKeyspace(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestShard(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckShard(context.Background(), t, ts)
|
||||
test.CheckShard(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestTablet(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckTablet(context.Background(), t, ts)
|
||||
test.CheckTablet(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestShardReplication(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckShardReplication(t, ts)
|
||||
test.CheckShardReplication(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestServingGraph(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckServingGraph(context.Background(), t, ts)
|
||||
test.CheckServingGraph(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestWatchEndPoints(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckWatchEndPoints(context.Background(), t, ts)
|
||||
test.CheckWatchEndPoints(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestKeyspaceLock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckKeyspaceLock(t, ts)
|
||||
test.CheckKeyspaceLock(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestShardLock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if testing.Short() {
|
||||
t.Skip("skipping wait-based test in short mode.")
|
||||
}
|
||||
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckShardLock(t, ts)
|
||||
test.CheckShardLock(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestSrvShardLock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if testing.Short() {
|
||||
t.Skip("skipping wait-based test in short mode.")
|
||||
}
|
||||
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckSrvShardLock(t, ts)
|
||||
test.CheckSrvShardLock(ctx, t, ts)
|
||||
}
|
||||
|
||||
func TestVSchema(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if testing.Short() {
|
||||
t.Skip("skipping wait-based test in short mode.")
|
||||
}
|
||||
|
||||
ts := newTestServer(t, []string{"test"})
|
||||
defer ts.Close()
|
||||
test.CheckVSchema(t, ts)
|
||||
test.CheckVSchema(ctx, t, ts)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/jscfg"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// WatchSleepDuration is how many seconds interval to poll for in case
|
||||
|
@ -22,7 +23,7 @@ import (
|
|||
var WatchSleepDuration = 30 * time.Second
|
||||
|
||||
// GetSrvTabletTypesPerShard implements topo.Server.
|
||||
func (s *Server) GetSrvTabletTypesPerShard(cellName, keyspace, shard string) ([]topo.TabletType, error) {
|
||||
func (s *Server) GetSrvTabletTypesPerShard(ctx context.Context, cellName, keyspace, shard string) ([]topo.TabletType, error) {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -44,7 +45,7 @@ func (s *Server) GetSrvTabletTypesPerShard(cellName, keyspace, shard string) ([]
|
|||
}
|
||||
|
||||
// UpdateEndPoints implements topo.Server.
|
||||
func (s *Server) UpdateEndPoints(cellName, keyspace, shard string, tabletType topo.TabletType, addrs *topo.EndPoints) error {
|
||||
func (s *Server) UpdateEndPoints(ctx context.Context, cellName, keyspace, shard string, tabletType topo.TabletType, addrs *topo.EndPoints) error {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -71,7 +72,7 @@ func (s *Server) updateEndPoints(cellName, keyspace, shard string, tabletType to
|
|||
}
|
||||
|
||||
// GetEndPoints implements topo.Server.
|
||||
func (s *Server) GetEndPoints(cell, keyspace, shard string, tabletType topo.TabletType) (*topo.EndPoints, error) {
|
||||
func (s *Server) GetEndPoints(ctx context.Context, cell, keyspace, shard string, tabletType topo.TabletType) (*topo.EndPoints, error) {
|
||||
value, _, err := s.getEndPoints(cell, keyspace, shard, tabletType)
|
||||
return value, err
|
||||
}
|
||||
|
@ -100,7 +101,7 @@ func (s *Server) getEndPoints(cellName, keyspace, shard string, tabletType topo.
|
|||
}
|
||||
|
||||
// DeleteEndPoints implements topo.Server.
|
||||
func (s *Server) DeleteEndPoints(cellName, keyspace, shard string, tabletType topo.TabletType) error {
|
||||
func (s *Server) DeleteEndPoints(ctx context.Context, cellName, keyspace, shard string, tabletType topo.TabletType) error {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -111,7 +112,7 @@ func (s *Server) DeleteEndPoints(cellName, keyspace, shard string, tabletType to
|
|||
}
|
||||
|
||||
// UpdateSrvShard implements topo.Server.
|
||||
func (s *Server) UpdateSrvShard(cellName, keyspace, shard string, srvShard *topo.SrvShard) error {
|
||||
func (s *Server) UpdateSrvShard(ctx context.Context, cellName, keyspace, shard string, srvShard *topo.SrvShard) error {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -124,7 +125,7 @@ func (s *Server) UpdateSrvShard(cellName, keyspace, shard string, srvShard *topo
|
|||
}
|
||||
|
||||
// GetSrvShard implements topo.Server.
|
||||
func (s *Server) GetSrvShard(cellName, keyspace, shard string) (*topo.SrvShard, error) {
|
||||
func (s *Server) GetSrvShard(ctx context.Context, cellName, keyspace, shard string) (*topo.SrvShard, error) {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -146,7 +147,7 @@ func (s *Server) GetSrvShard(cellName, keyspace, shard string) (*topo.SrvShard,
|
|||
}
|
||||
|
||||
// DeleteSrvShard implements topo.Server.
|
||||
func (s *Server) DeleteSrvShard(cellName, keyspace, shard string) error {
|
||||
func (s *Server) DeleteSrvShard(ctx context.Context, cellName, keyspace, shard string) error {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -157,7 +158,7 @@ func (s *Server) DeleteSrvShard(cellName, keyspace, shard string) error {
|
|||
}
|
||||
|
||||
// UpdateSrvKeyspace implements topo.Server.
|
||||
func (s *Server) UpdateSrvKeyspace(cellName, keyspace string, srvKeyspace *topo.SrvKeyspace) error {
|
||||
func (s *Server) UpdateSrvKeyspace(ctx context.Context, cellName, keyspace string, srvKeyspace *topo.SrvKeyspace) error {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -170,7 +171,7 @@ func (s *Server) UpdateSrvKeyspace(cellName, keyspace string, srvKeyspace *topo.
|
|||
}
|
||||
|
||||
// GetSrvKeyspace implements topo.Server.
|
||||
func (s *Server) GetSrvKeyspace(cellName, keyspace string) (*topo.SrvKeyspace, error) {
|
||||
func (s *Server) GetSrvKeyspace(ctx context.Context, cellName, keyspace string) (*topo.SrvKeyspace, error) {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -192,7 +193,7 @@ func (s *Server) GetSrvKeyspace(cellName, keyspace string) (*topo.SrvKeyspace, e
|
|||
}
|
||||
|
||||
// GetSrvKeyspaceNames implements topo.Server.
|
||||
func (s *Server) GetSrvKeyspaceNames(cellName string) ([]string, error) {
|
||||
func (s *Server) GetSrvKeyspaceNames(ctx context.Context, cellName string) ([]string, error) {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -206,7 +207,7 @@ func (s *Server) GetSrvKeyspaceNames(cellName string) ([]string, error) {
|
|||
}
|
||||
|
||||
// UpdateTabletEndpoint implements topo.Server.
|
||||
func (s *Server) UpdateTabletEndpoint(cell, keyspace, shard string, tabletType topo.TabletType, addr *topo.EndPoint) error {
|
||||
func (s *Server) UpdateTabletEndpoint(ctx context.Context, cell, keyspace, shard string, tabletType topo.TabletType, addr *topo.EndPoint) error {
|
||||
for {
|
||||
addrs, version, err := s.getEndPoints(cell, keyspace, shard, tabletType)
|
||||
if err == topo.ErrNoNode {
|
||||
|
@ -239,7 +240,7 @@ func (s *Server) UpdateTabletEndpoint(cell, keyspace, shard string, tabletType t
|
|||
}
|
||||
|
||||
// WatchEndPoints is part of the topo.Server interface
|
||||
func (s *Server) WatchEndPoints(cellName, keyspace, shard string, tabletType topo.TabletType) (<-chan *topo.EndPoints, chan<- struct{}, error) {
|
||||
func (s *Server) WatchEndPoints(ctx context.Context, cellName, keyspace, shard string, tabletType topo.TabletType) (<-chan *topo.EndPoints, chan<- struct{}, error) {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("WatchEndPoints cannot get cell: %v", err)
|
||||
|
|
|
@ -12,10 +12,11 @@ import (
|
|||
"github.com/youtube/vitess/go/jscfg"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/topo/events"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// CreateShard implements topo.Server.
|
||||
func (s *Server) CreateShard(keyspace, shard string, value *topo.Shard) error {
|
||||
func (s *Server) CreateShard(ctx context.Context, keyspace, shard string, value *topo.Shard) error {
|
||||
data := jscfg.ToJSON(value)
|
||||
global := s.getGlobal()
|
||||
|
||||
|
@ -42,7 +43,7 @@ func (s *Server) CreateShard(keyspace, shard string, value *topo.Shard) error {
|
|||
}
|
||||
|
||||
// UpdateShard implements topo.Server.
|
||||
func (s *Server) UpdateShard(si *topo.ShardInfo, existingVersion int64) (int64, error) {
|
||||
func (s *Server) UpdateShard(ctx context.Context, si *topo.ShardInfo, existingVersion int64) (int64, error) {
|
||||
data := jscfg.ToJSON(si.Shard)
|
||||
|
||||
resp, err := s.getGlobal().CompareAndSwap(shardFilePath(si.Keyspace(), si.ShardName()),
|
||||
|
@ -62,13 +63,13 @@ func (s *Server) UpdateShard(si *topo.ShardInfo, existingVersion int64) (int64,
|
|||
}
|
||||
|
||||
// ValidateShard implements topo.Server.
|
||||
func (s *Server) ValidateShard(keyspace, shard string) error {
|
||||
_, err := s.GetShard(keyspace, shard)
|
||||
func (s *Server) ValidateShard(ctx context.Context, keyspace, shard string) error {
|
||||
_, err := s.GetShard(ctx, keyspace, shard)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetShard implements topo.Server.
|
||||
func (s *Server) GetShard(keyspace, shard string) (*topo.ShardInfo, error) {
|
||||
func (s *Server) GetShard(ctx context.Context, keyspace, shard string) (*topo.ShardInfo, error) {
|
||||
resp, err := s.getGlobal().Get(shardFilePath(keyspace, shard), false /* sort */, false /* recursive */)
|
||||
if err != nil {
|
||||
return nil, convertError(err)
|
||||
|
@ -86,7 +87,7 @@ func (s *Server) GetShard(keyspace, shard string) (*topo.ShardInfo, error) {
|
|||
}
|
||||
|
||||
// GetShardNames implements topo.Server.
|
||||
func (s *Server) GetShardNames(keyspace string) ([]string, error) {
|
||||
func (s *Server) GetShardNames(ctx context.Context, keyspace string) ([]string, error) {
|
||||
resp, err := s.getGlobal().Get(shardsDirPath(keyspace), true /* sort */, false /* recursive */)
|
||||
if err != nil {
|
||||
return nil, convertError(err)
|
||||
|
@ -95,7 +96,7 @@ func (s *Server) GetShardNames(keyspace string) ([]string, error) {
|
|||
}
|
||||
|
||||
// DeleteShard implements topo.Server.
|
||||
func (s *Server) DeleteShard(keyspace, shard string) error {
|
||||
func (s *Server) DeleteShard(ctx context.Context, keyspace, shard string) error {
|
||||
_, err := s.getGlobal().Delete(shardDirPath(keyspace, shard), true /* recursive */)
|
||||
if err != nil {
|
||||
return convertError(err)
|
||||
|
|
|
@ -12,10 +12,11 @@ import (
|
|||
"github.com/youtube/vitess/go/jscfg"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/topo/events"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// CreateTablet implements topo.Server.
|
||||
func (s *Server) CreateTablet(tablet *topo.Tablet) error {
|
||||
func (s *Server) CreateTablet(ctx context.Context, tablet *topo.Tablet) error {
|
||||
cell, err := s.getCell(tablet.Alias.Cell)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -35,7 +36,7 @@ func (s *Server) CreateTablet(tablet *topo.Tablet) error {
|
|||
}
|
||||
|
||||
// UpdateTablet implements topo.Server.
|
||||
func (s *Server) UpdateTablet(ti *topo.TabletInfo, existingVersion int64) (int64, error) {
|
||||
func (s *Server) UpdateTablet(ctx context.Context, ti *topo.TabletInfo, existingVersion int64) (int64, error) {
|
||||
cell, err := s.getCell(ti.Alias.Cell)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
|
@ -59,18 +60,18 @@ func (s *Server) UpdateTablet(ti *topo.TabletInfo, existingVersion int64) (int64
|
|||
}
|
||||
|
||||
// UpdateTabletFields implements topo.Server.
|
||||
func (s *Server) UpdateTabletFields(tabletAlias topo.TabletAlias, updateFunc func(*topo.Tablet) error) error {
|
||||
func (s *Server) UpdateTabletFields(ctx context.Context, tabletAlias topo.TabletAlias, updateFunc func(*topo.Tablet) error) error {
|
||||
var ti *topo.TabletInfo
|
||||
var err error
|
||||
|
||||
for {
|
||||
if ti, err = s.GetTablet(tabletAlias); err != nil {
|
||||
if ti, err = s.GetTablet(ctx, tabletAlias); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = updateFunc(ti.Tablet); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = s.UpdateTablet(ti, ti.Version()); err != topo.ErrBadVersion {
|
||||
if _, err = s.UpdateTablet(ctx, ti, ti.Version()); err != topo.ErrBadVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -86,14 +87,14 @@ func (s *Server) UpdateTabletFields(tabletAlias topo.TabletAlias, updateFunc fun
|
|||
}
|
||||
|
||||
// DeleteTablet implements topo.Server.
|
||||
func (s *Server) DeleteTablet(tabletAlias topo.TabletAlias) error {
|
||||
func (s *Server) DeleteTablet(ctx context.Context, tabletAlias topo.TabletAlias) error {
|
||||
cell, err := s.getCell(tabletAlias.Cell)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the keyspace and shard names for the TabletChange event.
|
||||
ti, tiErr := s.GetTablet(tabletAlias)
|
||||
ti, tiErr := s.GetTablet(ctx, tabletAlias)
|
||||
|
||||
_, err = cell.Delete(tabletDirPath(tabletAlias.String()), true /* recursive */)
|
||||
if err != nil {
|
||||
|
@ -116,13 +117,13 @@ func (s *Server) DeleteTablet(tabletAlias topo.TabletAlias) error {
|
|||
}
|
||||
|
||||
// ValidateTablet implements topo.Server.
|
||||
func (s *Server) ValidateTablet(tabletAlias topo.TabletAlias) error {
|
||||
_, err := s.GetTablet(tabletAlias)
|
||||
func (s *Server) ValidateTablet(ctx context.Context, tabletAlias topo.TabletAlias) error {
|
||||
_, err := s.GetTablet(ctx, tabletAlias)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTablet implements topo.Server.
|
||||
func (s *Server) GetTablet(tabletAlias topo.TabletAlias) (*topo.TabletInfo, error) {
|
||||
func (s *Server) GetTablet(ctx context.Context, tabletAlias topo.TabletAlias) (*topo.TabletInfo, error) {
|
||||
cell, err := s.getCell(tabletAlias.Cell)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -145,7 +146,7 @@ func (s *Server) GetTablet(tabletAlias topo.TabletAlias) (*topo.TabletInfo, erro
|
|||
}
|
||||
|
||||
// GetTabletsByCell implements topo.Server.
|
||||
func (s *Server) GetTabletsByCell(cellName string) ([]topo.TabletAlias, error) {
|
||||
func (s *Server) GetTabletsByCell(ctx context.Context, cellName string) ([]topo.TabletAlias, error) {
|
||||
cell, err := s.getCell(cellName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -3,6 +3,7 @@ package etcdtopo
|
|||
import (
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/planbuilder"
|
||||
"golang.org/x/net/context"
|
||||
// vindexes needs to be imported so that they register
|
||||
// themselves against vtgate/planbuilder. This will allow
|
||||
// us to sanity check the schema being uploaded.
|
||||
|
@ -14,7 +15,7 @@ This file contains the vschema management code for etcdtopo.Server
|
|||
*/
|
||||
|
||||
// SaveVSchema saves the JSON vschema into the topo.
|
||||
func (s *Server) SaveVSchema(vschema string) error {
|
||||
func (s *Server) SaveVSchema(ctx context.Context, vschema string) error {
|
||||
_, err := planbuilder.NewSchema([]byte(vschema))
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -28,7 +29,7 @@ func (s *Server) SaveVSchema(vschema string) error {
|
|||
}
|
||||
|
||||
// GetVSchema fetches the JSON vschema from the topo.
|
||||
func (s *Server) GetVSchema() (string, error) {
|
||||
func (s *Server) GetVSchema(ctx context.Context) (string, error) {
|
||||
resp, err := s.getGlobal().Get(vschemaPath, false /* sort */, false /* recursive */)
|
||||
if err != nil {
|
||||
err = convertError(err)
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/concurrency"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -26,11 +25,11 @@ func init() {
|
|||
type Reporter interface {
|
||||
// Report returns the replication delay gathered by this
|
||||
// module (or 0 if it thinks it's not behind), assuming that
|
||||
// its tablet type is TabletType, and that its query service
|
||||
// it is a slave type or not, and that its query service
|
||||
// should be running or not. If Report returns an error it
|
||||
// implies that the tablet is in a bad shape and not able to
|
||||
// handle queries.
|
||||
Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (replicationDelay time.Duration, err error)
|
||||
Report(isSlaveType, shouldQueryServiceBeRunning bool) (replicationDelay time.Duration, err error)
|
||||
|
||||
// HTMLName returns a displayable name for the module.
|
||||
// Can be used to be displayed in the status page.
|
||||
|
@ -38,11 +37,11 @@ type Reporter interface {
|
|||
}
|
||||
|
||||
// FunctionReporter is a function that may act as a Reporter.
|
||||
type FunctionReporter func(topo.TabletType, bool) (time.Duration, error)
|
||||
type FunctionReporter func(bool, bool) (time.Duration, error)
|
||||
|
||||
// Report implements Reporter.Report
|
||||
func (fc FunctionReporter) Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
|
||||
return fc(tabletType, shouldQueryServiceBeRunning)
|
||||
func (fc FunctionReporter) Report(isSlaveType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
|
||||
return fc(isSlaveType, shouldQueryServiceBeRunning)
|
||||
}
|
||||
|
||||
// HTMLName implements Reporter.HTMLName
|
||||
|
@ -71,7 +70,7 @@ func NewAggregator() *Aggregator {
|
|||
// The returned replication delay will be the highest of all the replication
|
||||
// delays returned by the Reporter implementations (although typically
|
||||
// only one implementation will actually return a meaningful one).
|
||||
func (ag *Aggregator) Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
|
||||
func (ag *Aggregator) Report(isSlaveType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
rec concurrency.AllErrorRecorder
|
||||
|
@ -83,7 +82,7 @@ func (ag *Aggregator) Report(tabletType topo.TabletType, shouldQueryServiceBeRun
|
|||
wg.Add(1)
|
||||
go func(name string, rep Reporter) {
|
||||
defer wg.Done()
|
||||
replicationDelay, err := rep.Report(tabletType, shouldQueryServiceBeRunning)
|
||||
replicationDelay, err := rep.Report(isSlaveType, shouldQueryServiceBeRunning)
|
||||
if err != nil {
|
||||
rec.RecordError(fmt.Errorf("%v: %v", name, err))
|
||||
return
|
||||
|
|
|
@ -4,23 +4,21 @@ import (
|
|||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
)
|
||||
|
||||
func TestReporters(t *testing.T) {
|
||||
|
||||
ag := NewAggregator()
|
||||
|
||||
ag.Register("a", FunctionReporter(func(topo.TabletType, bool) (time.Duration, error) {
|
||||
ag.Register("a", FunctionReporter(func(bool, bool) (time.Duration, error) {
|
||||
return 10 * time.Second, nil
|
||||
}))
|
||||
|
||||
ag.Register("b", FunctionReporter(func(topo.TabletType, bool) (time.Duration, error) {
|
||||
ag.Register("b", FunctionReporter(func(bool, bool) (time.Duration, error) {
|
||||
return 5 * time.Second, nil
|
||||
}))
|
||||
|
||||
delay, err := ag.Report(topo.TYPE_REPLICA, true)
|
||||
delay, err := ag.Report(true, true)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -29,10 +27,10 @@ func TestReporters(t *testing.T) {
|
|||
t.Errorf("delay=%v, want 10s", delay)
|
||||
}
|
||||
|
||||
ag.Register("c", FunctionReporter(func(topo.TabletType, bool) (time.Duration, error) {
|
||||
ag.Register("c", FunctionReporter(func(bool, bool) (time.Duration, error) {
|
||||
return 0, errors.New("e error")
|
||||
}))
|
||||
if _, err := ag.Report(topo.TYPE_REPLICA, false); err == nil {
|
||||
if _, err := ag.Report(true, false); err == nil {
|
||||
t.Errorf("ag.Run: expected error")
|
||||
}
|
||||
|
||||
|
|
|
@ -116,9 +116,14 @@ func (hook *Hook) Execute() (result *HookResult) {
|
|||
// Execute an optional hook, returns a printable error
|
||||
func (hook *Hook) ExecuteOptional() error {
|
||||
hr := hook.Execute()
|
||||
if hr.ExitStatus == HOOK_DOES_NOT_EXIST {
|
||||
switch hr.ExitStatus {
|
||||
case HOOK_DOES_NOT_EXIST:
|
||||
log.Infof("%v hook doesn't exist", hook.Name)
|
||||
} else if hr.ExitStatus != HOOK_SUCCESS {
|
||||
case HOOK_VTROOT_ERROR:
|
||||
log.Infof("VTROOT not set, so %v hook doesn't exist", hook.Name)
|
||||
case HOOK_SUCCESS:
|
||||
// nothing to do here
|
||||
default:
|
||||
return fmt.Errorf("%v hook failed(%v): %v", hook.Name, hr.ExitStatus, hr.Stderr)
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -0,0 +1,557 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mysqlctl
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
|
||||
"github.com/youtube/vitess/go/cgzip"
|
||||
"github.com/youtube/vitess/go/sync2"
|
||||
"github.com/youtube/vitess/go/vt/concurrency"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl/backupstorage"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
|
||||
)
|
||||
|
||||
// This file handles the backup and restore related code
|
||||
|
||||
const (
|
||||
// the three bases for files to restore
|
||||
backupInnodbDataHomeDir = "InnoDBData"
|
||||
backupInnodbLogGroupHomeDir = "InnoDBLog"
|
||||
backupData = "Data"
|
||||
|
||||
// the manifest file name
|
||||
backupManifest = "MANIFEST"
|
||||
)
|
||||
|
||||
const (
|
||||
// slaveStartDeadline is the deadline for starting a slave
|
||||
slaveStartDeadline = 30
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoBackup is returned when there is no backup
|
||||
ErrNoBackup = errors.New("no available backup")
|
||||
)
|
||||
|
||||
// FileEntry is one file to backup
|
||||
type FileEntry struct {
|
||||
// Base is one of:
|
||||
// - backupInnodbDataHomeDir for files that go into Mycnf.InnodbDataHomeDir
|
||||
// - backupInnodbLogGroupHomeDir for files that go into Mycnf.InnodbLogGroupHomeDir
|
||||
// - backupData for files that go into Mycnf.DataDir
|
||||
Base string
|
||||
|
||||
// Name is the file name, relative to Base
|
||||
Name string
|
||||
|
||||
// Hash is the hash of the gzip compressed data stored in the
|
||||
// BackupStorage.
|
||||
Hash string
|
||||
}
|
||||
|
||||
func (fe *FileEntry) open(cnf *Mycnf, readOnly bool) (*os.File, error) {
|
||||
// find the root to use
|
||||
var root string
|
||||
switch fe.Base {
|
||||
case backupInnodbDataHomeDir:
|
||||
root = cnf.InnodbDataHomeDir
|
||||
case backupInnodbLogGroupHomeDir:
|
||||
root = cnf.InnodbLogGroupHomeDir
|
||||
case backupData:
|
||||
root = cnf.DataDir
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown base: %v", fe.Base)
|
||||
}
|
||||
|
||||
// and open the file
|
||||
name := path.Join(root, fe.Name)
|
||||
var fd *os.File
|
||||
var err error
|
||||
if readOnly {
|
||||
fd, err = os.Open(name)
|
||||
} else {
|
||||
dir := path.Dir(name)
|
||||
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return nil, fmt.Errorf("cannot create destination directory %v: %v", dir, err)
|
||||
}
|
||||
fd, err = os.Create(name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot open source file %v: %v", name, err)
|
||||
}
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
// BackupManifest represents the backup. It lists all the files, and
|
||||
// the ReplicationPosition that the backup was taken at.
|
||||
type BackupManifest struct {
|
||||
// FileEntries contains all the files in the backup
|
||||
FileEntries []FileEntry
|
||||
|
||||
// ReplicationPosition is the position at which the backup was taken
|
||||
ReplicationPosition proto.ReplicationPosition
|
||||
}
|
||||
|
||||
// isDbDir returns true if the given directory contains a DB
|
||||
func isDbDir(p string) bool {
|
||||
// db.opt is there
|
||||
if _, err := os.Stat(path.Join(p, "db.opt")); err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Look for at least one .frm file
|
||||
fis, err := ioutil.ReadDir(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, fi := range fis {
|
||||
if strings.HasSuffix(fi.Name(), ".frm") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func addDirectory(fes []FileEntry, base string, baseDir string, subDir string) ([]FileEntry, error) {
|
||||
p := path.Join(baseDir, subDir)
|
||||
|
||||
fis, err := ioutil.ReadDir(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fi := range fis {
|
||||
fes = append(fes, FileEntry{
|
||||
Base: base,
|
||||
Name: path.Join(subDir, fi.Name()),
|
||||
})
|
||||
}
|
||||
return fes, nil
|
||||
}
|
||||
|
||||
func findFilesTobackup(cnf *Mycnf) ([]FileEntry, error) {
|
||||
var err error
|
||||
var result []FileEntry
|
||||
|
||||
// first add inno db files
|
||||
result, err = addDirectory(result, backupInnodbDataHomeDir, cnf.InnodbDataHomeDir, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err = addDirectory(result, backupInnodbLogGroupHomeDir, cnf.InnodbLogGroupHomeDir, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// then add DB directories
|
||||
fis, err := ioutil.ReadDir(cnf.DataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, fi := range fis {
|
||||
p := path.Join(cnf.DataDir, fi.Name())
|
||||
if isDbDir(p) {
|
||||
result, err = addDirectory(result, backupData, cnf.DataDir, fi.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Backup is the main entry point for a backup:
|
||||
// - uses the BackupStorage service to store a new backup
|
||||
// - shuts down Mysqld during the backup
|
||||
// - remember if we were replicating, restore the exact same state
|
||||
func Backup(mysqld MysqlDaemon, logger logutil.Logger, bucket, name string, backupConcurrency int, hookExtraEnv map[string]string) error {
|
||||
|
||||
// start the backup with the BackupStorage
|
||||
bs, err := backupstorage.GetBackupStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bh, err := bs.StartBackup(bucket, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("StartBackup failed: %v", err)
|
||||
}
|
||||
|
||||
if err = backup(mysqld, logger, bh, backupConcurrency, hookExtraEnv); err != nil {
|
||||
if abortErr := bh.AbortBackup(); abortErr != nil {
|
||||
logger.Errorf("failed to abort backup: %v", abortErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return bh.EndBackup()
|
||||
}
|
||||
|
||||
func backup(mysqld MysqlDaemon, logger logutil.Logger, bh backupstorage.BackupHandle, backupConcurrency int, hookExtraEnv map[string]string) error {
|
||||
|
||||
// save initial state so we can restore
|
||||
slaveStartRequired := false
|
||||
sourceIsMaster := false
|
||||
readOnly := true
|
||||
var replicationPosition proto.ReplicationPosition
|
||||
|
||||
// see if we need to restart replication after backup
|
||||
logger.Infof("getting current replication status")
|
||||
slaveStatus, err := mysqld.SlaveStatus()
|
||||
switch err {
|
||||
case nil:
|
||||
slaveStartRequired = slaveStatus.SlaveRunning()
|
||||
case ErrNotSlave:
|
||||
// keep going if we're the master, might be a degenerate case
|
||||
sourceIsMaster = true
|
||||
default:
|
||||
return fmt.Errorf("cannot get slave status: %v", err)
|
||||
}
|
||||
|
||||
// get the read-only flag
|
||||
readOnly, err = mysqld.IsReadOnly()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get read only status: %v", err)
|
||||
}
|
||||
|
||||
// get the replication position
|
||||
if sourceIsMaster {
|
||||
if !readOnly {
|
||||
logger.Infof("turning master read-onyl before backup")
|
||||
if err = mysqld.SetReadOnly(true); err != nil {
|
||||
return fmt.Errorf("cannot get read only status: %v", err)
|
||||
}
|
||||
}
|
||||
replicationPosition, err = mysqld.MasterPosition()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get master position: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err = StopSlave(mysqld, hookExtraEnv); err != nil {
|
||||
return fmt.Errorf("cannot stop slave: %v", err)
|
||||
}
|
||||
var slaveStatus proto.ReplicationStatus
|
||||
slaveStatus, err = mysqld.SlaveStatus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get slave status: %v", err)
|
||||
}
|
||||
replicationPosition = slaveStatus.Position
|
||||
}
|
||||
logger.Infof("using replication position: %v", replicationPosition)
|
||||
|
||||
// shutdown mysqld
|
||||
if err = mysqld.Shutdown(true, MysqlWaitTime); err != nil {
|
||||
return fmt.Errorf("cannot shutdown mysqld: %v", err)
|
||||
}
|
||||
|
||||
// get the files to backup
|
||||
fes, err := findFilesTobackup(mysqld.Cnf())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot find files to backup: %v", err)
|
||||
}
|
||||
logger.Infof("found %v files to backup", len(fes))
|
||||
|
||||
// backup everything
|
||||
if err := backupFiles(mysqld, logger, bh, fes, replicationPosition, backupConcurrency); err != nil {
|
||||
return fmt.Errorf("cannot backup files: %v", err)
|
||||
}
|
||||
|
||||
// Try to restart mysqld
|
||||
if err := mysqld.Start(MysqlWaitTime); err != nil {
|
||||
return fmt.Errorf("cannot restart mysqld: %v", err)
|
||||
}
|
||||
|
||||
// Restore original mysqld state that we saved above.
|
||||
if slaveStartRequired {
|
||||
logger.Infof("restarting mysql replication")
|
||||
if err := StartSlave(mysqld, hookExtraEnv); err != nil {
|
||||
return fmt.Errorf("cannot restart slave: %v", err)
|
||||
}
|
||||
|
||||
// this should be quick, but we might as well just wait
|
||||
if err := WaitForSlaveStart(mysqld, slaveStartDeadline); err != nil {
|
||||
return fmt.Errorf("slave is not restarting: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// And set read-only mode
|
||||
logger.Infof("resetting mysqld read-only to %v", readOnly)
|
||||
if err := mysqld.SetReadOnly(readOnly); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func backupFiles(mysqld MysqlDaemon, logger logutil.Logger, bh backupstorage.BackupHandle, fes []FileEntry, replicationPosition proto.ReplicationPosition, backupConcurrency int) (err error) {
|
||||
sema := sync2.NewSemaphore(backupConcurrency, 0)
|
||||
rec := concurrency.AllErrorRecorder{}
|
||||
wg := sync.WaitGroup{}
|
||||
for i, fe := range fes {
|
||||
wg.Add(1)
|
||||
go func(i int, fe FileEntry) {
|
||||
defer wg.Done()
|
||||
|
||||
// wait until we are ready to go, skip if we already
|
||||
// encountered an error
|
||||
sema.Acquire()
|
||||
defer sema.Release()
|
||||
if rec.HasErrors() {
|
||||
return
|
||||
}
|
||||
|
||||
// open the source file for reading
|
||||
source, err := fe.open(mysqld.Cnf(), true)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
// open the destination file for writing, and a buffer
|
||||
name := fmt.Sprintf("%v", i)
|
||||
wc, err := bh.AddFile(name)
|
||||
if err != nil {
|
||||
rec.RecordError(fmt.Errorf("cannot add file: %v", err))
|
||||
return
|
||||
}
|
||||
defer func() { rec.RecordError(wc.Close()) }()
|
||||
dst := bufio.NewWriterSize(wc, 2*1024*1024)
|
||||
|
||||
// create the hasher and the tee on top
|
||||
hasher := newHasher()
|
||||
tee := io.MultiWriter(dst, hasher)
|
||||
|
||||
// create the gzip compression filter
|
||||
gzip, err := cgzip.NewWriterLevel(tee, cgzip.Z_BEST_SPEED)
|
||||
if err != nil {
|
||||
rec.RecordError(fmt.Errorf("cannot create gziper: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// copy from the source file to gzip to tee to output file and hasher
|
||||
_, err = io.Copy(gzip, source)
|
||||
if err != nil {
|
||||
rec.RecordError(fmt.Errorf("cannot copy data: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// close gzip to flush it, after that the hash is good
|
||||
if err = gzip.Close(); err != nil {
|
||||
rec.RecordError(fmt.Errorf("cannot close gzip: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// flush the buffer to finish writing, save the hash
|
||||
rec.RecordError(dst.Flush())
|
||||
fes[i].Hash = hasher.HashString()
|
||||
}(i, fe)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if rec.HasErrors() {
|
||||
return rec.Error()
|
||||
}
|
||||
|
||||
// open the MANIFEST
|
||||
wc, err := bh.AddFile(backupManifest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add %v to backup: %v", backupManifest, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := wc.Close(); err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
// JSON-encode and write the MANIFEST
|
||||
bm := &BackupManifest{
|
||||
FileEntries: fes,
|
||||
ReplicationPosition: replicationPosition,
|
||||
}
|
||||
data, err := json.MarshalIndent(bm, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot JSON encode %v: %v", backupManifest, err)
|
||||
}
|
||||
if _, err := wc.Write([]byte(data)); err != nil {
|
||||
return fmt.Errorf("cannot write %v: %v", backupManifest, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkNoDB makes sure there is no vt_ db already there. Used by Restore,
|
||||
// we do not wnat to destroy an existing DB.
|
||||
func checkNoDB(mysqld MysqlDaemon) error {
|
||||
qr, err := mysqld.FetchSuperQuery("SHOW DATABASES")
|
||||
if err != nil {
|
||||
return fmt.Errorf("checkNoDB failed: %v", err)
|
||||
}
|
||||
|
||||
for _, row := range qr.Rows {
|
||||
if strings.HasPrefix(row[0].String(), "vt_") {
|
||||
dbName := row[0].String()
|
||||
tableQr, err := mysqld.FetchSuperQuery("SHOW TABLES FROM " + dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checkNoDB failed: %v", err)
|
||||
} else if len(tableQr.Rows) == 0 {
|
||||
// no tables == empty db, all is well
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("checkNoDB failed, found active db %v", dbName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreFiles will copy all the files from the BackupStorage to the
|
||||
// right place
|
||||
func restoreFiles(cnf *Mycnf, bh backupstorage.BackupHandle, fes []FileEntry, restoreConcurrency int) error {
|
||||
sema := sync2.NewSemaphore(restoreConcurrency, 0)
|
||||
rec := concurrency.AllErrorRecorder{}
|
||||
wg := sync.WaitGroup{}
|
||||
for i, fe := range fes {
|
||||
wg.Add(1)
|
||||
go func(i int, fe FileEntry) {
|
||||
defer wg.Done()
|
||||
|
||||
// wait until we are ready to go, skip if we already
|
||||
// encountered an error
|
||||
sema.Acquire()
|
||||
defer sema.Release()
|
||||
if rec.HasErrors() {
|
||||
return
|
||||
}
|
||||
|
||||
// open the source file for reading
|
||||
name := fmt.Sprintf("%v", i)
|
||||
source, err := bh.ReadFile(name)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
// open the destination file for writing
|
||||
dstFile, err := fe.open(cnf, false)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
}
|
||||
defer func() { rec.RecordError(dstFile.Close()) }()
|
||||
|
||||
// create a buffering output
|
||||
dst := bufio.NewWriterSize(dstFile, 2*1024*1024)
|
||||
|
||||
// create hash to write the compressed data to
|
||||
hasher := newHasher()
|
||||
|
||||
// create a Tee: we split the input into the hasher
|
||||
// and into the gunziper
|
||||
tee := io.TeeReader(source, hasher)
|
||||
|
||||
// create the uncompresser
|
||||
gz, err := cgzip.NewReader(tee)
|
||||
if err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
}
|
||||
defer func() { rec.RecordError(gz.Close()) }()
|
||||
|
||||
// copy the data. Will also write to the hasher
|
||||
if _, err = io.Copy(dst, gz); err != nil {
|
||||
rec.RecordError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// check the hash
|
||||
hash := hasher.HashString()
|
||||
if hash != fe.Hash {
|
||||
rec.RecordError(fmt.Errorf("hash mismatch for %v, got %v expected %v", fe.Name, hash, fe.Hash))
|
||||
return
|
||||
}
|
||||
|
||||
// flush the buffer
|
||||
rec.RecordError(dst.Flush())
|
||||
}(i, fe)
|
||||
}
|
||||
wg.Wait()
|
||||
return rec.Error()
|
||||
}
|
||||
|
||||
// Restore is the main entry point for backup restore. If there is no
|
||||
// appropriate backup on the BackupStorage, Restore logs an error
|
||||
// and returns ErrNoBackup. Any other error is returned.
|
||||
func Restore(mysqld MysqlDaemon, bucket string, restoreConcurrency int, hookExtraEnv map[string]string) (proto.ReplicationPosition, error) {
|
||||
// find the right backup handle: most recent one, with a MANIFEST
|
||||
log.Infof("Restore: looking for a suitable backup to restore")
|
||||
bs, err := backupstorage.GetBackupStorage()
|
||||
if err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
bhs, err := bs.ListBackups(bucket)
|
||||
if err != nil {
|
||||
return proto.ReplicationPosition{}, fmt.Errorf("ListBackups failed: %v", err)
|
||||
}
|
||||
toRestore := len(bhs) - 1
|
||||
var bh backupstorage.BackupHandle
|
||||
var bm BackupManifest
|
||||
for toRestore >= 0 {
|
||||
bh = bhs[toRestore]
|
||||
if rc, err := bh.ReadFile(backupManifest); err == nil {
|
||||
dec := json.NewDecoder(rc)
|
||||
err := dec.Decode(&bm)
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
log.Warningf("Possibly incomplete backup %v in bucket %v on BackupStorage (cannot JSON decode MANIFEST: %v)", bh.Name(), bucket, err)
|
||||
} else {
|
||||
log.Infof("Restore: found backup %v %v to restore with %v files", bh.Bucket(), bh.Name(), len(bm.FileEntries))
|
||||
break
|
||||
}
|
||||
} else {
|
||||
log.Warningf("Possibly incomplete backup %v in bucket %v on BackupStorage (cannot read MANIFEST)", bh.Name(), bucket)
|
||||
}
|
||||
toRestore--
|
||||
}
|
||||
if toRestore < 0 {
|
||||
log.Errorf("No backup to restore on BackupStorage for bucket %v", bucket)
|
||||
return proto.ReplicationPosition{}, ErrNoBackup
|
||||
}
|
||||
|
||||
log.Infof("Restore: checking no existing data is present")
|
||||
if err := checkNoDB(mysqld); err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
|
||||
log.Infof("Restore: shutdown mysqld")
|
||||
if err := mysqld.Shutdown(true, MysqlWaitTime); err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
|
||||
log.Infof("Restore: copying all files")
|
||||
if err := restoreFiles(mysqld.Cnf(), bh, bm.FileEntries, restoreConcurrency); err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
|
||||
log.Infof("Restore: restart mysqld")
|
||||
if err := mysqld.Start(MysqlWaitTime); err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
|
||||
return bm.ReplicationPosition, nil
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mysqlctl
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFindFilesToBackup(t *testing.T) {
|
||||
root, err := ioutil.TempDir("", "backuptest")
|
||||
if err != nil {
|
||||
t.Fatalf("os.TempDir failed: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(root)
|
||||
|
||||
// Initialize the fake mysql root directories
|
||||
innodbDataDir := path.Join(root, "innodb_data")
|
||||
innodbLogDir := path.Join(root, "innodb_log")
|
||||
dataDir := path.Join(root, "data")
|
||||
dataDbDir := path.Join(dataDir, "vt_db")
|
||||
extraDir := path.Join(dataDir, "extra_dir")
|
||||
outsideDbDir := path.Join(root, "outside_db")
|
||||
for _, s := range []string{innodbDataDir, innodbLogDir, dataDbDir, extraDir, outsideDbDir} {
|
||||
if err := os.MkdirAll(s, os.ModePerm); err != nil {
|
||||
t.Fatalf("failed to create directory %v: %v", s, err)
|
||||
}
|
||||
}
|
||||
if err := ioutil.WriteFile(path.Join(innodbDataDir, "innodb_data_1"), []byte("innodb data 1 contents"), os.ModePerm); err != nil {
|
||||
t.Fatalf("failed to write file innodb_data_1: %v", err)
|
||||
}
|
||||
if err := ioutil.WriteFile(path.Join(innodbLogDir, "innodb_log_1"), []byte("innodb log 1 contents"), os.ModePerm); err != nil {
|
||||
t.Fatalf("failed to write file innodb_log_1: %v", err)
|
||||
}
|
||||
if err := ioutil.WriteFile(path.Join(dataDbDir, "db.opt"), []byte("db opt file"), os.ModePerm); err != nil {
|
||||
t.Fatalf("failed to write file db.opt: %v", err)
|
||||
}
|
||||
if err := ioutil.WriteFile(path.Join(extraDir, "extra.stuff"), []byte("extra file"), os.ModePerm); err != nil {
|
||||
t.Fatalf("failed to write file extra.stuff: %v", err)
|
||||
}
|
||||
if err := ioutil.WriteFile(path.Join(outsideDbDir, "table1.frm"), []byte("frm file"), os.ModePerm); err != nil {
|
||||
t.Fatalf("failed to write file table1.opt: %v", err)
|
||||
}
|
||||
if err := os.Symlink(outsideDbDir, path.Join(dataDir, "vt_symlink")); err != nil {
|
||||
t.Fatalf("failed to symlink vt_symlink: %v", err)
|
||||
}
|
||||
|
||||
cnf := &Mycnf{
|
||||
InnodbDataHomeDir: innodbDataDir,
|
||||
InnodbLogGroupHomeDir: innodbLogDir,
|
||||
DataDir: dataDir,
|
||||
}
|
||||
|
||||
result, err := findFilesTobackup(cnf)
|
||||
if err != nil {
|
||||
t.Fatalf("findFilesTobackup failed: %v", err)
|
||||
}
|
||||
sort.Sort(forTest(result))
|
||||
t.Logf("findFilesTobackup returned: %v", result)
|
||||
expected := []FileEntry{
|
||||
FileEntry{
|
||||
Base: "Data",
|
||||
Name: "vt_db/db.opt",
|
||||
},
|
||||
FileEntry{
|
||||
Base: "Data",
|
||||
Name: "vt_symlink/table1.frm",
|
||||
},
|
||||
FileEntry{
|
||||
Base: "InnoDBData",
|
||||
Name: "innodb_data_1",
|
||||
},
|
||||
FileEntry{
|
||||
Base: "InnoDBLog",
|
||||
Name: "innodb_log_1",
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Fatalf("got wrong list of FileEntry %v, expected %v", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
type forTest []FileEntry
|
||||
|
||||
func (f forTest) Len() int { return len(f) }
|
||||
func (f forTest) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
|
||||
func (f forTest) Less(i, j int) bool { return f[i].Base+f[i].Name < f[j].Base+f[j].Name }
|
|
@ -0,0 +1,85 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package backupstorage contains the interface and file system implementation
|
||||
// of the backup system.
|
||||
package backupstorage
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
// BackupStorageImplementation is the implementation to use
|
||||
// for BackupStorage. Exported for test purposes.
|
||||
BackupStorageImplementation = flag.String("backup_storage_implementation", "", "which implementation to use for the backup storage feature")
|
||||
)
|
||||
|
||||
// BackupHandle describes an individual backup.
|
||||
type BackupHandle interface {
|
||||
// Bucket is the location of the backup. Will contain keyspace/shard.
|
||||
Bucket() string
|
||||
|
||||
// Name is the individual name of the backup. Will contain
|
||||
// tabletAlias-timestamp.
|
||||
Name() string
|
||||
|
||||
// AddFile opens a new file to be added to the backup.
|
||||
// Only works for read-write backups (created by StartBackup).
|
||||
// filename is guaranteed to only contain alphanumerical
|
||||
// characters and hyphens.
|
||||
// It should be thread safe, it is possible to call AddFile in
|
||||
// multiple go routines once a backup has been started.
|
||||
AddFile(filename string) (io.WriteCloser, error)
|
||||
|
||||
// EndBackup stops and closes a backup. The contents should be kept.
|
||||
// Only works for read-write backups (created by StartBackup).
|
||||
EndBackup() error
|
||||
|
||||
// AbortBackup stops a backup, and removes the contents that
|
||||
// have been copied already. It is called if an error occurs
|
||||
// while the backup is being taken, and the backup cannot be finished.
|
||||
// Only works for read-write backups (created by StartBackup).
|
||||
AbortBackup() error
|
||||
|
||||
// ReadFile starts reading a file from a backup.
|
||||
// Only works for read-only backups (created by ListBackups).
|
||||
ReadFile(filename string) (io.ReadCloser, error)
|
||||
}
|
||||
|
||||
// BackupStorage is the interface to the storage system
|
||||
type BackupStorage interface {
|
||||
// ListBackups returns all the backups in a bucket. The
|
||||
// returned backups are read-only (ReadFile can be called, but
|
||||
// AddFile/EndBackup/AbortBackup cannot).
|
||||
// The backups are string-sorted by Name(), ascending (ends up
|
||||
// being the oldest backup first).
|
||||
ListBackups(bucket string) ([]BackupHandle, error)
|
||||
|
||||
// StartBackup creates a new backup with the given name. If a
|
||||
// backup with the same name already exists, it's an error.
|
||||
// The returned backup is read-write
|
||||
// (AddFile/EndBackup/AbortBackup cann all be called, not
|
||||
// ReadFile)
|
||||
StartBackup(bucket, name string) (BackupHandle, error)
|
||||
|
||||
// RemoveBackup removes all the data associated with a backup.
|
||||
// It will not appear in ListBackups after RemoveBackup succeeds.
|
||||
RemoveBackup(bucket, name string) error
|
||||
}
|
||||
|
||||
// BackupStorageMap contains the registered implementations for BackupStorage
|
||||
var BackupStorageMap = make(map[string]BackupStorage)
|
||||
|
||||
// GetBackupStorage returns the current BackupStorage implementation.
|
||||
// Should be called after flags have been initialized.
|
||||
func GetBackupStorage() (BackupStorage, error) {
|
||||
bs, ok := BackupStorageMap[*BackupStorageImplementation]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no registered implementation of BackupStorage")
|
||||
}
|
||||
return bs, nil
|
||||
}
|
|
@ -1,474 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mysqlctl
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/ioutil2"
|
||||
"github.com/youtube/vitess/go/vt/hook"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
|
||||
)
|
||||
|
||||
// These methods deal with cloning a running instance of mysql.
|
||||
|
||||
const (
|
||||
maxLagSeconds = 5
|
||||
)
|
||||
|
||||
const (
|
||||
// slaveStartDeadline is the deadline for starting a slave
|
||||
slaveStartDeadline = 30
|
||||
)
|
||||
|
||||
const (
|
||||
// SnapshotManifestFile is the file name for the snapshot manifest.
|
||||
SnapshotManifestFile = "snapshot_manifest.json"
|
||||
|
||||
// SnapshotURLPath is the URL where to find the snapshot manifest.
|
||||
SnapshotURLPath = "/snapshot"
|
||||
)
|
||||
|
||||
// Validate that this instance is a reasonable source of data.
|
||||
func (mysqld *Mysqld) validateCloneSource(serverMode bool, hookExtraEnv map[string]string) error {
|
||||
// NOTE(msolomon) Removing this check for now - I don't see the value of validating this.
|
||||
// // needs to be master, or slave that's not too far behind
|
||||
// slaveStatus, err := mysqld.slaveStatus()
|
||||
// if err != nil {
|
||||
// if err != ErrNotSlave {
|
||||
// return fmt.Errorf("mysqlctl: validateCloneSource failed, %v", err)
|
||||
// }
|
||||
// } else {
|
||||
// lagSeconds, _ := strconv.Atoi(slaveStatus["seconds_behind_master"])
|
||||
// if lagSeconds > maxLagSeconds {
|
||||
// return fmt.Errorf("mysqlctl: validateCloneSource failed, lag_seconds exceed maximum tolerance (%v)", lagSeconds)
|
||||
// }
|
||||
// }
|
||||
|
||||
// make sure we can write locally
|
||||
if err := mysqld.ValidateSnapshotPath(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// run a hook to check local things
|
||||
// FIXME(alainjobart) What other parameters do we have to
|
||||
// provide? dbname, host, socket?
|
||||
params := make([]string, 0, 1)
|
||||
if serverMode {
|
||||
params = append(params, "--server-mode")
|
||||
}
|
||||
h := hook.NewHook("preflight_snapshot", params)
|
||||
h.ExtraEnv = hookExtraEnv
|
||||
if err := h.ExecuteOptional(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// FIXME(msolomon) check free space based on an estimate of the current
|
||||
// size of the db files.
|
||||
// Also, check that we aren't already cloning/compressing or acting as a
|
||||
// source. Mysqld being down isn't enough, presumably that will be
|
||||
// restarted as soon as the snapshot is taken.
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateCloneTarget makes sure this mysql daemon is a valid target
|
||||
// for a clone.
|
||||
func (mysqld *Mysqld) ValidateCloneTarget(hookExtraEnv map[string]string) error {
|
||||
// run a hook to check local things
|
||||
h := hook.NewSimpleHook("preflight_restore")
|
||||
h.ExtraEnv = hookExtraEnv
|
||||
if err := h.ExecuteOptional(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
qr, err := mysqld.fetchSuperQuery("SHOW DATABASES")
|
||||
if err != nil {
|
||||
return fmt.Errorf("mysqlctl: ValidateCloneTarget failed, %v", err)
|
||||
}
|
||||
|
||||
for _, row := range qr.Rows {
|
||||
if strings.HasPrefix(row[0].String(), "vt_") {
|
||||
dbName := row[0].String()
|
||||
tableQr, err := mysqld.fetchSuperQuery("SHOW TABLES FROM " + dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mysqlctl: ValidateCloneTarget failed, %v", err)
|
||||
} else if len(tableQr.Rows) == 0 {
|
||||
// no tables == empty db, all is well
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("mysqlctl: ValidateCloneTarget failed, found active db %v", dbName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findFilesToServe(srcDir, dstDir string, compress bool) ([]string, []string, error) {
|
||||
fiList, err := ioutil.ReadDir(srcDir)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
sources := make([]string, 0, len(fiList))
|
||||
destinations := make([]string, 0, len(fiList))
|
||||
for _, fi := range fiList {
|
||||
if !fi.IsDir() {
|
||||
srcPath := path.Join(srcDir, fi.Name())
|
||||
var dstPath string
|
||||
if compress {
|
||||
dstPath = path.Join(dstDir, fi.Name()+".gz")
|
||||
} else {
|
||||
dstPath = path.Join(dstDir, fi.Name())
|
||||
}
|
||||
sources = append(sources, srcPath)
|
||||
destinations = append(destinations, dstPath)
|
||||
}
|
||||
}
|
||||
return sources, destinations, nil
|
||||
}
|
||||
|
||||
func (mysqld *Mysqld) createSnapshot(logger logutil.Logger, concurrency int, serverMode bool) ([]SnapshotFile, error) {
|
||||
sources := make([]string, 0, 128)
|
||||
destinations := make([]string, 0, 128)
|
||||
|
||||
// clean out and start fresh
|
||||
logger.Infof("removing previous snapshots: %v", mysqld.SnapshotDir)
|
||||
if err := os.RemoveAll(mysqld.SnapshotDir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// FIXME(msolomon) innodb paths must match patterns in mycnf -
|
||||
// probably belongs as a derived path.
|
||||
type snapPair struct{ srcDir, dstDir string }
|
||||
dps := []snapPair{
|
||||
{mysqld.config.InnodbDataHomeDir, path.Join(mysqld.SnapshotDir, innodbDataSubdir)},
|
||||
{mysqld.config.InnodbLogGroupHomeDir, path.Join(mysqld.SnapshotDir, innodbLogSubdir)},
|
||||
}
|
||||
|
||||
dataDirEntries, err := ioutil.ReadDir(mysqld.config.DataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, de := range dataDirEntries {
|
||||
dbDirPath := path.Join(mysqld.config.DataDir, de.Name())
|
||||
// If this is not a directory, try to eval it as a syslink.
|
||||
if !de.IsDir() {
|
||||
dbDirPath, err = filepath.EvalSymlinks(dbDirPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
de, err = os.Stat(dbDirPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if de.IsDir() {
|
||||
// Copy anything that defines a db.opt file - that includes empty databases.
|
||||
_, err := os.Stat(path.Join(dbDirPath, "db.opt"))
|
||||
if err == nil {
|
||||
dps = append(dps, snapPair{dbDirPath, path.Join(mysqld.SnapshotDir, dataDir, de.Name())})
|
||||
} else {
|
||||
// Look for at least one .frm file
|
||||
dbDirEntries, err := ioutil.ReadDir(dbDirPath)
|
||||
if err == nil {
|
||||
for _, dbEntry := range dbDirEntries {
|
||||
if strings.HasSuffix(dbEntry.Name(), ".frm") {
|
||||
dps = append(dps, snapPair{dbDirPath, path.Join(mysqld.SnapshotDir, dataDir, de.Name())})
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, dp := range dps {
|
||||
if err := os.MkdirAll(dp.dstDir, 0775); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s, d, err := findFilesToServe(dp.srcDir, dp.dstDir, !serverMode); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
sources = append(sources, s...)
|
||||
destinations = append(destinations, d...)
|
||||
}
|
||||
}
|
||||
|
||||
return newSnapshotFiles(sources, destinations, mysqld.SnapshotDir, concurrency, !serverMode)
|
||||
}
|
||||
|
||||
// CreateSnapshot runs on the machine acting as the source for the clone.
|
||||
//
|
||||
// Check master/slave status and determine restore needs.
|
||||
// If this instance is a slave, stop replication, otherwise place in read-only mode.
|
||||
// Record replication position.
|
||||
// Shutdown mysql
|
||||
// Check paths for storing data
|
||||
//
|
||||
// Depending on the serverMode flag, we do the following:
|
||||
// serverMode = false:
|
||||
// Compress /vt/vt_[0-9a-f]+/data/vt_.+
|
||||
// Compute hash (of compressed files, as we serve .gz files here)
|
||||
// Place in /vt/clone_src where they will be served by http server (not rpc)
|
||||
// Restart mysql
|
||||
// serverMode = true:
|
||||
// Make symlinks for /vt/vt_[0-9a-f]+/data/vt_.+ to innodb files
|
||||
// Compute hash (of uncompressed files, as we serve uncompressed files)
|
||||
// Place symlinks in /vt/clone_src where they will be served by http server
|
||||
// Leave mysql stopped, return slaveStartRequired, readOnly
|
||||
func (mysqld *Mysqld) CreateSnapshot(logger logutil.Logger, dbName, sourceAddr string, allowHierarchicalReplication bool, concurrency int, serverMode bool, hookExtraEnv map[string]string) (snapshotManifestURLPath string, slaveStartRequired, readOnly bool, err error) {
|
||||
if dbName == "" {
|
||||
return "", false, false, errors.New("CreateSnapshot failed: no database name provided")
|
||||
}
|
||||
|
||||
if err = mysqld.validateCloneSource(serverMode, hookExtraEnv); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// save initial state so we can restore on Start()
|
||||
slaveStartRequired = false
|
||||
sourceIsMaster := false
|
||||
readOnly = true
|
||||
|
||||
slaveStatus, err := mysqld.SlaveStatus()
|
||||
if err == nil {
|
||||
slaveStartRequired = slaveStatus.SlaveRunning()
|
||||
} else if err == ErrNotSlave {
|
||||
sourceIsMaster = true
|
||||
} else {
|
||||
// If we can't get any data, just fail.
|
||||
return
|
||||
}
|
||||
|
||||
readOnly, err = mysqld.IsReadOnly()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Stop sources of writes so we can get a consistent replication position.
|
||||
// If the source is a slave use the master replication position
|
||||
// unless we are allowing hierarchical replicas.
|
||||
masterAddr := ""
|
||||
var replicationPosition proto.ReplicationPosition
|
||||
if sourceIsMaster {
|
||||
if err = mysqld.SetReadOnly(true); err != nil {
|
||||
return
|
||||
}
|
||||
replicationPosition, err = mysqld.MasterPosition()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
masterAddr = mysqld.IPAddr()
|
||||
} else {
|
||||
if err = mysqld.StopSlave(hookExtraEnv); err != nil {
|
||||
return
|
||||
}
|
||||
var slaveStatus *proto.ReplicationStatus
|
||||
slaveStatus, err = mysqld.SlaveStatus()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
replicationPosition = slaveStatus.Position
|
||||
|
||||
// We are a slave, check our replication strategy before
|
||||
// choosing the master address.
|
||||
if allowHierarchicalReplication {
|
||||
masterAddr = mysqld.IPAddr()
|
||||
} else {
|
||||
masterAddr, err = mysqld.GetMasterAddr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = mysqld.Shutdown(true, MysqlWaitTime); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var smFile string
|
||||
dataFiles, snapshotErr := mysqld.createSnapshot(logger, concurrency, serverMode)
|
||||
if snapshotErr != nil {
|
||||
logger.Errorf("CreateSnapshot failed: %v", snapshotErr)
|
||||
} else {
|
||||
var sm *SnapshotManifest
|
||||
sm, snapshotErr = newSnapshotManifest(sourceAddr, mysqld.IPAddr(),
|
||||
masterAddr, dbName, dataFiles, replicationPosition, proto.ReplicationPosition{})
|
||||
if snapshotErr != nil {
|
||||
logger.Errorf("CreateSnapshot failed: %v", snapshotErr)
|
||||
} else {
|
||||
smFile = path.Join(mysqld.SnapshotDir, SnapshotManifestFile)
|
||||
if snapshotErr = writeJSON(smFile, sm); snapshotErr != nil {
|
||||
logger.Errorf("CreateSnapshot failed: %v", snapshotErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// restore our state if required
|
||||
if serverMode && snapshotErr == nil {
|
||||
logger.Infof("server mode snapshot worked, not restarting mysql")
|
||||
} else {
|
||||
if err = mysqld.SnapshotSourceEnd(slaveStartRequired, readOnly, false /*deleteSnapshot*/, hookExtraEnv); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if snapshotErr != nil {
|
||||
return "", slaveStartRequired, readOnly, snapshotErr
|
||||
}
|
||||
relative, err := filepath.Rel(mysqld.SnapshotDir, smFile)
|
||||
if err != nil {
|
||||
return "", slaveStartRequired, readOnly, nil
|
||||
}
|
||||
return path.Join(SnapshotURLPath, relative), slaveStartRequired, readOnly, nil
|
||||
}
|
||||
|
||||
// SnapshotSourceEnd removes the current snapshot, and restarts mysqld.
|
||||
func (mysqld *Mysqld) SnapshotSourceEnd(slaveStartRequired, readOnly, deleteSnapshot bool, hookExtraEnv map[string]string) error {
|
||||
if deleteSnapshot {
|
||||
// clean out our files
|
||||
log.Infof("removing snapshot links: %v", mysqld.SnapshotDir)
|
||||
if err := os.RemoveAll(mysqld.SnapshotDir); err != nil {
|
||||
log.Warningf("failed to remove old snapshot: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Try to restart mysqld
|
||||
if err := mysqld.Start(MysqlWaitTime); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore original mysqld state that we saved above.
|
||||
if slaveStartRequired {
|
||||
if err := mysqld.StartSlave(hookExtraEnv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// this should be quick, but we might as well just wait
|
||||
if err := mysqld.WaitForSlaveStart(slaveStartDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// And set read-only mode
|
||||
if err := mysqld.SetReadOnly(readOnly); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJSON(filename string, x interface{}) error {
|
||||
data, err := json.MarshalIndent(x, " ", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil2.WriteFileAtomic(filename, data, 0660)
|
||||
}
|
||||
|
||||
// ReadSnapshotManifest reads and unpacks a SnapshotManifest
|
||||
func ReadSnapshotManifest(filename string) (*SnapshotManifest, error) {
|
||||
data, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sm := new(SnapshotManifest)
|
||||
if err = json.Unmarshal(data, sm); err != nil {
|
||||
return nil, fmt.Errorf("ReadSnapshotManifest failed: %v %v", filename, err)
|
||||
}
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// RestoreFromSnapshot runs on the presumably empty machine acting as
|
||||
// the target in the create replica action.
|
||||
//
|
||||
// validate target (self)
|
||||
// shutdown_mysql()
|
||||
// create temp data directory /vt/target/vt_<keyspace>
|
||||
// copy compressed data files via HTTP
|
||||
// verify hash of compressed files
|
||||
// uncompress into /vt/vt_<target-uid>/data/vt_<keyspace>
|
||||
// start_mysql()
|
||||
// clean up compressed files
|
||||
func (mysqld *Mysqld) RestoreFromSnapshot(logger logutil.Logger, snapshotManifest *SnapshotManifest, fetchConcurrency, fetchRetryCount int, dontWaitForSlaveStart bool, hookExtraEnv map[string]string) error {
|
||||
if snapshotManifest == nil {
|
||||
return errors.New("RestoreFromSnapshot: nil snapshotManifest")
|
||||
}
|
||||
|
||||
logger.Infof("ValidateCloneTarget")
|
||||
if err := mysqld.ValidateCloneTarget(hookExtraEnv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("Shutdown mysqld")
|
||||
if err := mysqld.Shutdown(true, MysqlWaitTime); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("Fetch snapshot")
|
||||
if err := mysqld.fetchSnapshot(snapshotManifest, fetchConcurrency, fetchRetryCount); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("Restart mysqld")
|
||||
if err := mysqld.Start(MysqlWaitTime); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmdList, err := mysqld.StartReplicationCommands(snapshotManifest.ReplicationStatus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mysqld.ExecuteSuperQueryList(cmdList); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dontWaitForSlaveStart {
|
||||
if err := mysqld.WaitForSlaveStart(slaveStartDeadline); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
h := hook.NewSimpleHook("postflight_restore")
|
||||
h.ExtraEnv = hookExtraEnv
|
||||
if err := h.ExecuteOptional(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mysqld *Mysqld) fetchSnapshot(snapshotManifest *SnapshotManifest, fetchConcurrency, fetchRetryCount int) error {
|
||||
replicaDbPath := path.Join(mysqld.config.DataDir, snapshotManifest.DbName)
|
||||
|
||||
cleanDirs := []string{mysqld.SnapshotDir, replicaDbPath,
|
||||
mysqld.config.InnodbDataHomeDir, mysqld.config.InnodbLogGroupHomeDir}
|
||||
|
||||
// clean out and start fresh
|
||||
// FIXME(msolomon) this might be changed to allow partial recovery, but at that point
|
||||
// we are starting to reimplement rsync.
|
||||
for _, dir := range cleanDirs {
|
||||
if err := os.RemoveAll(dir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0775); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return fetchFiles(snapshotManifest, mysqld.TabletDir, fetchConcurrency, fetchRetryCount)
|
||||
}
|
|
@ -0,0 +1,141 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package filebackupstorage implements the BacksupStorage interface
|
||||
// for a local filesystem (which can be an NFS mount).
|
||||
package filebackupstorage
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl/backupstorage"
|
||||
)
|
||||
|
||||
var (
|
||||
// FileBackupStorageRoot is where the backups will go.
|
||||
// Exported for test purposes.
|
||||
FileBackupStorageRoot = flag.String("file_backup_storage_root", "", "root directory for the file backup storage")
|
||||
)
|
||||
|
||||
// FileBackupHandle implements BackupHandle for local file system.
|
||||
type FileBackupHandle struct {
|
||||
fbs *FileBackupStorage
|
||||
bucket string
|
||||
name string
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
// Bucket is part of the BackupHandle interface
|
||||
func (fbh *FileBackupHandle) Bucket() string {
|
||||
return fbh.bucket
|
||||
}
|
||||
|
||||
// Name is part of the BackupHandle interface
|
||||
func (fbh *FileBackupHandle) Name() string {
|
||||
return fbh.name
|
||||
}
|
||||
|
||||
// AddFile is part of the BackupHandle interface
|
||||
func (fbh *FileBackupHandle) AddFile(filename string) (io.WriteCloser, error) {
|
||||
if fbh.readOnly {
|
||||
return nil, fmt.Errorf("AddFile cannot be called on read-only backup")
|
||||
}
|
||||
p := path.Join(*FileBackupStorageRoot, fbh.bucket, fbh.name, filename)
|
||||
return os.Create(p)
|
||||
}
|
||||
|
||||
// EndBackup is part of the BackupHandle interface
|
||||
func (fbh *FileBackupHandle) EndBackup() error {
|
||||
if fbh.readOnly {
|
||||
return fmt.Errorf("EndBackup cannot be called on read-only backup")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AbortBackup is part of the BackupHandle interface
|
||||
func (fbh *FileBackupHandle) AbortBackup() error {
|
||||
if fbh.readOnly {
|
||||
return fmt.Errorf("AbortBackup cannot be called on read-only backup")
|
||||
}
|
||||
return fbh.fbs.RemoveBackup(fbh.bucket, fbh.name)
|
||||
}
|
||||
|
||||
// ReadFile is part of the BackupHandle interface
|
||||
func (fbh *FileBackupHandle) ReadFile(filename string) (io.ReadCloser, error) {
|
||||
if !fbh.readOnly {
|
||||
return nil, fmt.Errorf("ReadFile cannot be called on read-write backup")
|
||||
}
|
||||
p := path.Join(*FileBackupStorageRoot, fbh.bucket, fbh.name, filename)
|
||||
return os.Open(p)
|
||||
}
|
||||
|
||||
// FileBackupStorage implements BackupStorage for local file system.
|
||||
type FileBackupStorage struct{}
|
||||
|
||||
// ListBackups is part of the BackupStorage interface
|
||||
func (fbs *FileBackupStorage) ListBackups(bucket string) ([]backupstorage.BackupHandle, error) {
|
||||
// ReadDir already sorts the results
|
||||
p := path.Join(*FileBackupStorageRoot, bucket)
|
||||
fi, err := ioutil.ReadDir(p)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]backupstorage.BackupHandle, 0, len(fi))
|
||||
for _, info := range fi {
|
||||
if !info.IsDir() {
|
||||
continue
|
||||
}
|
||||
if info.Name() == "." || info.Name() == ".." {
|
||||
continue
|
||||
}
|
||||
result = append(result, &FileBackupHandle{
|
||||
fbs: fbs,
|
||||
bucket: bucket,
|
||||
name: info.Name(),
|
||||
readOnly: true,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// StartBackup is part of the BackupStorage interface
|
||||
func (fbs *FileBackupStorage) StartBackup(bucket, name string) (backupstorage.BackupHandle, error) {
|
||||
// make sure the bucket directory exists
|
||||
p := path.Join(*FileBackupStorageRoot, bucket)
|
||||
if err := os.MkdirAll(p, os.ModePerm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// creates the backup directory
|
||||
p = path.Join(p, name)
|
||||
if err := os.Mkdir(p, os.ModePerm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &FileBackupHandle{
|
||||
fbs: fbs,
|
||||
bucket: bucket,
|
||||
name: name,
|
||||
readOnly: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RemoveBackup is part of the BackupStorage interface
|
||||
func (fbs *FileBackupStorage) RemoveBackup(bucket, name string) error {
|
||||
p := path.Join(*FileBackupStorageRoot, bucket, name)
|
||||
return os.RemoveAll(p)
|
||||
}
|
||||
|
||||
func init() {
|
||||
backupstorage.BackupStorageMap["file"] = &FileBackupStorage{}
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filebackupstorage
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// This file tests the file BackupStorage engine.
|
||||
|
||||
// Note this is a very generic test for BackupStorage implementations,
|
||||
// we test the interface only. But making it a generic test library is
|
||||
// more cumbersome, we'll do that when we have an actual need for
|
||||
// another BackupStorage implementation.
|
||||
|
||||
// setupFileBackupStorage creates a temporary directory, and
|
||||
// returns a FileBackupStorage based on it
|
||||
func setupFileBackupStorage(t *testing.T) *FileBackupStorage {
|
||||
root, err := ioutil.TempDir("", "fbstest")
|
||||
if err != nil {
|
||||
t.Fatalf("os.TempDir failed: %v", err)
|
||||
}
|
||||
*FileBackupStorageRoot = root
|
||||
return &FileBackupStorage{}
|
||||
}
|
||||
|
||||
// cleanupFileBackupStorage removes the entire directory
|
||||
func cleanupFileBackupStorage(fbs *FileBackupStorage) {
|
||||
os.RemoveAll(*FileBackupStorageRoot)
|
||||
}
|
||||
|
||||
func TestListBackups(t *testing.T) {
|
||||
fbs := setupFileBackupStorage(t)
|
||||
defer cleanupFileBackupStorage(fbs)
|
||||
|
||||
// verify we have no entry now
|
||||
bucket := "keyspace/shard"
|
||||
bhs, err := fbs.ListBackups(bucket)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBackups on empty fbs failed: %v", err)
|
||||
}
|
||||
if len(bhs) != 0 {
|
||||
t.Fatalf("ListBackups on empty fbs returned results: %#v", bhs)
|
||||
}
|
||||
|
||||
// add one empty backup
|
||||
firstBackup := "cell-0001-2015-01-14-10-00-00"
|
||||
bh, err := fbs.StartBackup(bucket, firstBackup)
|
||||
if err != nil {
|
||||
t.Fatalf("fbs.StartBackup failed: %v", err)
|
||||
}
|
||||
if err := bh.EndBackup(); err != nil {
|
||||
t.Fatalf("bh.EndBackup failed: %v", err)
|
||||
}
|
||||
|
||||
// verify we have one entry now
|
||||
bhs, err = fbs.ListBackups(bucket)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBackups on empty fbs failed: %v", err)
|
||||
}
|
||||
if len(bhs) != 1 ||
|
||||
bhs[0].Bucket() != bucket ||
|
||||
bhs[0].Name() != firstBackup {
|
||||
t.Fatalf("ListBackups with one backup returned wrong results: %#v", bhs)
|
||||
}
|
||||
|
||||
// add another one, with earlier date
|
||||
secondBackup := "cell-0001-2015-01-12-10-00-00"
|
||||
bh, err = fbs.StartBackup(bucket, secondBackup)
|
||||
if err != nil {
|
||||
t.Fatalf("fbs.StartBackup failed: %v", err)
|
||||
}
|
||||
if err := bh.EndBackup(); err != nil {
|
||||
t.Fatalf("bh.EndBackup failed: %v", err)
|
||||
}
|
||||
|
||||
// verify we have two sorted entries now
|
||||
bhs, err = fbs.ListBackups(bucket)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBackups on empty fbs failed: %v", err)
|
||||
}
|
||||
if len(bhs) != 2 ||
|
||||
bhs[0].Bucket() != bucket ||
|
||||
bhs[0].Name() != secondBackup ||
|
||||
bhs[1].Bucket() != bucket ||
|
||||
bhs[1].Name() != firstBackup {
|
||||
t.Fatalf("ListBackups with two backups returned wrong results: %#v", bhs)
|
||||
}
|
||||
|
||||
// remove a backup, back to one
|
||||
if err := fbs.RemoveBackup(bucket, secondBackup); err != nil {
|
||||
t.Fatalf("RemoveBackup failed: %v", err)
|
||||
}
|
||||
bhs, err = fbs.ListBackups(bucket)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBackups after deletion failed: %v", err)
|
||||
}
|
||||
if len(bhs) != 1 ||
|
||||
bhs[0].Bucket() != bucket ||
|
||||
bhs[0].Name() != firstBackup {
|
||||
t.Fatalf("ListBackups after deletion returned wrong results: %#v", bhs)
|
||||
}
|
||||
|
||||
// add a backup but abort it, should stay at one
|
||||
bh, err = fbs.StartBackup(bucket, secondBackup)
|
||||
if err != nil {
|
||||
t.Fatalf("fbs.StartBackup failed: %v", err)
|
||||
}
|
||||
if err := bh.AbortBackup(); err != nil {
|
||||
t.Fatalf("bh.AbortBackup failed: %v", err)
|
||||
}
|
||||
bhs, err = fbs.ListBackups(bucket)
|
||||
if err != nil {
|
||||
t.Fatalf("ListBackups after abort failed: %v", err)
|
||||
}
|
||||
if len(bhs) != 1 ||
|
||||
bhs[0].Bucket() != bucket ||
|
||||
bhs[0].Name() != firstBackup {
|
||||
t.Fatalf("ListBackups after abort returned wrong results: %#v", bhs)
|
||||
}
|
||||
|
||||
// check we cannot chaneg a backup we listed
|
||||
if _, err := bhs[0].AddFile("test"); err == nil {
|
||||
t.Fatalf("was able to AddFile to read-only backup")
|
||||
}
|
||||
if err := bhs[0].EndBackup(); err == nil {
|
||||
t.Fatalf("was able to EndBackup a read-only backup")
|
||||
}
|
||||
if err := bhs[0].AbortBackup(); err == nil {
|
||||
t.Fatalf("was able to AbortBackup a read-only backup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileContents(t *testing.T) {
|
||||
fbs := setupFileBackupStorage(t)
|
||||
defer cleanupFileBackupStorage(fbs)
|
||||
|
||||
bucket := "keyspace/shard"
|
||||
name := "cell-0001-2015-01-14-10-00-00"
|
||||
filename1 := "file1"
|
||||
contents1 := "contents of the first file"
|
||||
|
||||
// start a backup, add a file
|
||||
bh, err := fbs.StartBackup(bucket, name)
|
||||
if err != nil {
|
||||
t.Fatalf("fbs.StartBackup failed: %v", err)
|
||||
}
|
||||
wc, err := bh.AddFile(filename1)
|
||||
if err != nil {
|
||||
t.Fatalf("bh.AddFile failed: %v", err)
|
||||
}
|
||||
if _, err := wc.Write([]byte(contents1)); err != nil {
|
||||
t.Fatalf("wc.Write failed: %v", err)
|
||||
}
|
||||
if err := wc.Close(); err != nil {
|
||||
t.Fatalf("wc.Close failed: %v", err)
|
||||
}
|
||||
|
||||
// test we can't read back on read-write backup
|
||||
if _, err := bh.ReadFile(filename1); err == nil {
|
||||
t.Fatalf("was able to ReadFile to read-write backup")
|
||||
}
|
||||
|
||||
// and close
|
||||
if err := bh.EndBackup(); err != nil {
|
||||
t.Fatalf("bh.EndBackup failed: %v", err)
|
||||
}
|
||||
|
||||
// re-read the file
|
||||
bhs, err := fbs.ListBackups(bucket)
|
||||
if err != nil || len(bhs) != 1 {
|
||||
t.Fatalf("ListBackups after abort returned wrong return: %v %v", err, bhs)
|
||||
}
|
||||
rc, err := bhs[0].ReadFile(filename1)
|
||||
if err != nil {
|
||||
t.Fatalf("bhs[0].ReadFile failed: %v", err)
|
||||
}
|
||||
buf := make([]byte, len(contents1)+10)
|
||||
if n, err := rc.Read(buf); (err != nil && err != io.EOF) || n != len(contents1) {
|
||||
t.Fatalf("rc.Read returned wrong result: %v %#v", err, n)
|
||||
}
|
||||
if err := rc.Close(); err != nil {
|
||||
t.Fatalf("rc.Close failed: %v", err)
|
||||
}
|
||||
}
|
|
@ -5,25 +5,13 @@
|
|||
package mysqlctl
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
// "crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
// "hash/crc64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/cgzip"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
|
||||
)
|
||||
|
||||
// Use this to simulate failures in tests
|
||||
|
@ -75,452 +63,3 @@ func newHasher() *hasher {
|
|||
func (h *hasher) HashString() string {
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// SnapshotFile describes a file to serve.
|
||||
// 'Path' is the path component of the URL. SnapshotManifest.Addr is
|
||||
// the host+port component of the URL.
|
||||
// If path ends in '.gz', it is compressed.
|
||||
// Size and Hash are computed on the Path itself
|
||||
// if TableName is set, this file belongs to that table
|
||||
type SnapshotFile struct {
|
||||
Path string
|
||||
Size int64
|
||||
Hash string
|
||||
TableName string
|
||||
}
|
||||
|
||||
type SnapshotFiles []SnapshotFile
|
||||
|
||||
// sort.Interface
|
||||
// we sort by descending file size
|
||||
func (s SnapshotFiles) Len() int { return len(s) }
|
||||
func (s SnapshotFiles) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s SnapshotFiles) Less(i, j int) bool { return s[i].Size > s[j].Size }
|
||||
|
||||
// This function returns the local file used to store the SnapshotFile,
|
||||
// relative to the basePath.
|
||||
// for instance, if the source path is something like:
|
||||
// /vt/snapshot/vt_0000062344/data/vt_snapshot_test-MA,Mw/vt_insert_test.csv.gz
|
||||
// we will get everything starting with 'data/...', append it to basepath,
|
||||
// and remove the .gz extension. So with basePath=myPath, it will return:
|
||||
// myPath/data/vt_snapshot_test-MA,Mw/vt_insert_test.csv
|
||||
func (dataFile *SnapshotFile) getLocalFilename(basePath string) string {
|
||||
filename := path.Join(basePath, dataFile.Path)
|
||||
// trim compression extension if present
|
||||
if strings.HasSuffix(filename, ".gz") {
|
||||
filename = filename[:len(filename)-3]
|
||||
}
|
||||
return filename
|
||||
}
|
||||
|
||||
// newSnapshotFile behavior depends on the compress flag:
|
||||
// - if compress is true , it compresses a single file with gzip, and
|
||||
// computes the hash on the compressed version.
|
||||
// - if compress is false, just symlinks and computes the hash on the file
|
||||
// The source file is always left intact.
|
||||
// The path of the returned SnapshotFile will be relative
|
||||
// to root.
|
||||
func newSnapshotFile(srcPath, dstPath, root string, compress bool) (*SnapshotFile, error) {
|
||||
// open the source file
|
||||
srcFile, err := os.OpenFile(srcPath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
src := bufio.NewReaderSize(srcFile, 2*1024*1024)
|
||||
|
||||
var hash string
|
||||
var size int64
|
||||
if compress {
|
||||
log.Infof("newSnapshotFile: starting to compress %v into %v", srcPath, dstPath)
|
||||
|
||||
// open the temporary destination file
|
||||
dir, filePrefix := path.Split(dstPath)
|
||||
dstFile, err := ioutil.TempFile(dir, filePrefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
// try to close and delete the file. in the
|
||||
// success case, the file will already be
|
||||
// closed and renamed, so all of this would
|
||||
// fail anyway, no biggie
|
||||
dstFile.Close()
|
||||
os.Remove(dstFile.Name())
|
||||
}()
|
||||
dst := bufio.NewWriterSize(dstFile, 2*1024*1024)
|
||||
|
||||
// create the hasher and the tee on top
|
||||
hasher := newHasher()
|
||||
tee := io.MultiWriter(dst, hasher)
|
||||
|
||||
// create the gzip compression filter
|
||||
gzip, err := cgzip.NewWriterLevel(tee, cgzip.Z_BEST_SPEED)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// copy from the file to gzip to tee to output file and hasher
|
||||
_, err = io.Copy(gzip, src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// close gzip to flush it
|
||||
if err = gzip.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// close dst manually to flush all buffers to disk
|
||||
dst.Flush()
|
||||
dstFile.Close()
|
||||
hash = hasher.HashString()
|
||||
|
||||
// atomically move completed compressed file
|
||||
err = os.Rename(dstFile.Name(), dstPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// and get its size
|
||||
fi, err := os.Stat(dstPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size = fi.Size()
|
||||
} else {
|
||||
log.Infof("newSnapshotFile: starting to hash and symlinking %v to %v", srcPath, dstPath)
|
||||
|
||||
// get the hash
|
||||
hasher := newHasher()
|
||||
_, err = io.Copy(hasher, src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hash = hasher.HashString()
|
||||
|
||||
// do the symlink
|
||||
err = os.Symlink(srcPath, dstPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// and get the size
|
||||
fi, err := os.Stat(srcPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size = fi.Size()
|
||||
}
|
||||
|
||||
log.Infof("clone data ready %v:%v", dstPath, hash)
|
||||
relativeDst, err := filepath.Rel(root, dstPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SnapshotFile{relativeDst, size, hash, ""}, nil
|
||||
}
|
||||
|
||||
// newSnapshotFiles processes multiple files in parallel. The Paths of
|
||||
// the returned SnapshotFiles will be relative to root.
|
||||
// - if compress is true, we compress the files and compute the hash on
|
||||
// the compressed version.
|
||||
// - if compress is false, we symlink the files, and compute the hash on
|
||||
// the original version.
|
||||
func newSnapshotFiles(sources, destinations []string, root string, concurrency int, compress bool) ([]SnapshotFile, error) {
|
||||
if len(sources) != len(destinations) || len(sources) == 0 {
|
||||
return nil, fmt.Errorf("programming error: bad array lengths: %v %v", len(sources), len(destinations))
|
||||
}
|
||||
|
||||
workQueue := make(chan int, len(sources))
|
||||
for i := 0; i < len(sources); i++ {
|
||||
workQueue <- i
|
||||
}
|
||||
close(workQueue)
|
||||
|
||||
snapshotFiles := make([]SnapshotFile, len(sources))
|
||||
resultQueue := make(chan error, len(sources))
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
for i := range workQueue {
|
||||
sf, err := newSnapshotFile(sources[i], destinations[i], root, compress)
|
||||
if err == nil {
|
||||
snapshotFiles[i] = *sf
|
||||
}
|
||||
resultQueue <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var err error
|
||||
for i := 0; i < len(sources); i++ {
|
||||
if compressErr := <-resultQueue; compressErr != nil {
|
||||
err = compressErr
|
||||
}
|
||||
}
|
||||
|
||||
// clean up files if we had an error
|
||||
// FIXME(alainjobart) it seems extreme to delete all files if
|
||||
// the last one failed. Since we only move the file into
|
||||
// its destination when it worked, we could assume if the file
|
||||
// already exists it's good, and re-compute its hash.
|
||||
if err != nil {
|
||||
log.Infof("Error happened, deleting all the files we already compressed")
|
||||
for _, dest := range destinations {
|
||||
os.Remove(dest)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return snapshotFiles, nil
|
||||
}
|
||||
|
||||
// a SnapshotManifest describes multiple SnapshotFiles and where
|
||||
// to get them from.
|
||||
type SnapshotManifest struct {
|
||||
Addr string // this is the address of the tabletserver, not mysql
|
||||
|
||||
DbName string
|
||||
Files SnapshotFiles
|
||||
|
||||
ReplicationStatus *proto.ReplicationStatus
|
||||
MasterPosition proto.ReplicationPosition
|
||||
}
|
||||
|
||||
func newSnapshotManifest(addr, mysqlAddr, masterAddr, dbName string, files []SnapshotFile, pos, masterPos proto.ReplicationPosition) (*SnapshotManifest, error) {
|
||||
nrs, err := proto.NewReplicationStatus(masterAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs := &SnapshotManifest{
|
||||
Addr: addr,
|
||||
DbName: dbName,
|
||||
Files: files,
|
||||
ReplicationStatus: nrs,
|
||||
MasterPosition: masterPos,
|
||||
}
|
||||
sort.Sort(rs.Files)
|
||||
rs.ReplicationStatus.Position = pos
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// fetchFile fetches data from the web server. It then sends it to a
|
||||
// tee, which on one side has an hash checksum reader, and on the other
|
||||
// a gunzip reader writing to a file. It will compare the hash
|
||||
// checksum after the copy is done.
|
||||
func fetchFile(srcUrl, srcHash, dstFilename string) error {
|
||||
log.Infof("fetchFile: starting to fetch %v from %v", dstFilename, srcUrl)
|
||||
|
||||
// open the URL
|
||||
req, err := http.NewRequest("GET", srcUrl, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("NewRequest failed for %v: %v", srcUrl, err)
|
||||
}
|
||||
// we set the 'gzip' encoding ourselves so the library doesn't
|
||||
// do it for us and ends up using go gzip (we want to use our own
|
||||
// cgzip which is much faster)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("failed fetching %v: %v", srcUrl, resp.Status)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// see if we need some uncompression
|
||||
var reader io.Reader = resp.Body
|
||||
ce := resp.Header.Get("Content-Encoding")
|
||||
if ce != "" {
|
||||
if ce == "gzip" {
|
||||
gz, err := cgzip.NewReader(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gz.Close()
|
||||
reader = gz
|
||||
} else {
|
||||
return fmt.Errorf("unsupported Content-Encoding: %v", ce)
|
||||
}
|
||||
}
|
||||
|
||||
return uncompressAndCheck(reader, srcHash, dstFilename, strings.HasSuffix(srcUrl, ".gz"))
|
||||
}
|
||||
|
||||
// uncompressAndCheck uses the provided reader to read data, and then
|
||||
// sends it to a tee, which on one side has an hash checksum reader,
|
||||
// and on the other a gunzip reader writing to a file. It will
|
||||
// compare the hash checksum after the copy is done.
|
||||
func uncompressAndCheck(reader io.Reader, srcHash, dstFilename string, needsUncompress bool) error {
|
||||
// create destination directory
|
||||
dir, filePrefix := path.Split(dstFilename)
|
||||
if dirErr := os.MkdirAll(dir, 0775); dirErr != nil {
|
||||
return dirErr
|
||||
}
|
||||
|
||||
// create a temporary file to uncompress to
|
||||
dstFile, err := ioutil.TempFile(dir, filePrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// try to close and delete the file.
|
||||
// in the success case, the file will already be closed
|
||||
// and renamed, so all of this would fail anyway, no biggie
|
||||
dstFile.Close()
|
||||
os.Remove(dstFile.Name())
|
||||
}()
|
||||
|
||||
// create a buffering output
|
||||
dst := bufio.NewWriterSize(dstFile, 2*1024*1024)
|
||||
|
||||
// create hash to write the compressed data to
|
||||
hasher := newHasher()
|
||||
|
||||
// create a Tee: we split the HTTP input into the hasher
|
||||
// and into the gunziper
|
||||
tee := io.TeeReader(reader, hasher)
|
||||
|
||||
// create the uncompresser
|
||||
var decompressor io.Reader
|
||||
if needsUncompress {
|
||||
gz, err := cgzip.NewReader(tee)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gz.Close()
|
||||
decompressor = gz
|
||||
} else {
|
||||
decompressor = tee
|
||||
}
|
||||
|
||||
// see if we need to introduce failures
|
||||
if simulateFailures {
|
||||
failureCounter++
|
||||
if failureCounter%10 == 0 {
|
||||
return fmt.Errorf("Simulated error")
|
||||
}
|
||||
}
|
||||
|
||||
// copy the data. Will also write to the hasher
|
||||
if _, err = io.Copy(dst, decompressor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check the hash
|
||||
hash := hasher.HashString()
|
||||
if srcHash != hash {
|
||||
return fmt.Errorf("hash mismatch for %v, %v != %v", dstFilename, srcHash, hash)
|
||||
}
|
||||
|
||||
// we're good
|
||||
log.Infof("processed snapshot file: %v", dstFilename)
|
||||
dst.Flush()
|
||||
dstFile.Close()
|
||||
|
||||
// atomically move uncompressed file
|
||||
if err := os.Chmod(dstFile.Name(), 0664); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(dstFile.Name(), dstFilename)
|
||||
}
|
||||
|
||||
// fetchFileWithRetry fetches data from the web server, retrying a few
|
||||
// times.
|
||||
func fetchFileWithRetry(srcUrl, srcHash, dstFilename string, fetchRetryCount int) (err error) {
|
||||
for i := 0; i < fetchRetryCount; i++ {
|
||||
err = fetchFile(srcUrl, srcHash, dstFilename)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
log.Warningf("fetching snapshot file %v failed (try=%v): %v", dstFilename, i, err)
|
||||
}
|
||||
|
||||
log.Errorf("fetching snapshot file %v failed too many times", dstFilename)
|
||||
return err
|
||||
}
|
||||
|
||||
// FIXME(msolomon) Should we add deadlines? What really matters more
|
||||
// than a deadline is probably a sense of progress, more like a
|
||||
// "progress timeout" - how long will we wait if there is no change in
|
||||
// received bytes.
|
||||
// FIXME(alainjobart) support fetching files in chunks: create a new
|
||||
// struct fileChunk {
|
||||
// snapshotFile *SnapshotFile
|
||||
// relatedChunks []*fileChunk
|
||||
// start,end uint64
|
||||
// observedCrc32 uint32
|
||||
// }
|
||||
// Create a slice of fileChunk objects, populate it:
|
||||
// For files smaller than <threshold>, create one fileChunk
|
||||
// For files bigger than <threshold>, create N fileChunks
|
||||
// (the first one has the list of all the others)
|
||||
// Fetch them all:
|
||||
// - change the workqueue to have indexes on the fileChunk slice
|
||||
// - compute the crc32 while fetching, but don't compare right away
|
||||
// Collect results the same way, write observedCrc32 in the fileChunk
|
||||
// For each fileChunk, compare checksum:
|
||||
// - if single file, compare snapshotFile.hash with observedCrc32
|
||||
// - if multiple chunks and first chunk, merge observedCrc32, and compare
|
||||
func fetchFiles(snapshotManifest *SnapshotManifest, destinationPath string, fetchConcurrency, fetchRetryCount int) (err error) {
|
||||
// create a workQueue, a resultQueue, and the go routines
|
||||
// to process entries out of workQueue into resultQueue
|
||||
// the mutex protects the error response
|
||||
workQueue := make(chan SnapshotFile, len(snapshotManifest.Files))
|
||||
resultQueue := make(chan error, len(snapshotManifest.Files))
|
||||
mutex := sync.Mutex{}
|
||||
for i := 0; i < fetchConcurrency; i++ {
|
||||
go func() {
|
||||
for sf := range workQueue {
|
||||
// if someone else errored out, we skip our job
|
||||
mutex.Lock()
|
||||
previousError := err
|
||||
mutex.Unlock()
|
||||
if previousError != nil {
|
||||
resultQueue <- previousError
|
||||
continue
|
||||
}
|
||||
|
||||
// do our fetch, save the error
|
||||
filename := sf.getLocalFilename(destinationPath)
|
||||
furl := "http://" + snapshotManifest.Addr + path.Join(SnapshotURLPath, sf.Path)
|
||||
fetchErr := fetchFileWithRetry(furl, sf.Hash, filename, fetchRetryCount)
|
||||
if fetchErr != nil {
|
||||
mutex.Lock()
|
||||
err = fetchErr
|
||||
mutex.Unlock()
|
||||
}
|
||||
resultQueue <- fetchErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// add the jobs (writing on the channel will block if the queue
|
||||
// is full, no big deal)
|
||||
jobCount := 0
|
||||
for _, fi := range snapshotManifest.Files {
|
||||
workQueue <- fi
|
||||
jobCount++
|
||||
}
|
||||
close(workQueue)
|
||||
|
||||
// read the responses (we guarantee one response per job)
|
||||
for i := 0; i < jobCount; i++ {
|
||||
<-resultQueue
|
||||
}
|
||||
|
||||
// clean up files if we had an error
|
||||
// FIXME(alainjobart) it seems extreme to delete all files if
|
||||
// the last one failed. Maybe we shouldn't, and if a file already
|
||||
// exists, we hash it before retransmitting.
|
||||
if err != nil {
|
||||
log.Infof("Error happened, deleting all the files we already got")
|
||||
for _, fi := range snapshotManifest.Files {
|
||||
filename := fi.getLocalFilename(destinationPath)
|
||||
os.Remove(filename)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/health"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
)
|
||||
|
||||
// mysqlReplicationLag implements health.Reporter
|
||||
|
@ -15,8 +14,8 @@ type mysqlReplicationLag struct {
|
|||
}
|
||||
|
||||
// Report is part of the health.Reporter interface
|
||||
func (mrl *mysqlReplicationLag) Report(tabletType topo.TabletType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
|
||||
if !topo.IsSlaveType(tabletType) {
|
||||
func (mrl *mysqlReplicationLag) Report(isSlaveType, shouldQueryServiceBeRunning bool) (time.Duration, error) {
|
||||
if !isSlaveType {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/sqldb"
|
||||
"github.com/youtube/vitess/go/stats"
|
||||
"github.com/youtube/vitess/go/vt/dbconnpool"
|
||||
|
@ -19,21 +20,23 @@ import (
|
|||
|
||||
// MysqlDaemon is the interface we use for abstracting Mysqld.
|
||||
type MysqlDaemon interface {
|
||||
// GetMasterAddr returns the mysql master address, as shown by
|
||||
// 'show slave status'.
|
||||
GetMasterAddr() (string, error)
|
||||
// Cnf returns the underlying mycnf
|
||||
Cnf() *Mycnf
|
||||
|
||||
// methods related to mysql running or not
|
||||
Start(mysqlWaitTime time.Duration) error
|
||||
Shutdown(waitForMysqld bool, mysqlWaitTime time.Duration) error
|
||||
|
||||
// GetMysqlPort returns the current port mysql is listening on.
|
||||
GetMysqlPort() (int, error)
|
||||
|
||||
// replication related methods
|
||||
StartSlave(hookExtraEnv map[string]string) error
|
||||
StopSlave(hookExtraEnv map[string]string) error
|
||||
SlaveStatus() (*proto.ReplicationStatus, error)
|
||||
SlaveStatus() (proto.ReplicationStatus, error)
|
||||
|
||||
// reparenting related methods
|
||||
ResetReplicationCommands() ([]string, error)
|
||||
MasterPosition() (proto.ReplicationPosition, error)
|
||||
IsReadOnly() (bool, error)
|
||||
SetReadOnly(on bool) error
|
||||
StartReplicationCommands(status *proto.ReplicationStatus) ([]string, error)
|
||||
SetMasterCommands(masterHost string, masterPort int) ([]string, error)
|
||||
|
@ -52,21 +55,42 @@ type MysqlDaemon interface {
|
|||
|
||||
// Schema related methods
|
||||
GetSchema(dbName string, tables, excludeTables []string, includeViews bool) (*proto.SchemaDefinition, error)
|
||||
PreflightSchemaChange(dbName string, change string) (*proto.SchemaChangeResult, error)
|
||||
ApplySchemaChange(dbName string, change *proto.SchemaChange) (*proto.SchemaChangeResult, error)
|
||||
|
||||
// GetAppConnection returns a app connection to be able to talk to the database.
|
||||
GetAppConnection() (dbconnpool.PoolConnection, error)
|
||||
// GetDbaConnection returns a dba connection.
|
||||
GetDbaConnection() (*dbconnpool.DBConnection, error)
|
||||
// query execution methods
|
||||
|
||||
// ExecuteSuperQueryList executes a list of queries, no result
|
||||
ExecuteSuperQueryList(queryList []string) error
|
||||
|
||||
// FetchSuperQuery executes one query, returns the result
|
||||
FetchSuperQuery(query string) (*mproto.QueryResult, error)
|
||||
|
||||
// NewSlaveConnection returns a SlaveConnection to the database.
|
||||
NewSlaveConnection() (*SlaveConnection, error)
|
||||
|
||||
// EnableBinlogPlayback enables playback of binlog events
|
||||
EnableBinlogPlayback() error
|
||||
|
||||
// DisableBinlogPlayback disable playback of binlog events
|
||||
DisableBinlogPlayback() error
|
||||
|
||||
// Close will close this instance of Mysqld. It will wait for all dba
|
||||
// queries to be finished.
|
||||
Close()
|
||||
}
|
||||
|
||||
// FakeMysqlDaemon implements MysqlDaemon and allows the user to fake
|
||||
// everything.
|
||||
type FakeMysqlDaemon struct {
|
||||
// MasterAddr will be returned by GetMasterAddr(). Set to "" to return
|
||||
// ErrNotSlave, or to "ERROR" to return an error.
|
||||
MasterAddr string
|
||||
// Mycnf will be returned by Cnf()
|
||||
Mycnf *Mycnf
|
||||
|
||||
// Running is used by Start / Shutdown
|
||||
Running bool
|
||||
|
||||
// MysqlPort will be returned by GetMysqlPort(). Set to -1 to
|
||||
// return an error.
|
||||
|
@ -77,9 +101,6 @@ type FakeMysqlDaemon struct {
|
|||
// test owner responsability to have these two match)
|
||||
Replicating bool
|
||||
|
||||
// CurrentSlaveStatus is returned by SlaveStatus
|
||||
CurrentSlaveStatus *proto.ReplicationStatus
|
||||
|
||||
// ResetReplicationResult is returned by ResetReplication
|
||||
ResetReplicationResult []string
|
||||
|
||||
|
@ -87,8 +108,15 @@ type FakeMysqlDaemon struct {
|
|||
ResetReplicationError error
|
||||
|
||||
// CurrentMasterPosition is returned by MasterPosition
|
||||
// and SlaveStatus
|
||||
CurrentMasterPosition proto.ReplicationPosition
|
||||
|
||||
// CurrentMasterHost is returned by SlaveStatus
|
||||
CurrentMasterHost string
|
||||
|
||||
// CurrentMasterport is returned by SlaveStatus
|
||||
CurrentMasterPort int
|
||||
|
||||
// ReadOnly is the current value of the flag
|
||||
ReadOnly bool
|
||||
|
||||
|
@ -120,12 +148,17 @@ type FakeMysqlDaemon struct {
|
|||
// PromoteSlaveResult is returned by PromoteSlave
|
||||
PromoteSlaveResult proto.ReplicationPosition
|
||||
|
||||
// Schema that will be returned by GetSchema. If nil we'll
|
||||
// Schema will be returned by GetSchema. If nil we'll
|
||||
// return an error.
|
||||
Schema *proto.SchemaDefinition
|
||||
|
||||
// DbaConnectionFactory is the factory for making fake dba connections
|
||||
DbaConnectionFactory func() (dbconnpool.PoolConnection, error)
|
||||
// PreflightSchemaChangeResult will be returned by PreflightSchemaChange.
|
||||
// If nil we'll return an error.
|
||||
PreflightSchemaChangeResult *proto.SchemaChangeResult
|
||||
|
||||
// ApplySchemaChangeResult will be returned by ApplySchemaChange.
|
||||
// If nil we'll return an error.
|
||||
ApplySchemaChangeResult *proto.SchemaChangeResult
|
||||
|
||||
// DbAppConnectionFactory is the factory for making fake db app connections
|
||||
DbAppConnectionFactory func() (dbconnpool.PoolConnection, error)
|
||||
|
@ -141,17 +174,43 @@ type FakeMysqlDaemon struct {
|
|||
// ExpectedExecuteSuperQueryCurrent is the current index of the queries
|
||||
// we expect
|
||||
ExpectedExecuteSuperQueryCurrent int
|
||||
|
||||
// FetchSuperQueryResults is used by FetchSuperQuery
|
||||
FetchSuperQueryMap map[string]*mproto.QueryResult
|
||||
|
||||
// BinlogPlayerEnabled is used by {Enable,Disable}BinlogPlayer
|
||||
BinlogPlayerEnabled bool
|
||||
}
|
||||
|
||||
// GetMasterAddr is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) GetMasterAddr() (string, error) {
|
||||
if fmd.MasterAddr == "" {
|
||||
return "", ErrNotSlave
|
||||
// NewFakeMysqlDaemon returns a FakeMysqlDaemon where mysqld appears
|
||||
// to be running
|
||||
func NewFakeMysqlDaemon() *FakeMysqlDaemon {
|
||||
return &FakeMysqlDaemon{
|
||||
Running: true,
|
||||
}
|
||||
if fmd.MasterAddr == "ERROR" {
|
||||
return "", fmt.Errorf("FakeMysqlDaemon.GetMasterAddr returns an error")
|
||||
}
|
||||
|
||||
// Cnf is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) Cnf() *Mycnf {
|
||||
return fmd.Mycnf
|
||||
}
|
||||
|
||||
// Start is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) Start(mysqlWaitTime time.Duration) error {
|
||||
if fmd.Running {
|
||||
return fmt.Errorf("fake mysql daemon already running")
|
||||
}
|
||||
return fmd.MasterAddr, nil
|
||||
fmd.Running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) Shutdown(waitForMysqld bool, mysqlWaitTime time.Duration) error {
|
||||
if !fmd.Running {
|
||||
return fmt.Errorf("fake mysql daemon not running")
|
||||
}
|
||||
fmd.Running = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMysqlPort is part of the MysqlDaemon interface
|
||||
|
@ -162,24 +221,15 @@ func (fmd *FakeMysqlDaemon) GetMysqlPort() (int, error) {
|
|||
return fmd.MysqlPort, nil
|
||||
}
|
||||
|
||||
// StartSlave is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) StartSlave(hookExtraEnv map[string]string) error {
|
||||
fmd.Replicating = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopSlave is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) StopSlave(hookExtraEnv map[string]string) error {
|
||||
fmd.Replicating = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// SlaveStatus is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) SlaveStatus() (*proto.ReplicationStatus, error) {
|
||||
if fmd.CurrentSlaveStatus == nil {
|
||||
return nil, fmt.Errorf("no slave status defined")
|
||||
}
|
||||
return fmd.CurrentSlaveStatus, nil
|
||||
func (fmd *FakeMysqlDaemon) SlaveStatus() (proto.ReplicationStatus, error) {
|
||||
return proto.ReplicationStatus{
|
||||
Position: fmd.CurrentMasterPosition,
|
||||
SlaveIORunning: fmd.Replicating,
|
||||
SlaveSQLRunning: fmd.Replicating,
|
||||
MasterHost: fmd.CurrentMasterHost,
|
||||
MasterPort: fmd.CurrentMasterPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ResetReplicationCommands is part of the MysqlDaemon interface
|
||||
|
@ -192,6 +242,11 @@ func (fmd *FakeMysqlDaemon) MasterPosition() (proto.ReplicationPosition, error)
|
|||
return fmd.CurrentMasterPosition, nil
|
||||
}
|
||||
|
||||
// IsReadOnly is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) IsReadOnly() (bool, error) {
|
||||
return fmd.ReadOnly, nil
|
||||
}
|
||||
|
||||
// SetReadOnly is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) SetReadOnly(on bool) error {
|
||||
fmd.ReadOnly = on
|
||||
|
@ -261,10 +316,58 @@ func (fmd *FakeMysqlDaemon) ExecuteSuperQueryList(queryList []string) error {
|
|||
if expected != query {
|
||||
return fmt.Errorf("wrong query for ExecuteSuperQueryList: expected %v got %v", expected, query)
|
||||
}
|
||||
|
||||
// intercept some queries to update our status
|
||||
switch query {
|
||||
case SqlStartSlave:
|
||||
fmd.Replicating = true
|
||||
case SqlStopSlave:
|
||||
fmd.Replicating = false
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchSuperQuery returns the results from the map, if any
|
||||
func (fmd *FakeMysqlDaemon) FetchSuperQuery(query string) (*mproto.QueryResult, error) {
|
||||
if fmd.FetchSuperQueryMap == nil {
|
||||
return nil, fmt.Errorf("unexpected query: %v", query)
|
||||
}
|
||||
|
||||
qr, ok := fmd.FetchSuperQueryMap[query]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected query: %v", query)
|
||||
}
|
||||
return qr, nil
|
||||
}
|
||||
|
||||
// NewSlaveConnection is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) NewSlaveConnection() (*SlaveConnection, error) {
|
||||
panic(fmt.Errorf("not implemented on FakeMysqlDaemon"))
|
||||
}
|
||||
|
||||
// EnableBinlogPlayback is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) EnableBinlogPlayback() error {
|
||||
if fmd.BinlogPlayerEnabled {
|
||||
return fmt.Errorf("binlog player already enabled")
|
||||
}
|
||||
fmd.BinlogPlayerEnabled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableBinlogPlayback disable playback of binlog events
|
||||
func (fmd *FakeMysqlDaemon) DisableBinlogPlayback() error {
|
||||
if fmd.BinlogPlayerEnabled {
|
||||
return fmt.Errorf("binlog player already disabled")
|
||||
}
|
||||
fmd.BinlogPlayerEnabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) Close() {
|
||||
}
|
||||
|
||||
// CheckSuperQueryList returns an error if all the queries we expected
|
||||
// haven't been seen.
|
||||
func (fmd *FakeMysqlDaemon) CheckSuperQueryList() error {
|
||||
|
@ -282,6 +385,22 @@ func (fmd *FakeMysqlDaemon) GetSchema(dbName string, tables, excludeTables []str
|
|||
return fmd.Schema.FilterTables(tables, excludeTables, includeViews)
|
||||
}
|
||||
|
||||
// PreflightSchemaChange is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) PreflightSchemaChange(dbName string, change string) (*proto.SchemaChangeResult, error) {
|
||||
if fmd.PreflightSchemaChangeResult == nil {
|
||||
return nil, fmt.Errorf("no preflight result defined")
|
||||
}
|
||||
return fmd.PreflightSchemaChangeResult, nil
|
||||
}
|
||||
|
||||
// ApplySchemaChange is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) ApplySchemaChange(dbName string, change *proto.SchemaChange) (*proto.SchemaChangeResult, error) {
|
||||
if fmd.ApplySchemaChangeResult == nil {
|
||||
return nil, fmt.Errorf("no apply schema defined")
|
||||
}
|
||||
return fmd.ApplySchemaChangeResult, nil
|
||||
}
|
||||
|
||||
// GetAppConnection is part of the MysqlDaemon interface
|
||||
func (fmd *FakeMysqlDaemon) GetAppConnection() (dbconnpool.PoolConnection, error) {
|
||||
if fmd.DbAppConnectionFactory == nil {
|
||||
|
|
|
@ -30,7 +30,7 @@ type MysqlFlavor interface {
|
|||
MasterPosition(mysqld *Mysqld) (proto.ReplicationPosition, error)
|
||||
|
||||
// SlaveStatus returns the ReplicationStatus of a slave.
|
||||
SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error)
|
||||
SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error)
|
||||
|
||||
// ResetReplicationCommands returns the commands to completely reset
|
||||
// replication on the host.
|
||||
|
@ -63,7 +63,7 @@ type MysqlFlavor interface {
|
|||
// SendBinlogDumpCommand sends the flavor-specific version of
|
||||
// the COM_BINLOG_DUMP command to start dumping raw binlog
|
||||
// events over a slave connection, starting at a given GTID.
|
||||
SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error
|
||||
SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error
|
||||
|
||||
// MakeBinlogEvent takes a raw packet from the MySQL binlog
|
||||
// stream connection and returns a BinlogEvent through which
|
||||
|
@ -120,7 +120,7 @@ func (mysqld *Mysqld) detectFlavor() (MysqlFlavor, error) {
|
|||
|
||||
// If no environment variable set, fall back to auto-detect.
|
||||
log.Infof("MYSQL_FLAVOR empty or unset, attempting to auto-detect...")
|
||||
qr, err := mysqld.fetchSuperQuery("SELECT VERSION()")
|
||||
qr, err := mysqld.FetchSuperQuery("SELECT VERSION()")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't SELECT VERSION(): %v", err)
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ func (*mariaDB10) VersionMatch(version string) bool {
|
|||
|
||||
// MasterPosition implements MysqlFlavor.MasterPosition().
|
||||
func (flavor *mariaDB10) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPosition, err error) {
|
||||
qr, err := mysqld.fetchSuperQuery("SELECT @@GLOBAL.gtid_binlog_pos")
|
||||
qr, err := mysqld.FetchSuperQuery("SELECT @@GLOBAL.gtid_binlog_pos")
|
||||
if err != nil {
|
||||
return rp, err
|
||||
}
|
||||
|
@ -40,16 +40,16 @@ func (flavor *mariaDB10) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPos
|
|||
}
|
||||
|
||||
// SlaveStatus implements MysqlFlavor.SlaveStatus().
|
||||
func (flavor *mariaDB10) SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error) {
|
||||
func (flavor *mariaDB10) SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error) {
|
||||
fields, err := mysqld.fetchSuperQueryMap("SHOW ALL SLAVES STATUS")
|
||||
if err != nil {
|
||||
return nil, ErrNotSlave
|
||||
return proto.ReplicationStatus{}, ErrNotSlave
|
||||
}
|
||||
status := parseSlaveStatus(fields)
|
||||
|
||||
status.Position, err = flavor.ParseReplicationPosition(fields["Gtid_Slave_Pos"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SlaveStatus can't parse MariaDB GTID (Gtid_Slave_Pos: %#v): %v", fields["Gtid_Slave_Pos"], err)
|
||||
return proto.ReplicationStatus{}, fmt.Errorf("SlaveStatus can't parse MariaDB GTID (Gtid_Slave_Pos: %#v): %v", fields["Gtid_Slave_Pos"], err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ func (*mariaDB10) WaitMasterPos(mysqld *Mysqld, targetPos proto.ReplicationPosit
|
|||
}
|
||||
|
||||
log.Infof("Waiting for minimum replication position with query: %v", query)
|
||||
qr, err := mysqld.fetchSuperQuery(query)
|
||||
qr, err := mysqld.FetchSuperQuery(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("MASTER_GTID_WAIT() failed: %v", err)
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ func (*mariaDB10) ParseReplicationPosition(s string) (proto.ReplicationPosition,
|
|||
}
|
||||
|
||||
// SendBinlogDumpCommand implements MysqlFlavor.SendBinlogDumpCommand().
|
||||
func (*mariaDB10) SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error {
|
||||
func (*mariaDB10) SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error {
|
||||
const ComBinlogDump = 0x12
|
||||
|
||||
// Tell the server that we understand GTIDs by setting our slave capability
|
||||
|
|
|
@ -30,7 +30,7 @@ func (*mysql56) VersionMatch(version string) bool {
|
|||
|
||||
// MasterPosition implements MysqlFlavor.MasterPosition().
|
||||
func (flavor *mysql56) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPosition, err error) {
|
||||
qr, err := mysqld.fetchSuperQuery("SELECT @@GLOBAL.gtid_executed")
|
||||
qr, err := mysqld.FetchSuperQuery("SELECT @@GLOBAL.gtid_executed")
|
||||
if err != nil {
|
||||
return rp, err
|
||||
}
|
||||
|
@ -41,16 +41,16 @@ func (flavor *mysql56) MasterPosition(mysqld *Mysqld) (rp proto.ReplicationPosit
|
|||
}
|
||||
|
||||
// SlaveStatus implements MysqlFlavor.SlaveStatus().
|
||||
func (flavor *mysql56) SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error) {
|
||||
func (flavor *mysql56) SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error) {
|
||||
fields, err := mysqld.fetchSuperQueryMap("SHOW SLAVE STATUS")
|
||||
if err != nil {
|
||||
return nil, ErrNotSlave
|
||||
return proto.ReplicationStatus{}, ErrNotSlave
|
||||
}
|
||||
status := parseSlaveStatus(fields)
|
||||
|
||||
status.Position, err = flavor.ParseReplicationPosition(fields["Executed_Gtid_Set"])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SlaveStatus can't parse MySQL 5.6 GTID (Executed_Gtid_Set: %#v): %v", fields["Executed_Gtid_Set"], err)
|
||||
return proto.ReplicationStatus{}, fmt.Errorf("SlaveStatus can't parse MySQL 5.6 GTID (Executed_Gtid_Set: %#v): %v", fields["Executed_Gtid_Set"], err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func (*mysql56) WaitMasterPos(mysqld *Mysqld, targetPos proto.ReplicationPositio
|
|||
query = fmt.Sprintf("SELECT WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS('%s', %v)", targetPos, int(waitTimeout.Seconds()))
|
||||
|
||||
log.Infof("Waiting for minimum replication position with query: %v", query)
|
||||
qr, err := mysqld.fetchSuperQuery(query)
|
||||
qr, err := mysqld.FetchSuperQuery(query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS() failed: %v", err)
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func (*mysql56) ParseReplicationPosition(s string) (proto.ReplicationPosition, e
|
|||
}
|
||||
|
||||
// SendBinlogDumpCommand implements MysqlFlavor.SendBinlogDumpCommand().
|
||||
func (flavor *mysql56) SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error {
|
||||
func (flavor *mysql56) SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error {
|
||||
const ComBinlogDumpGTID = 0x1E // COM_BINLOG_DUMP_GTID
|
||||
|
||||
gtidSet, ok := startPos.GTIDSet.(proto.Mysql56GTIDSet)
|
||||
|
|
|
@ -24,7 +24,7 @@ func (fakeMysqlFlavor) MakeBinlogEvent(buf []byte) blproto.BinlogEvent { return
|
|||
func (fakeMysqlFlavor) ParseReplicationPosition(string) (proto.ReplicationPosition, error) {
|
||||
return proto.ReplicationPosition{}, nil
|
||||
}
|
||||
func (fakeMysqlFlavor) SendBinlogDumpCommand(mysqld *Mysqld, conn *SlaveConnection, startPos proto.ReplicationPosition) error {
|
||||
func (fakeMysqlFlavor) SendBinlogDumpCommand(conn *SlaveConnection, startPos proto.ReplicationPosition) error {
|
||||
return nil
|
||||
}
|
||||
func (fakeMysqlFlavor) WaitMasterPos(mysqld *Mysqld, targetPos proto.ReplicationPosition, waitTimeout time.Duration) error {
|
||||
|
@ -33,7 +33,9 @@ func (fakeMysqlFlavor) WaitMasterPos(mysqld *Mysqld, targetPos proto.Replication
|
|||
func (fakeMysqlFlavor) MasterPosition(mysqld *Mysqld) (proto.ReplicationPosition, error) {
|
||||
return proto.ReplicationPosition{}, nil
|
||||
}
|
||||
func (fakeMysqlFlavor) SlaveStatus(mysqld *Mysqld) (*proto.ReplicationStatus, error) { return nil, nil }
|
||||
func (fakeMysqlFlavor) SlaveStatus(mysqld *Mysqld) (proto.ReplicationStatus, error) {
|
||||
return proto.ReplicationStatus{}, nil
|
||||
}
|
||||
func (fakeMysqlFlavor) StartReplicationCommands(params *sqldb.ConnParams, status *proto.ReplicationStatus) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -42,8 +42,6 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
// TODO(aaijazi): for reasons I don't understand, the dba pool size needs to be fairly large (15+)
|
||||
// for test/clone.py to pass.
|
||||
dbaPoolSize = flag.Int("dba_pool_size", 20, "Size of the connection pool for dba connections")
|
||||
dbaIdleTimeout = flag.Duration("dba_idle_timeout", time.Minute, "Idle timeout for dba connections")
|
||||
appPoolSize = flag.Int("app_pool_size", 40, "Size of the connection pool for app connections")
|
||||
|
@ -402,7 +400,7 @@ func (mysqld *Mysqld) initConfig(root string) error {
|
|||
|
||||
func (mysqld *Mysqld) createDirs() error {
|
||||
log.Infof("creating directory %s", mysqld.TabletDir)
|
||||
if err := os.MkdirAll(mysqld.TabletDir, 0775); err != nil {
|
||||
if err := os.MkdirAll(mysqld.TabletDir, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, dir := range TopLevelDirs() {
|
||||
|
@ -412,7 +410,7 @@ func (mysqld *Mysqld) createDirs() error {
|
|||
}
|
||||
for _, dir := range mysqld.config.directoryList() {
|
||||
log.Infof("creating directory %s", dir)
|
||||
if err := os.MkdirAll(dir, 0775); err != nil {
|
||||
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
// FIXME(msolomon) validate permissions?
|
||||
|
@ -435,14 +433,14 @@ func (mysqld *Mysqld) createTopDir(dir string) error {
|
|||
if os.IsNotExist(err) {
|
||||
topdir := path.Join(mysqld.TabletDir, dir)
|
||||
log.Infof("creating directory %s", topdir)
|
||||
return os.MkdirAll(topdir, 0775)
|
||||
return os.MkdirAll(topdir, os.ModePerm)
|
||||
}
|
||||
return err
|
||||
}
|
||||
linkto := path.Join(target, vtname)
|
||||
source := path.Join(mysqld.TabletDir, dir)
|
||||
log.Infof("creating directory %s", linkto)
|
||||
err = os.MkdirAll(linkto, 0775)
|
||||
err = os.MkdirAll(linkto, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,11 +8,12 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
|
||||
)
|
||||
|
||||
func (mysqld *Mysqld) GetPermissions() (*proto.Permissions, error) {
|
||||
// GetPermissions lists the permissions on the mysqld
|
||||
func GetPermissions(mysqld MysqlDaemon) (*proto.Permissions, error) {
|
||||
permissions := &proto.Permissions{}
|
||||
|
||||
// get Users
|
||||
qr, err := mysqld.fetchSuperQuery("SELECT * FROM mysql.user")
|
||||
qr, err := mysqld.FetchSuperQuery("SELECT * FROM mysql.user")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -21,7 +22,7 @@ func (mysqld *Mysqld) GetPermissions() (*proto.Permissions, error) {
|
|||
}
|
||||
|
||||
// get Dbs
|
||||
qr, err = mysqld.fetchSuperQuery("SELECT * FROM mysql.db")
|
||||
qr, err = mysqld.FetchSuperQuery("SELECT * FROM mysql.db")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -30,7 +31,7 @@ func (mysqld *Mysqld) GetPermissions() (*proto.Permissions, error) {
|
|||
}
|
||||
|
||||
// get Hosts
|
||||
qr, err = mysqld.fetchSuperQuery("SELECT * FROM mysql.host")
|
||||
qr, err = mysqld.FetchSuperQuery("SELECT * FROM mysql.host")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/youtube/vitess/go/jscfg"
|
||||
|
@ -17,41 +16,38 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
TABLE_BASE_TABLE = "BASE TABLE"
|
||||
TABLE_VIEW = "VIEW"
|
||||
// TableBaseTable indicates the table type is a base table.
|
||||
TableBaseTable = "BASE TABLE"
|
||||
// TableView indicates the table type is a view.
|
||||
TableView = "VIEW"
|
||||
)
|
||||
|
||||
// TableDefinition contains all schema information about a table.
|
||||
type TableDefinition struct {
|
||||
Name string // the table name
|
||||
Schema string // the SQL to run to create the table
|
||||
Columns []string // the columns in the order that will be used to dump and load the data
|
||||
PrimaryKeyColumns []string // the columns used by the primary key, in order
|
||||
Type string // TABLE_BASE_TABLE or TABLE_VIEW
|
||||
Type string // TableBaseTable or TableView
|
||||
DataLength uint64 // how much space the data file takes.
|
||||
RowCount uint64 // how many rows in the table (may
|
||||
// be approximate count)
|
||||
}
|
||||
|
||||
// helper methods for sorting
|
||||
// TableDefinitions is a list of TableDefinition.
|
||||
type TableDefinitions []*TableDefinition
|
||||
|
||||
// Len returns TableDefinitions length.
|
||||
func (tds TableDefinitions) Len() int {
|
||||
return len(tds)
|
||||
}
|
||||
|
||||
// Swap used for sorting TableDefinitions.
|
||||
func (tds TableDefinitions) Swap(i, j int) {
|
||||
tds[i], tds[j] = tds[j], tds[i]
|
||||
}
|
||||
|
||||
// sort by reverse DataLength
|
||||
type ByReverseDataLength struct {
|
||||
TableDefinitions
|
||||
}
|
||||
|
||||
func (bdl ByReverseDataLength) Less(i, j int) bool {
|
||||
return bdl.TableDefinitions[j].DataLength < bdl.TableDefinitions[i].DataLength
|
||||
}
|
||||
|
||||
// SchemaDefinition defines schema for a certain database.
|
||||
type SchemaDefinition struct {
|
||||
// the 'CREATE DATABASE...' statement, with db name as {{.DatabaseName}}
|
||||
DatabaseSchema string
|
||||
|
@ -67,10 +63,6 @@ func (sd *SchemaDefinition) String() string {
|
|||
return jscfg.ToJSON(sd)
|
||||
}
|
||||
|
||||
func (sd *SchemaDefinition) SortByReverseDataLength() {
|
||||
sort.Sort(ByReverseDataLength{sd.TableDefinitions})
|
||||
}
|
||||
|
||||
// FilterTables returns a copy which includes only
|
||||
// whitelisted tables (tables), no blacklisted tables (excludeTables) and optionally views (includeViews).
|
||||
func (sd *SchemaDefinition) FilterTables(tables, excludeTables []string, includeViews bool) (*SchemaDefinition, error) {
|
||||
|
@ -126,7 +118,7 @@ func (sd *SchemaDefinition) FilterTables(tables, excludeTables []string, include
|
|||
continue
|
||||
}
|
||||
|
||||
if !includeViews && table.Type == TABLE_VIEW {
|
||||
if !includeViews && table.Type == TableView {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -141,6 +133,8 @@ func (sd *SchemaDefinition) FilterTables(tables, excludeTables []string, include
|
|||
return ©, nil
|
||||
}
|
||||
|
||||
// GenerateSchemaVersion return a unique schema version string based on
|
||||
// its TableDefinitions.
|
||||
func (sd *SchemaDefinition) GenerateSchemaVersion() {
|
||||
hasher := md5.New()
|
||||
for _, td := range sd.TableDefinitions {
|
||||
|
@ -151,6 +145,7 @@ func (sd *SchemaDefinition) GenerateSchemaVersion() {
|
|||
sd.Version = hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// GetTable returns TableDefinition for a given table name.
|
||||
func (sd *SchemaDefinition) GetTable(table string) (td *TableDefinition, ok bool) {
|
||||
for _, td := range sd.TableDefinitions {
|
||||
if td.Name == table {
|
||||
|
@ -170,7 +165,7 @@ func (sd *SchemaDefinition) ToSQLStrings() []string {
|
|||
sqlStrings = append(sqlStrings, sd.DatabaseSchema)
|
||||
|
||||
for _, td := range sd.TableDefinitions {
|
||||
if td.Type == TABLE_VIEW {
|
||||
if td.Type == TableView {
|
||||
createViewSql = append(createViewSql, td.Schema)
|
||||
} else {
|
||||
lines := strings.Split(td.Schema, "\n")
|
||||
|
@ -186,9 +181,16 @@ func (sd *SchemaDefinition) ToSQLStrings() []string {
|
|||
return append(sqlStrings, createViewSql...)
|
||||
}
|
||||
|
||||
// generates a report on what's different between two SchemaDefinition
|
||||
// for now, we skip the VIEW entirely.
|
||||
// DiffSchema generates a report on what's different between two SchemaDefinitions
|
||||
// including views.
|
||||
func DiffSchema(leftName string, left *SchemaDefinition, rightName string, right *SchemaDefinition, er concurrency.ErrorRecorder) {
|
||||
if left == nil && right == nil {
|
||||
return
|
||||
}
|
||||
if left == nil || right == nil {
|
||||
er.RecordError(fmt.Errorf("%v and %v are different, %s: %v, %s: %v", leftName, rightName, leftName, left, rightName, right))
|
||||
return
|
||||
}
|
||||
if left.DatabaseSchema != right.DatabaseSchema {
|
||||
er.RecordError(fmt.Errorf("%v and %v don't agree on database creation command:\n%v\n differs from:\n%v", leftName, rightName, left.DatabaseSchema, right.DatabaseSchema))
|
||||
}
|
||||
|
@ -196,16 +198,6 @@ func DiffSchema(leftName string, left *SchemaDefinition, rightName string, right
|
|||
leftIndex := 0
|
||||
rightIndex := 0
|
||||
for leftIndex < len(left.TableDefinitions) && rightIndex < len(right.TableDefinitions) {
|
||||
// skip views
|
||||
if left.TableDefinitions[leftIndex].Type == TABLE_VIEW {
|
||||
leftIndex++
|
||||
continue
|
||||
}
|
||||
if right.TableDefinitions[rightIndex].Type == TABLE_VIEW {
|
||||
rightIndex++
|
||||
continue
|
||||
}
|
||||
|
||||
// extra table on the left side
|
||||
if left.TableDefinitions[leftIndex].Name < right.TableDefinitions[rightIndex].Name {
|
||||
er.RecordError(fmt.Errorf("%v has an extra table named %v", leftName, left.TableDefinitions[leftIndex].Name))
|
||||
|
@ -224,34 +216,46 @@ func DiffSchema(leftName string, left *SchemaDefinition, rightName string, right
|
|||
if left.TableDefinitions[leftIndex].Schema != right.TableDefinitions[rightIndex].Schema {
|
||||
er.RecordError(fmt.Errorf("%v and %v disagree on schema for table %v:\n%v\n differs from:\n%v", leftName, rightName, left.TableDefinitions[leftIndex].Name, left.TableDefinitions[leftIndex].Schema, right.TableDefinitions[rightIndex].Schema))
|
||||
}
|
||||
|
||||
if left.TableDefinitions[leftIndex].Type != right.TableDefinitions[rightIndex].Type {
|
||||
er.RecordError(fmt.Errorf("%v and %v disagree on table type for table %v:\n%v\n differs from:\n%v", leftName, rightName, left.TableDefinitions[leftIndex].Name, left.TableDefinitions[leftIndex].Type, right.TableDefinitions[rightIndex].Type))
|
||||
}
|
||||
|
||||
leftIndex++
|
||||
rightIndex++
|
||||
}
|
||||
|
||||
for leftIndex < len(left.TableDefinitions) {
|
||||
if left.TableDefinitions[leftIndex].Type == TABLE_BASE_TABLE {
|
||||
if left.TableDefinitions[leftIndex].Type == TableBaseTable {
|
||||
er.RecordError(fmt.Errorf("%v has an extra table named %v", leftName, left.TableDefinitions[leftIndex].Name))
|
||||
}
|
||||
if left.TableDefinitions[leftIndex].Type == TableView {
|
||||
er.RecordError(fmt.Errorf("%v has an extra view named %v", leftName, left.TableDefinitions[leftIndex].Name))
|
||||
}
|
||||
leftIndex++
|
||||
}
|
||||
for rightIndex < len(right.TableDefinitions) {
|
||||
if right.TableDefinitions[rightIndex].Type == TABLE_BASE_TABLE {
|
||||
if right.TableDefinitions[rightIndex].Type == TableBaseTable {
|
||||
er.RecordError(fmt.Errorf("%v has an extra table named %v", rightName, right.TableDefinitions[rightIndex].Name))
|
||||
}
|
||||
if right.TableDefinitions[rightIndex].Type == TableView {
|
||||
er.RecordError(fmt.Errorf("%v has an extra view named %v", rightName, right.TableDefinitions[rightIndex].Name))
|
||||
}
|
||||
rightIndex++
|
||||
}
|
||||
}
|
||||
|
||||
// DiffSchemaToArray diffs two schemas and return the schema diffs if there is any.
|
||||
func DiffSchemaToArray(leftName string, left *SchemaDefinition, rightName string, right *SchemaDefinition) (result []string) {
|
||||
er := concurrency.AllErrorRecorder{}
|
||||
DiffSchema(leftName, left, rightName, right, &er)
|
||||
if er.HasErrors() {
|
||||
return er.ErrorStrings()
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SchemaChange contains all necessary information to apply a schema change.
|
||||
type SchemaChange struct {
|
||||
Sql string
|
||||
Force bool
|
||||
|
@ -260,6 +264,8 @@ type SchemaChange struct {
|
|||
AfterSchema *SchemaDefinition
|
||||
}
|
||||
|
||||
// SchemaChangeResult contains before and after table schemas for
|
||||
// a schema change sql.
|
||||
type SchemaChangeResult struct {
|
||||
BeforeSchema *SchemaDefinition
|
||||
AfterSchema *SchemaDefinition
|
||||
|
|
|
@ -6,6 +6,7 @@ package proto
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
@ -13,12 +14,12 @@ import (
|
|||
var basicTable1 = &TableDefinition{
|
||||
Name: "table1",
|
||||
Schema: "table schema 1",
|
||||
Type: TABLE_BASE_TABLE,
|
||||
Type: TableBaseTable,
|
||||
}
|
||||
var basicTable2 = &TableDefinition{
|
||||
Name: "table2",
|
||||
Schema: "table schema 2",
|
||||
Type: TABLE_BASE_TABLE,
|
||||
Type: TableBaseTable,
|
||||
}
|
||||
|
||||
var table3 = &TableDefinition{
|
||||
|
@ -26,19 +27,19 @@ var table3 = &TableDefinition{
|
|||
Schema: "CREATE TABLE `table3` (\n" +
|
||||
"id bigint not null,\n" +
|
||||
") Engine=InnoDB",
|
||||
Type: TABLE_BASE_TABLE,
|
||||
Type: TableBaseTable,
|
||||
}
|
||||
|
||||
var view1 = &TableDefinition{
|
||||
Name: "view1",
|
||||
Schema: "view schema 1",
|
||||
Type: TABLE_VIEW,
|
||||
Type: TableView,
|
||||
}
|
||||
|
||||
var view2 = &TableDefinition{
|
||||
Name: "view2",
|
||||
Schema: "view schema 2",
|
||||
Type: TABLE_VIEW,
|
||||
Type: TableView,
|
||||
}
|
||||
|
||||
func TestToSQLStrings(t *testing.T) {
|
||||
|
@ -152,30 +153,89 @@ func TestSchemaDiff(t *testing.T) {
|
|||
&TableDefinition{
|
||||
Name: "table1",
|
||||
Schema: "schema1",
|
||||
Type: TABLE_BASE_TABLE,
|
||||
Type: TableBaseTable,
|
||||
},
|
||||
&TableDefinition{
|
||||
Name: "table2",
|
||||
Schema: "schema2",
|
||||
Type: TABLE_BASE_TABLE,
|
||||
Type: TableBaseTable,
|
||||
},
|
||||
},
|
||||
}
|
||||
testDiff(t, sd1, sd1, "sd1", "sd2", []string{})
|
||||
|
||||
sd2 := &SchemaDefinition{TableDefinitions: make([]*TableDefinition, 0, 2)}
|
||||
|
||||
sd3 := &SchemaDefinition{
|
||||
TableDefinitions: []*TableDefinition{
|
||||
&TableDefinition{
|
||||
Name: "table2",
|
||||
Schema: "schema2",
|
||||
Type: TableBaseTable,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sd4 := &SchemaDefinition{
|
||||
TableDefinitions: []*TableDefinition{
|
||||
&TableDefinition{
|
||||
Name: "table2",
|
||||
Schema: "table2",
|
||||
Type: TableView,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
sd5 := &SchemaDefinition{
|
||||
TableDefinitions: []*TableDefinition{
|
||||
&TableDefinition{
|
||||
Name: "table2",
|
||||
Schema: "table2",
|
||||
Type: TableBaseTable,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testDiff(t, sd1, sd1, "sd1", "sd2", []string{})
|
||||
|
||||
testDiff(t, sd2, sd2, "sd2", "sd2", []string{})
|
||||
|
||||
// two schemas are considered the same if both nil
|
||||
testDiff(t, nil, nil, "sd1", "sd2", nil)
|
||||
|
||||
testDiff(t, sd1, nil, "sd1", "sd2", []string{
|
||||
fmt.Sprintf("sd1 and sd2 are different, sd1: %v, sd2: null", sd1),
|
||||
})
|
||||
|
||||
testDiff(t, sd1, sd3, "sd1", "sd3", []string{
|
||||
"sd1 has an extra table named table1",
|
||||
})
|
||||
|
||||
testDiff(t, sd3, sd1, "sd3", "sd1", []string{
|
||||
"sd1 has an extra table named table1",
|
||||
})
|
||||
|
||||
testDiff(t, sd2, sd4, "sd2", "sd4", []string{
|
||||
"sd4 has an extra view named table2",
|
||||
})
|
||||
|
||||
testDiff(t, sd4, sd2, "sd4", "sd2", []string{
|
||||
"sd4 has an extra view named table2",
|
||||
})
|
||||
|
||||
testDiff(t, sd4, sd5, "sd4", "sd5", []string{
|
||||
fmt.Sprintf("sd4 and sd5 disagree on table type for table table2:\nVIEW\n differs from:\nBASE TABLE"),
|
||||
})
|
||||
|
||||
sd1.DatabaseSchema = "CREATE DATABASE {{.DatabaseName}}"
|
||||
sd2.DatabaseSchema = "DONT CREATE DATABASE {{.DatabaseName}}"
|
||||
testDiff(t, sd1, sd2, "sd1", "sd2", []string{"sd1 and sd2 don't agree on database creation command:\nCREATE DATABASE {{.DatabaseName}}\n differs from:\nDONT CREATE DATABASE {{.DatabaseName}}", "sd1 has an extra table named table1", "sd1 has an extra table named table2"})
|
||||
sd2.DatabaseSchema = "CREATE DATABASE {{.DatabaseName}}"
|
||||
testDiff(t, sd2, sd1, "sd2", "sd1", []string{"sd1 has an extra table named table1", "sd1 has an extra table named table2"})
|
||||
|
||||
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table1", Schema: "schema1", Type: TABLE_BASE_TABLE})
|
||||
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table1", Schema: "schema1", Type: TableBaseTable})
|
||||
testDiff(t, sd1, sd2, "sd1", "sd2", []string{"sd1 has an extra table named table2"})
|
||||
|
||||
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table2", Schema: "schema3", Type: TABLE_BASE_TABLE})
|
||||
sd2.TableDefinitions = append(sd2.TableDefinitions, &TableDefinition{Name: "table2", Schema: "schema3", Type: TableBaseTable})
|
||||
testDiff(t, sd1, sd2, "sd1", "sd2", []string{"sd1 and sd2 disagree on schema for table table2:\nschema2\n differs from:\nschema3"})
|
||||
}
|
||||
|
||||
|
|
|
@ -33,8 +33,8 @@ func (mysqld *Mysqld) ExecuteSuperQueryList(queryList []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// fetchSuperQuery returns the results of executing a query as a super user.
|
||||
func (mysqld *Mysqld) fetchSuperQuery(query string) (*mproto.QueryResult, error) {
|
||||
// FetchSuperQuery returns the results of executing a query as a super user.
|
||||
func (mysqld *Mysqld) FetchSuperQuery(query string) (*mproto.QueryResult, error) {
|
||||
conn, connErr := mysqld.dbaPool.Get(0)
|
||||
if connErr != nil {
|
||||
return nil, connErr
|
||||
|
@ -51,7 +51,7 @@ func (mysqld *Mysqld) fetchSuperQuery(query string) (*mproto.QueryResult, error)
|
|||
// fetchSuperQueryMap returns a map from column names to cell data for a query
|
||||
// that should return exactly 1 row.
|
||||
func (mysqld *Mysqld) fetchSuperQueryMap(query string) (map[string]string, error) {
|
||||
qr, err := mysqld.fetchSuperQuery(query)
|
||||
qr, err := mysqld.FetchSuperQuery(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,6 +69,9 @@ func (mysqld *Mysqld) fetchSuperQueryMap(query string) (map[string]string, error
|
|||
return rowMap, nil
|
||||
}
|
||||
|
||||
const masterPasswordStart = " MASTER_PASSWORD = '"
|
||||
const masterPasswordEnd = "',\n"
|
||||
|
||||
func redactMasterPassword(input string) string {
|
||||
i := strings.Index(input, masterPasswordStart)
|
||||
if i == -1 {
|
||||
|
|
|
@ -52,7 +52,7 @@ func queryReparentJournal(timeCreatedNS int64) string {
|
|||
// the row in the reparent_journal table.
|
||||
func (mysqld *Mysqld) WaitForReparentJournal(ctx context.Context, timeCreatedNS int64) error {
|
||||
for {
|
||||
qr, err := mysqld.fetchSuperQuery(queryReparentJournal(timeCreatedNS))
|
||||
qr, err := mysqld.FetchSuperQuery(queryReparentJournal(timeCreatedNS))
|
||||
if err == nil && len(qr.Rows) == 1 {
|
||||
// we have the row, we're done
|
||||
return nil
|
||||
|
@ -82,12 +82,10 @@ func (mysqld *Mysqld) DemoteMaster() (rp proto.ReplicationPosition, err error) {
|
|||
return mysqld.MasterPosition()
|
||||
}
|
||||
|
||||
// PromoteSlave will promote a slave to be the new master
|
||||
// PromoteSlave will promote a slave to be the new master.
|
||||
func (mysqld *Mysqld) PromoteSlave(hookExtraEnv map[string]string) (proto.ReplicationPosition, error) {
|
||||
// stop replication for good
|
||||
if err := mysqld.StopSlave(hookExtraEnv); err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
// we handle replication, just stop it
|
||||
cmds := []string{SqlStopSlave}
|
||||
|
||||
// Promote to master.
|
||||
flavor, err := mysqld.flavor()
|
||||
|
@ -95,7 +93,7 @@ func (mysqld *Mysqld) PromoteSlave(hookExtraEnv map[string]string) (proto.Replic
|
|||
err = fmt.Errorf("PromoteSlave needs flavor: %v", err)
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
cmds := flavor.PromoteSlaveCommands()
|
||||
cmds = append(cmds, flavor.PromoteSlaveCommands()...)
|
||||
if err := mysqld.ExecuteSuperQueryList(cmds); err != nil {
|
||||
return proto.ReplicationPosition{}, err
|
||||
}
|
||||
|
|
|
@ -12,8 +12,6 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
@ -30,8 +28,13 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/mysqlctl/proto"
|
||||
)
|
||||
|
||||
var masterPasswordStart = " MASTER_PASSWORD = '"
|
||||
var masterPasswordEnd = "',\n"
|
||||
const (
|
||||
// SqlStartSlave is the SQl command issued to start MySQL replication
|
||||
SqlStartSlave = "START SLAVE"
|
||||
|
||||
// SqlStopSlave is the SQl command issued to stop MySQL replication
|
||||
SqlStopSlave = "STOP SLAVE"
|
||||
)
|
||||
|
||||
func fillStringTemplate(tmpl string, vars interface{}) (string, error) {
|
||||
myTemplate := template.Must(template.New("").Parse(tmpl))
|
||||
|
@ -69,8 +72,8 @@ func changeMasterArgs(params *sqldb.ConnParams, masterHost string, masterPort in
|
|||
}
|
||||
|
||||
// parseSlaveStatus parses the common fields of SHOW SLAVE STATUS.
|
||||
func parseSlaveStatus(fields map[string]string) *proto.ReplicationStatus {
|
||||
status := &proto.ReplicationStatus{
|
||||
func parseSlaveStatus(fields map[string]string) proto.ReplicationStatus {
|
||||
status := proto.ReplicationStatus{
|
||||
MasterHost: fields["Master_Host"],
|
||||
SlaveIORunning: fields["Slave_IO_Running"] == "Yes",
|
||||
SlaveSQLRunning: fields["Slave_SQL_Running"] == "Yes",
|
||||
|
@ -84,8 +87,9 @@ func parseSlaveStatus(fields map[string]string) *proto.ReplicationStatus {
|
|||
return status
|
||||
}
|
||||
|
||||
// WaitForSlaveStart waits a slave until given deadline passed
|
||||
func (mysqld *Mysqld) WaitForSlaveStart(slaveStartDeadline int) error {
|
||||
// WaitForSlaveStart waits until the deadline for replication to start.
|
||||
// This validates the current master is correct and can be connected to.
|
||||
func WaitForSlaveStart(mysqld MysqlDaemon, slaveStartDeadline int) error {
|
||||
var rowMap map[string]string
|
||||
for slaveWait := 0; slaveWait < slaveStartDeadline; slaveWait++ {
|
||||
status, err := mysqld.SlaveStatus()
|
||||
|
@ -112,9 +116,9 @@ func (mysqld *Mysqld) WaitForSlaveStart(slaveStartDeadline int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// StartSlave starts a slave
|
||||
func (mysqld *Mysqld) StartSlave(hookExtraEnv map[string]string) error {
|
||||
if err := mysqld.ExecuteSuperQuery("START SLAVE"); err != nil {
|
||||
// StartSlave starts a slave on the provided MysqldDaemon
|
||||
func StartSlave(md MysqlDaemon, hookExtraEnv map[string]string) error {
|
||||
if err := md.ExecuteSuperQueryList([]string{SqlStartSlave}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -123,29 +127,20 @@ func (mysqld *Mysqld) StartSlave(hookExtraEnv map[string]string) error {
|
|||
return h.ExecuteOptional()
|
||||
}
|
||||
|
||||
// StopSlave stops a slave
|
||||
func (mysqld *Mysqld) StopSlave(hookExtraEnv map[string]string) error {
|
||||
// StopSlave stops a slave on the provided MysqldDaemon
|
||||
func StopSlave(md MysqlDaemon, hookExtraEnv map[string]string) error {
|
||||
h := hook.NewSimpleHook("preflight_stop_slave")
|
||||
h.ExtraEnv = hookExtraEnv
|
||||
if err := h.ExecuteOptional(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mysqld.ExecuteSuperQuery("STOP SLAVE")
|
||||
}
|
||||
|
||||
// GetMasterAddr returns master address
|
||||
func (mysqld *Mysqld) GetMasterAddr() (string, error) {
|
||||
slaveStatus, err := mysqld.SlaveStatus()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return slaveStatus.MasterAddr(), nil
|
||||
return md.ExecuteSuperQueryList([]string{SqlStopSlave})
|
||||
}
|
||||
|
||||
// GetMysqlPort returns mysql port
|
||||
func (mysqld *Mysqld) GetMysqlPort() (int, error) {
|
||||
qr, err := mysqld.fetchSuperQuery("SHOW VARIABLES LIKE 'port'")
|
||||
qr, err := mysqld.FetchSuperQuery("SHOW VARIABLES LIKE 'port'")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -161,7 +156,7 @@ func (mysqld *Mysqld) GetMysqlPort() (int, error) {
|
|||
|
||||
// IsReadOnly return true if the instance is read only
|
||||
func (mysqld *Mysqld) IsReadOnly() (bool, error) {
|
||||
qr, err := mysqld.fetchSuperQuery("SHOW VARIABLES LIKE 'read_only'")
|
||||
qr, err := mysqld.FetchSuperQuery("SHOW VARIABLES LIKE 'read_only'")
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
@ -202,10 +197,10 @@ func (mysqld *Mysqld) WaitMasterPos(targetPos proto.ReplicationPosition, waitTim
|
|||
}
|
||||
|
||||
// SlaveStatus returns the slave replication statuses
|
||||
func (mysqld *Mysqld) SlaveStatus() (*proto.ReplicationStatus, error) {
|
||||
func (mysqld *Mysqld) SlaveStatus() (proto.ReplicationStatus, error) {
|
||||
flavor, err := mysqld.flavor()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SlaveStatus needs flavor: %v", err)
|
||||
return proto.ReplicationStatus{}, fmt.Errorf("SlaveStatus needs flavor: %v", err)
|
||||
}
|
||||
return flavor.SlaveStatus(mysqld)
|
||||
}
|
||||
|
@ -250,43 +245,6 @@ func (mysqld *Mysqld) SetMasterCommands(masterHost string, masterPort int) ([]st
|
|||
return flavor.SetMasterCommands(¶ms, masterHost, masterPort, int(masterConnectRetry.Seconds()))
|
||||
}
|
||||
|
||||
// WaitForSlave waits for a slave if its lag is larger than given maxLag
|
||||
func (mysqld *Mysqld) WaitForSlave(maxLag int) (err error) {
|
||||
// FIXME(msolomon) verify that slave started based on show slave status;
|
||||
var rowMap map[string]string
|
||||
for {
|
||||
rowMap, err = mysqld.fetchSuperQueryMap("SHOW SLAVE STATUS")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if rowMap["Seconds_Behind_Master"] == "NULL" {
|
||||
break
|
||||
} else {
|
||||
lag, err := strconv.Atoi(rowMap["Seconds_Behind_Master"])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if lag < maxLag {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
errorKeys := []string{"Last_Error", "Last_IO_Error", "Last_SQL_Error"}
|
||||
errs := make([]string, 0, len(errorKeys))
|
||||
for _, key := range errorKeys {
|
||||
if rowMap[key] != "" {
|
||||
errs = append(errs, key+": "+rowMap[key])
|
||||
}
|
||||
}
|
||||
if len(errs) != 0 {
|
||||
return errors.New(strings.Join(errs, ", "))
|
||||
}
|
||||
return errors.New("replication stopped, it will never catch up")
|
||||
}
|
||||
|
||||
// ResetReplicationCommands returns the commands to run to reset all
|
||||
// replication for this host.
|
||||
func (mysqld *Mysqld) ResetReplicationCommands() ([]string, error) {
|
||||
|
@ -319,8 +277,8 @@ const (
|
|||
)
|
||||
|
||||
// FindSlaves gets IP addresses for all currently connected slaves.
|
||||
func (mysqld *Mysqld) FindSlaves() ([]string, error) {
|
||||
qr, err := mysqld.fetchSuperQuery("SHOW PROCESSLIST")
|
||||
func FindSlaves(mysqld MysqlDaemon) ([]string, error) {
|
||||
qr, err := mysqld.FetchSuperQuery("SHOW PROCESSLIST")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -339,25 +297,9 @@ func (mysqld *Mysqld) FindSlaves() ([]string, error) {
|
|||
return addrs, nil
|
||||
}
|
||||
|
||||
// ValidateSnapshotPath is a helper function to make sure we can write to the local snapshot area, before we actually do any action
|
||||
// (can be used for both partial and full snapshots)
|
||||
func (mysqld *Mysqld) ValidateSnapshotPath() error {
|
||||
_path := path.Join(mysqld.SnapshotDir, "validate_test")
|
||||
if err := os.RemoveAll(_path); err != nil {
|
||||
return fmt.Errorf("ValidateSnapshotPath: Cannot validate snapshot directory: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(_path, 0775); err != nil {
|
||||
return fmt.Errorf("ValidateSnapshotPath: Cannot validate snapshot directory: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(_path); err != nil {
|
||||
return fmt.Errorf("ValidateSnapshotPath: Cannot validate snapshot directory: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitBlpPosition will wait for the filtered replication to reach at least
|
||||
// the provided position.
|
||||
func (mysqld *Mysqld) WaitBlpPosition(bp *blproto.BlpPosition, waitTimeout time.Duration) error {
|
||||
func WaitBlpPosition(mysqld MysqlDaemon, bp *blproto.BlpPosition, waitTimeout time.Duration) error {
|
||||
timeOut := time.Now().Add(waitTimeout)
|
||||
for {
|
||||
if time.Now().After(timeOut) {
|
||||
|
@ -365,7 +307,7 @@ func (mysqld *Mysqld) WaitBlpPosition(bp *blproto.BlpPosition, waitTimeout time.
|
|||
}
|
||||
|
||||
cmd := binlogplayer.QueryBlpCheckpoint(bp.Uid)
|
||||
qr, err := mysqld.fetchSuperQuery(cmd)
|
||||
qr, err := mysqld.FetchSuperQuery(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче