зеркало из https://github.com/github/vitess-gh.git
Merge branch 'suguwork' into sequence
This commit is contained in:
Коммит
bed67e8fff
|
@ -39,7 +39,6 @@ cache:
|
|||
- $HOME/gopath/dist/grpc/usr/local
|
||||
- $HOME/gopath/dist/py-mock-1.0.1/.build_finished
|
||||
- $HOME/gopath/dist/py-mock-1.0.1/lib/python2.7/site-packages
|
||||
- $HOME/gopath/dist/py-vt-bson-0.3.2/lib/python2.7/site-packages
|
||||
- $HOME/gopath/dist/vt-zookeeper-3.4.6/.build_finished
|
||||
- $HOME/gopath/dist/vt-zookeeper-3.4.6/include
|
||||
- $HOME/gopath/dist/vt-zookeeper-3.4.6/lib
|
||||
|
|
5
Makefile
5
Makefile
|
@ -8,7 +8,7 @@ MAKEFLAGS = -s
|
|||
# Since we are not using this Makefile for compilation, limiting parallelism will not increase build time.
|
||||
.NOTPARALLEL:
|
||||
|
||||
.PHONY: all build test clean unit_test unit_test_cover unit_test_race integration_test bson proto site_test site_integration_test docker_bootstrap docker_test docker_unit_test java_test php_test reshard_tests
|
||||
.PHONY: all build test clean unit_test unit_test_cover unit_test_race integration_test proto site_test site_integration_test docker_bootstrap docker_test docker_unit_test java_test php_test reshard_tests
|
||||
|
||||
all: build test
|
||||
|
||||
|
@ -87,9 +87,6 @@ php_test:
|
|||
godep go install ./go/cmd/vtgateclienttest
|
||||
phpunit php/tests
|
||||
|
||||
bson:
|
||||
go generate ./go/...
|
||||
|
||||
# This rule rebuilds all the go files from the proto definitions for gRPC.
|
||||
# 1. list all proto files.
|
||||
# 2. remove 'proto/' prefix and '.proto' suffix.
|
||||
|
|
10
bootstrap.sh
10
bootstrap.sh
|
@ -181,16 +181,6 @@ else
|
|||
echo "Libs:" "$($VT_MYSQL_ROOT/bin/mysql_config --libs_r)" >> $VTROOT/lib/gomysql.pc
|
||||
fi
|
||||
|
||||
# install bson
|
||||
bson_dist=$VTROOT/dist/py-vt-bson-0.3.2
|
||||
if [ -f $bson_dist/lib/python2.7/site-packages/bson/__init__.py ]; then
|
||||
echo "skipping bson python build"
|
||||
else
|
||||
cd $VTTOP/third_party/py/bson-0.3.2 && \
|
||||
python ./setup.py install --prefix=$bson_dist && \
|
||||
rm -r build
|
||||
fi
|
||||
|
||||
# install mock
|
||||
mock_dist=$VTROOT/dist/py-mock-1.0.1
|
||||
if [ -f $mock_dist/.build_finished ]; then
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType struct {
|
||||
Cust1 Custom1
|
||||
Cust2 *Custom1
|
||||
Cust3 pkg.Custom2
|
||||
Cust4 *pkg.Custom2
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType int
|
|
@ -1,17 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType struct {
|
||||
Map map[string]string
|
||||
MapBytes map[string][]byte
|
||||
MapPtr map[string]*string
|
||||
MapSlice map[string][]string
|
||||
MapMap map[string]map[string]int64
|
||||
MapCustom map[string]Custom
|
||||
MapCustomPtr map[string]*Custom
|
||||
CustomMap map[Custom]string
|
||||
MapExternal map[pkg.Custom]string
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
func foo() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
type MyType struct {
|
||||
Val int64
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType struct {
|
||||
Public int64
|
||||
private int64
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType struct {
|
||||
Ptr *int64
|
||||
PtrPtr **int64
|
||||
PtrBytes *[]byte
|
||||
PtrSlice *[]int64
|
||||
PtrMap *map[string]int64
|
||||
PtrCustom *Custom
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import "time"
|
||||
|
||||
type MyType struct {
|
||||
Float64 float64
|
||||
String string
|
||||
Bool bool
|
||||
Int64 int64
|
||||
Int32 int32
|
||||
Int int
|
||||
Uint64 uint64
|
||||
Uint32 uint32
|
||||
Uint uint
|
||||
Bytes []byte
|
||||
Time time.Time
|
||||
Interface interface{}
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType struct {
|
||||
Slice []string
|
||||
SliceBytes [][]byte
|
||||
SlicePtr []*string
|
||||
SliceSlice [][]string
|
||||
SliceMap []map[string]int64
|
||||
SliceCustom []Custom
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
type MyType struct {
|
||||
local int `bson:"Local"`
|
||||
Local2 int `bson:"Local1"`
|
||||
}
|
|
@ -1,75 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
myType.Cust1.MarshalBson(buf, "Cust1")
|
||||
// *Custom1
|
||||
if myType.Cust2 == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "Cust2")
|
||||
} else {
|
||||
(*myType.Cust2).MarshalBson(buf, "Cust2")
|
||||
}
|
||||
myType.Cust3.MarshalBson(buf, "Cust3")
|
||||
// *pkg.Custom2
|
||||
if myType.Cust4 == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "Cust4")
|
||||
} else {
|
||||
(*myType.Cust4).MarshalBson(buf, "Cust4")
|
||||
}
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Cust1":
|
||||
myType.Cust1.UnmarshalBson(buf, kind)
|
||||
case "Cust2":
|
||||
// *Custom1
|
||||
if kind != bson.Null {
|
||||
myType.Cust2 = new(Custom1)
|
||||
(*myType.Cust2).UnmarshalBson(buf, kind)
|
||||
}
|
||||
case "Cust3":
|
||||
myType.Cust3.UnmarshalBson(buf, kind)
|
||||
case "Cust4":
|
||||
// *pkg.Custom2
|
||||
if kind != bson.Null {
|
||||
myType.Cust4 = new(pkg.Custom2)
|
||||
(*myType.Cust4).UnmarshalBson(buf, kind)
|
||||
}
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
if key == "" {
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
defer lenWriter.Close()
|
||||
key = bson.MAGICTAG
|
||||
}
|
||||
bson.EncodeInt(buf, key, int(myType))
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
if kind == bson.EOO {
|
||||
bson.Next(buf, 4)
|
||||
kind = bson.NextByte(buf)
|
||||
bson.ReadCString(buf)
|
||||
}
|
||||
*myType = MyType(bson.DecodeInt(buf, kind))
|
||||
}
|
|
@ -1,320 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
// map[string]string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "Map")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v1 := range myType.Map {
|
||||
bson.EncodeString(buf, _k, _v1)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[string][]byte
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapBytes")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v2 := range myType.MapBytes {
|
||||
bson.EncodeBinary(buf, _k, _v2)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[string]*string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapPtr")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v3 := range myType.MapPtr {
|
||||
// *string
|
||||
if _v3 == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, _k)
|
||||
} else {
|
||||
bson.EncodeString(buf, _k, (*_v3))
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[string][]string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapSlice")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v4 := range myType.MapSlice {
|
||||
// []string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, _k)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v5 := range _v4 {
|
||||
bson.EncodeString(buf, bson.Itoa(_i), _v5)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[string]map[string]int64
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapMap")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v6 := range myType.MapMap {
|
||||
// map[string]int64
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, _k)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v7 := range _v6 {
|
||||
bson.EncodeInt64(buf, _k, _v7)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[string]Custom
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapCustom")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v8 := range myType.MapCustom {
|
||||
_v8.MarshalBson(buf, _k)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[string]*Custom
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapCustomPtr")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v9 := range myType.MapCustomPtr {
|
||||
// *Custom
|
||||
if _v9 == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, _k)
|
||||
} else {
|
||||
(*_v9).MarshalBson(buf, _k)
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[Custom]string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "CustomMap")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v10 := range myType.CustomMap {
|
||||
bson.EncodeString(buf, string(_k), _v10)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// map[pkg.Custom]string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "MapExternal")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v11 := range myType.MapExternal {
|
||||
bson.EncodeString(buf, string(_k), _v11)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Map":
|
||||
// map[string]string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.Map", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.Map = make(map[string]string)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v1 string
|
||||
_v1 = bson.DecodeString(buf, kind)
|
||||
myType.Map[_k] = _v1
|
||||
}
|
||||
}
|
||||
case "MapBytes":
|
||||
// map[string][]byte
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapBytes", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapBytes = make(map[string][]byte)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v2 []byte
|
||||
_v2 = bson.DecodeBinary(buf, kind)
|
||||
myType.MapBytes[_k] = _v2
|
||||
}
|
||||
}
|
||||
case "MapPtr":
|
||||
// map[string]*string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapPtr", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapPtr = make(map[string]*string)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v3 *string
|
||||
// *string
|
||||
if kind != bson.Null {
|
||||
_v3 = new(string)
|
||||
(*_v3) = bson.DecodeString(buf, kind)
|
||||
}
|
||||
myType.MapPtr[_k] = _v3
|
||||
}
|
||||
}
|
||||
case "MapSlice":
|
||||
// map[string][]string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapSlice", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapSlice = make(map[string][]string)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v4 []string
|
||||
// []string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for _v4", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
_v4 = make([]string, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v5 string
|
||||
_v5 = bson.DecodeString(buf, kind)
|
||||
_v4 = append(_v4, _v5)
|
||||
}
|
||||
}
|
||||
myType.MapSlice[_k] = _v4
|
||||
}
|
||||
}
|
||||
case "MapMap":
|
||||
// map[string]map[string]int64
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapMap", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapMap = make(map[string]map[string]int64)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v6 map[string]int64
|
||||
// map[string]int64
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for _v6", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
_v6 = make(map[string]int64)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v7 int64
|
||||
_v7 = bson.DecodeInt64(buf, kind)
|
||||
_v6[_k] = _v7
|
||||
}
|
||||
}
|
||||
myType.MapMap[_k] = _v6
|
||||
}
|
||||
}
|
||||
case "MapCustom":
|
||||
// map[string]Custom
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapCustom", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapCustom = make(map[string]Custom)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v8 Custom
|
||||
_v8.UnmarshalBson(buf, kind)
|
||||
myType.MapCustom[_k] = _v8
|
||||
}
|
||||
}
|
||||
case "MapCustomPtr":
|
||||
// map[string]*Custom
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapCustomPtr", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapCustomPtr = make(map[string]*Custom)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v9 *Custom
|
||||
// *Custom
|
||||
if kind != bson.Null {
|
||||
_v9 = new(Custom)
|
||||
(*_v9).UnmarshalBson(buf, kind)
|
||||
}
|
||||
myType.MapCustomPtr[_k] = _v9
|
||||
}
|
||||
}
|
||||
case "CustomMap":
|
||||
// map[Custom]string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.CustomMap", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.CustomMap = make(map[Custom]string)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := Custom(bson.ReadCString(buf))
|
||||
var _v10 string
|
||||
_v10 = bson.DecodeString(buf, kind)
|
||||
myType.CustomMap[_k] = _v10
|
||||
}
|
||||
}
|
||||
case "MapExternal":
|
||||
// map[pkg.Custom]string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.MapExternal", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.MapExternal = make(map[pkg.Custom]string)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := pkg.Custom(bson.ReadCString(buf))
|
||||
var _v11 string
|
||||
_v11 = bson.DecodeString(buf, kind)
|
||||
myType.MapExternal[_k] = _v11
|
||||
}
|
||||
}
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeInt64(buf, "Val", myType.Val)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Val":
|
||||
myType.Val = bson.DecodeInt64(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeInt64(buf, "Public", myType.Public)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Public":
|
||||
myType.Public = bson.DecodeInt64(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,168 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
// *int64
|
||||
if myType.Ptr == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "Ptr")
|
||||
} else {
|
||||
bson.EncodeInt64(buf, "Ptr", (*myType.Ptr))
|
||||
}
|
||||
// **int64
|
||||
if myType.PtrPtr == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "PtrPtr")
|
||||
} else {
|
||||
// *int64
|
||||
if (*myType.PtrPtr) == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "PtrPtr")
|
||||
} else {
|
||||
bson.EncodeInt64(buf, "PtrPtr", (*(*myType.PtrPtr)))
|
||||
}
|
||||
}
|
||||
// *[]byte
|
||||
if myType.PtrBytes == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "PtrBytes")
|
||||
} else {
|
||||
bson.EncodeBinary(buf, "PtrBytes", (*myType.PtrBytes))
|
||||
}
|
||||
// *[]int64
|
||||
if myType.PtrSlice == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "PtrSlice")
|
||||
} else {
|
||||
// []int64
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "PtrSlice")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v1 := range *myType.PtrSlice {
|
||||
bson.EncodeInt64(buf, bson.Itoa(_i), _v1)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
// *map[string]int64
|
||||
if myType.PtrMap == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "PtrMap")
|
||||
} else {
|
||||
// map[string]int64
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, "PtrMap")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v2 := range *myType.PtrMap {
|
||||
bson.EncodeInt64(buf, _k, _v2)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
// *Custom
|
||||
if myType.PtrCustom == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, "PtrCustom")
|
||||
} else {
|
||||
(*myType.PtrCustom).MarshalBson(buf, "PtrCustom")
|
||||
}
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Ptr":
|
||||
// *int64
|
||||
if kind != bson.Null {
|
||||
myType.Ptr = new(int64)
|
||||
(*myType.Ptr) = bson.DecodeInt64(buf, kind)
|
||||
}
|
||||
case "PtrPtr":
|
||||
// **int64
|
||||
if kind != bson.Null {
|
||||
myType.PtrPtr = new(*int64)
|
||||
// *int64
|
||||
if kind != bson.Null {
|
||||
(*myType.PtrPtr) = new(int64)
|
||||
(*(*myType.PtrPtr)) = bson.DecodeInt64(buf, kind)
|
||||
}
|
||||
}
|
||||
case "PtrBytes":
|
||||
// *[]byte
|
||||
if kind != bson.Null {
|
||||
myType.PtrBytes = new([]byte)
|
||||
(*myType.PtrBytes) = bson.DecodeBinary(buf, kind)
|
||||
}
|
||||
case "PtrSlice":
|
||||
// *[]int64
|
||||
if kind != bson.Null {
|
||||
myType.PtrSlice = new([]int64)
|
||||
// []int64
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for (*myType.PtrSlice)", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
(*myType.PtrSlice) = make([]int64, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v1 int64
|
||||
_v1 = bson.DecodeInt64(buf, kind)
|
||||
(*myType.PtrSlice) = append((*myType.PtrSlice), _v1)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "PtrMap":
|
||||
// *map[string]int64
|
||||
if kind != bson.Null {
|
||||
myType.PtrMap = new(map[string]int64)
|
||||
// map[string]int64
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for (*myType.PtrMap)", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
(*myType.PtrMap) = make(map[string]int64)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v2 int64
|
||||
_v2 = bson.DecodeInt64(buf, kind)
|
||||
(*myType.PtrMap)[_k] = _v2
|
||||
}
|
||||
}
|
||||
}
|
||||
case "PtrCustom":
|
||||
// *Custom
|
||||
if kind != bson.Null {
|
||||
myType.PtrCustom = new(Custom)
|
||||
(*myType.PtrCustom).UnmarshalBson(buf, kind)
|
||||
}
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeFloat64(buf, "Float64", myType.Float64)
|
||||
bson.EncodeString(buf, "String", myType.String)
|
||||
bson.EncodeBool(buf, "Bool", myType.Bool)
|
||||
bson.EncodeInt64(buf, "Int64", myType.Int64)
|
||||
bson.EncodeInt32(buf, "Int32", myType.Int32)
|
||||
bson.EncodeInt(buf, "Int", myType.Int)
|
||||
bson.EncodeUint64(buf, "Uint64", myType.Uint64)
|
||||
bson.EncodeUint32(buf, "Uint32", myType.Uint32)
|
||||
bson.EncodeUint(buf, "Uint", myType.Uint)
|
||||
bson.EncodeBinary(buf, "Bytes", myType.Bytes)
|
||||
bson.EncodeTime(buf, "Time", myType.Time)
|
||||
bson.EncodeInterface(buf, "Interface", myType.Interface)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Float64":
|
||||
myType.Float64 = bson.DecodeFloat64(buf, kind)
|
||||
case "String":
|
||||
myType.String = bson.DecodeString(buf, kind)
|
||||
case "Bool":
|
||||
myType.Bool = bson.DecodeBool(buf, kind)
|
||||
case "Int64":
|
||||
myType.Int64 = bson.DecodeInt64(buf, kind)
|
||||
case "Int32":
|
||||
myType.Int32 = bson.DecodeInt32(buf, kind)
|
||||
case "Int":
|
||||
myType.Int = bson.DecodeInt(buf, kind)
|
||||
case "Uint64":
|
||||
myType.Uint64 = bson.DecodeUint64(buf, kind)
|
||||
case "Uint32":
|
||||
myType.Uint32 = bson.DecodeUint32(buf, kind)
|
||||
case "Uint":
|
||||
myType.Uint = bson.DecodeUint(buf, kind)
|
||||
case "Bytes":
|
||||
myType.Bytes = bson.DecodeBinary(buf, kind)
|
||||
case "Time":
|
||||
myType.Time = bson.DecodeTime(buf, kind)
|
||||
case "Interface":
|
||||
myType.Interface = bson.DecodeInterface(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,240 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
// []string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "Slice")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v1 := range myType.Slice {
|
||||
bson.EncodeString(buf, bson.Itoa(_i), _v1)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// [][]byte
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "SliceBytes")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v2 := range myType.SliceBytes {
|
||||
bson.EncodeBinary(buf, bson.Itoa(_i), _v2)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// []*string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "SlicePtr")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v3 := range myType.SlicePtr {
|
||||
// *string
|
||||
if _v3 == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, bson.Itoa(_i))
|
||||
} else {
|
||||
bson.EncodeString(buf, bson.Itoa(_i), (*_v3))
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// [][]string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "SliceSlice")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v4 := range myType.SliceSlice {
|
||||
// []string
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, bson.Itoa(_i))
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v5 := range _v4 {
|
||||
bson.EncodeString(buf, bson.Itoa(_i), _v5)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// []map[string]int64
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "SliceMap")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v6 := range myType.SliceMap {
|
||||
// map[string]int64
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, bson.Itoa(_i))
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, _v7 := range _v6 {
|
||||
bson.EncodeInt64(buf, _k, _v7)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
// []Custom
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "SliceCustom")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v8 := range myType.SliceCustom {
|
||||
_v8.MarshalBson(buf, bson.Itoa(_i))
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Slice":
|
||||
// []string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.Slice", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.Slice = make([]string, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v1 string
|
||||
_v1 = bson.DecodeString(buf, kind)
|
||||
myType.Slice = append(myType.Slice, _v1)
|
||||
}
|
||||
}
|
||||
case "SliceBytes":
|
||||
// [][]byte
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.SliceBytes", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.SliceBytes = make([][]byte, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v2 []byte
|
||||
_v2 = bson.DecodeBinary(buf, kind)
|
||||
myType.SliceBytes = append(myType.SliceBytes, _v2)
|
||||
}
|
||||
}
|
||||
case "SlicePtr":
|
||||
// []*string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.SlicePtr", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.SlicePtr = make([]*string, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v3 *string
|
||||
// *string
|
||||
if kind != bson.Null {
|
||||
_v3 = new(string)
|
||||
(*_v3) = bson.DecodeString(buf, kind)
|
||||
}
|
||||
myType.SlicePtr = append(myType.SlicePtr, _v3)
|
||||
}
|
||||
}
|
||||
case "SliceSlice":
|
||||
// [][]string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.SliceSlice", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.SliceSlice = make([][]string, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v4 []string
|
||||
// []string
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for _v4", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
_v4 = make([]string, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v5 string
|
||||
_v5 = bson.DecodeString(buf, kind)
|
||||
_v4 = append(_v4, _v5)
|
||||
}
|
||||
}
|
||||
myType.SliceSlice = append(myType.SliceSlice, _v4)
|
||||
}
|
||||
}
|
||||
case "SliceMap":
|
||||
// []map[string]int64
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.SliceMap", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.SliceMap = make([]map[string]int64, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v6 map[string]int64
|
||||
// map[string]int64
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for _v6", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
_v6 = make(map[string]int64)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := bson.ReadCString(buf)
|
||||
var _v7 int64
|
||||
_v7 = bson.DecodeInt64(buf, kind)
|
||||
_v6[_k] = _v7
|
||||
}
|
||||
}
|
||||
myType.SliceMap = append(myType.SliceMap, _v6)
|
||||
}
|
||||
}
|
||||
case "SliceCustom":
|
||||
// []Custom
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for myType.SliceCustom", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
myType.SliceCustom = make([]Custom, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v8 Custom
|
||||
_v8.UnmarshalBson(buf, kind)
|
||||
myType.SliceCustom = append(myType.SliceCustom, _v8)
|
||||
}
|
||||
}
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mytype
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes MyType.
|
||||
func (myType *MyType) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeInt(buf, "Local", myType.local)
|
||||
bson.EncodeInt(buf, "Local1", myType.Local2)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into MyType.
|
||||
func (myType *MyType) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for MyType", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Local":
|
||||
myType.local = bson.DecodeInt(buf, kind)
|
||||
case "Local1":
|
||||
myType.Local2 = bson.DecodeInt(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -47,13 +47,14 @@ Vitess client libraries follow these core principles:
|
|||
* Each client library should support language-specific, idiomatic
|
||||
constructs to simplify application development in that language.
|
||||
* Client libraries should integrate with the following language-specific
|
||||
database drivers, though this support is not yet provided:
|
||||
* Go: [database/sql package](http://golang.org/pkg/database/sql/)
|
||||
database drivers, though this support is not yet provided in some cases:
|
||||
* Go: [database/sql package](http://golang.org/pkg/database/sql/) (done)
|
||||
* Java: [JDBC](https://docs.oracle.com/javase/tutorial/jdbc/index.html)
|
||||
compliance
|
||||
compliance (in progress)
|
||||
* PHP: [PHP Data Objects \(PDO\)](http://php.net/manual/en/intro.pdo.php)
|
||||
compliance
|
||||
compliance (in progress)
|
||||
* Python: [DB API](https://www.python.org/dev/peps/pep-0249/) compliance
|
||||
(done)
|
||||
* Libraries provide a thin wrapper around the proto3 service definitions.
|
||||
Those wrappers could be extended with adapters to higher level libraries
|
||||
like SQLAlchemy (Python) or JDBC (Java), with other object-based helper
|
||||
|
@ -107,4 +108,4 @@ test your application against an actual instance.
|
|||
|
||||
### Python
|
||||
|
||||
* [Python client](https://github.com/youtube/vitess/blob/master/py/vtdb/vtgatev2.py)
|
||||
* [Python client](https://github.com/youtube/vitess/blob/master/py/vtdb/vtgate_client.py)
|
||||
|
|
|
@ -255,7 +255,6 @@ In addition, Vitess requires the software and libraries listed below.
|
|||
# skipping zookeeper build
|
||||
# go install golang.org/x/tools/cmd/cover ...
|
||||
# Found MariaDB installation in ...
|
||||
# skipping bson python build
|
||||
# creating git pre-commit hooks
|
||||
#
|
||||
# source dev.env in your shell before building
|
||||
|
|
|
@ -8,10 +8,6 @@ Life of A Query
|
|||
* [TopoServer](#toposerver)
|
||||
* [Streaming Query](#streaming-query)
|
||||
* [Scatter Query](#scatter-query)
|
||||
* [Misc](#misc)
|
||||
* [Rpc Server Code Path (VtGate)](#rpc-server-code-path-vtgate)
|
||||
* [VtGate to VtTablet Code Path](#vtgate-to-vttablet-code-path)
|
||||
* [VtTablet to MySQL Code Path](#vttablet-to-mysql-code-path)
|
||||
|
||||
A query means a request for information from database and it involves four components in the case of Vitess, including the client application, VtGate, VtTablet and MySQL instance. This doc explains the interaction which happens between and within components.
|
||||
|
||||
|
@ -21,11 +17,11 @@ At a very high level, as the graph shows, first the client sends a query to VtGa
|
|||
|
||||
## From Client to VtGate
|
||||
|
||||
A client application first sends a bson rpc with an embedded sql query to VtGate. VtGate's rpc server unmarshals this rpc request, calls the appropriate VtGate method and return its result back to client. VtGate has an rpc server that listens to localhost:port/\_bson\_rpc\_ for http requests and localhost:port/\_bson\_rpc\_/auth for https requests.
|
||||
A client application first sends an rpc with an embedded sql query to VtGate. VtGate's rpc server unmarshals this rpc request, calls the appropriate VtGate method and return its result back to client.
|
||||
|
||||
![](https://raw.githubusercontent.com/youtube/vitess/master/doc/life_of_a_query_client_to_vtgate.png)
|
||||
|
||||
VtGate keeps an in-memory table that stores all available rpc methods for each service, e.g. VtGate uses "VTGate" as its service name and most of its methods defined in [go/vt/vtgate/vtgate.go](../go/vt/vtgate/vtgate.go) are used to serve rpc request [go/rpcplus/server.go](../go/rpcplus/server.go).
|
||||
VtGate keeps an in-memory table that stores all available rpc methods for each service, e.g. VtGate uses "VTGate" as its service name and most of its methods defined in [go/vt/vtgate/vtgate.go](../go/vt/vtgate/vtgate.go) are used to serve rpc request.
|
||||
|
||||
## From VtGate to VtTablet
|
||||
|
||||
|
@ -58,73 +54,3 @@ Generally speaking, a streaming query means query results will be returned as a
|
|||
## Scatter Query
|
||||
|
||||
A scatter query, as its name indicates, will hit multiple shards. In Vitess, a scatter query is recognized once VtGate determines a query needs to hit multiple VtTablets. VtGate then sends the query to these VtTablets, assembles the result after receiving all responses and returns the combined result to the client.
|
||||
|
||||
## Misc
|
||||
|
||||
### Rpc Server Code Path (VtGate)
|
||||
|
||||
Init an rpc server
|
||||
|
||||
```
|
||||
go/cmd/vtgate/vtgate.go: main() ->
|
||||
go/vt/servenv/servenv.go: RunDefault() -> // use the port specified in command line "--port"
|
||||
go/vt/servenv/run.go: Run(port int) ->
|
||||
go/vt/servenv/rpc.go: ServeRPC() -> // set up rpc server
|
||||
go/rpcwrap/bsonrpc/codec.go: ServeRPC() -> // set up bson rpc server
|
||||
go/rpcwrap/rpcwrap.go: ServeRPC("bson", NewServerCodec) -> // common code to register rpc server
|
||||
```
|
||||
|
||||
ServeRPC("bson", NewServerCodec) registers an rpcHandler instance whose ServeHTTP(http.ResponseWriter, *http.Request) will be called for every http request
|
||||
|
||||
The rpc server handles the http request:
|
||||
|
||||
```
|
||||
go/rpcwrap/rpcwrap.go rpcHandler.ServeHTTP ->
|
||||
go/rpcwrap/rpcwrap.go rpcHandler.server.ServeCodecWithContext ->
|
||||
go/rpcplus/server.go Server.ServeCodecWithContext(context.Context, ServerCodec) (note: rpcHandler uses a global DefaultServer instance defined in the server.go) ->
|
||||
go/rpcplus/server.go Server.readRequest(ServeCodec) will use a given codec to extract (service, methodType, request, request arguments, reply value, keep reading)
|
||||
```
|
||||
|
||||
Finally we do "service.call(..)" with parameters provided in the request. In the current setup, service.call will always call some method in VtGate (go/vt/vtgate/vtgate.go).
|
||||
|
||||
### VtGate to VtTablet Code Path
|
||||
|
||||
Here is the code path for a query with keyspace id.
|
||||
|
||||
```
|
||||
go/vt/vtgate/vtgate.go VTGate.ExecuteKeyspaceIds(context.Context, *proto.KeyspaceIdQuery, *proto.QueryResult) ->
|
||||
go/vt/vtgate/resolver.go resolver.ExecuteKeyspaceIds(context.Context, *proto.KeyspaceIdQuery) ->
|
||||
go/vt/vtgate/resolver.go resolver.Execute ->
|
||||
go/vt/vtgate/scatter_conn.go ScatterConn.Execute ->
|
||||
go/vt/vtgate/scatter_conn.go ScatterConn.multiGo ->
|
||||
go/vt/vtgate/scatter_conn.go ScatterConn.getConnection ->
|
||||
go/vt/vtgate/shard_conn.go ShardConn.Execute ->
|
||||
go/vt/vtgate/shard_conn.go ShardConn.withRetry ->
|
||||
go/vt/vtgate/shard_conn.go ShardConn.getConn ->
|
||||
go/vt/tabletserver/tabletconn/tablet_conn.go tabletconn.GetDialer ->
|
||||
go/vt/tabletserver/tabletconn/tablet_conn.go tabletconn.TabletConn.Execute ->
|
||||
go/vt/tabletserver/gorpctabletconn/conn.go TabletBson.Execute ->
|
||||
go/vt/tabletserver/gorpctabletconn/conn.go TabletBson.rpcClient.Call ->
|
||||
go/rpcplus/client.go rpcplus.Client.Call ->
|
||||
go/rpcplus/client.go rpcplus.Client.Go ->
|
||||
go/rpcplus/client.go rpcplus.Client.send
|
||||
```
|
||||
|
||||
### VtTablet to MySQL Code Path
|
||||
|
||||
Here is the code path for a select query.
|
||||
|
||||
```
|
||||
go/vt/tabletserver/sqlquery.go SqlQuery.Execute ->
|
||||
go/vt/tabletserver/query_executor.go QueryExecutor.Execute ->
|
||||
go/vt/tabletserver/query_executor.go QueryExecutor.execSelect ->
|
||||
go/vt/tabletserver/request_context.go RequestContext.getConn -> // QueryExecutor composes a RequestContext
|
||||
go/vt/tabletserver/request_context.go RequestContext.fullFetch ->
|
||||
go/vt/tabletserver/request_context.go RequestContext.execSQL ->
|
||||
go/vt/tabletserver/request_context.go RequestContext.execSQLNoPanic ->
|
||||
go/vt/tabletserver/request_context.go RequestContext.execSQLOnce ->
|
||||
go/vt/dbconnpool/connection_pool.go PoolConnection.ExecuteFetch (current implementation is in DBConnection) ->
|
||||
go/vt/dbconnpool/connection.go PooledConnection.DBConnection.ExecuteFetch ->
|
||||
go/mysql/mysql.go mysql.Connection.ExecuteFetch ->
|
||||
go/mysql/mysql.go mysql.Connection.fetchAll
|
||||
```
|
||||
|
|
|
@ -28,7 +28,7 @@ Vitess improves a vanilla MySQL implementation in several ways:
|
|||
<tbody>
|
||||
<tr>
|
||||
<td>Every MySQL connection has a memory overhead that ranges between 256KB and almost 3MB, depending on which MySQL release you're using. As your user base grows, you need to add RAM to support additional connections, but the RAM does not contribute to faster queries. In addition, there is a significant CPU cost associated with obtaining the connections.</td>
|
||||
<td>Vitess' BSON-based protocol creates very lightweight connections that are around 32KB. Vitess' connection pooling feature uses Go's concurrency support to map these lightweight connections to a small pool of MySQL connections. As such, Vitess can easily handle thousands of connections.</td>
|
||||
<td>Vitess' gRPC-based protocol creates very lightweight connections. Vitess' connection pooling feature uses Go's concurrency support to map these lightweight connections to a small pool of MySQL connections. As such, Vitess can easily handle thousands of connections.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Poorly written queries, such as those that don't set a LIMIT, can negatively impact database performance for all users.</td>
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -1 +1 @@
|
|||
<mxGraphModel dx="894" dy="566" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" fold="1" page="1" pageScale="1" pageWidth="826" pageHeight="1169" style="default-style2" math="0"><root><mxCell id="0"/><mxCell id="1" parent="0"/><mxCell id="3" value="VtGate<div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div>" style="shape=ext;rounded=1;html=1;whiteSpace=wrap;dashed=1;dashPattern=1 4;" parent="1" vertex="1"><mxGeometry x="55" y="60" width="250" height="340" as="geometry"/></mxCell><mxCell id="2" value="Invoke VTGate.Execute*" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="120" y="80" width="120" height="40" as="geometry"/></mxCell><mxCell id="4" value="VtTablet<div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div>" style="shape=ext;rounded=1;html=1;whiteSpace=wrap;dashed=1;dashPattern=1 4;" parent="1" vertex="1"><mxGeometry x="380" y="60" width="250" height="340" as="geometry"/></mxCell><mxCell id="6" value="Http Server" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="445" y="105" width="120" height="60" as="geometry"/></mxCell><mxCell id="7" value="Extract request parameters" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="445" y="188" width="120" height="60" as="geometry"/></mxCell><mxCell id="8" value="Launch a go routine and call SqlQuery.Execute" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="445" y="275" width="120" height="60" as="geometry"/></mxCell><mxCell id="11" value="Send queries to multiple shards in parallel<div>ScatterConn.multiGo</div>" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="120" y="138" width="120" height="65" as="geometry"/></mxCell><mxCell id="13" value="&nbsp;Randomly connect to a desired tablet for each shard.<div>ShardConn.getConn</div>" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="95" y="220" width="170" height="50" as="geometry"/></mxCell><mxCell id="14" value="Send a bson rpc to VtTablet" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="95" y="290" width="170" height="50" as="geometry"/></mxCell><mxCell id="15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="2" target="11" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="16" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="11" target="13" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="17" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="13" target="14" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="19" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=1;exitY=0.5;entryX=0;entryY=0.5" parent="1" source="14" target="6" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="20" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="6" target="7" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="21" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1" parent="1" source="7" target="8" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell></root></mxGraphModel>
|
||||
<mxGraphModel dx="894" dy="566" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" fold="1" page="1" pageScale="1" pageWidth="826" pageHeight="1169" style="default-style2" math="0"><root><mxCell id="0"/><mxCell id="1" parent="0"/><mxCell id="3" value="VtGate<div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div>" style="shape=ext;rounded=1;html=1;whiteSpace=wrap;dashed=1;dashPattern=1 4;" parent="1" vertex="1"><mxGeometry x="55" y="60" width="250" height="340" as="geometry"/></mxCell><mxCell id="2" value="Invoke VTGate.Execute*" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="120" y="80" width="120" height="40" as="geometry"/></mxCell><mxCell id="4" value="VtTablet<div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div><div><br></div>" style="shape=ext;rounded=1;html=1;whiteSpace=wrap;dashed=1;dashPattern=1 4;" parent="1" vertex="1"><mxGeometry x="380" y="60" width="250" height="340" as="geometry"/></mxCell><mxCell id="6" value="Http Server" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="445" y="105" width="120" height="60" as="geometry"/></mxCell><mxCell id="7" value="Extract request parameters" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="445" y="188" width="120" height="60" as="geometry"/></mxCell><mxCell id="8" value="Launch a go routine and call SqlQuery.Execute" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="445" y="275" width="120" height="60" as="geometry"/></mxCell><mxCell id="11" value="Send queries to multiple shards in parallel<div>ScatterConn.multiGo</div>" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="120" y="138" width="120" height="65" as="geometry"/></mxCell><mxCell id="13" value="&nbsp;Randomly connect to a desired tablet for each shard.<div>ShardConn.getConn</div>" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="95" y="220" width="170" height="50" as="geometry"/></mxCell><mxCell id="14" value="Send an rpc to VtTablet" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1"><mxGeometry x="95" y="290" width="170" height="50" as="geometry"/></mxCell><mxCell id="15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="2" target="11" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="16" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="11" target="13" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="17" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="13" target="14" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="19" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=1;exitY=0.5;entryX=0;entryY=0.5" parent="1" source="14" target="6" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="20" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1;entryX=0.5;entryY=0" parent="1" source="6" target="7" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell><mxCell id="21" style="edgeStyle=orthogonalEdgeStyle;rounded=0;html=1;exitX=0.5;exitY=1" parent="1" source="7" target="8" edge="1"><mxGeometry relative="1" as="geometry"/></mxCell></root></mxGraphModel>
|
||||
|
|
|
@ -66,7 +66,7 @@ ENV GOTOP $VTTOP/go
|
|||
ENV PYTOP $VTTOP/py
|
||||
ENV VTDATAROOT $VTROOT/vtdataroot
|
||||
ENV VTPORTSTART 15000
|
||||
ENV PYTHONPATH $VTROOT/dist/py-mock-1.0.1/lib/python2.7/site-packages:$VTROOT/dist/py-vt-bson-0.3.2/lib/python2.7/site-packages:$VTROOT/py-vtdb
|
||||
ENV PYTHONPATH $VTROOT/dist/py-mock-1.0.1/lib/python2.7/site-packages:$VTROOT/py-vtdb
|
||||
ENV GOBIN $VTROOT/bin
|
||||
ENV GOPATH $VTROOT
|
||||
ENV PATH $VTROOT/bin:$VTROOT/dist/maven/bin:$PATH
|
||||
|
|
|
@ -7,13 +7,14 @@ import struct
|
|||
import hashlib
|
||||
|
||||
from flask import Flask
|
||||
app = Flask(__name__)
|
||||
|
||||
from vtdb import vtgate_client
|
||||
|
||||
# Register gRPC protocol.
|
||||
from vtdb import grpc_vtgate_client # pylint: disable=unused-import
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# conn is the connection to vtgate.
|
||||
conn = None
|
||||
|
||||
|
@ -71,7 +72,7 @@ def list_guestbook(page):
|
|||
keyspace_ids=[keyspace_id])
|
||||
|
||||
cursor.execute(
|
||||
'SELECT message FROM messages WHERE page=%(page)s'
|
||||
'SELECT message FROM messages WHERE page=:page'
|
||||
' ORDER BY time_created_ns',
|
||||
{'page': page})
|
||||
entries = [row[0] for row in cursor.fetchall()]
|
||||
|
@ -92,7 +93,7 @@ def add_entry(page, value):
|
|||
cursor.begin()
|
||||
cursor.execute(
|
||||
'INSERT INTO messages (page, time_created_ns, keyspace_id, message)'
|
||||
' VALUES (%(page)s, %(time_created_ns)s, %(keyspace_id)s, %(message)s)',
|
||||
' VALUES (:page, :time_created_ns, :keyspace_id, :message)',
|
||||
{
|
||||
'page': page,
|
||||
'time_created_ns': int(time.time() * 1e9),
|
||||
|
@ -104,7 +105,7 @@ def add_entry(page, value):
|
|||
# Read the list back from master (critical read) because it's
|
||||
# important that the user sees their own addition immediately.
|
||||
cursor.execute(
|
||||
'SELECT message FROM messages WHERE page=%(page)s'
|
||||
'SELECT message FROM messages WHERE page=:page'
|
||||
' ORDER BY time_created_ns',
|
||||
{'page': page})
|
||||
entries = [row[0] for row in cursor.fetchall()]
|
||||
|
|
|
@ -43,7 +43,7 @@ spec:
|
|||
-port 15001
|
||||
-grpc_port 15991
|
||||
-tablet_protocol grpc
|
||||
-service_map 'bsonrpc-vt-vtgateservice,grpc-vtgateservice'
|
||||
-service_map 'grpc-vtgateservice'
|
||||
-cells_to_watch {{cell}}
|
||||
-tablet_types_to_wait MASTER,REPLICA
|
||||
-gateway_implementation discoverygateway
|
||||
|
|
|
@ -43,7 +43,7 @@ spec:
|
|||
-port 15001
|
||||
-grpc_port 15991
|
||||
-tablet_protocol grpc
|
||||
-service_map 'bsonrpc-vt-vtgateservice,grpc-vtgateservice'
|
||||
-service_map 'grpc-vtgateservice'
|
||||
-cells_to_watch {{cell}}
|
||||
-tablet_types_to_wait MASTER,REPLICA
|
||||
-gateway_implementation discoverygateway
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
set -e
|
||||
|
||||
cell='test'
|
||||
web_port=15001 # This is also the bsonrpc port.
|
||||
web_port=15001
|
||||
grpc_port=15991
|
||||
|
||||
script_root=`dirname "${BASH_SOURCE}"`
|
||||
|
@ -21,7 +21,7 @@ $VTROOT/bin/vtgate \
|
|||
-tablet_types_to_wait MASTER,REPLICA \
|
||||
-gateway_implementation discoverygateway \
|
||||
-tablet_protocol grpc \
|
||||
-service_map 'bsonrpc-vt-vtgateservice,grpc-vtgateservice' \
|
||||
-service_map 'grpc-vtgateservice' \
|
||||
-pid_file $VTDATAROOT/tmp/vtgate.pid \
|
||||
> $VTDATAROOT/tmp/vtgate.out 2>&1 &
|
||||
|
||||
|
|
|
@ -1,457 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
type alltypes struct {
|
||||
Bytes []byte
|
||||
Float64 float64
|
||||
String string
|
||||
Bool bool
|
||||
Time time.Time
|
||||
Int64 int64
|
||||
Int32 int32
|
||||
Int int
|
||||
Uint64 uint64
|
||||
Uint32 uint32
|
||||
Uint uint
|
||||
Strings []string
|
||||
Nil interface{}
|
||||
}
|
||||
|
||||
func (a *alltypes) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
VerifyObject(kind)
|
||||
Next(buf, 4)
|
||||
|
||||
kind = NextByte(buf)
|
||||
for kind != EOO {
|
||||
key := ReadCString(buf)
|
||||
switch key {
|
||||
case "Bytes":
|
||||
verifyKind("Bytes", Binary, kind)
|
||||
a.Bytes = DecodeBinary(buf, kind)
|
||||
case "Float64":
|
||||
verifyKind("Float64", Number, kind)
|
||||
a.Float64 = DecodeFloat64(buf, kind)
|
||||
case "String":
|
||||
verifyKind("String", Binary, kind)
|
||||
// Put an easter egg here to verify the function is called
|
||||
a.String = DecodeString(buf, kind) + "1"
|
||||
case "Bool":
|
||||
verifyKind("Bool", Boolean, kind)
|
||||
a.Bool = DecodeBool(buf, kind)
|
||||
case "Time":
|
||||
verifyKind("Time", Datetime, kind)
|
||||
a.Time = DecodeTime(buf, kind)
|
||||
case "Int32":
|
||||
verifyKind("Int32", Int, kind)
|
||||
a.Int32 = DecodeInt32(buf, kind)
|
||||
case "Int":
|
||||
verifyKind("Int", Long, kind)
|
||||
a.Int = DecodeInt(buf, kind)
|
||||
case "Int64":
|
||||
verifyKind("Int64", Long, kind)
|
||||
a.Int64 = DecodeInt64(buf, kind)
|
||||
case "Uint64":
|
||||
verifyKind("Uint64", Long, kind)
|
||||
a.Uint64 = DecodeUint64(buf, kind)
|
||||
case "Uint32":
|
||||
verifyKind("Uint32", Long, kind)
|
||||
a.Uint32 = DecodeUint32(buf, kind)
|
||||
case "Uint":
|
||||
verifyKind("Uint", Long, kind)
|
||||
a.Uint = DecodeUint(buf, kind)
|
||||
case "Strings":
|
||||
verifyKind("Strings", Array, kind)
|
||||
a.Strings = DecodeStringArray(buf, kind)
|
||||
case "Nil":
|
||||
verifyKind("Nil", Null, kind)
|
||||
default:
|
||||
Skip(buf, kind)
|
||||
}
|
||||
kind = NextByte(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyKind(tag string, want, got byte) {
|
||||
if want != got {
|
||||
panic(NewBsonError("Decode %s, kind is %v, want %v", tag, got, want))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(sougou): Revisit usefulness of this test
|
||||
func TestUnmarshalUtil(t *testing.T) {
|
||||
a := alltypes{
|
||||
Bytes: []byte("bytes"),
|
||||
Float64: float64(64),
|
||||
String: "string",
|
||||
Bool: true,
|
||||
Time: time.Unix(1136243045, 0).UTC(),
|
||||
Int64: int64(-0x8000000000000000),
|
||||
Int32: int32(-0x80000000),
|
||||
Int: int(-0x80000000),
|
||||
Uint64: uint64(0xFFFFFFFFFFFFFFFF),
|
||||
Uint32: uint32(0xFFFFFFFF),
|
||||
Uint: uint(0xFFFFFFFF),
|
||||
Strings: []string{"a", "b"},
|
||||
Nil: nil,
|
||||
}
|
||||
got := verifyMarshal(t, a)
|
||||
var out alltypes
|
||||
verifyUnmarshal(t, got, &out)
|
||||
// Verify easter egg
|
||||
if out.String != "string1" {
|
||||
t.Errorf("got %s, want %s", out.String, "string1")
|
||||
}
|
||||
out.String = "string"
|
||||
if !reflect.DeepEqual(a, out) {
|
||||
t.Errorf("got\n%+v, want\n%+v", out, a)
|
||||
}
|
||||
|
||||
b := alltypes{Bytes: []byte(""), Strings: []string{"a"}}
|
||||
got = verifyMarshal(t, b)
|
||||
var outb alltypes
|
||||
verifyUnmarshal(t, got, &outb)
|
||||
if outb.Bytes == nil || len(outb.Bytes) != 0 {
|
||||
t.Errorf("got %q, want nil", string(outb.Bytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypes(t *testing.T) {
|
||||
in := map[string]interface{}{
|
||||
"bytes": []byte("bytes"),
|
||||
"float64": float64(64),
|
||||
"string": "string",
|
||||
"bool": true,
|
||||
"time": time.Unix(1136243045, 0).UTC(),
|
||||
"int64": int64(-0x8000000000000000),
|
||||
"int32": int32(-0x80000000),
|
||||
"int": int(-0x80000000),
|
||||
"uint64": uint64(0xFFFFFFFFFFFFFFFF),
|
||||
"uint32": uint32(0xFFFFFFFF),
|
||||
"uint": uint(0xFFFFFFFF),
|
||||
"slice": []interface{}{1, nil},
|
||||
"nil": nil,
|
||||
}
|
||||
marshalled := verifyMarshal(t, in)
|
||||
got := make(map[string]interface{})
|
||||
verifyUnmarshal(t, marshalled, &got)
|
||||
|
||||
want := map[string]interface{}{
|
||||
"bytes": []byte("bytes"),
|
||||
"float64": float64(64),
|
||||
"string": []byte("string"),
|
||||
"bool": true,
|
||||
"time": time.Unix(1136243045, 0).UTC(),
|
||||
"int64": int64(-0x8000000000000000),
|
||||
"int32": int32(-0x80000000),
|
||||
"int": int64(-0x80000000),
|
||||
"uint64": int64(-1),
|
||||
"uint32": int64(0xFFFFFFFF),
|
||||
"uint": int64(0xFFFFFFFF),
|
||||
"slice": []interface{}{int64(1), nil},
|
||||
"nil": nil,
|
||||
}
|
||||
// We do the range so the errors are more precise.
|
||||
for k, v := range got {
|
||||
if !reflect.DeepEqual(v, want[k]) {
|
||||
t.Errorf("got \n%+v, want \n%+v", v, want[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// test that we are calling the right encoding method
|
||||
// if we use the reflection code, this will fail as reflection
|
||||
// cannot access the non-exported field
|
||||
type PrivateStruct struct {
|
||||
veryPrivate uint64
|
||||
}
|
||||
|
||||
func (ps *PrivateStruct) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
EncodeOptionalPrefix(buf, Object, key)
|
||||
lenWriter := NewLenWriter(buf)
|
||||
|
||||
EncodeUint64(buf, "Type", ps.veryPrivate)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
func (ps *PrivateStruct) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
VerifyObject(kind)
|
||||
Next(buf, 4)
|
||||
|
||||
for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) {
|
||||
key := ReadCString(buf)
|
||||
switch key {
|
||||
case "Type":
|
||||
verifyKind("Type", Long, kind)
|
||||
ps.veryPrivate = DecodeUint64(buf, kind)
|
||||
default:
|
||||
Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// an array can use non-pointers for custom marshaler
|
||||
type PrivateStructList struct {
|
||||
List []PrivateStruct
|
||||
}
|
||||
|
||||
// the map has to be using pointers, so the custom marshaler is used
|
||||
type PrivateStructMap struct {
|
||||
Map map[string]*PrivateStruct
|
||||
}
|
||||
|
||||
type PrivateStructStruct struct {
|
||||
Inner *PrivateStruct
|
||||
}
|
||||
|
||||
func TestCustomStruct(t *testing.T) {
|
||||
// This should use the custom marshaler & unmarshaler
|
||||
s := PrivateStruct{1}
|
||||
got := verifyMarshal(t, &s)
|
||||
want := "\x13\x00\x00\x00\x12Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", string(got), want)
|
||||
}
|
||||
var s2 PrivateStruct
|
||||
verifyUnmarshal(t, got, &s2)
|
||||
if s2 != s {
|
||||
t.Errorf("got \n%+v, want \n%+v", s2, s)
|
||||
}
|
||||
|
||||
// This should use the custom marshaler & unmarshaler
|
||||
sl := PrivateStructList{make([]PrivateStruct, 1)}
|
||||
sl.List[0] = s
|
||||
got = verifyMarshal(t, &sl)
|
||||
want = "&\x00\x00\x00\x04List\x00\x1b\x00\x00\x00\x030\x00\x13\x00\x00\x00\x12Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", string(got), want)
|
||||
}
|
||||
var sl2 PrivateStructList
|
||||
verifyUnmarshal(t, got, &sl2)
|
||||
if !reflect.DeepEqual(sl2, sl) {
|
||||
t.Errorf("got \n%+v, want \n%+v", sl2, sl)
|
||||
}
|
||||
|
||||
// This should use the custom marshaler & unmarshaler
|
||||
smp := make(map[string]*PrivateStruct)
|
||||
smp["first"] = &s
|
||||
got = verifyMarshal(t, smp)
|
||||
want = "\x1f\x00\x00\x00\x03first\x00\x13\x00\x00\x00\x12Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", string(got), want)
|
||||
}
|
||||
smp2 := make(map[string]*PrivateStruct)
|
||||
verifyUnmarshal(t, got, &smp2)
|
||||
if !reflect.DeepEqual(smp2, smp) {
|
||||
t.Errorf("got \n%+v, want \n%+v", smp2, smp)
|
||||
}
|
||||
|
||||
// This should not use the custom unmarshaler
|
||||
sm := make(map[string]PrivateStruct)
|
||||
sm["first"] = s
|
||||
sm2 := make(map[string]PrivateStruct)
|
||||
verifyUnmarshal(t, got, &sm2)
|
||||
if reflect.DeepEqual(sm2, sm) {
|
||||
t.Errorf("got \n%+v, want \n%+v", sm2, sm)
|
||||
}
|
||||
|
||||
// This should not use the custom marshaler
|
||||
got = verifyMarshal(t, sm)
|
||||
want = "\x11\x00\x00\x00\x03first\x00\x05\x00\x00\x00\x00\x00"
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", string(got), want)
|
||||
}
|
||||
|
||||
// This should not use the custom marshaler (or crash)
|
||||
nilinner := PrivateStructStruct{}
|
||||
got = verifyMarshal(t, &nilinner)
|
||||
want = "\f\x00\x00\x00\nInner\x00\x00"
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", string(got), want)
|
||||
}
|
||||
}
|
||||
|
||||
type HasPrivate struct {
|
||||
private string
|
||||
Public string
|
||||
}
|
||||
|
||||
func TestIgnorePrivateFields(t *testing.T) {
|
||||
v := HasPrivate{private: "private", Public: "public"}
|
||||
marshaled := verifyMarshal(t, v)
|
||||
unmarshaled := new(HasPrivate)
|
||||
Unmarshal(marshaled, unmarshaled)
|
||||
if unmarshaled.Public != "Public" && unmarshaled.private != "" {
|
||||
t.Errorf("private fields were not ignored: %+v", unmarshaled)
|
||||
}
|
||||
}
|
||||
|
||||
type LotsMoreFields struct {
|
||||
CommonField1 string
|
||||
ExtraField1 float64
|
||||
ExtraField2 string
|
||||
ExtraField3 HasPrivate
|
||||
ExtraField4 []string
|
||||
CommonField2 string
|
||||
ExtraField5 []byte
|
||||
ExtraField6 bool
|
||||
ExtraField7 time.Time
|
||||
ExtraField8 *int
|
||||
ExtraField9 int32
|
||||
ExtraField10 int64
|
||||
ExtraField11 uint64
|
||||
}
|
||||
|
||||
type LotsFewerFields struct {
|
||||
CommonField1 string
|
||||
CommonField2 string
|
||||
}
|
||||
|
||||
func TestSkipUnknownFields(t *testing.T) {
|
||||
v := LotsMoreFields{
|
||||
CommonField1: "value1",
|
||||
ExtraField1: 1.0,
|
||||
ExtraField2: "abcd",
|
||||
ExtraField3: HasPrivate{private: "private", Public: "public"},
|
||||
ExtraField4: []string{"s1", "s2"},
|
||||
CommonField2: "value3",
|
||||
ExtraField5: []byte("abcd"),
|
||||
ExtraField6: true,
|
||||
ExtraField7: time.Now(),
|
||||
}
|
||||
marshaled := verifyMarshal(t, v)
|
||||
unmarshaled := LotsFewerFields{}
|
||||
verifyUnmarshal(t, marshaled, &unmarshaled)
|
||||
want := LotsFewerFields{
|
||||
CommonField1: "value1",
|
||||
CommonField2: "value3",
|
||||
}
|
||||
if unmarshaled != want {
|
||||
t.Errorf("got \n%+v, want \n%+v", unmarshaled, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeFieldNil(t *testing.T) {
|
||||
buf := bytes2.NewChunkedWriter(DefaultBufferSize)
|
||||
EncodeField(buf, "Val", nil)
|
||||
got := string(buf.Bytes())
|
||||
want := "\nVal\x00"
|
||||
if got != want {
|
||||
t.Errorf("nil encode: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
err := MarshalToStream(buf, 1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := "\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
got := buf.String()
|
||||
if got != want {
|
||||
t.Errorf("got \n%q, want %q", got, want)
|
||||
}
|
||||
readbuf := bytes.NewBuffer(buf.Bytes())
|
||||
var out int64
|
||||
err = UnmarshalFromStream(readbuf, &out)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if out != 1 {
|
||||
t.Errorf("got %d, want 1", out)
|
||||
}
|
||||
err = MarshalToStream(buf, make(chan int))
|
||||
want = "unexpected type chan int"
|
||||
got = err.Error()
|
||||
if got != want {
|
||||
t.Errorf("got \n%q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
var testMap map[string]interface{}
|
||||
var testBlob []byte
|
||||
|
||||
func init() {
|
||||
testMap = map[string]interface{}{
|
||||
"bytes": []byte("bytes"),
|
||||
"float64": float64(64),
|
||||
"string": "string",
|
||||
"bool": true,
|
||||
"time": time.Unix(1136243045, 0),
|
||||
"int64": int64(-0x8000000000000000),
|
||||
"int32": int32(-0x80000000),
|
||||
"int": int(-0x80000000),
|
||||
"uint64": uint64(0xFFFFFFFFFFFFFFFF),
|
||||
"uint32": uint32(0xFFFFFFFF),
|
||||
"uint": uint(0xFFFFFFFF),
|
||||
"slice": []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, nil},
|
||||
"nil": nil,
|
||||
}
|
||||
testBlob, _ = Marshal(testMap)
|
||||
}
|
||||
|
||||
func BenchmarkMarshal(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := Marshal(testMap)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmarshal(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
v := make(map[string]interface{})
|
||||
err := Unmarshal(testBlob, &v)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeField(b *testing.B) {
|
||||
values := []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := bytes2.NewChunkedWriter(2048)
|
||||
EncodeField(buf, "Val", values)
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeInterface(b *testing.B) {
|
||||
values := []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := bytes2.NewChunkedWriter(2048)
|
||||
EncodeInterface(buf, "Val", values)
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func verifyMarshal(t *testing.T, val interface{}) []byte {
|
||||
got, err := Marshal(val)
|
||||
if err != nil {
|
||||
t.Errorf("Marshal error for %+v: %v\n", val, err)
|
||||
}
|
||||
return got
|
||||
}
|
||||
|
||||
func verifyUnmarshal(t *testing.T, buf []byte, val interface{}) {
|
||||
if err := Unmarshal(buf, val); err != nil {
|
||||
t.Errorf("Unmarshal error for %+v: %v\n", val, err)
|
||||
}
|
||||
}
|
|
@ -1,77 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package bson implements encoding and decoding of BSON objects.
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pack is the BSON binary packing protocol.
|
||||
// It's little endian.
|
||||
var Pack = binary.LittleEndian
|
||||
|
||||
var (
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
bytesType = reflect.TypeOf([]byte(nil))
|
||||
)
|
||||
|
||||
// Words size in bytes.
|
||||
const (
|
||||
WORD32 = 4
|
||||
WORD64 = 8
|
||||
)
|
||||
|
||||
const (
|
||||
EOO = 0x00
|
||||
Number = 0x01
|
||||
String = 0x02
|
||||
Object = 0x03
|
||||
Array = 0x04
|
||||
Binary = 0x05
|
||||
Undefined = 0x06 // deprecated
|
||||
OID = 0x07 // unsupported
|
||||
Boolean = 0x08
|
||||
Datetime = 0x09
|
||||
Null = 0x0A
|
||||
Regex = 0x0B // unsupported
|
||||
Ref = 0x0C // deprecated
|
||||
Code = 0x0D // unsupported
|
||||
Symbol = 0x0E // unsupported
|
||||
CodeWithScope = 0x0F // unsupported
|
||||
Int = 0x10
|
||||
Timestamp = 0x11 // unsupported
|
||||
Long = 0x12
|
||||
Ulong = 0x3F // nonstandard extension
|
||||
MinKey = 0xFF // unsupported
|
||||
MaxKey = 0x7F // unsupported
|
||||
)
|
||||
|
||||
const (
|
||||
// MAGICTAG is the tag used to embed simple types inside
|
||||
// a bson document.
|
||||
MAGICTAG = "_Val_"
|
||||
)
|
||||
|
||||
type BsonError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func NewBsonError(format string, args ...interface{}) BsonError {
|
||||
return BsonError{fmt.Sprintf(format, args...)}
|
||||
}
|
||||
|
||||
func (err BsonError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func handleError(err *error) {
|
||||
if x := recover(); x != nil {
|
||||
*err = x.(BsonError)
|
||||
}
|
||||
}
|
|
@ -1,249 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
const (
|
||||
bsonValNil = "\nVal\x00"
|
||||
bsonValBytes = "\x05Val\x00\x04\x00\x00\x00\x00test"
|
||||
bsonValInt64 = "\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
bsonValInt32 = "\x10Val\x00\x01\x00\x00\x00"
|
||||
bsonValUint64 = "\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
bsonValFloat64 = "\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?"
|
||||
bsonValBool = "\bVal\x00\x01"
|
||||
bsonValMap = "\x03Val\x00\x13\x00\x00\x00\x12Val1\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
bsonValSlice = "\x04Val\x00\x10\x00\x00\x00\x120\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
bsonValTime = "\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00"
|
||||
)
|
||||
|
||||
var interfaceMarshalCases = []struct {
|
||||
desc string
|
||||
in interface{}
|
||||
out string
|
||||
}{
|
||||
{"nil", nil, bsonValNil},
|
||||
{"string", "test", bsonValBytes},
|
||||
{"[]byte", []byte("test"), bsonValBytes},
|
||||
{"int64", int64(1), bsonValInt64},
|
||||
{"int32", int32(1), bsonValInt32},
|
||||
{"int", int(1), bsonValInt64},
|
||||
{"uint64", uint64(1), bsonValUint64},
|
||||
{"uint32", uint32(1), bsonValUint64},
|
||||
{"uint", uint(1), bsonValUint64},
|
||||
{"float64", float64(1.0), bsonValFloat64},
|
||||
{"bool", true, bsonValBool},
|
||||
{"nil map", map[string]interface{}(nil), bsonValNil},
|
||||
{"map", map[string]interface{}{"Val1": 1}, bsonValMap},
|
||||
{"nil slice", []interface{}(nil), bsonValNil},
|
||||
{"slice", []interface{}{1}, bsonValSlice},
|
||||
{"time", time.Unix(1136243045, 0).UTC(), bsonValTime},
|
||||
}
|
||||
|
||||
func TestInterfaceMarshal(t *testing.T) {
|
||||
for _, tcase := range interfaceMarshalCases {
|
||||
buf := bytes2.NewChunkedWriter(DefaultBufferSize)
|
||||
EncodeInterface(buf, "Val", tcase.in)
|
||||
got := string(buf.Bytes())
|
||||
if got != tcase.out {
|
||||
t.Errorf("%s: got \n%q, want \n%q", tcase.desc, got, tcase.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceMarshalFailure(t *testing.T) {
|
||||
want := "don't know how to marshal chan int"
|
||||
func() {
|
||||
defer func() {
|
||||
if x := recover(); x != nil {
|
||||
got := x.(BsonError).Error()
|
||||
if got != want {
|
||||
t.Errorf("got %s, want %s", got, want)
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
buf := bytes2.NewChunkedWriter(DefaultBufferSize)
|
||||
EncodeInterface(buf, "Val", make(chan int))
|
||||
t.Errorf("got no error, want %s", want)
|
||||
}()
|
||||
}
|
||||
|
||||
const (
|
||||
bsonString = "\x05\x00\x00\x00test\x00"
|
||||
bsonBinary = "\x04\x00\x00\x00\x00test"
|
||||
bsonInt = "\x01\x00\x00\x00"
|
||||
bsonLong = "\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
bsonNumber = "\x00\x00\x00\x00\x00\x00\xf0?"
|
||||
bsonDatetime = "\x88\xf2\\\x8d\b\x01\x00\x00"
|
||||
bsonBoolean = "\x01"
|
||||
bsonObject = "\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00"
|
||||
bsonObjectNull = "\v\x00\x00\x00\nVal2\x00\x00"
|
||||
bsonArray = "\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00"
|
||||
bsonArrayNull = "\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
bsonStringArray = "\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00"
|
||||
)
|
||||
|
||||
func stringDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeString(buf, kind) }
|
||||
func binaryDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeBinary(buf, kind) }
|
||||
func int64Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInt64(buf, kind) }
|
||||
func int32Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInt32(buf, kind) }
|
||||
func intDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInt(buf, kind) }
|
||||
func uint64Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeUint64(buf, kind) }
|
||||
func uint32Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeUint32(buf, kind) }
|
||||
func uintDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeUint(buf, kind) }
|
||||
func float64Decoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeFloat64(buf, kind) }
|
||||
func boolDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeBool(buf, kind) }
|
||||
func timeDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeTime(buf, kind) }
|
||||
func interfaceDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeInterface(buf, kind) }
|
||||
func mapDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeMap(buf, kind) }
|
||||
func arrayDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeArray(buf, kind) }
|
||||
func skipDecoder(buf *bytes.Buffer, kind byte) interface{} { Skip(buf, kind); return nil }
|
||||
func stringArrayDecoder(buf *bytes.Buffer, kind byte) interface{} { return DecodeStringArray(buf, kind) }
|
||||
|
||||
var customUnmarshalCases = []struct {
|
||||
desc string
|
||||
in string
|
||||
kind byte
|
||||
decoder func(buf *bytes.Buffer, kind byte) interface{}
|
||||
out interface{}
|
||||
}{
|
||||
{"String->string", bsonString, String, stringDecoder, "test"},
|
||||
{"Binary->string", bsonBinary, Binary, stringDecoder, "test"},
|
||||
{"Null->string", "", Null, stringDecoder, ""},
|
||||
{"String->bytes", bsonString, String, binaryDecoder, []byte("test")},
|
||||
{"Binary->bytes", bsonBinary, Binary, binaryDecoder, []byte("test")},
|
||||
{"Null->bytes", "", Null, binaryDecoder, []byte(nil)},
|
||||
{"Int->int64", bsonInt, Int, int64Decoder, int64(1)},
|
||||
{"Long->int64", bsonLong, Long, int64Decoder, int64(1)},
|
||||
{"Ulong->int64", bsonLong, Ulong, int64Decoder, int64(1)},
|
||||
{"Null->int64", "", Null, int64Decoder, int64(0)},
|
||||
{"Int->int32", bsonInt, Int, int32Decoder, int32(1)},
|
||||
{"Null->int32", "", Null, int32Decoder, int32(0)},
|
||||
{"Int->int", bsonInt, Int, intDecoder, int(1)},
|
||||
{"Long->int", bsonLong, Long, intDecoder, int(1)},
|
||||
{"Ulong->int", bsonLong, Ulong, intDecoder, int(1)},
|
||||
{"Null->int", "", Null, intDecoder, int(0)},
|
||||
{"Int->uint64", bsonInt, Int, uint64Decoder, uint64(1)},
|
||||
{"Long->uint64", bsonLong, Long, uint64Decoder, uint64(1)},
|
||||
{"Ulong->uint64", bsonLong, Ulong, uint64Decoder, uint64(1)},
|
||||
{"Null->uint64", "", Null, uint64Decoder, uint64(0)},
|
||||
{"Int->uint32", bsonInt, Int, uint32Decoder, uint32(1)},
|
||||
{"Ulong->uint32", bsonLong, Ulong, uint32Decoder, uint32(1)},
|
||||
{"Null->uint32", "", Null, uint32Decoder, uint32(0)},
|
||||
{"Int->uint", bsonInt, Int, uintDecoder, uint(1)},
|
||||
{"Long->uint", bsonLong, Long, uintDecoder, uint(1)},
|
||||
{"Ulong->uint", bsonLong, Ulong, uintDecoder, uint(1)},
|
||||
{"Null->uint", "", Null, uintDecoder, uint(0)},
|
||||
{"Number->float64", bsonNumber, Number, float64Decoder, float64(1.0)},
|
||||
{"Null->float64", "", Null, float64Decoder, float64(0.0)},
|
||||
{"Boolean->bool", bsonBoolean, Boolean, boolDecoder, true},
|
||||
{"Null->bool", "", Null, boolDecoder, false},
|
||||
{"Datetime->time.Time", bsonDatetime, Datetime, timeDecoder, time.Unix(1136243045, 0).UTC()},
|
||||
{"Null->time.Time", "", Null, timeDecoder, time.Time{}},
|
||||
{"Number->interface{}", bsonNumber, Number, interfaceDecoder, float64(1.0)},
|
||||
{"String->interface{}", bsonString, String, interfaceDecoder, "test"},
|
||||
{"Object->interface{}", bsonObject, Object, interfaceDecoder, map[string]interface{}{"Val2": []byte("test")}},
|
||||
{"Object->interface{} with null element", bsonObjectNull, Object, interfaceDecoder, map[string]interface{}{"Val2": nil}},
|
||||
{"Array->interface{}", bsonArray, Array, interfaceDecoder, []interface{}{[]byte("test")}},
|
||||
{"Array->interface{} with null element", bsonArrayNull, Array, interfaceDecoder, []interface{}{nil, int64(1)}},
|
||||
{"Binary->interface{}", bsonBinary, Binary, interfaceDecoder, []byte("test")},
|
||||
{"Boolean->interface{}", bsonBoolean, Boolean, interfaceDecoder, true},
|
||||
{"Datetime->interface{}", bsonDatetime, Datetime, interfaceDecoder, time.Unix(1136243045, 0).UTC()},
|
||||
{"Int->interface{}", bsonInt, Int, interfaceDecoder, int32(1)},
|
||||
{"Long->interface{}", bsonLong, Long, interfaceDecoder, int64(1)},
|
||||
{"Ulong->interface{}", bsonLong, Ulong, interfaceDecoder, uint64(1)},
|
||||
{"Null->interface{}", "", Null, interfaceDecoder, nil},
|
||||
{"Null->map[string]interface{}", "", Null, mapDecoder, map[string]interface{}(nil)},
|
||||
{"Null->[]interface{}", "", Null, arrayDecoder, []interface{}(nil)},
|
||||
{"Number->Skip", bsonNumber, Number, skipDecoder, nil},
|
||||
{"String->Skip", bsonString, String, skipDecoder, nil},
|
||||
{"Object->Skip", bsonObject, Object, skipDecoder, nil},
|
||||
{"Object->Skip with null element", bsonObjectNull, Object, skipDecoder, nil},
|
||||
{"Array->Skip", bsonArray, Array, skipDecoder, nil},
|
||||
{"Array->Skip with null element", bsonArrayNull, Array, skipDecoder, nil},
|
||||
{"Binary->Skip", bsonBinary, Binary, skipDecoder, nil},
|
||||
{"Boolean->Skip", bsonBoolean, Boolean, skipDecoder, nil},
|
||||
{"Datetime->Skip", bsonDatetime, Datetime, skipDecoder, nil},
|
||||
{"Int->Skip", bsonInt, Int, skipDecoder, nil},
|
||||
{"Long->Skip", bsonLong, Long, skipDecoder, nil},
|
||||
{"Ulong->Skip", bsonLong, Ulong, skipDecoder, nil},
|
||||
{"Null->Skip", "", Null, skipDecoder, nil},
|
||||
{"Null->map[string]interface{}", "", Null, mapDecoder, map[string]interface{}(nil)},
|
||||
{"Null->[]interface{}", "", Null, arrayDecoder, []interface{}(nil)},
|
||||
{"Array->[]string", bsonStringArray, Array, stringArrayDecoder, []string{"test1", "test2"}},
|
||||
{"Null->[]string", "", Null, stringArrayDecoder, []string(nil)},
|
||||
}
|
||||
|
||||
func TestCustomUnmarshal(t *testing.T) {
|
||||
for _, tcase := range customUnmarshalCases {
|
||||
buf := bytes.NewBuffer([]byte(tcase.in))
|
||||
got := tcase.decoder(buf, tcase.kind)
|
||||
if !reflect.DeepEqual(got, tcase.out) {
|
||||
t.Errorf("%s: received: %v, want %v", tcase.desc, got, tcase.out)
|
||||
}
|
||||
if buf.Len() != 0 {
|
||||
t.Errorf("%s: %d unread bytes from %q, want 0", tcase.desc, buf.Len(), tcase.in)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var customUnmarshalFailureCases = []struct {
|
||||
typ string
|
||||
decoder func(buf *bytes.Buffer, kind byte) interface{}
|
||||
valid []byte
|
||||
}{
|
||||
{"string", stringDecoder, []byte{String, Binary, Null}},
|
||||
{"[]byte", binaryDecoder, []byte{String, Binary, Null}},
|
||||
{"int64", int64Decoder, []byte{Int, Long, Ulong, Null}},
|
||||
{"int32", int32Decoder, []byte{Int, Long, Null}},
|
||||
{"int", intDecoder, []byte{Int, Long, Ulong, Null}},
|
||||
{"uint64", uint64Decoder, []byte{Int, Long, Ulong, Null}},
|
||||
{"uint32", uint32Decoder, []byte{Int, Long, Ulong, Null}},
|
||||
{"uint", uintDecoder, []byte{Int, Long, Ulong, Null}},
|
||||
{"float64", float64Decoder, []byte{Number, Null}},
|
||||
{"bool", boolDecoder, []byte{Boolean, Int, Long, Ulong, Null}},
|
||||
{"time.Time", timeDecoder, []byte{Datetime, Null}},
|
||||
{"interface{}", interfaceDecoder, []byte{Number, String, Object, Array, Binary, Boolean, Datetime, Null, Int, Long, Ulong}},
|
||||
{"map", mapDecoder, []byte{Object, Null}},
|
||||
{"slice", arrayDecoder, []byte{Array, Null}},
|
||||
{"[]string", stringArrayDecoder, []byte{Array, Null}},
|
||||
{"skip", skipDecoder, []byte{Number, String, Object, Array, Binary, Boolean, Datetime, Null, Int, Long, Ulong}},
|
||||
}
|
||||
|
||||
func TestCustomUnmarshalFailures(t *testing.T) {
|
||||
allKinds := []byte{EOO, Number, String, Object, Array, Binary, Boolean, Datetime, Null, Int, Long, Ulong}
|
||||
for _, tcase := range customUnmarshalFailureCases {
|
||||
for _, kind := range allKinds {
|
||||
want := fmt.Sprintf("unexpected kind %v for %s", kind, tcase.typ)
|
||||
func() {
|
||||
defer func() {
|
||||
if x := recover(); x != nil {
|
||||
got := x.(BsonError).Error()
|
||||
if got != want {
|
||||
t.Errorf("got %s, want %s", got, want)
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
for _, valid := range tcase.valid {
|
||||
if kind == valid {
|
||||
return
|
||||
}
|
||||
}
|
||||
tcase.decoder(nil, kind)
|
||||
t.Errorf("got no error, want %s", want)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,381 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
// LenWriter records the current write position on the buffer
|
||||
// and can later be used to record the number of bytes written
|
||||
// in conformance to BSON spec
|
||||
type LenWriter struct {
|
||||
buf *bytes2.ChunkedWriter
|
||||
off int
|
||||
b []byte
|
||||
}
|
||||
|
||||
// NewLenWriter returns a LenWriter that reserves the
|
||||
// bytes buf so they can store the length later.
|
||||
func NewLenWriter(buf *bytes2.ChunkedWriter) LenWriter {
|
||||
off := buf.Len()
|
||||
b := buf.Reserve(WORD32)
|
||||
return LenWriter{buf, off, b}
|
||||
}
|
||||
|
||||
// Close closes the current object being encoded by
|
||||
// writing bson's EOO byte and recording the length.
|
||||
func (lw LenWriter) Close() {
|
||||
lw.buf.WriteByte(EOO)
|
||||
Pack.PutUint32(lw.b, uint32(lw.buf.Len()-lw.off))
|
||||
}
|
||||
|
||||
// Marshaler is the interface that needs to be
|
||||
// satisfied by types that want to implement a custom
|
||||
// marshaler.
|
||||
// When being invoked as a top level object, key will
|
||||
// be "". In such cases, MarshalBson must not encode
|
||||
// any prefix.
|
||||
type Marshaler interface {
|
||||
MarshalBson(buf *bytes2.ChunkedWriter, key string)
|
||||
}
|
||||
|
||||
func canMarshal(val reflect.Value) Marshaler {
|
||||
// Check the Marshaler interface on T.
|
||||
if marshaler, ok := val.Interface().(Marshaler); ok {
|
||||
// Don't call custom marshaler for nil values.
|
||||
switch val.Kind() {
|
||||
case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice:
|
||||
if val.IsNil() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return marshaler
|
||||
}
|
||||
// Check the Marshaler interface on *T.
|
||||
if val.CanAddr() {
|
||||
if marshaler, ok := val.Addr().Interface().(Marshaler); ok {
|
||||
return marshaler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultBufferSize is the default allocation size for ChunkedWriter.
|
||||
const DefaultBufferSize = 1024
|
||||
|
||||
// MarshalToStream marshals val into writer.
|
||||
func MarshalToStream(writer io.Writer, val interface{}) (err error) {
|
||||
buf := bytes2.NewChunkedWriter(DefaultBufferSize)
|
||||
if err = MarshalToBuffer(buf, val); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buf.WriteTo(writer)
|
||||
return err
|
||||
}
|
||||
|
||||
// Marshal marshals val into encoded.
|
||||
func Marshal(val interface{}) (encoded []byte, err error) {
|
||||
buf := bytes2.NewChunkedWriter(DefaultBufferSize)
|
||||
err = MarshalToBuffer(buf, val)
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
// MarshalToBuffer marshals val into buf. This is the most efficient
|
||||
// function to use, especially when marshaling large nested objects.
|
||||
func MarshalToBuffer(buf *bytes2.ChunkedWriter, val interface{}) (err error) {
|
||||
defer handleError(&err)
|
||||
if val == nil {
|
||||
return NewBsonError("cannot marshal nil")
|
||||
}
|
||||
|
||||
v := reflect.Indirect(reflect.ValueOf(val))
|
||||
if marshaler := canMarshal(v); marshaler != nil {
|
||||
marshaler.MarshalBson(buf, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.String,
|
||||
reflect.Int64, reflect.Int32, reflect.Int,
|
||||
reflect.Uint64, reflect.Uint32, reflect.Uint,
|
||||
reflect.Float64, reflect.Bool:
|
||||
EncodeSimple(buf, v.Interface())
|
||||
case reflect.Struct:
|
||||
if v.Type() == timeType {
|
||||
EncodeSimple(buf, v.Interface())
|
||||
} else {
|
||||
encodeStructContent(buf, v)
|
||||
}
|
||||
case reflect.Map:
|
||||
encodeMapContent(buf, v)
|
||||
case reflect.Slice, reflect.Array:
|
||||
if v.Type() == bytesType {
|
||||
EncodeSimple(buf, v.Interface())
|
||||
} else {
|
||||
encodeSliceContent(buf, v)
|
||||
}
|
||||
default:
|
||||
return NewBsonError("unexpected type %v", v.Type())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeSimple marshals simple objects that cannot be
|
||||
// encoded as a top level bson document.
|
||||
func EncodeSimple(buf *bytes2.ChunkedWriter, val interface{}) {
|
||||
lenWriter := NewLenWriter(buf)
|
||||
EncodeField(buf, MAGICTAG, val)
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// EncodeField encodes val using the supplied key as embedded tag.
|
||||
// Unlike EncodeInterface, EncodeField can handle complex objects
|
||||
// like structs, pointers, etc. But it is slower.
|
||||
func EncodeField(buf *bytes2.ChunkedWriter, key string, val interface{}) {
|
||||
encodeField(buf, key, reflect.ValueOf(val))
|
||||
}
|
||||
|
||||
func encodeField(buf *bytes2.ChunkedWriter, key string, val reflect.Value) {
|
||||
// nil interfaces show up as invalid
|
||||
if !val.IsValid() {
|
||||
EncodePrefix(buf, Null, key)
|
||||
return
|
||||
}
|
||||
if marshaler := canMarshal(val); marshaler != nil {
|
||||
marshaler.MarshalBson(buf, key)
|
||||
return
|
||||
}
|
||||
|
||||
switch val.Kind() {
|
||||
case reflect.String:
|
||||
EncodeString(buf, key, val.String())
|
||||
case reflect.Int64:
|
||||
EncodeInt64(buf, key, val.Int())
|
||||
case reflect.Int32:
|
||||
EncodeInt32(buf, key, int32(val.Int()))
|
||||
case reflect.Int:
|
||||
EncodeInt(buf, key, int(val.Int()))
|
||||
case reflect.Uint64:
|
||||
EncodeUint64(buf, key, uint64(val.Uint()))
|
||||
case reflect.Uint32:
|
||||
EncodeUint32(buf, key, uint32(val.Uint()))
|
||||
case reflect.Uint:
|
||||
EncodeUint(buf, key, uint(val.Uint()))
|
||||
case reflect.Float64:
|
||||
EncodeFloat64(buf, key, val.Float())
|
||||
case reflect.Bool:
|
||||
EncodeBool(buf, key, val.Bool())
|
||||
case reflect.Struct:
|
||||
if val.Type() == timeType {
|
||||
EncodeTime(buf, key, val.Interface().(time.Time))
|
||||
} else {
|
||||
encodeStruct(buf, key, val)
|
||||
}
|
||||
case reflect.Map:
|
||||
encodeMap(buf, key, val)
|
||||
case reflect.Slice:
|
||||
if val.Type() == bytesType {
|
||||
EncodeBinary(buf, key, val.Interface().([]byte))
|
||||
} else {
|
||||
encodeSlice(buf, key, val)
|
||||
}
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
if val.IsNil() {
|
||||
EncodePrefix(buf, Null, key)
|
||||
} else {
|
||||
encodeField(buf, key, val.Elem())
|
||||
}
|
||||
default:
|
||||
panic(NewBsonError("don't know how to marshal %v", val.Type()))
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeOptionalPrefix encodes the key as prefix if it's not empty.
|
||||
// If it is empty, then it's a no-op, with the assumption that
|
||||
// it's a top level object.
|
||||
func EncodeOptionalPrefix(buf *bytes2.ChunkedWriter, etype byte, key string) {
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
EncodePrefix(buf, etype, key)
|
||||
}
|
||||
|
||||
// EncodePrefix encodes key as prefix for the next object or value.
|
||||
func EncodePrefix(buf *bytes2.ChunkedWriter, etype byte, key string) {
|
||||
b := buf.Reserve(len(key) + 2)
|
||||
b[0] = etype
|
||||
copy(b[1:], key)
|
||||
b[len(b)-1] = 0
|
||||
}
|
||||
|
||||
// EncodeString encodes a string.
|
||||
func EncodeString(buf *bytes2.ChunkedWriter, key string, val string) {
|
||||
// Encode strings as binary; go strings are not necessarily unicode
|
||||
EncodePrefix(buf, Binary, key)
|
||||
putUint32(buf, uint32(len(val)))
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(val)
|
||||
}
|
||||
|
||||
// EncodeBinary encodes a []byte as binary.
|
||||
func EncodeBinary(buf *bytes2.ChunkedWriter, key string, val []byte) {
|
||||
EncodePrefix(buf, Binary, key)
|
||||
putUint32(buf, uint32(len(val)))
|
||||
buf.WriteByte(0)
|
||||
buf.Write(val)
|
||||
}
|
||||
|
||||
// EncodeInt64 encodes an int64.
|
||||
func EncodeInt64(buf *bytes2.ChunkedWriter, key string, val int64) {
|
||||
EncodePrefix(buf, Long, key)
|
||||
putUint64(buf, uint64(val))
|
||||
}
|
||||
|
||||
// EncodeInt32 encodes an int32.
|
||||
func EncodeInt32(buf *bytes2.ChunkedWriter, key string, val int32) {
|
||||
EncodePrefix(buf, Int, key)
|
||||
putUint32(buf, uint32(val))
|
||||
}
|
||||
|
||||
// EncodeInt encodes an int.
|
||||
func EncodeInt(buf *bytes2.ChunkedWriter, key string, val int) {
|
||||
EncodeInt64(buf, key, int64(val))
|
||||
}
|
||||
|
||||
// EncodeUint64 encodes an uint64.
|
||||
func EncodeUint64(buf *bytes2.ChunkedWriter, key string, val uint64) {
|
||||
EncodePrefix(buf, Long, key)
|
||||
putUint64(buf, val)
|
||||
}
|
||||
|
||||
// EncodeUint32 encodes an uint32.
|
||||
func EncodeUint32(buf *bytes2.ChunkedWriter, key string, val uint32) {
|
||||
EncodeUint64(buf, key, uint64(val))
|
||||
}
|
||||
|
||||
// EncodeUint encodes an uint.
|
||||
func EncodeUint(buf *bytes2.ChunkedWriter, key string, val uint) {
|
||||
EncodeUint64(buf, key, uint64(val))
|
||||
}
|
||||
|
||||
// EncodeFloat64 encodes a float64.
|
||||
func EncodeFloat64(buf *bytes2.ChunkedWriter, key string, val float64) {
|
||||
EncodePrefix(buf, Number, key)
|
||||
bits := math.Float64bits(val)
|
||||
putUint64(buf, bits)
|
||||
}
|
||||
|
||||
// EncodeBool encodes a bool.
|
||||
func EncodeBool(buf *bytes2.ChunkedWriter, key string, val bool) {
|
||||
EncodePrefix(buf, Boolean, key)
|
||||
if val {
|
||||
buf.WriteByte(1)
|
||||
} else {
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeTime encodes a time.Time.
|
||||
func EncodeTime(buf *bytes2.ChunkedWriter, key string, val time.Time) {
|
||||
EncodePrefix(buf, Datetime, key)
|
||||
mtime := val.UnixNano() / 1e6
|
||||
putUint64(buf, uint64(mtime))
|
||||
}
|
||||
|
||||
func encodeStruct(buf *bytes2.ChunkedWriter, key string, val reflect.Value) {
|
||||
EncodePrefix(buf, Object, key)
|
||||
encodeStructContent(buf, val)
|
||||
}
|
||||
|
||||
func encodeStructContent(buf *bytes2.ChunkedWriter, val reflect.Value) {
|
||||
lenWriter := NewLenWriter(buf)
|
||||
t := val.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
key := t.Field(i).Name
|
||||
|
||||
// NOTE(szopa): Ignore private fields (copied from
|
||||
// encoding/json). Yes, it feels like a hack.
|
||||
if t.Field(i).PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
encodeField(buf, key, val.Field(i))
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
func encodeMap(buf *bytes2.ChunkedWriter, key string, val reflect.Value) {
|
||||
EncodePrefix(buf, Object, key)
|
||||
encodeMapContent(buf, val)
|
||||
}
|
||||
|
||||
// a map seems to lose the 'CanAddr' property. So if we want
|
||||
// to use a custom marshaler with a struct pointer receiver, like:
|
||||
// func (ps *PrivateStruct) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
// the map has to be using pointers, i.e:
|
||||
// map[string]*PrivateStruct
|
||||
// and not:
|
||||
// map[string]PrivateStruct
|
||||
// (see unit test)
|
||||
func encodeMapContent(buf *bytes2.ChunkedWriter, val reflect.Value) {
|
||||
lenWriter := NewLenWriter(buf)
|
||||
mt := val.Type()
|
||||
if mt.Key().Kind() != reflect.String {
|
||||
panic(NewBsonError("can't marshall maps with non-string key types"))
|
||||
}
|
||||
keys := val.MapKeys()
|
||||
for _, k := range keys {
|
||||
key := k.String()
|
||||
encodeField(buf, key, val.MapIndex(k))
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
func encodeSlice(buf *bytes2.ChunkedWriter, key string, val reflect.Value) {
|
||||
EncodePrefix(buf, Array, key)
|
||||
encodeSliceContent(buf, val)
|
||||
}
|
||||
|
||||
func encodeSliceContent(buf *bytes2.ChunkedWriter, val reflect.Value) {
|
||||
lenWriter := NewLenWriter(buf)
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
encodeField(buf, Itoa(i), val.Index(i))
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
func putUint32(buf *bytes2.ChunkedWriter, val uint32) {
|
||||
Pack.PutUint32(buf.Reserve(WORD32), val)
|
||||
}
|
||||
|
||||
func putUint64(buf *bytes2.ChunkedWriter, val uint64) {
|
||||
Pack.PutUint64(buf.Reserve(WORD64), val)
|
||||
}
|
||||
|
||||
var intStrMap [intAliasSize + 1]string
|
||||
|
||||
const (
|
||||
intAliasSize = 1024
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i := 0; i <= intAliasSize; i++ {
|
||||
intStrMap[i] = strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
|
||||
// Itoa is used in code generated by bsongen.
|
||||
func Itoa(i int) string {
|
||||
if i <= intAliasSize {
|
||||
return intStrMap[i]
|
||||
}
|
||||
return strconv.Itoa(i)
|
||||
}
|
|
@ -1,235 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
type String1 string
|
||||
|
||||
func (cs String1) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
// Hardcode value to verify that function is called
|
||||
EncodeString(buf, key, "test")
|
||||
}
|
||||
|
||||
type String2 string
|
||||
|
||||
func (cs *String2) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
// Hardcode value to verify that function is called
|
||||
EncodeString(buf, key, "test")
|
||||
}
|
||||
|
||||
var marshaltest = []struct {
|
||||
desc string
|
||||
in interface{}
|
||||
out string
|
||||
}{{
|
||||
"struct encode",
|
||||
struct{ Val string }{"test"},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"struct encode nil",
|
||||
struct{ Val *int }{},
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
}, {
|
||||
"struct encode nil interface",
|
||||
struct{ Val interface{} }{},
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
}, {
|
||||
"map encode",
|
||||
map[string]string{"Val": "test"},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"embedded map encode",
|
||||
struct{ Inner map[string]string }{map[string]string{"Val": "test"}},
|
||||
"\x1f\x00\x00\x00\x03Inner\x00\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
}, {
|
||||
"embedded map encode nil",
|
||||
struct{ Inner map[string]string }{},
|
||||
"\x11\x00\x00\x00\x03Inner\x00\x05\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"slice encode",
|
||||
[]string{"test1", "test2"},
|
||||
"\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00",
|
||||
}, {
|
||||
"embedded slice encode",
|
||||
struct{ Inner []string }{[]string{"test1", "test2"}},
|
||||
"+\x00\x00\x00\x04Inner\x00\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00\x00",
|
||||
}, {
|
||||
"embedded slice encode nil",
|
||||
struct{ Inner []string }{},
|
||||
"\x11\x00\x00\x00\x04Inner\x00\x05\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"array encode",
|
||||
[2]string{"test1", "test2"},
|
||||
"\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00",
|
||||
}, {
|
||||
"string encode",
|
||||
"test",
|
||||
"\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"bytes encode",
|
||||
[]byte("test"),
|
||||
"\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"int64 encode",
|
||||
int64(1),
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"int32 encode",
|
||||
int32(1),
|
||||
"\x10\x00\x00\x00\x10_Val_\x00\x01\x00\x00\x00\x00",
|
||||
}, {
|
||||
"int encode",
|
||||
int(1),
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"uint64 encode",
|
||||
uint64(1),
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"uint32 encode",
|
||||
uint32(1),
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"uint encode",
|
||||
uint(1),
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"float encode",
|
||||
float64(1.0),
|
||||
"\x14\x00\x00\x00\x01_Val_\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
|
||||
}, {
|
||||
"bool encode",
|
||||
true,
|
||||
"\r\x00\x00\x00\b_Val_\x00\x01\x00",
|
||||
}, {
|
||||
"time encode",
|
||||
time.Unix(1136243045, 0).UTC(),
|
||||
"\x14\x00\x00\x00\t_Val_\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
|
||||
}, {
|
||||
|
||||
// Following encodes are for reference. They're used for
|
||||
// the decode tests.
|
||||
"embedded Object encode",
|
||||
struct{ Val struct{ Val2 string } }{struct{ Val2 string }{"test"}},
|
||||
"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
}, {
|
||||
"embedded Object encode nil element",
|
||||
struct{ Val struct{ Val2 *int64 } }{struct{ Val2 *int64 }{nil}},
|
||||
"\x15\x00\x00\x00\x03Val\x00\v\x00\x00\x00\nVal2\x00\x00\x00",
|
||||
}, {
|
||||
"embedded Array encode",
|
||||
struct{ Val []string }{Val: []string{"test"}},
|
||||
"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
}, {
|
||||
"Array encode nil element",
|
||||
struct{ Val []*int64 }{Val: []*int64{nil, newint64(1)}},
|
||||
"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"embedded Number encode",
|
||||
struct{ Val float64 }{1.0},
|
||||
"\x12\x00\x00\x00\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
|
||||
}, {
|
||||
"embedded Binary encode",
|
||||
struct{ Val string }{"test"},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"embedded Boolean encode",
|
||||
struct{ Val bool }{true},
|
||||
"\v\x00\x00\x00\bVal\x00\x01\x00",
|
||||
}, {
|
||||
"embedded Datetime encode",
|
||||
struct{ Val time.Time }{time.Unix(1136243045, 0).UTC()},
|
||||
"\x12\x00\x00\x00\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
|
||||
}, {
|
||||
"embedded Null encode",
|
||||
struct{ Val *int }{},
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
}, {
|
||||
"embedded Int encode",
|
||||
struct{ Val int32 }{1},
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
}, {
|
||||
"embedded Long encode",
|
||||
struct{ Val int64 }{1},
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"embedded Ulong encode",
|
||||
struct{ Val uint64 }{1},
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
}, {
|
||||
"embedded non-pointer encode with custom marshaler",
|
||||
struct{ Val String1 }{String1("foo")},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"embedded pointer encode with custom marshaler",
|
||||
struct{ Val *String1 }{func(cs String1) *String1 { return &cs }(String1("foo"))},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"embedded nil pointer encode with custom marshaler",
|
||||
struct{ Val *String1 }{},
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
}, {
|
||||
"embedded pointer encode with custom pointer marshaler",
|
||||
struct{ Val *String2 }{func(cs String2) *String2 { return &cs }(String2("foo"))},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"embedded addressable encode with custom pointer marshaler",
|
||||
&struct{ Val String2 }{String2("foo")},
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
}, {
|
||||
"embedded non-addressable encode with custom pointer marshaler",
|
||||
struct{ Val String2 }{String2("foo")},
|
||||
"\x12\x00\x00\x00\x05Val\x00\x03\x00\x00\x00\x00foo\x00",
|
||||
}}
|
||||
|
||||
func TestMarshal(t *testing.T) {
|
||||
for _, tcase := range marshaltest {
|
||||
got := verifyMarshal(t, tcase.in)
|
||||
if string(got) != tcase.out {
|
||||
t.Errorf("%s: encoded: \n%q, want\n%q", tcase.desc, got, tcase.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var marshalErrorCases = []struct {
|
||||
desc string
|
||||
in interface{}
|
||||
out string
|
||||
}{{
|
||||
"nil input",
|
||||
nil,
|
||||
"cannot marshal nil",
|
||||
}, {
|
||||
"chan input",
|
||||
make(chan int),
|
||||
"unexpected type chan int",
|
||||
}, {
|
||||
"embedded chan input",
|
||||
struct{ Val chan int }{},
|
||||
"don't know how to marshal chan int",
|
||||
}, {
|
||||
"map with int key",
|
||||
map[int]int{},
|
||||
"can't marshall maps with non-string key types",
|
||||
}}
|
||||
|
||||
func TestMarshalErrors(t *testing.T) {
|
||||
for _, tcase := range marshalErrorCases {
|
||||
_, err := Marshal(tcase.in)
|
||||
got := ""
|
||||
if err != nil {
|
||||
got = err.Error()
|
||||
}
|
||||
if got != tcase.out {
|
||||
t.Errorf("%s: received: %q, want %q", tcase.desc, got, tcase.out)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,86 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Utility functions for custom encoders
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
// EncodeInterface bson encodes an interface{}. Elements
|
||||
// can be basic bson encodable types, or []interface{},
|
||||
// or map[string]interface{}, whose elements have to in
|
||||
// turn be bson encodable.
|
||||
func EncodeInterface(buf *bytes2.ChunkedWriter, key string, val interface{}) {
|
||||
if val == nil {
|
||||
EncodePrefix(buf, Null, key)
|
||||
return
|
||||
}
|
||||
switch val := val.(type) {
|
||||
case string:
|
||||
EncodeString(buf, key, val)
|
||||
case []byte:
|
||||
EncodeBinary(buf, key, val)
|
||||
case int64:
|
||||
EncodeInt64(buf, key, val)
|
||||
case int32:
|
||||
EncodeInt32(buf, key, val)
|
||||
case int:
|
||||
EncodeInt(buf, key, val)
|
||||
case uint64:
|
||||
EncodeUint64(buf, key, val)
|
||||
case uint32:
|
||||
EncodeUint32(buf, key, val)
|
||||
case uint:
|
||||
EncodeUint(buf, key, val)
|
||||
case float64:
|
||||
EncodeFloat64(buf, key, val)
|
||||
case bool:
|
||||
EncodeBool(buf, key, val)
|
||||
case map[string]interface{}:
|
||||
if val == nil {
|
||||
EncodePrefix(buf, Null, key)
|
||||
return
|
||||
}
|
||||
EncodePrefix(buf, Object, key)
|
||||
lenWriter := NewLenWriter(buf)
|
||||
for k, v := range val {
|
||||
EncodeInterface(buf, k, v)
|
||||
}
|
||||
lenWriter.Close()
|
||||
case []interface{}:
|
||||
if val == nil {
|
||||
EncodePrefix(buf, Null, key)
|
||||
return
|
||||
}
|
||||
EncodePrefix(buf, Array, key)
|
||||
lenWriter := NewLenWriter(buf)
|
||||
for i, v := range val {
|
||||
EncodeInterface(buf, Itoa(i), v)
|
||||
}
|
||||
lenWriter.Close()
|
||||
case time.Time:
|
||||
EncodeTime(buf, key, val)
|
||||
default:
|
||||
panic(NewBsonError("don't know how to marshal %T", val))
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeStringArray bson encodes a []string
|
||||
func EncodeStringArray(buf *bytes2.ChunkedWriter, name string, values []string) {
|
||||
if values == nil {
|
||||
EncodePrefix(buf, Null, name)
|
||||
return
|
||||
}
|
||||
EncodePrefix(buf, Array, name)
|
||||
lenWriter := NewLenWriter(buf)
|
||||
for i, val := range values {
|
||||
EncodeString(buf, Itoa(i), val)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
|
@ -1,315 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Unmarshaler is the interface that needs to be satisfied
|
||||
// by types that want to perform custom unmarshaling.
|
||||
// If kind is EOO, then the type is being unmarshalled
|
||||
// as a top level object. Otherwise, it's an embedded
|
||||
// object, and kind will need to be type-checked
|
||||
// before unmarshaling.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalBson(buf *bytes.Buffer, kind byte)
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) canUnMarshal() Unmarshaler {
|
||||
// Don't use custom unmarshalers for map values.
|
||||
// It loses symmetry.
|
||||
if builder.map_.IsValid() {
|
||||
return nil
|
||||
}
|
||||
if builder.val.CanAddr() {
|
||||
if unmarshaler, ok := builder.val.Addr().Interface().(Unmarshaler); ok {
|
||||
return unmarshaler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmarshal unmarshals b into val.
|
||||
func Unmarshal(b []byte, val interface{}) (err error) {
|
||||
return UnmarshalFromBuffer(bytes.NewBuffer(b), val)
|
||||
}
|
||||
|
||||
// UnmarshalFromStream unmarshals from reader into val.
|
||||
func UnmarshalFromStream(reader io.Reader, val interface{}) (err error) {
|
||||
lenbuf := make([]byte, 4)
|
||||
var n int
|
||||
n, err = io.ReadFull(reader, lenbuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != 4 {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
length := Pack.Uint32(lenbuf)
|
||||
b := make([]byte, length)
|
||||
Pack.PutUint32(b, length)
|
||||
n, err = io.ReadFull(reader, b[4:])
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
if n != int(length-4) {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return UnmarshalFromBuffer(bytes.NewBuffer(b), val)
|
||||
}
|
||||
|
||||
// UnmarshalFromBuffer unmarshals from buf into val.
|
||||
func UnmarshalFromBuffer(buf *bytes.Buffer, val interface{}) (err error) {
|
||||
defer handleError(&err)
|
||||
if val == nil {
|
||||
Skip(buf, Object)
|
||||
return nil
|
||||
}
|
||||
|
||||
if unmarshaler, ok := val.(Unmarshaler); ok {
|
||||
unmarshaler.UnmarshalBson(buf, EOO)
|
||||
return nil
|
||||
}
|
||||
sb, err := topLevelBuilder(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodeDocument(buf, sb, EOO)
|
||||
sb.save()
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeDocument(buf *bytes.Buffer, builder *valueBuilder, kind byte) {
|
||||
if kind != EOO && kind != Object && kind != Array {
|
||||
panic(NewBsonError("unexpected kind: %v", kind))
|
||||
}
|
||||
Next(buf, 4)
|
||||
for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) {
|
||||
b2 := builder.initField(ReadCString(buf), kind)
|
||||
if b2 == nil {
|
||||
Skip(buf, kind)
|
||||
continue
|
||||
}
|
||||
if unmarshaler := b2.canUnMarshal(); unmarshaler != nil {
|
||||
unmarshaler.UnmarshalBson(buf, kind)
|
||||
continue
|
||||
}
|
||||
switch b2.val.Kind() {
|
||||
case reflect.String:
|
||||
b2.setString(DecodeString(buf, kind))
|
||||
case reflect.Int64:
|
||||
b2.setInt(DecodeInt64(buf, kind))
|
||||
case reflect.Int32:
|
||||
b2.setInt(int64(DecodeInt32(buf, kind)))
|
||||
case reflect.Int:
|
||||
b2.setInt(int64(DecodeInt(buf, kind)))
|
||||
case reflect.Uint64:
|
||||
b2.setUint(DecodeUint64(buf, kind))
|
||||
case reflect.Uint32:
|
||||
b2.setUint(uint64(DecodeUint32(buf, kind)))
|
||||
case reflect.Uint:
|
||||
b2.setUint(uint64(DecodeUint(buf, kind)))
|
||||
case reflect.Float64:
|
||||
b2.setFloat(DecodeFloat64(buf, kind))
|
||||
case reflect.Bool:
|
||||
b2.setBool(DecodeBool(buf, kind))
|
||||
case reflect.Struct:
|
||||
if b2.val.Type() == timeType {
|
||||
b2.setTime(DecodeTime(buf, kind))
|
||||
} else {
|
||||
decodeDocument(buf, b2, kind)
|
||||
}
|
||||
case reflect.Map, reflect.Array:
|
||||
decodeDocument(buf, b2, kind)
|
||||
case reflect.Slice:
|
||||
if b2.val.Type() == bytesType {
|
||||
b2.setBytes(DecodeBinary(buf, kind))
|
||||
} else {
|
||||
decodeDocument(buf, b2, kind)
|
||||
}
|
||||
case reflect.Interface:
|
||||
b2.setInterface(DecodeInterface(buf, kind))
|
||||
default:
|
||||
panic(NewBsonError("cannot unmarshal into %v", b2.val.Kind()))
|
||||
}
|
||||
b2.save()
|
||||
}
|
||||
}
|
||||
|
||||
// Maps & interface values will not give you a reference to their underlying object.
|
||||
// You can only update them through their Set methods.
|
||||
type valueBuilder struct {
|
||||
val reflect.Value
|
||||
|
||||
// if map_.IsValid(), write val to map_ using key.
|
||||
map_ reflect.Value
|
||||
key reflect.Value
|
||||
|
||||
// index tracks current index if val is an array.
|
||||
index int
|
||||
}
|
||||
|
||||
// topLevelBuilder returns a valid unmarshalable valueBuilder or an error
|
||||
func topLevelBuilder(val interface{}) (sb *valueBuilder, err error) {
|
||||
ival := reflect.ValueOf(val)
|
||||
if ival.Kind() != reflect.Ptr {
|
||||
return nil, fmt.Errorf("expecting pointer value, received %v", ival.Type())
|
||||
}
|
||||
return newValueBuilder(ival.Elem()), nil
|
||||
}
|
||||
|
||||
// newValuebuilder returns a valueBuilder for val. It perorms all
|
||||
// necessary memory allocations.
|
||||
func newValueBuilder(val reflect.Value) *valueBuilder {
|
||||
for val.Kind() == reflect.Ptr {
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.New(val.Type().Elem()))
|
||||
}
|
||||
val = val.Elem()
|
||||
}
|
||||
switch val.Kind() {
|
||||
case reflect.Map:
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.MakeMap(val.Type()))
|
||||
}
|
||||
case reflect.Slice:
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.MakeSlice(val.Type(), 0, 8))
|
||||
}
|
||||
}
|
||||
return &valueBuilder{val: val}
|
||||
}
|
||||
|
||||
// mapValueBuilder returns a valueBuilder that represents a map value.
|
||||
// You need to call save after building the value to make sure it gets
|
||||
// saved to the map.
|
||||
func mapValueBuilder(typ reflect.Type, map_ reflect.Value, key reflect.Value) *valueBuilder {
|
||||
if typ.Kind() == reflect.Ptr {
|
||||
addr := reflect.New(typ.Elem())
|
||||
map_.SetMapIndex(key, addr)
|
||||
return newValueBuilder(addr.Elem())
|
||||
}
|
||||
builder := newValueBuilder(reflect.New(typ).Elem())
|
||||
builder.map_ = map_
|
||||
builder.key = key
|
||||
return builder
|
||||
}
|
||||
|
||||
// save saves the built value into the map.
|
||||
func (builder *valueBuilder) save() {
|
||||
if builder.map_.IsValid() {
|
||||
builder.map_.SetMapIndex(builder.key, builder.val)
|
||||
}
|
||||
}
|
||||
|
||||
// initField returns a valueBuilder based on the requested key.
|
||||
// If the key is a the magic tag _Val_, it returns itself.
|
||||
// If builder is a struct, it looks for a field of that name.
|
||||
// If builder is a map, it creates an entry for that key.
|
||||
// If buider is an array, it ignores the key and returns the next
|
||||
// element of the array.
|
||||
// If builder is a slice, it returns a newly appended element.
|
||||
// If the key cannot be resolved, it returns null.
|
||||
// If kind is Null, it initializes the field to the zero value.
|
||||
// Otherwise, it allocates memory as needed.
|
||||
func (builder *valueBuilder) initField(k string, kind byte) *valueBuilder {
|
||||
if k == MAGICTAG {
|
||||
if kind == Null {
|
||||
setZero(builder.val)
|
||||
return nil
|
||||
}
|
||||
return builder
|
||||
}
|
||||
switch builder.val.Kind() {
|
||||
case reflect.Struct:
|
||||
t := builder.val.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
if t.Field(i).Name == k {
|
||||
if kind == Null {
|
||||
setZero(builder.val.Field(i))
|
||||
return nil
|
||||
}
|
||||
return newValueBuilder(builder.val.Field(i))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case reflect.Map:
|
||||
t := builder.val.Type()
|
||||
if t.Key().Kind() != reflect.String {
|
||||
panic(NewBsonError("map index is not a string: %s", k))
|
||||
}
|
||||
key := reflect.ValueOf(k)
|
||||
if kind == Null {
|
||||
zero := reflect.Zero(t.Elem())
|
||||
builder.val.SetMapIndex(key, zero)
|
||||
return nil
|
||||
}
|
||||
return mapValueBuilder(t.Elem(), builder.val, key)
|
||||
case reflect.Array:
|
||||
if builder.index >= builder.val.Len() {
|
||||
panic(NewBsonError("array index %v out of bounds", builder.index))
|
||||
}
|
||||
ind := builder.index
|
||||
builder.index++
|
||||
if kind == Null {
|
||||
setZero(builder.val.Index(ind))
|
||||
return nil
|
||||
}
|
||||
return newValueBuilder(builder.val.Index(ind))
|
||||
case reflect.Slice:
|
||||
zero := reflect.Zero(builder.val.Type().Elem())
|
||||
builder.val.Set(reflect.Append(builder.val, zero))
|
||||
if kind == Null {
|
||||
return nil
|
||||
}
|
||||
return newValueBuilder(builder.val.Index(builder.val.Len() - 1))
|
||||
}
|
||||
// Failsafe: this code is actually unreachable.
|
||||
panic(NewBsonError("internal error: unindexable type %v", builder.val.Type()))
|
||||
}
|
||||
|
||||
func setZero(v reflect.Value) {
|
||||
v.Set(reflect.Zero(v.Type()))
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setInt(i int64) {
|
||||
builder.val.SetInt(i)
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setUint(u uint64) {
|
||||
builder.val.SetUint(u)
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setFloat(f float64) {
|
||||
builder.val.SetFloat(f)
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setString(s string) {
|
||||
builder.val.SetString(s)
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setBool(tf bool) {
|
||||
builder.val.SetBool(tf)
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setTime(t time.Time) {
|
||||
builder.val.Set(reflect.ValueOf(t))
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setBytes(b []byte) {
|
||||
builder.val.Set(reflect.ValueOf(b))
|
||||
}
|
||||
|
||||
func (builder *valueBuilder) setInterface(i interface{}) {
|
||||
builder.val.Set(reflect.ValueOf(i))
|
||||
}
|
|
@ -1,485 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newstring(v string) *string { return &v }
|
||||
func newint64(v int64) *int64 { return &v }
|
||||
func newint32(v int32) *int32 { return &v }
|
||||
func newint(v int) *int { return &v }
|
||||
func newuint64(v uint64) *uint64 { return &v }
|
||||
func newuint32(v uint32) *uint32 { return &v }
|
||||
func newuint(v uint) *uint { return &v }
|
||||
func newfloat64(v float64) *float64 { return &v }
|
||||
func newbool(v bool) *bool { return &v }
|
||||
func newtime(v time.Time) *time.Time { return &v }
|
||||
func newinterface(v interface{}) *interface{} { return &v }
|
||||
|
||||
var unmarshaltest = []struct {
|
||||
desc string
|
||||
in string
|
||||
out interface{}
|
||||
want interface{}
|
||||
}{{
|
||||
|
||||
// top level decodes
|
||||
"top level nil decode",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
nil,
|
||||
nil,
|
||||
}, {
|
||||
"top level struct decode",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&struct{ Val string }{},
|
||||
&struct{ Val string }{"test"},
|
||||
}, {
|
||||
"top level map decode",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&map[string]string{},
|
||||
&map[string]string{"Val": "test"},
|
||||
}, {
|
||||
"top level slice decode",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&[]string{},
|
||||
&[]string{"test"},
|
||||
}, {
|
||||
"top level array decode",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&[2]string{},
|
||||
&[2]string{"test", ""},
|
||||
}, {
|
||||
"top level string decode",
|
||||
"\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00",
|
||||
newstring(""),
|
||||
newstring("test"),
|
||||
}, {
|
||||
"top level string decode from Null",
|
||||
"\x0c\x00\x00\x00\n_Val_\x00\x00",
|
||||
newstring("test"),
|
||||
newstring(""),
|
||||
}, {
|
||||
"top level bytes decode",
|
||||
"\x15\x00\x00\x00\x05_Val_\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&[]byte{},
|
||||
&[]byte{'t', 'e', 's', 't'},
|
||||
}, {
|
||||
"top level int64 decode",
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
newint64(0),
|
||||
newint64(1),
|
||||
}, {
|
||||
"top level int32 decode",
|
||||
"\x10\x00\x00\x00\x10_Val_\x00\x01\x00\x00\x00\x00",
|
||||
newint32(0),
|
||||
newint32(1),
|
||||
}, {
|
||||
"top level int decode",
|
||||
"\x10\x00\x00\x00\x10_Val_\x00\x01\x00\x00\x00\x00",
|
||||
newint(0),
|
||||
newint(1),
|
||||
}, {
|
||||
"top level uint64 decode",
|
||||
"\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
newuint64(0),
|
||||
newuint64(1),
|
||||
}, {
|
||||
"top level uint32 decode",
|
||||
"\x14\x00\x00\x00?_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
newuint32(0),
|
||||
newuint32(1),
|
||||
}, {
|
||||
"top level uint decode",
|
||||
"\x14\x00\x00\x00\x12_Val_\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
newuint(0),
|
||||
newuint(1),
|
||||
}, {
|
||||
"top level float64 decode",
|
||||
"\x14\x00\x00\x00\x01_Val_\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
|
||||
newfloat64(0),
|
||||
newfloat64(1.0),
|
||||
}, {
|
||||
"top level bool decode",
|
||||
"\r\x00\x00\x00\b_Val_\x00\x01\x00",
|
||||
newbool(false),
|
||||
newbool(true),
|
||||
}, {
|
||||
"top level time decode",
|
||||
"\x14\x00\x00\x00\t_Val_\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
|
||||
newtime(time.Now()),
|
||||
newtime(time.Unix(1136243045, 0).UTC()),
|
||||
}, {
|
||||
"top level interface decode",
|
||||
"\x14\x00\x00\x00\x01_Val_\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
|
||||
newinterface(nil),
|
||||
newinterface(float64(1.0)),
|
||||
}, {
|
||||
|
||||
// embedded decodes
|
||||
"struct decode from Object",
|
||||
"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val struct{ Val2 string } }{},
|
||||
&struct{ Val struct{ Val2 string } }{struct{ Val2 string }{"test"}},
|
||||
}, {
|
||||
"struct decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val struct{ Val2 string } }{struct{ Val2 string }{"test"}},
|
||||
&struct{ Val struct{ Val2 string } }{},
|
||||
}, {
|
||||
"map decode from Object",
|
||||
"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val map[string]string }{},
|
||||
&struct{ Val map[string]string }{map[string]string{"Val2": "test"}},
|
||||
}, {
|
||||
"map decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val map[string]string }{map[string]string{"Val2": "test"}},
|
||||
&struct{ Val map[string]string }{},
|
||||
}, {
|
||||
"map decode from Null element",
|
||||
"\x15\x00\x00\x00\x03Val\x00\v\x00\x00\x00\nVal2\x00\x00\x00",
|
||||
&struct{ Val map[string]string }{},
|
||||
&struct{ Val map[string]string }{map[string]string{"Val2": ""}},
|
||||
}, {
|
||||
"slice decode from Array",
|
||||
"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val []string }{},
|
||||
&struct{ Val []string }{[]string{"test"}},
|
||||
}, {
|
||||
"slice decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val []string }{[]string{"test"}},
|
||||
&struct{ Val []string }{},
|
||||
}, {
|
||||
"slice decode from Null element",
|
||||
"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val []*int64 }{},
|
||||
&struct{ Val []*int64 }{[]*int64{nil, newint64(1)}},
|
||||
}, {
|
||||
"array decode from Array",
|
||||
"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val [2]string }{},
|
||||
&struct{ Val [2]string }{[2]string{"test", ""}},
|
||||
}, {
|
||||
"array decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val [2]string }{[2]string{"test", ""}},
|
||||
&struct{ Val [2]string }{},
|
||||
}, {
|
||||
"array decode from Null element",
|
||||
"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val [2]*int64 }{},
|
||||
&struct{ Val [2]*int64 }{[2]*int64{nil, newint64(1)}},
|
||||
}, {
|
||||
"string decode from String",
|
||||
"\x13\x00\x00\x00\x02Val\x00\x05\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val string }{},
|
||||
&struct{ Val string }{"test"},
|
||||
}, {
|
||||
"string decode from Binary",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&struct{ Val string }{},
|
||||
&struct{ Val string }{"test"},
|
||||
}, {
|
||||
"string decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val string }{"test"},
|
||||
&struct{ Val string }{},
|
||||
}, {
|
||||
"bytes decode from String",
|
||||
"\x13\x00\x00\x00\x02Val\x00\x05\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val []byte }{},
|
||||
&struct{ Val []byte }{[]byte("test")},
|
||||
}, {
|
||||
"bytes decode from Binary",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&struct{ Val []byte }{},
|
||||
&struct{ Val []byte }{[]byte("test")},
|
||||
}, {
|
||||
"bytes decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val []byte }{[]byte("test")},
|
||||
&struct{ Val []byte }{},
|
||||
}, {
|
||||
"int64 decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val int64 }{},
|
||||
&struct{ Val int64 }{1},
|
||||
}, {
|
||||
"int64 decode from Long",
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val int64 }{},
|
||||
&struct{ Val int64 }{1},
|
||||
}, {
|
||||
"int64 decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val int64 }{},
|
||||
&struct{ Val int64 }{1},
|
||||
}, {
|
||||
"int64 decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val int64 }{1},
|
||||
&struct{ Val int64 }{},
|
||||
}, {
|
||||
"int32 decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val int32 }{},
|
||||
&struct{ Val int32 }{1},
|
||||
}, {
|
||||
"int32 decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val int32 }{1},
|
||||
&struct{ Val int32 }{},
|
||||
}, {
|
||||
"int decode from Long",
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val int }{},
|
||||
&struct{ Val int }{1},
|
||||
}, {
|
||||
"int decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val int }{},
|
||||
&struct{ Val int }{1},
|
||||
}, {
|
||||
"int decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val int }{},
|
||||
&struct{ Val int }{1},
|
||||
}, {
|
||||
"int decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val int }{1},
|
||||
&struct{ Val int }{},
|
||||
}, {
|
||||
"uint64 decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val uint64 }{},
|
||||
&struct{ Val uint64 }{1},
|
||||
}, {
|
||||
"uint64 decode from Long",
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val uint64 }{},
|
||||
&struct{ Val uint64 }{1},
|
||||
}, {
|
||||
"uint64 decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val uint64 }{},
|
||||
&struct{ Val uint64 }{1},
|
||||
}, {
|
||||
"uint64 decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val uint64 }{1},
|
||||
&struct{ Val uint64 }{},
|
||||
}, {
|
||||
"uint32 decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val uint32 }{},
|
||||
&struct{ Val uint32 }{1},
|
||||
}, {
|
||||
"uint32 decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val uint32 }{},
|
||||
&struct{ Val uint32 }{1},
|
||||
}, {
|
||||
"uint32 decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val uint32 }{1},
|
||||
&struct{ Val uint32 }{},
|
||||
}, {
|
||||
"uint decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val uint }{},
|
||||
&struct{ Val uint }{1},
|
||||
}, {
|
||||
"uint decode from Long",
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val uint }{},
|
||||
&struct{ Val uint }{1},
|
||||
}, {
|
||||
"uint decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val uint }{},
|
||||
&struct{ Val uint }{1},
|
||||
}, {
|
||||
"uint decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val uint }{1},
|
||||
&struct{ Val uint }{},
|
||||
}, {
|
||||
"float64 decode from Number",
|
||||
"\x12\x00\x00\x00\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
|
||||
&struct{ Val float64 }{},
|
||||
&struct{ Val float64 }{1.0},
|
||||
}, {
|
||||
"float64 decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val float64 }{1.0},
|
||||
&struct{ Val float64 }{},
|
||||
}, {
|
||||
"bool decode from Boolean",
|
||||
"\v\x00\x00\x00\bVal\x00\x01\x00",
|
||||
&struct{ Val bool }{},
|
||||
&struct{ Val bool }{true},
|
||||
}, {
|
||||
"bool decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val bool }{},
|
||||
&struct{ Val bool }{true},
|
||||
}, {
|
||||
"bool decode from Long",
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val bool }{},
|
||||
&struct{ Val bool }{true},
|
||||
}, {
|
||||
"bool decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val bool }{},
|
||||
&struct{ Val bool }{true},
|
||||
}, {
|
||||
"bool decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val bool }{true},
|
||||
&struct{ Val bool }{},
|
||||
}, {
|
||||
"time decode from Datetime",
|
||||
"\x12\x00\x00\x00\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
|
||||
&struct{ Val time.Time }{},
|
||||
&struct{ Val time.Time }{time.Unix(1136243045, 0).UTC()},
|
||||
}, {
|
||||
"time decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val time.Time }{time.Unix(1136243045, 0).UTC()},
|
||||
&struct{ Val time.Time }{},
|
||||
}, {
|
||||
"interface decode from Number",
|
||||
"\x12\x00\x00\x00\x01Val\x00\x00\x00\x00\x00\x00\x00\xf0?\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{float64(1.0)},
|
||||
}, {
|
||||
"interface decode from String",
|
||||
"\x13\x00\x00\x00\x02Val\x00\x05\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{"test"},
|
||||
}, {
|
||||
"interface decode from Binary",
|
||||
"\x13\x00\x00\x00\x05Val\x00\x04\x00\x00\x00\x00test\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{[]byte("test")},
|
||||
}, {
|
||||
"interface decode from Boolean",
|
||||
"\v\x00\x00\x00\bVal\x00\x01\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{true},
|
||||
}, {
|
||||
"interface decode from Datetime",
|
||||
"\x12\x00\x00\x00\tVal\x00\x88\xf2\\\x8d\b\x01\x00\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{time.Unix(1136243045, 0).UTC()},
|
||||
}, {
|
||||
"interface decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{int32(1)},
|
||||
}, {
|
||||
"interface decode from Long",
|
||||
"\x12\x00\x00\x00\x12Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{int64(1)},
|
||||
}, {
|
||||
"interface decode from Ulong",
|
||||
"\x12\x00\x00\x00?Val\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{uint64(1)},
|
||||
}, {
|
||||
"interface decode from Object",
|
||||
"\x1e\x00\x00\x00\x03Val\x00\x14\x00\x00\x00\x05Val2\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{map[string]interface{}{"Val2": []byte("test")}},
|
||||
}, {
|
||||
"interface decode from Object with Null element",
|
||||
"\x15\x00\x00\x00\x03Val\x00\v\x00\x00\x00\nVal2\x00\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{map[string]interface{}{"Val2": nil}},
|
||||
}, {
|
||||
"interface decode from Array",
|
||||
"\x1b\x00\x00\x00\x04Val\x00\x11\x00\x00\x00\x050\x00\x04\x00\x00\x00\x00test\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{[]interface{}{[]byte("test")}},
|
||||
}, {
|
||||
"interface decode from Array null element",
|
||||
"\x1d\x00\x00\x00\x04Val\x00\x13\x00\x00\x00\n0\x00\x121\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
&struct{ Val interface{} }{},
|
||||
&struct{ Val interface{} }{[]interface{}{nil, int64(1)}},
|
||||
}, {
|
||||
"interface decode from Null",
|
||||
"\n\x00\x00\x00\nVal\x00\x00",
|
||||
&struct{ Val interface{} }{uint64(1)},
|
||||
&struct{ Val interface{} }{},
|
||||
}, {
|
||||
"pointer decode from Int",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val *int64 }{},
|
||||
&struct{ Val *int64 }{newint64(1)},
|
||||
}}
|
||||
|
||||
func TestUnmarshal(t *testing.T) {
|
||||
for _, tcase := range unmarshaltest {
|
||||
verifyUnmarshal(t, []byte(tcase.in), tcase.out)
|
||||
if !reflect.DeepEqual(tcase.out, tcase.want) {
|
||||
out := reflect.ValueOf(tcase.out).Elem().Interface()
|
||||
want := reflect.ValueOf(tcase.want).Elem().Interface()
|
||||
t.Errorf("%s: decoded: \n%#v, want\n%#v", tcase.desc, out, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var unmarshalErrorCases = []struct {
|
||||
desc string
|
||||
in string
|
||||
out interface{}
|
||||
want string
|
||||
}{{
|
||||
"non pointer input",
|
||||
"",
|
||||
10,
|
||||
"expecting pointer value, received int",
|
||||
}, {
|
||||
"invalid bson kind",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val struct{ Val2 int } }{},
|
||||
"unexpected kind: 16",
|
||||
}, {
|
||||
"map with int key",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&map[int]int{},
|
||||
"map index is not a string: Val",
|
||||
}, {
|
||||
"small array",
|
||||
"\x1f\x00\x00\x00\x050\x00\x05\x00\x00\x00\x00test1\x051\x00\x05\x00\x00\x00\x00test2\x00",
|
||||
&[1]string{},
|
||||
"array index 1 out of bounds",
|
||||
}, {
|
||||
"chan in struct",
|
||||
"\x0e\x00\x00\x00\x10Val\x00\x01\x00\x00\x00\x00",
|
||||
&struct{ Val chan int }{},
|
||||
"cannot unmarshal into chan",
|
||||
}}
|
||||
|
||||
func TestUnmarshalErrors(t *testing.T) {
|
||||
for _, tcase := range unmarshalErrorCases {
|
||||
err := Unmarshal([]byte(tcase.in), tcase.out)
|
||||
got := ""
|
||||
if err != nil {
|
||||
got = err.Error()
|
||||
}
|
||||
if got != tcase.want {
|
||||
t.Errorf("%s: received: %q, want %q", tcase.desc, got, tcase.want)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,355 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Utility functions for custom decoders
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/hack"
|
||||
)
|
||||
|
||||
// VerifyObject verifies kind to make sure it's
|
||||
// either a top level document (EOO) or an Object.
|
||||
// TODO(sougou): deprecate this function.
|
||||
func VerifyObject(kind byte) {
|
||||
if kind != EOO && kind != Object {
|
||||
panic(NewBsonError("unexpected kind: %v", kind))
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeString decodes a string from buf.
|
||||
// Allowed types: String, Binary, Null,
|
||||
func DecodeString(buf *bytes.Buffer, kind byte) string {
|
||||
switch kind {
|
||||
case String:
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
s := Next(buf, l-1)
|
||||
NextByte(buf)
|
||||
return string(s)
|
||||
case Binary:
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
NextByte(buf)
|
||||
return string(Next(buf, l))
|
||||
case Null:
|
||||
return ""
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for string", kind))
|
||||
}
|
||||
|
||||
// DecodeBinary decodes a []byte from buf.
|
||||
// Allowed types: String, Binary, Null.
|
||||
func DecodeBinary(buf *bytes.Buffer, kind byte) []byte {
|
||||
switch kind {
|
||||
case String:
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
b := Next(buf, l-1)
|
||||
NextByte(buf)
|
||||
return b
|
||||
case Binary:
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
NextByte(buf)
|
||||
return Next(buf, l)
|
||||
case Null:
|
||||
return nil
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for []byte", kind))
|
||||
}
|
||||
|
||||
// DecodeInt64 decodes a int64 from buf.
|
||||
// Allowed types: Int, Long, Ulong, Null.
|
||||
func DecodeInt64(buf *bytes.Buffer, kind byte) int64 {
|
||||
switch kind {
|
||||
case Int:
|
||||
return int64(int32(Pack.Uint32(Next(buf, 4))))
|
||||
case Long, Ulong:
|
||||
return int64(Pack.Uint64(Next(buf, 8)))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for int64", kind))
|
||||
}
|
||||
|
||||
// DecodeInt32 decodes a int32 from buf.
|
||||
// Allowed types: Int, Long, Null.
|
||||
func DecodeInt32(buf *bytes.Buffer, kind byte) int32 {
|
||||
switch kind {
|
||||
case Int:
|
||||
return int32(Pack.Uint32(Next(buf, 4)))
|
||||
case Long:
|
||||
return int32(Pack.Uint64(Next(buf, 8)))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for int32", kind))
|
||||
}
|
||||
|
||||
// DecodeInt decodes a int64 from buf.
|
||||
// Allowed types: Int, Long, Ulong, Null.
|
||||
func DecodeInt(buf *bytes.Buffer, kind byte) int {
|
||||
switch kind {
|
||||
case Int:
|
||||
return int(Pack.Uint32(Next(buf, 4)))
|
||||
case Long, Ulong:
|
||||
return int(Pack.Uint64(Next(buf, 8)))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for int", kind))
|
||||
}
|
||||
|
||||
// DecodeUint64 decodes a uint64 from buf.
|
||||
// Allowed types: Int, Long, Ulong, Null.
|
||||
func DecodeUint64(buf *bytes.Buffer, kind byte) uint64 {
|
||||
switch kind {
|
||||
case Int:
|
||||
return uint64(Pack.Uint32(Next(buf, 4)))
|
||||
case Long, Ulong:
|
||||
return Pack.Uint64(Next(buf, 8))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for uint64", kind))
|
||||
}
|
||||
|
||||
// DecodeUint32 decodes a uint32 from buf.
|
||||
// Allowed types: Int, Long, Null.
|
||||
func DecodeUint32(buf *bytes.Buffer, kind byte) uint32 {
|
||||
switch kind {
|
||||
case Int:
|
||||
return Pack.Uint32(Next(buf, 4))
|
||||
case Ulong, Long:
|
||||
return uint32(Pack.Uint64(Next(buf, 8)))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for uint32", kind))
|
||||
}
|
||||
|
||||
// DecodeUint decodes a uint64 from buf.
|
||||
// Allowed types: Int, Long, Ulong, Null.
|
||||
func DecodeUint(buf *bytes.Buffer, kind byte) uint {
|
||||
switch kind {
|
||||
case Int:
|
||||
return uint(Pack.Uint32(Next(buf, 4)))
|
||||
case Long, Ulong:
|
||||
return uint(Pack.Uint64(Next(buf, 8)))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for uint", kind))
|
||||
}
|
||||
|
||||
// DecodeFloat64 decodes a float64 from buf.
|
||||
// Allowed types: Number, Null.
|
||||
func DecodeFloat64(buf *bytes.Buffer, kind byte) float64 {
|
||||
switch kind {
|
||||
case Number:
|
||||
return float64(math.Float64frombits(Pack.Uint64(Next(buf, 8))))
|
||||
case Null:
|
||||
return 0
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for float64", kind))
|
||||
}
|
||||
|
||||
// DecodeBool decodes a bool from buf.
|
||||
// Allowed types: Boolean, Int, Long, Ulong, Null.
|
||||
func DecodeBool(buf *bytes.Buffer, kind byte) bool {
|
||||
switch kind {
|
||||
case Boolean:
|
||||
b, _ := buf.ReadByte()
|
||||
return (b != 0)
|
||||
case Int:
|
||||
return (Pack.Uint32(Next(buf, 4)) != 0)
|
||||
case Long, Ulong:
|
||||
return (Pack.Uint64(Next(buf, 8)) != 0)
|
||||
case Null:
|
||||
return false
|
||||
default:
|
||||
panic(NewBsonError("unexpected kind %v for bool", kind))
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeTime decodes a time.Time from buf.
|
||||
// Allowed types: Datetime, Null.
|
||||
func DecodeTime(buf *bytes.Buffer, kind byte) time.Time {
|
||||
switch kind {
|
||||
case Datetime:
|
||||
ui64 := Pack.Uint64(Next(buf, 8))
|
||||
return time.Unix(0, int64(ui64)*1e6).UTC()
|
||||
case Null:
|
||||
return time.Time{}
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for time.Time", kind))
|
||||
}
|
||||
|
||||
// DecodeInterface decodes the next object into an interface.
|
||||
// Object is decoded as map[string]interface{}.
|
||||
// Array is decoded as []interface{}
|
||||
func DecodeInterface(buf *bytes.Buffer, kind byte) interface{} {
|
||||
switch kind {
|
||||
case Number:
|
||||
return DecodeFloat64(buf, kind)
|
||||
case String:
|
||||
return DecodeString(buf, kind)
|
||||
case Object:
|
||||
return DecodeMap(buf, kind)
|
||||
case Array:
|
||||
return DecodeArray(buf, kind)
|
||||
case Binary:
|
||||
return DecodeBinary(buf, kind)
|
||||
case Boolean:
|
||||
return DecodeBool(buf, kind)
|
||||
case Datetime:
|
||||
return DecodeTime(buf, kind)
|
||||
case Null:
|
||||
return nil
|
||||
case Int:
|
||||
return DecodeInt32(buf, kind)
|
||||
case Long:
|
||||
return DecodeInt64(buf, kind)
|
||||
case Ulong:
|
||||
return DecodeUint64(buf, kind)
|
||||
}
|
||||
panic(NewBsonError("unexpected kind %v for interface{}", kind))
|
||||
}
|
||||
|
||||
// DecodeMap decodes a map[string]interface{} from buf.
|
||||
// Allowed types: Object, Null.
|
||||
func DecodeMap(buf *bytes.Buffer, kind byte) map[string]interface{} {
|
||||
switch kind {
|
||||
case Object:
|
||||
// valid
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
panic(NewBsonError("unexpected kind %v for map", kind))
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
Next(buf, 4)
|
||||
for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) {
|
||||
key := ReadCString(buf)
|
||||
if kind == Null {
|
||||
result[key] = nil
|
||||
continue
|
||||
}
|
||||
result[key] = DecodeInterface(buf, kind)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// DecodeArray decodes a []interface{} from buf.
|
||||
// Allowed types: Array, Null.
|
||||
func DecodeArray(buf *bytes.Buffer, kind byte) []interface{} {
|
||||
switch kind {
|
||||
case Array:
|
||||
// valid
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
panic(NewBsonError("unexpected kind %v for slice", kind))
|
||||
}
|
||||
|
||||
result := make([]interface{}, 0, 8)
|
||||
Next(buf, 4)
|
||||
for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) {
|
||||
ReadCString(buf)
|
||||
if kind == Null {
|
||||
result = append(result, nil)
|
||||
continue
|
||||
}
|
||||
result = append(result, DecodeInterface(buf, kind))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// DecodeStringArray decodes a []string from buf.
|
||||
// Allowed types: Array, Null.
|
||||
func DecodeStringArray(buf *bytes.Buffer, kind byte) []string {
|
||||
switch kind {
|
||||
case Array:
|
||||
// valid
|
||||
case Null:
|
||||
return nil
|
||||
default:
|
||||
panic(NewBsonError("unexpected kind %v for []string", kind))
|
||||
}
|
||||
|
||||
result := make([]string, 0, 8)
|
||||
Next(buf, 4)
|
||||
for kind := NextByte(buf); kind != EOO; kind = NextByte(buf) {
|
||||
if kind != Binary {
|
||||
panic(NewBsonError("unexpected kind %v for string", kind))
|
||||
}
|
||||
SkipIndex(buf)
|
||||
result = append(result, DecodeString(buf, kind))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Skip will skip the next field we don't want to read.
|
||||
func Skip(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case Number, Datetime, Long, Ulong:
|
||||
Next(buf, 8)
|
||||
case String:
|
||||
// length of a string includes the 0 at the end, but not the size
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
Next(buf, l)
|
||||
case Object, Array:
|
||||
// the encoded length includes the 4 bytes for the size
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
if l < 4 {
|
||||
panic(NewBsonError("Object or Array should at least be 4 bytes long"))
|
||||
}
|
||||
Next(buf, l-4)
|
||||
case Binary:
|
||||
// length of a binary doesn't include the subtype
|
||||
l := int(Pack.Uint32(Next(buf, 4)))
|
||||
Next(buf, l+1)
|
||||
case Boolean:
|
||||
buf.ReadByte()
|
||||
case Int:
|
||||
Next(buf, 4)
|
||||
case Null:
|
||||
// no op
|
||||
default:
|
||||
panic(NewBsonError("unexpected kind %v for skip", kind))
|
||||
}
|
||||
}
|
||||
|
||||
// SkipIndex must be used to skip indexes in arrays.
|
||||
func SkipIndex(buf *bytes.Buffer) {
|
||||
ReadCString(buf)
|
||||
}
|
||||
|
||||
// ReadCString reads the the bson document tag.
|
||||
func ReadCString(buf *bytes.Buffer) string {
|
||||
index := bytes.IndexByte(buf.Bytes(), 0)
|
||||
if index < 0 {
|
||||
panic(NewBsonError("unexpected EOF"))
|
||||
}
|
||||
// Read including null termination, but
|
||||
// return the string without the null.
|
||||
return hack.String(Next(buf, index+1)[:index])
|
||||
}
|
||||
|
||||
// Next returns the next n bytes from buf.
|
||||
func Next(buf *bytes.Buffer, n int) []byte {
|
||||
b := buf.Next(n)
|
||||
if len(b) != n {
|
||||
panic(NewBsonError("unexpected EOF"))
|
||||
}
|
||||
return b[:n:n]
|
||||
}
|
||||
|
||||
// NextByte returns the next byte from buf.
|
||||
func NextByte(buf *bytes.Buffer) byte {
|
||||
return Next(buf, 1)[0]
|
||||
}
|
|
@ -1,616 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
bsongen will generate bson encoders and decoders for
|
||||
a given go type. It uses goimports to fix the import
|
||||
statetments post-generation. It assumes goimports is
|
||||
in the path. If you specify a GOIMPORTS environment
|
||||
variable, it will use that instead.
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
filename = flag.String("file", "", "input file name")
|
||||
typename = flag.String("type", "", "type to generate code for")
|
||||
outfile = flag.String("o", "", "output file name, default stdout")
|
||||
counter = 0
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
if *filename == "" || *typename == "" {
|
||||
flag.PrintDefaults()
|
||||
return
|
||||
}
|
||||
b, err := ioutil.ReadFile(*filename)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
out, err := generateCode(string(b), *typename)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
return
|
||||
}
|
||||
fout := os.Stdout
|
||||
if *outfile != "" {
|
||||
fout, err = os.Create(*outfile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
return
|
||||
}
|
||||
defer fout.Close()
|
||||
}
|
||||
fmt.Fprintf(fout, "%s", out)
|
||||
}
|
||||
|
||||
// encoderMap and decoderMap specify encoder and decoder
|
||||
// functions that can be directly called to execute the
|
||||
// encode/decode operation.
|
||||
var (
|
||||
encoderMap = map[string]string{
|
||||
"string": "EncodeString",
|
||||
"[]byte": "EncodeBinary",
|
||||
"int64": "EncodeInt64",
|
||||
"int32": "EncodeInt32",
|
||||
"int": "EncodeInt",
|
||||
"uint64": "EncodeUint64",
|
||||
"uint32": "EncodeUint32",
|
||||
"uint": "EncodeUint",
|
||||
"float64": "EncodeFloat64",
|
||||
"bool": "EncodeBool",
|
||||
"interface{}": "EncodeInterface",
|
||||
"time.Time": "EncodeTime",
|
||||
}
|
||||
decoderMap = map[string]string{
|
||||
"string": "DecodeString",
|
||||
"[]byte": "DecodeBinary",
|
||||
"int64": "DecodeInt64",
|
||||
"int32": "DecodeInt32",
|
||||
"int": "DecodeInt",
|
||||
"uint64": "DecodeUint64",
|
||||
"uint32": "DecodeUint32",
|
||||
"uint": "DecodeUint",
|
||||
"float64": "DecodeFloat64",
|
||||
"bool": "DecodeBool",
|
||||
"interface{}": "DecodeInterface",
|
||||
"time.Time": "DecodeTime",
|
||||
}
|
||||
)
|
||||
|
||||
// TypeInfo is the top level struct generated by the buildType
|
||||
// function.
|
||||
type TypeInfo struct {
|
||||
Package string
|
||||
Imports []string
|
||||
Name string
|
||||
Var string
|
||||
Fields []*FieldInfo
|
||||
Type string
|
||||
}
|
||||
|
||||
// Encoder returns the encoder function for the top level type.
|
||||
// This is used only if the top level type is a simple type.
|
||||
func (t *TypeInfo) Encoder() string {
|
||||
return encoderMap[t.Type]
|
||||
}
|
||||
|
||||
// Decoder returns the decoder function for the top level type.
|
||||
// This is used only if the top level type is a simple type.
|
||||
func (t *TypeInfo) Decoder() string {
|
||||
return decoderMap[t.Type]
|
||||
}
|
||||
|
||||
// FieldInfo contains the necessary info to generate the encoder
|
||||
// or decoder code for an individual field. If the field is complex,
|
||||
// then it recursively describes the subcomponents using Subfield.
|
||||
// For example, in the case of a *int, the field will
|
||||
// be a '*' and the Subfield will be 'int'.
|
||||
type FieldInfo struct {
|
||||
Tag string
|
||||
Name string
|
||||
typ string
|
||||
KeyType string
|
||||
Subfield *FieldInfo
|
||||
}
|
||||
|
||||
// IsPointer returns true if the field represents a pointer.
|
||||
func (f *FieldInfo) IsPointer() bool {
|
||||
return f.typ == "*"
|
||||
}
|
||||
|
||||
// IsSlice returns true if the field represents a slice.
|
||||
func (f *FieldInfo) IsSlice() bool {
|
||||
return f.typ == "[]"
|
||||
}
|
||||
|
||||
// IsMap returns true if the field represents a map.
|
||||
func (f *FieldInfo) IsMap() bool {
|
||||
return f.KeyType != ""
|
||||
}
|
||||
|
||||
// IsSimpleMap returns true if the field represents a map
|
||||
// whose key is a string.
|
||||
func (f *FieldInfo) IsSimpleMap() bool {
|
||||
return f.KeyType == "string"
|
||||
}
|
||||
|
||||
// IsCustom returns true if the field is not a simple type
|
||||
// for which encode/decode functions exist.
|
||||
func (f *FieldInfo) IsCustom() bool {
|
||||
if f.IsPointer() || f.IsSlice() || f.IsMap() {
|
||||
return false
|
||||
}
|
||||
return encoderMap[f.typ] == ""
|
||||
}
|
||||
|
||||
// Encoder returns the encoder function for a simple type.
|
||||
func (f *FieldInfo) Encoder() string {
|
||||
return encoderMap[f.typ]
|
||||
}
|
||||
|
||||
// Decoder returns the decoder function for a simple type.
|
||||
func (f *FieldInfo) Decoder() string {
|
||||
return decoderMap[f.typ]
|
||||
}
|
||||
|
||||
// NewType emits the string for instantiating a new variable
|
||||
// of the field type.
|
||||
func (f *FieldInfo) NewType() string {
|
||||
if f.typ != "*" {
|
||||
return ""
|
||||
}
|
||||
typ := ""
|
||||
for field := f.Subfield; field != nil; field = field.Subfield {
|
||||
typ += field.typ
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
// Type emits the string representation of the field type.
|
||||
func (f *FieldInfo) Type() string {
|
||||
typ := f.typ
|
||||
for field := f.Subfield; field != nil; field = field.Subfield {
|
||||
typ += field.typ
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
// buildType looks for the specified type in the ast, and builds
|
||||
// the corresponding TypeInfo for it. It returns an error if
|
||||
// tye type is not found or bson code cannot be generated
|
||||
// for such a type.
|
||||
func buildType(file *ast.File, name string) (*TypeInfo, error) {
|
||||
typeInfo := &TypeInfo{
|
||||
Package: file.Name.Name,
|
||||
}
|
||||
for _, decl := range file.Decls {
|
||||
genDecl, ok := decl.(*ast.GenDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if genDecl.Tok == token.IMPORT {
|
||||
typeInfo.Imports = append(typeInfo.Imports, buildImports(genDecl.Specs)...)
|
||||
continue
|
||||
}
|
||||
if genDecl.Tok != token.TYPE {
|
||||
continue
|
||||
}
|
||||
if len(genDecl.Specs) != 1 {
|
||||
continue
|
||||
}
|
||||
typeSpec, ok := genDecl.Specs[0].(*ast.TypeSpec)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if typeSpec.Name.Name != name {
|
||||
continue
|
||||
}
|
||||
typeInfo.Name = name
|
||||
typeInfo.Var = strings.ToLower(name[:1]) + name[1:]
|
||||
switch spec := typeSpec.Type.(type) {
|
||||
case *ast.StructType:
|
||||
fields, err := buildFields(spec, typeInfo.Var)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
typeInfo.Fields = fields
|
||||
return typeInfo, nil
|
||||
case *ast.Ident:
|
||||
if encoderMap[spec.Name] == "" {
|
||||
return nil, fmt.Errorf("%s is not a struct or a simple type", name)
|
||||
}
|
||||
typeInfo.Type = spec.Name
|
||||
return typeInfo, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%s is not a struct or a simple type", name)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%s not found", name)
|
||||
}
|
||||
|
||||
// buildImports returns the list of imports.
|
||||
func buildImports(importSpecs []ast.Spec) (imports []string) {
|
||||
for _, spec := range importSpecs {
|
||||
importSpec, ok := spec.(*ast.ImportSpec)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var str string
|
||||
if importSpec.Name == nil {
|
||||
str = importSpec.Path.Value
|
||||
} else {
|
||||
str = importSpec.Name.Name + " " + importSpec.Path.Value
|
||||
}
|
||||
imports = append(imports, str)
|
||||
}
|
||||
return imports
|
||||
}
|
||||
|
||||
var (
|
||||
tagRE = regexp.MustCompile(`bson:("[a-zA-Z0-9_]*")`)
|
||||
)
|
||||
|
||||
// buildFields builds the fields of a struct into a list.
|
||||
func buildFields(structType *ast.StructType, varName string) (fields []*FieldInfo, err error) {
|
||||
for _, field := range structType.Fields.List {
|
||||
if field.Names == nil {
|
||||
return nil, fmt.Errorf("anonymous embeds not supported: %+v", field.Type)
|
||||
}
|
||||
for _, name := range field.Names {
|
||||
var tag string
|
||||
if field.Tag != nil {
|
||||
values := tagRE.FindStringSubmatch(field.Tag.Value)
|
||||
if len(values) >= 2 {
|
||||
tag = values[1]
|
||||
}
|
||||
}
|
||||
if tag == "" {
|
||||
if unicode.IsLower(rune(name.Name[0])) {
|
||||
continue
|
||||
}
|
||||
// Use var name as tag.
|
||||
tag = "\"" + name.Name + "\""
|
||||
}
|
||||
fullName := varName + "." + name.Name
|
||||
fieldInfo, err := buildField(field.Type, tag, fullName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fields = append(fields, fieldInfo)
|
||||
}
|
||||
}
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// buildField builds an individual field of a struct. It populates the info
|
||||
// such that it goes hand-in-hand with the code generation templates. For example,
|
||||
// the tag for an array type is bson.Itoa(_i), because it knows that the template
|
||||
// uses _i as the variable in the for loop.
|
||||
// In the case of maps it uses _k, based on similar knowledge.
|
||||
func buildField(fieldType ast.Expr, tag, name string) (*FieldInfo, error) {
|
||||
switch ident := fieldType.(type) {
|
||||
case *ast.Ident:
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: ident.Name}, nil
|
||||
case *ast.InterfaceType:
|
||||
if ident.Methods.List != nil {
|
||||
goto notSimple
|
||||
}
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: "interface{}"}, nil
|
||||
case *ast.ArrayType:
|
||||
if ident.Len != nil {
|
||||
goto notSimple
|
||||
}
|
||||
innerIdent, ok := ident.Elt.(*ast.Ident)
|
||||
// Treat []byte as simple type.
|
||||
if ok && innerIdent.Name == "byte" {
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: "[]byte"}, nil
|
||||
}
|
||||
// The tag for array elements is an expression based on the
|
||||
// variable used for iteration. In the generator templates,
|
||||
// this is _i. bson.Itoa(_i) returns a string represntation of
|
||||
// this index.
|
||||
subfield, err := buildField(ident.Elt, "bson.Itoa(_i)", newVarName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: "[]", Subfield: subfield}, nil
|
||||
case *ast.StarExpr:
|
||||
// We have to enclose the name in parenthesis to disambiguate
|
||||
// constructs like this: (*a)[i].
|
||||
subfield, err := buildField(ident.X, tag, "(*"+name+")")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: "*", Subfield: subfield}, nil
|
||||
case *ast.MapType:
|
||||
var keytype string
|
||||
switch kt := ident.Key.(type) {
|
||||
case *ast.Ident:
|
||||
keytype = kt.Name
|
||||
case *ast.SelectorExpr:
|
||||
pkg, ok := kt.X.(*ast.Ident)
|
||||
if !ok {
|
||||
goto notSimple
|
||||
}
|
||||
keytype = pkg.Name + "." + kt.Sel.Name
|
||||
}
|
||||
// For map elements, the tag is they key. The template uses _k
|
||||
// as the iteration variable. If the var is not a string, we
|
||||
// assume it's castable to a string.
|
||||
subtag := "_k"
|
||||
if keytype != "string" {
|
||||
subtag = "string(_k)"
|
||||
}
|
||||
subfield, err := buildField(ident.Value, subtag, newVarName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: fmt.Sprintf("map[%s]", keytype), KeyType: keytype, Subfield: subfield}, nil
|
||||
case *ast.SelectorExpr:
|
||||
pkg, ok := ident.X.(*ast.Ident)
|
||||
if !ok {
|
||||
goto notSimple
|
||||
}
|
||||
return &FieldInfo{Tag: tag, Name: name, typ: pkg.Name + "." + ident.Sel.Name}, nil
|
||||
}
|
||||
notSimple:
|
||||
return nil, fmt.Errorf("%+v is not a simple type", fieldType)
|
||||
}
|
||||
|
||||
// newVarName generates a unique variable naume using a simple counter.
|
||||
func newVarName() string {
|
||||
counter++
|
||||
return fmt.Sprintf("_v%d", counter)
|
||||
}
|
||||
|
||||
// generateCode generates the formatted code.
|
||||
func generateCode(in string, typename string) (out []byte, err error) {
|
||||
raw, err := generateRawCode(in, typename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return formatCode(raw)
|
||||
}
|
||||
|
||||
// generateRawCode performs the initial unformatted generation of the code.
|
||||
func generateRawCode(in string, typename string) (out []byte, err error) {
|
||||
counter = 0
|
||||
fset := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fset, "", in, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Uncomment this line for debugging.
|
||||
//ast.Print(fset, f)
|
||||
typeInfo, err := buildType(f, typename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := bytes.NewBuffer(nil)
|
||||
genTmpl := "StructBody"
|
||||
if typeInfo.Type != "" {
|
||||
genTmpl = "SimpleBody"
|
||||
}
|
||||
err = generator.ExecuteTemplate(buf, genTmpl, typeInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// formatCode uses goimports to format the generated code.
|
||||
func formatCode(in []byte) (out []byte, err error) {
|
||||
goimports := os.Getenv("GOIMPORTS")
|
||||
if goimports == "" {
|
||||
goimports = "goimports"
|
||||
}
|
||||
cmd := exec.Command(goimports)
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cmd.Wait()
|
||||
go func() {
|
||||
bytes.NewBuffer(in).WriteTo(stdin)
|
||||
stdin.Close()
|
||||
}()
|
||||
b, err := ioutil.ReadAll(stdout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// generator contains all the templates for generating code. It contains the following templates:
|
||||
// StructBody is the main generator for structs.
|
||||
// For the generating the encoder, StructBody uses the Encoder template.
|
||||
// Encoder calls one of: StarEncoder, SliceEncoder, MapEncoder, CustomEncoder, or SimpleEncoder,
|
||||
// depending on the type of the field.
|
||||
// An Encoder that's not simple, generates code for it type, which eventually calls back Encoder
|
||||
// on its Subfield. This goes recursively until a SimpleEncoder or CustomEncoder is encountered.
|
||||
// Decoder code generation follows a similar flow.
|
||||
// If the TypeInfo is not a struct, then SimpleBody is used instead of StructBody.
|
||||
var generator = template.Must(template.New("Generator").Parse(`
|
||||
{{define "SimpleEncoder"}}bson.{{.Encoder}}(buf, {{.Tag}}, {{.Name}}){{end}}
|
||||
|
||||
{{define "CustomEncoder"}}{{.Name}}.MarshalBson(buf, {{.Tag}}){{end}}
|
||||
|
||||
{{define "StarEncoder"}}// {{.Type}}
|
||||
if {{.Name}} == nil {
|
||||
bson.EncodePrefix(buf, bson.Null, {{.Tag}})
|
||||
} else {
|
||||
{{template "Encoder" .Subfield}}
|
||||
}{{end}}
|
||||
|
||||
{{define "SliceEncoder"}}// {{.Type}}
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, {{.Tag}})
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, {{.Subfield.Name}} := range {{.Name}} {
|
||||
{{template "Encoder" .Subfield}}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}{{end}}
|
||||
|
||||
{{define "MapEncoder"}}// {{.Type}}
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Object, {{.Tag}})
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _k, {{.Subfield.Name}} := range {{.Name}} {
|
||||
{{template "Encoder" .Subfield}}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}{{end}}
|
||||
|
||||
{{define "Encoder"}}{{if .IsPointer}}{{template "StarEncoder" .}}{{else if .IsSlice}}{{template "SliceEncoder" .}}{{else if .IsMap}}{{template "MapEncoder" .}}{{else if .IsCustom}}{{template "CustomEncoder" .}}{{else}}{{template "SimpleEncoder" .}}{{end}}{{end}}
|
||||
|
||||
{{define "SimpleDecoder"}}{{.Name}} = bson.{{.Decoder}}(buf, kind){{end}}
|
||||
|
||||
{{define "CustomDecoder"}}{{.Name}}.UnmarshalBson(buf, kind){{end}}
|
||||
|
||||
{{define "StarDecoder"}}// {{.Type}}
|
||||
if kind != bson.Null {
|
||||
{{.Name}} = new({{.NewType}})
|
||||
{{template "Decoder" .Subfield}}
|
||||
}{{end}}
|
||||
|
||||
{{define "SliceDecoder"}}// {{.Type}}
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for {{.Name}}", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
{{.Name}} = make({{.Type}}, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var {{.Subfield.Name}} {{.Subfield.Type}}
|
||||
{{template "Decoder" .Subfield}}
|
||||
{{.Name}} = append({{.Name}}, {{.Subfield.Name}})
|
||||
}
|
||||
}{{end}}
|
||||
|
||||
{{define "MapDecoder"}}// {{.Type}}
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Object {
|
||||
panic(bson.NewBsonError("unexpected kind %v for {{.Name}}", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
{{.Name}} = make({{.Type}})
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
_k := {{if .IsSimpleMap}}bson.ReadCString(buf){{else}}{{.KeyType}}(bson.ReadCString(buf)){{end}}
|
||||
var {{.Subfield.Name}} {{.Subfield.Type}}
|
||||
{{template "Decoder" .Subfield}}
|
||||
{{.Name}}[_k] = {{.Subfield.Name}}
|
||||
}
|
||||
}{{end}}
|
||||
|
||||
{{define "Decoder"}}{{if .IsPointer}}{{template "StarDecoder" .}}{{else if .IsSlice}}{{template "SliceDecoder" .}}{{else if .IsMap}}{{template "MapDecoder" .}}{{else if .IsCustom}}{{template "CustomDecoder" .}}{{else}}{{template "SimpleDecoder" .}}{{end}}{{end}}
|
||||
|
||||
{{define "StructBody"}}// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package {{.Package}}
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
import (
|
||||
{{range .Imports}} {{.}}
|
||||
{{end}}
|
||||
)
|
||||
|
||||
// MarshalBson bson-encodes {{.Name}}.
|
||||
func ({{.Var}} *{{.Name}}) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
{{range .Fields}} {{template "Encoder" .}}
|
||||
{{end}}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into {{.Name}}.
|
||||
func ({{.Var}} *{{.Name}}) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for {{.Name}}", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
{{range .Fields}} case {{.Tag}}:
|
||||
{{template "Decoder" .}}
|
||||
{{end}} default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
{{end}}
|
||||
|
||||
{{define "SimpleBody"}}// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package {{.Package}}
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
import (
|
||||
{{range .Imports}} {{.}}
|
||||
{{end}}
|
||||
)
|
||||
|
||||
// MarshalBson bson-encodes {{.Name}}.
|
||||
func ({{.Var}} {{.Name}}) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
if key == "" {
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
defer lenWriter.Close()
|
||||
key = bson.MAGICTAG
|
||||
}
|
||||
bson.{{.Encoder}}(buf, key, {{.Type}}({{.Var}}))
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into {{.Name}}.
|
||||
func ({{.Var}} *{{.Name}}) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
if kind == bson.EOO {
|
||||
bson.Next(buf, 4)
|
||||
kind = bson.NextByte(buf)
|
||||
bson.ReadCString(buf)
|
||||
}
|
||||
*{{.Var}} = {{.Name}}(bson.{{.Decoder}}(buf, kind))
|
||||
}
|
||||
{{end}}`))
|
|
@ -1,148 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
)
|
||||
|
||||
func TestValidFiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode.")
|
||||
}
|
||||
|
||||
inputs := testfiles.Glob("bson_test/input*.go")
|
||||
for _, input := range inputs {
|
||||
b, err := ioutil.ReadFile(input)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.ReadFile error: %v", err)
|
||||
}
|
||||
want, err := ioutil.ReadFile(strings.Replace(input, "input", "output", 1))
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.ReadFile error: %v", err)
|
||||
}
|
||||
|
||||
out, err := generateCode(string(b), "MyType")
|
||||
if err != nil {
|
||||
t.Fatalf("generateCode error: %v", err)
|
||||
}
|
||||
|
||||
// goimports is flaky. So, let's not test that part.
|
||||
want, err = skipImports(want)
|
||||
if err != nil {
|
||||
t.Fatalf("skipImports error: %v", err)
|
||||
}
|
||||
out, err = skipImports(out)
|
||||
if err != nil {
|
||||
t.Fatalf("skipImports error: %v", err)
|
||||
}
|
||||
|
||||
d, err := diff(want, out)
|
||||
if err != nil {
|
||||
t.Fatalf("diff error: %v", err)
|
||||
}
|
||||
if len(d) != 0 {
|
||||
t.Errorf("Unexpected output for %s:\n%s", input, string(d))
|
||||
if testing.Verbose() {
|
||||
t.Logf("%s:\n%s", input, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// diff copied from gofmt.go
|
||||
func diff(b1, b2 []byte) (data []byte, err error) {
|
||||
f1, err := ioutil.TempFile("", "bsongen")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer os.Remove(f1.Name())
|
||||
defer f1.Close()
|
||||
|
||||
f2, err := ioutil.TempFile("", "bsongen")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer os.Remove(f2.Name())
|
||||
defer f2.Close()
|
||||
|
||||
f1.Write(b1)
|
||||
f2.Write(b2)
|
||||
|
||||
data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
|
||||
if len(data) > 0 {
|
||||
// diff exits with a non-zero status when the files don't match.
|
||||
// Ignore that failure as long as we get output.
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func skipImports(b []byte) ([]byte, error) {
|
||||
begin := bytes.Index(b, []byte("\nimport (\n"))
|
||||
if begin < 0 {
|
||||
return nil, errors.New("couldn't find beginning of import block")
|
||||
}
|
||||
end := bytes.Index(b, []byte("\n)\n"))
|
||||
if end < 0 {
|
||||
return nil, errors.New("couldn't find end of imports block")
|
||||
}
|
||||
return append(b[:begin], b[end+3:]...), nil
|
||||
}
|
||||
|
||||
var invalidInputs = []struct{ title, input, err string }{
|
||||
{
|
||||
"func type",
|
||||
`package a; func MyType(){};`,
|
||||
"MyType not found",
|
||||
}, {
|
||||
"non-struct non-simple top level type",
|
||||
`package a; type MyType Custom;`,
|
||||
"MyType is not a struct or a simple type",
|
||||
}, {
|
||||
// Maybe support this in the future?
|
||||
"map type",
|
||||
`package a; type MyType map[string]Custom;`,
|
||||
"MyType is not a struct or a simple type",
|
||||
}, {
|
||||
// Maybe support this in the future?
|
||||
"slice type",
|
||||
`package a; type MyType []Custom;`,
|
||||
"MyType is not a struct or a simple type",
|
||||
}, {
|
||||
"anonymous embed",
|
||||
`package a; type MyType struct{Custom};`,
|
||||
"anonymous embeds not supported: Custom",
|
||||
}, {
|
||||
"interface with methods",
|
||||
`package a; type MyType struct{Val interface{Custom}};`,
|
||||
"is not a simple type",
|
||||
}, {
|
||||
// Maybe support this in the future?
|
||||
"array",
|
||||
`package a; type MyType struct{Val [5]int};`,
|
||||
"is not a simple type",
|
||||
},
|
||||
}
|
||||
|
||||
func TestInvalidInputs(t *testing.T) {
|
||||
for _, tcase := range invalidInputs {
|
||||
out, err := generateCode(tcase.input, "MyType")
|
||||
if err == nil {
|
||||
t.Errorf("Expecting error for %s:\n%s", tcase.title, string(out))
|
||||
}
|
||||
if !strings.Contains(err.Error(), tcase.err) {
|
||||
t.Errorf("%s: got '%v', error should contain '%s'", tcase.title, err, tcase.err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2013, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc vtgateservice server
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice"
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc vtgateconn client
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2015, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc vtgateconn client
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2013, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc vtgateservice server
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice"
|
||||
)
|
|
@ -5,12 +5,10 @@
|
|||
package goclienttest
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/vterrors"
|
||||
|
@ -234,10 +232,6 @@ func checkError(t *testing.T, err error, query, errStr string, errCode vtrpcpb.E
|
|||
if got, want := vtErr.VtErrorCode(), errCode; got != want {
|
||||
t.Errorf("[%v] error code = %v, want %v", query, got, want)
|
||||
}
|
||||
case rpcplus.ServerError:
|
||||
if !strings.Contains(string(vtErr), errStr) {
|
||||
t.Errorf("[%v] error = %q, want contains %q", query, vtErr, errStr)
|
||||
}
|
||||
default:
|
||||
t.Errorf("[%v] unrecognized error type: %T, error: %#v", query, err, err)
|
||||
return
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gorpcclienttest
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/cmd/vtgateclienttest/goclienttest"
|
||||
"github.com/youtube/vitess/go/cmd/vtgateclienttest/services"
|
||||
"github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice"
|
||||
)
|
||||
|
||||
// TestGoRPCGoClient tests the go client using goRPC
|
||||
func TestGoRPCGoClient(t *testing.T) {
|
||||
service := services.CreateServices()
|
||||
|
||||
// listen on a random port
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Create a Go Rpc server and listen on the port
|
||||
server := rpcplus.NewServer()
|
||||
server.Register(gorpcvtgateservice.New(service))
|
||||
|
||||
// create the HTTP server, serve the server from it
|
||||
handler := http.NewServeMux()
|
||||
bsonrpc.ServeCustomRPC(handler, server)
|
||||
httpServer := http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
go httpServer.Serve(listener)
|
||||
|
||||
// and run the test suite
|
||||
goclienttest.TestGoClient(t, "gorpc", listener.Addr().String())
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gorpcclienttest
|
||||
|
||||
import (
|
||||
// import the gorpc client, it will register itself
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
|
||||
)
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright 2013, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
// Imports and register the gorpc vtgateservice server
|
||||
|
||||
import (
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice"
|
||||
)
|
|
@ -1,50 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
// DO NOT EDIT.
|
||||
// FILE GENERATED BY BSONGEN.
|
||||
|
||||
// MarshalBson bson-encodes RPCError.
|
||||
func (rPCError *RPCError) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeInt64(buf, "Code", rPCError.Code)
|
||||
bson.EncodeString(buf, "Message", rPCError.Message)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into RPCError.
|
||||
func (rPCError *RPCError) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for RPCError", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Code":
|
||||
rPCError.Code = bson.DecodeInt64(buf, kind)
|
||||
case "Message":
|
||||
rPCError.Message = bson.DecodeString(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto
|
||||
|
||||
// RPCError is the structure that is returned by each RPC call, which contains
|
||||
// the error information for that call.
|
||||
type RPCError struct {
|
||||
Code int64
|
||||
Message string
|
||||
}
|
||||
|
||||
//go:generate bsongen -file $GOFILE -type RPCError -o rpcerror_bson.go
|
|
@ -1,373 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package rpcplus
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/youtube/vitess/go/trace"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// ServerError represents an error that has been returned from
|
||||
// the remote side of the RPC connection.
|
||||
type ServerError string
|
||||
|
||||
func (e ServerError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// ErrShutdown holds the specific error for closing/closed connections
|
||||
var ErrShutdown = errors.New("connection is shut down")
|
||||
|
||||
// Call represents an active RPC.
|
||||
type Call struct {
|
||||
ServiceMethod string // The name of the service and method to call.
|
||||
Args interface{} // The argument to the function (*struct).
|
||||
Reply interface{} // The reply from the function (*struct for single, chan * struct for streaming).
|
||||
Error error // After completion, the error status.
|
||||
Done chan *Call // Strobes when call is complete (nil for streaming RPCs)
|
||||
Stream bool // True for a streaming RPC call, false otherwise
|
||||
Subseq uint64 // The next expected subseq in the packets
|
||||
}
|
||||
|
||||
// Client represents an RPC Client.
|
||||
// There may be multiple outstanding Calls associated
|
||||
// with a single Client, and a Client may be used by
|
||||
// multiple goroutines simultaneously.
|
||||
type Client struct {
|
||||
mutex sync.Mutex // protects pending, seq, request
|
||||
sending sync.Mutex
|
||||
request Request
|
||||
seq uint64
|
||||
codec ClientCodec
|
||||
pending map[uint64]*Call
|
||||
closing bool
|
||||
shutdown bool
|
||||
}
|
||||
|
||||
// A ClientCodec implements writing of RPC requests and
|
||||
// reading of RPC responses for the client side of an RPC session.
|
||||
// The client calls WriteRequest to write a request to the connection
|
||||
// and calls ReadResponseHeader and ReadResponseBody in pairs
|
||||
// to read responses. The client calls Close when finished with the
|
||||
// connection. ReadResponseBody may be called with a nil
|
||||
// argument to force the body of the response to be read and then
|
||||
// discarded.
|
||||
type ClientCodec interface {
|
||||
WriteRequest(*Request, interface{}) error
|
||||
ReadResponseHeader(*Response) error
|
||||
ReadResponseBody(interface{}) error
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (client *Client) send(call *Call) {
|
||||
client.sending.Lock()
|
||||
defer client.sending.Unlock()
|
||||
|
||||
// Register this call.
|
||||
client.mutex.Lock()
|
||||
if client.shutdown {
|
||||
call.Error = ErrShutdown
|
||||
client.mutex.Unlock()
|
||||
call.done()
|
||||
return
|
||||
}
|
||||
seq := client.seq
|
||||
client.seq++
|
||||
client.pending[seq] = call
|
||||
client.mutex.Unlock()
|
||||
|
||||
// Encode and send the request.
|
||||
client.request.Seq = seq
|
||||
client.request.ServiceMethod = call.ServiceMethod
|
||||
err := client.codec.WriteRequest(&client.request, call.Args)
|
||||
if err != nil {
|
||||
client.mutex.Lock()
|
||||
call = client.pending[seq]
|
||||
delete(client.pending, seq)
|
||||
client.mutex.Unlock()
|
||||
if call != nil {
|
||||
call.Error = err
|
||||
call.done()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) input() {
|
||||
var err error
|
||||
var response Response
|
||||
for err == nil {
|
||||
response = Response{}
|
||||
err = client.codec.ReadResponseHeader(&response)
|
||||
if err != nil {
|
||||
if err == io.EOF && !client.closing {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
break
|
||||
}
|
||||
seq := response.Seq
|
||||
client.mutex.Lock()
|
||||
call := client.pending[seq]
|
||||
client.mutex.Unlock()
|
||||
|
||||
switch {
|
||||
case call == nil:
|
||||
// We've got no pending call. That usually means that
|
||||
// WriteRequest partially failed, and call was already
|
||||
// removed; response is a server telling us about an
|
||||
// error reading request body. We should still attempt
|
||||
// to read error body, but there's no one to give it to.
|
||||
err = client.codec.ReadResponseBody(nil)
|
||||
if err != nil {
|
||||
err = errors.New("reading error body: " + err.Error())
|
||||
}
|
||||
case response.Error != "":
|
||||
// We've got an error response. Give this to the request;
|
||||
// any subsequent requests will get the ReadResponseBody
|
||||
// error if there is one.
|
||||
if !(call.Stream && response.Error == lastStreamResponseError) {
|
||||
call.Error = ServerError(response.Error)
|
||||
}
|
||||
err = client.codec.ReadResponseBody(nil)
|
||||
if err != nil {
|
||||
err = errors.New("reading error payload: " + err.Error())
|
||||
}
|
||||
client.done(seq)
|
||||
case call.Stream:
|
||||
// call.Reply is a chan *T2
|
||||
// we need to create a T2 and get a *T2 back
|
||||
value := reflect.New(reflect.TypeOf(call.Reply).Elem().Elem()).Interface()
|
||||
err = client.codec.ReadResponseBody(value)
|
||||
if err != nil {
|
||||
call.Error = errors.New("reading body " + err.Error())
|
||||
} else {
|
||||
// writing on the channel could block forever. For
|
||||
// instance, if a client calls 'close', this might block
|
||||
// forever. the current suggestion is for the
|
||||
// client to drain the receiving channel in that case
|
||||
reflect.ValueOf(call.Reply).Send(reflect.ValueOf(value))
|
||||
}
|
||||
default:
|
||||
err = client.codec.ReadResponseBody(call.Reply)
|
||||
if err != nil {
|
||||
call.Error = errors.New("reading body " + err.Error())
|
||||
}
|
||||
client.done(seq)
|
||||
}
|
||||
}
|
||||
// Terminate pending calls.
|
||||
client.sending.Lock()
|
||||
client.mutex.Lock()
|
||||
client.shutdown = true
|
||||
closing := client.closing
|
||||
for _, call := range client.pending {
|
||||
call.Error = err
|
||||
call.done()
|
||||
}
|
||||
client.mutex.Unlock()
|
||||
client.sending.Unlock()
|
||||
if err != io.EOF && !closing {
|
||||
log.Println("rpc: client protocol error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) done(seq uint64) {
|
||||
client.mutex.Lock()
|
||||
call := client.pending[seq]
|
||||
delete(client.pending, seq)
|
||||
client.mutex.Unlock()
|
||||
|
||||
if call != nil {
|
||||
call.done()
|
||||
}
|
||||
}
|
||||
|
||||
func (call *Call) done() {
|
||||
if call.Stream {
|
||||
// need to close the channel. Client won't be able to read any more.
|
||||
reflect.ValueOf(call.Reply).Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case call.Done <- call:
|
||||
// ok
|
||||
default:
|
||||
// We don't want to block here. It is the caller's responsibility to make
|
||||
// sure the channel has enough buffer space. See comment in Go().
|
||||
log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient returns a new Client to handle requests to the
|
||||
// set of services at the other end of the connection.
|
||||
// It adds a buffer to the write side of the connection so
|
||||
// the header and payload are sent as a unit.
|
||||
func NewClient(conn io.ReadWriteCloser) *Client {
|
||||
encBuf := bufio.NewWriter(conn)
|
||||
client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
|
||||
return NewClientWithCodec(client)
|
||||
}
|
||||
|
||||
// NewClientWithCodec is like NewClient but uses the specified
|
||||
// codec to encode requests and decode responses.
|
||||
func NewClientWithCodec(codec ClientCodec) *Client {
|
||||
client := &Client{
|
||||
codec: codec,
|
||||
pending: make(map[uint64]*Call),
|
||||
}
|
||||
go client.input()
|
||||
return client
|
||||
}
|
||||
|
||||
type gobClientCodec struct {
|
||||
rwc io.ReadWriteCloser
|
||||
dec *gob.Decoder
|
||||
enc *gob.Encoder
|
||||
encBuf *bufio.Writer
|
||||
}
|
||||
|
||||
func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
|
||||
if err = c.enc.Encode(r); err != nil {
|
||||
return
|
||||
}
|
||||
if err = c.enc.Encode(body); err != nil {
|
||||
return
|
||||
}
|
||||
return c.encBuf.Flush()
|
||||
}
|
||||
|
||||
func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
|
||||
return c.dec.Decode(r)
|
||||
}
|
||||
|
||||
func (c *gobClientCodec) ReadResponseBody(body interface{}) error {
|
||||
return c.dec.Decode(body)
|
||||
}
|
||||
|
||||
func (c *gobClientCodec) Close() error {
|
||||
return c.rwc.Close()
|
||||
}
|
||||
|
||||
// DialHTTP connects to an HTTP RPC server at the specified network address
|
||||
// listening on the default HTTP RPC path.
|
||||
func DialHTTP(network, address string) (*Client, error) {
|
||||
return DialHTTPPath(network, address, DefaultRPCPath)
|
||||
}
|
||||
|
||||
// DialHTTPPath connects to an HTTP RPC server
|
||||
// at the specified network address and path.
|
||||
func DialHTTPPath(network, address, path string) (*Client, error) {
|
||||
var err error
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
|
||||
|
||||
// Require successful HTTP response
|
||||
// before switching to RPC protocol.
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
|
||||
if err == nil && resp.Status == connected {
|
||||
return NewClient(conn), nil
|
||||
}
|
||||
if err == nil {
|
||||
err = errors.New("unexpected HTTP response: " + resp.Status)
|
||||
}
|
||||
conn.Close()
|
||||
return nil, &net.OpError{
|
||||
Op: "dial-http",
|
||||
Net: network + " " + address,
|
||||
Addr: nil,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial connects to an RPC server at the specified network address.
|
||||
func Dial(network, address string) (*Client, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClient(conn), nil
|
||||
}
|
||||
|
||||
// Close closes the client connection
|
||||
func (client *Client) Close() error {
|
||||
client.mutex.Lock()
|
||||
if client.shutdown || client.closing {
|
||||
client.mutex.Unlock()
|
||||
return ErrShutdown
|
||||
}
|
||||
client.closing = true
|
||||
client.mutex.Unlock()
|
||||
return client.codec.Close()
|
||||
}
|
||||
|
||||
// Go invokes the function asynchronously. It returns the Call structure representing
|
||||
// the invocation. The done channel will signal when the call is complete by returning
|
||||
// the same Call object. If done is nil, Go will allocate a new channel.
|
||||
// If non-nil, done must be buffered or Go will deliberately crash.
|
||||
func (client *Client) Go(ctx context.Context, serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
|
||||
span := trace.NewSpanFromContext(ctx)
|
||||
span.StartClient(serviceMethod)
|
||||
defer span.Finish()
|
||||
|
||||
call := new(Call)
|
||||
call.ServiceMethod = serviceMethod
|
||||
call.Args = args
|
||||
call.Reply = reply
|
||||
if done == nil {
|
||||
done = make(chan *Call, 10) // buffered.
|
||||
} else {
|
||||
// If caller passes done != nil, it must arrange that
|
||||
// done has enough buffer for the number of simultaneous
|
||||
// RPCs that will be using that channel. If the channel
|
||||
// is totally unbuffered, it's best not to run at all.
|
||||
if cap(done) == 0 {
|
||||
log.Panic("rpc: done channel is unbuffered")
|
||||
}
|
||||
}
|
||||
call.Done = done
|
||||
client.send(call)
|
||||
return call
|
||||
}
|
||||
|
||||
// StreamGo invokes the streaming function asynchronously. It returns the Call structure representing
|
||||
// the invocation.
|
||||
func (client *Client) StreamGo(serviceMethod string, args interface{}, replyStream interface{}) *Call {
|
||||
// first check the replyStream object is a stream of pointers to a data structure
|
||||
typ := reflect.TypeOf(replyStream)
|
||||
// FIXME: check the direction of the channel, maybe?
|
||||
if typ.Kind() != reflect.Chan || typ.Elem().Kind() != reflect.Ptr {
|
||||
log.Panic("rpc: replyStream is not a channel of pointers")
|
||||
return nil
|
||||
}
|
||||
|
||||
call := new(Call)
|
||||
call.ServiceMethod = serviceMethod
|
||||
call.Args = args
|
||||
call.Reply = replyStream
|
||||
call.Stream = true
|
||||
call.Subseq = 0
|
||||
client.send(call)
|
||||
return call
|
||||
}
|
||||
|
||||
// Call invokes the named function, waits for it to complete, and returns its error status.
|
||||
func (client *Client) Call(ctx context.Context, serviceMethod string, args interface{}, reply interface{}) error {
|
||||
call := <-client.Go(ctx, serviceMethod, args, reply, make(chan *Call, 1)).Done
|
||||
return call.Error
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package rpcplus
|
||||
|
||||
/*
|
||||
Some HTML presented at http://machine:port/debug/rpc
|
||||
Lists services, their methods, and some statistics, still rudimentary.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
const debugText = `<html>
|
||||
<body>
|
||||
<title>Services</title>
|
||||
{{range .}}
|
||||
<hr>
|
||||
Service {{.Name}}
|
||||
<hr>
|
||||
<table>
|
||||
<th align=center>Method</th><th align=center>Calls</th>
|
||||
{{range .Method}}
|
||||
<tr>
|
||||
<td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td>
|
||||
<td align=center>{{.Type.NumCalls}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
{{end}}
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
var debug = template.Must(template.New("RPC debug").Parse(debugText))
|
||||
|
||||
type debugMethod struct {
|
||||
Type *methodType
|
||||
Name string
|
||||
}
|
||||
|
||||
type methodArray []debugMethod
|
||||
|
||||
type debugService struct {
|
||||
Service *service
|
||||
Name string
|
||||
Method methodArray
|
||||
}
|
||||
|
||||
type serviceArray []debugService
|
||||
|
||||
func (s serviceArray) Len() int { return len(s) }
|
||||
func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
|
||||
func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
func (m methodArray) Len() int { return len(m) }
|
||||
func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
|
||||
func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
||||
|
||||
type debugHTTP struct {
|
||||
*Server
|
||||
}
|
||||
|
||||
// Runs at /debug/rpc
|
||||
func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// Build a sorted version of the data.
|
||||
var services = make(serviceArray, len(server.serviceMap))
|
||||
i := 0
|
||||
server.mu.Lock()
|
||||
for sname, service := range server.serviceMap {
|
||||
services[i] = debugService{service, sname, make(methodArray, len(service.method))}
|
||||
j := 0
|
||||
for mname, method := range service.method {
|
||||
services[i].Method[j] = debugMethod{method, mname}
|
||||
j++
|
||||
}
|
||||
sort.Sort(services[i].Method)
|
||||
i++
|
||||
}
|
||||
server.mu.Unlock()
|
||||
sort.Sort(services)
|
||||
err := debug.Execute(w, services)
|
||||
if err != nil {
|
||||
fmt.Fprintln(w, "rpc: error executing template:", err.Error())
|
||||
}
|
||||
}
|
|
@ -1,270 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package jsonrpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/youtube/vitess/go/rpcplus"
|
||||
)
|
||||
|
||||
type Args struct {
|
||||
A, B int
|
||||
}
|
||||
|
||||
type Reply struct {
|
||||
C int
|
||||
}
|
||||
|
||||
type Arith int
|
||||
|
||||
func (t *Arith) Add(args *Args, reply *Reply) error {
|
||||
reply.C = args.A + args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Mul(args *Args, reply *Reply) error {
|
||||
reply.C = args.A * args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Div(args *Args, reply *Reply) error {
|
||||
if args.B == 0 {
|
||||
return errors.New("divide by zero")
|
||||
}
|
||||
reply.C = args.A / args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Error(args *Args, reply *Reply) error {
|
||||
panic("ERROR")
|
||||
}
|
||||
|
||||
func (t *Arith) Thrive(args *Args, sendReply func(reply interface{}) error) error {
|
||||
for i := 0; i < args.A; i++ {
|
||||
r := &Reply{C: i}
|
||||
err := sendReply(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
rpcplus.Register(new(Arith))
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
type addResp struct {
|
||||
ID interface{} `json:"id"`
|
||||
Result Reply `json:"result"`
|
||||
Error interface{} `json:"error"`
|
||||
}
|
||||
|
||||
cli, srv := net.Pipe()
|
||||
defer cli.Close()
|
||||
go ServeConn(srv)
|
||||
dec := json.NewDecoder(cli)
|
||||
|
||||
// Send hand-coded requests to server, parse responses.
|
||||
for i := 0; i < 10; i++ {
|
||||
fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
|
||||
var resp addResp
|
||||
err := dec.Decode(&resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode: %s", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("resp.Error: %s", resp.Error)
|
||||
}
|
||||
if resp.ID.(string) != string(i) {
|
||||
t.Fatalf("resp: bad id %q want %q", resp.ID.(string), string(i))
|
||||
}
|
||||
if resp.Result.C != 2*i+1 {
|
||||
t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(cli, "{}\n")
|
||||
var resp addResp
|
||||
if err := dec.Decode(&resp); err != nil {
|
||||
t.Fatalf("Decode after empty: %s", err)
|
||||
}
|
||||
if resp.Error == nil {
|
||||
t.Fatalf("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// Assume server is okay (TestServer is above).
|
||||
// Test client against server.
|
||||
cli, srv := net.Pipe()
|
||||
go ServeConn(srv)
|
||||
|
||||
client := NewClient(cli)
|
||||
defer client.Close()
|
||||
|
||||
// Synchronous calls
|
||||
args := &Args{7, 8}
|
||||
reply := new(Reply)
|
||||
err := client.Call(ctx, "Arith.Add", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
||||
}
|
||||
|
||||
args = &Args{7, 8}
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Mul", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Mul: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A*args.B {
|
||||
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
|
||||
}
|
||||
|
||||
// Out of order.
|
||||
args = &Args{7, 8}
|
||||
mulReply := new(Reply)
|
||||
mulCall := client.Go(ctx, "Arith.Mul", args, mulReply, nil)
|
||||
addReply := new(Reply)
|
||||
addCall := client.Go(ctx, "Arith.Add", args, addReply, nil)
|
||||
|
||||
addCall = <-addCall.Done
|
||||
if addCall.Error != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
|
||||
}
|
||||
if addReply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
|
||||
}
|
||||
|
||||
mulCall = <-mulCall.Done
|
||||
if mulCall.Error != nil {
|
||||
t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
|
||||
}
|
||||
if mulReply.C != args.A*args.B {
|
||||
t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
|
||||
}
|
||||
|
||||
// Error test
|
||||
args = &Args{7, 0}
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Div", args, reply)
|
||||
// expect an error: zero divide
|
||||
if err == nil {
|
||||
t.Error("Div: expected error")
|
||||
} else if err.Error() != "divide by zero" {
|
||||
t.Error("Div: expected divide by zero error; got", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMalformedInput(t *testing.T) {
|
||||
cli, srv := net.Pipe()
|
||||
go cli.Write([]byte(`{id:1}`)) // invalid json
|
||||
ServeConn(srv) // must return, not loop
|
||||
}
|
||||
|
||||
func TestUnexpectedError(t *testing.T) {
|
||||
cli, srv := myPipe()
|
||||
go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
|
||||
ServeConn(srv) // must return, not loop
|
||||
}
|
||||
|
||||
func TestStreamingCall(t *testing.T) {
|
||||
// Assume server is okay (TestServer is above).
|
||||
// Test client against server.
|
||||
cli, srv := net.Pipe()
|
||||
go ServeConn(srv)
|
||||
|
||||
client := NewClient(cli)
|
||||
defer client.Close()
|
||||
|
||||
args := &Args{7, 0}
|
||||
rowChan := make(chan *Reply, 10)
|
||||
c := client.StreamGo("Arith.Thrive", args, rowChan)
|
||||
|
||||
// fetch all the rows
|
||||
count := 0
|
||||
for row := range rowChan {
|
||||
if row.C != count {
|
||||
t.Fatal("unexpected value:", row.C)
|
||||
}
|
||||
count++
|
||||
|
||||
// log.Println("Values: ", row)
|
||||
}
|
||||
|
||||
if c.Error != nil {
|
||||
t.Fatal("unexpected error:", c.Error.Error())
|
||||
}
|
||||
|
||||
if count != 7 {
|
||||
t.Fatal("Didn't receive the right number of packets back:", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Copied from package net.
|
||||
func myPipe() (*pipe, *pipe) {
|
||||
r1, w1 := io.Pipe()
|
||||
r2, w2 := io.Pipe()
|
||||
|
||||
return &pipe{r1, w2}, &pipe{r2, w1}
|
||||
}
|
||||
|
||||
type pipe struct {
|
||||
*io.PipeReader
|
||||
*io.PipeWriter
|
||||
}
|
||||
|
||||
type pipeAddr int
|
||||
|
||||
func (pipeAddr) Network() string {
|
||||
return "pipe"
|
||||
}
|
||||
|
||||
func (pipeAddr) String() string {
|
||||
return "pipe"
|
||||
}
|
||||
|
||||
func (p *pipe) Close() error {
|
||||
err := p.PipeReader.Close()
|
||||
err1 := p.PipeWriter.Close()
|
||||
if err == nil {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *pipe) LocalAddr() net.Addr {
|
||||
return pipeAddr(0)
|
||||
}
|
||||
|
||||
func (p *pipe) RemoteAddr() net.Addr {
|
||||
return pipeAddr(0)
|
||||
}
|
||||
|
||||
func (p *pipe) SetTimeout(nsec int64) error {
|
||||
return errors.New("net.Pipe does not support timeouts")
|
||||
}
|
||||
|
||||
func (p *pipe) SetReadTimeout(nsec int64) error {
|
||||
return errors.New("net.Pipe does not support timeouts")
|
||||
}
|
||||
|
||||
func (p *pipe) SetWriteTimeout(nsec int64) error {
|
||||
return errors.New("net.Pipe does not support timeouts")
|
||||
}
|
|
@ -1,190 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec
|
||||
// for the rpcplus package.
|
||||
package jsonrpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
)
|
||||
|
||||
type clientCodec struct {
|
||||
dec *json.Decoder // for reading JSON values
|
||||
enc *json.Encoder // for writing JSON values
|
||||
c io.Closer
|
||||
|
||||
// temporary work space
|
||||
req clientRequest
|
||||
resp clientResponse
|
||||
|
||||
// JSON-RPC responses include the request id but not the request method.
|
||||
// Package rpc expects both.
|
||||
// We save the request method in pending when sending a request
|
||||
// and then look it up by request ID when filling out the rpc Response.
|
||||
mutex sync.Mutex // protects pending
|
||||
pending map[uint64]string // map request id to method name
|
||||
}
|
||||
|
||||
// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
|
||||
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
|
||||
return &clientCodec{
|
||||
dec: json.NewDecoder(conn),
|
||||
enc: json.NewEncoder(conn),
|
||||
c: conn,
|
||||
pending: make(map[uint64]string),
|
||||
}
|
||||
}
|
||||
|
||||
type clientRequest struct {
|
||||
Method string `json:"method"`
|
||||
Params [1]interface{} `json:"params"`
|
||||
ID uint64 `json:"id"`
|
||||
}
|
||||
|
||||
func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error {
|
||||
c.mutex.Lock()
|
||||
c.pending[r.Seq] = r.ServiceMethod
|
||||
c.mutex.Unlock()
|
||||
c.req.Method = r.ServiceMethod
|
||||
c.req.Params[0] = param
|
||||
c.req.ID = r.Seq
|
||||
return c.enc.Encode(&c.req)
|
||||
}
|
||||
|
||||
type clientResponse struct {
|
||||
ID uint64 `json:"id"`
|
||||
Result *json.RawMessage `json:"result"`
|
||||
Error interface{} `json:"error"`
|
||||
}
|
||||
|
||||
func (r *clientResponse) reset() {
|
||||
r.ID = 0
|
||||
r.Result = nil
|
||||
r.Error = nil
|
||||
}
|
||||
|
||||
func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error {
|
||||
c.resp.reset()
|
||||
if err := c.dec.Decode(&c.resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
r.ServiceMethod = c.pending[c.resp.ID]
|
||||
delete(c.pending, c.resp.ID)
|
||||
c.mutex.Unlock()
|
||||
|
||||
r.Error = ""
|
||||
r.Seq = c.resp.ID
|
||||
if c.resp.Error != nil {
|
||||
x, ok := c.resp.Error.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid error %v", c.resp.Error)
|
||||
}
|
||||
if x == "" {
|
||||
x = "unspecified error"
|
||||
}
|
||||
r.Error = x
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientCodec) ReadResponseBody(x interface{}) error {
|
||||
if x == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(*c.resp.Result, x)
|
||||
}
|
||||
|
||||
func (c *clientCodec) Close() error {
|
||||
return c.c.Close()
|
||||
}
|
||||
|
||||
// NewClient returns a new rpc.Client to handle requests to the
|
||||
// set of services at the other end of the connection.
|
||||
func NewClient(conn io.ReadWriteCloser) *rpc.Client {
|
||||
return rpc.NewClientWithCodec(NewClientCodec(conn))
|
||||
}
|
||||
|
||||
// Dial connects to a JSON-RPC server at the specified network address.
|
||||
func Dial(network, address string) (*rpc.Client, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClient(conn), err
|
||||
}
|
||||
|
||||
// HTTPClient holds the required parameters and functions for communicating with
|
||||
// the HTTP RPC server
|
||||
type HTTPClient struct {
|
||||
Addr string
|
||||
seq uint64
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// NewHTTPClient creates a helper json rpc client for regular http based
|
||||
// endpoints
|
||||
func NewHTTPClient(addr string) *HTTPClient {
|
||||
return &HTTPClient{
|
||||
Addr: addr,
|
||||
seq: 0,
|
||||
m: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Call calls the http rpc endpoint with given parameters, uses POST request and
|
||||
// can be called by multiple go routines
|
||||
func (h *HTTPClient) Call(serviceMethod string, args interface{}, reply interface{}) error {
|
||||
var params [1]interface{}
|
||||
params[0] = args
|
||||
|
||||
h.m.Lock()
|
||||
seq := h.seq
|
||||
h.seq++
|
||||
h.m.Unlock()
|
||||
|
||||
cr := &clientRequest{
|
||||
Method: serviceMethod,
|
||||
Params: params,
|
||||
ID: seq,
|
||||
}
|
||||
|
||||
byteData, err := json.Marshal(cr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", h.Addr, bytes.NewReader(byteData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v := &clientResponse{}
|
||||
err = json.NewDecoder(res.Body).Decode(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v.Error != nil {
|
||||
return errors.New(v.Error.(string))
|
||||
}
|
||||
|
||||
return json.Unmarshal(*v.Result, reply)
|
||||
}
|
|
@ -1,147 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package jsonrpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type serverCodec struct {
|
||||
dec *json.Decoder // for reading JSON values
|
||||
enc *json.Encoder // for writing JSON values
|
||||
c io.Closer
|
||||
|
||||
// temporary work space
|
||||
req serverRequest
|
||||
resp serverResponse
|
||||
|
||||
// JSON-RPC clients can use arbitrary json values as request IDs.
|
||||
// Package rpc expects uint64 request IDs.
|
||||
// We assign uint64 sequence numbers to incoming requests
|
||||
// but save the original request ID in the pending map.
|
||||
// When rpc responds, we use the sequence number in
|
||||
// the response to find the original request ID.
|
||||
mutex sync.Mutex // protects seq, pending
|
||||
seq uint64
|
||||
pending map[uint64]*json.RawMessage
|
||||
}
|
||||
|
||||
// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
|
||||
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
|
||||
return &serverCodec{
|
||||
dec: json.NewDecoder(conn),
|
||||
enc: json.NewEncoder(conn),
|
||||
c: conn,
|
||||
pending: make(map[uint64]*json.RawMessage),
|
||||
}
|
||||
}
|
||||
|
||||
type serverRequest struct {
|
||||
Method string `json:"method"`
|
||||
Params *json.RawMessage `json:"params"`
|
||||
ID *json.RawMessage `json:"id"`
|
||||
}
|
||||
|
||||
func (r *serverRequest) reset() {
|
||||
r.Method = ""
|
||||
if r.Params != nil {
|
||||
*r.Params = (*r.Params)[0:0]
|
||||
}
|
||||
if r.ID != nil {
|
||||
*r.ID = (*r.ID)[0:0]
|
||||
}
|
||||
}
|
||||
|
||||
type serverResponse struct {
|
||||
ID *json.RawMessage `json:"id"`
|
||||
Result interface{} `json:"result"`
|
||||
Error interface{} `json:"error"`
|
||||
}
|
||||
|
||||
func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error {
|
||||
c.req.reset()
|
||||
if err := c.dec.Decode(&c.req); err != nil {
|
||||
return err
|
||||
}
|
||||
r.ServiceMethod = c.req.Method
|
||||
|
||||
// JSON request id can be any JSON value;
|
||||
// RPC package expects uint64. Translate to
|
||||
// internal uint64 and save JSON on the side.
|
||||
c.mutex.Lock()
|
||||
c.seq++
|
||||
c.pending[c.seq] = c.req.ID
|
||||
c.req.ID = nil
|
||||
r.Seq = c.seq
|
||||
c.mutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverCodec) ReadRequestBody(x interface{}) error {
|
||||
if x == nil {
|
||||
return nil
|
||||
}
|
||||
// JSON params is array value.
|
||||
// RPC params is struct.
|
||||
// Unmarshal into array containing struct for now.
|
||||
// Should think about making RPC more general.
|
||||
var params [1]interface{}
|
||||
params[0] = x
|
||||
return json.Unmarshal(*c.req.Params, ¶ms)
|
||||
}
|
||||
|
||||
var null = json.RawMessage([]byte("null"))
|
||||
|
||||
func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}, last bool) error {
|
||||
var resp serverResponse
|
||||
c.mutex.Lock()
|
||||
b, ok := c.pending[r.Seq]
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
return errors.New("invalid sequence number in response")
|
||||
}
|
||||
if last {
|
||||
delete(c.pending, r.Seq)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
|
||||
if b == nil {
|
||||
// Invalid request so no id. Use JSON null.
|
||||
b = &null
|
||||
}
|
||||
resp.ID = b
|
||||
resp.Result = x
|
||||
if r.Error == "" {
|
||||
resp.Error = nil
|
||||
} else {
|
||||
resp.Error = r.Error
|
||||
}
|
||||
return c.enc.Encode(resp)
|
||||
}
|
||||
|
||||
// Close closes the server connection
|
||||
func (c *serverCodec) Close() error {
|
||||
return c.c.Close()
|
||||
}
|
||||
|
||||
// ServeConn runs the JSON-RPC server on a single connection.
|
||||
// ServeConn blocks, serving the connection until the client hangs up.
|
||||
// The caller typically invokes ServeConn in a go statement.
|
||||
func ServeConn(conn io.ReadWriteCloser) {
|
||||
ServeConnWithContext(context.TODO(), conn)
|
||||
}
|
||||
|
||||
// ServeConnWithContext is like ServeConn but it allows to pass a
|
||||
// connection context to the RPC methods.
|
||||
func ServeConnWithContext(ctx context.Context, conn io.ReadWriteCloser) {
|
||||
rpc.ServeCodecWithContext(ctx, NewServerCodec(conn))
|
||||
}
|
|
@ -1,104 +0,0 @@
|
|||
// Package pbrpc implements a ClientCodec and ServerCodec
|
||||
// for the rpcplus package using Protocol Buffers.
|
||||
package pbrpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
)
|
||||
|
||||
// NewClientCodec returns a new rpc.ClientCodec using Protobuf on conn.
|
||||
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
|
||||
return &pbClientCodec{rwc: conn}
|
||||
}
|
||||
|
||||
// NewClient returns a new rpc.Client to handle requests to the
|
||||
// set of services at the other end of the connection.
|
||||
func NewClient(conn io.ReadWriteCloser) *rpc.Client {
|
||||
return rpc.NewClientWithCodec(NewClientCodec(conn))
|
||||
}
|
||||
|
||||
// Dial connects to a Protobuf-RPC server at the specified network address.
|
||||
func Dial(network, address string) (*rpc.Client, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClient(conn), err
|
||||
}
|
||||
|
||||
type pbClientCodec struct {
|
||||
mu sync.Mutex
|
||||
rwc io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// WriteRequest - implement rpc.ClientCodec interface.
|
||||
func (c *pbClientCodec) WriteRequest(r *rpc.Request, body interface{}) (err error) {
|
||||
// Use a mutex to guarantee the header/body are written in the correct order.
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// This is protobuf, of course we copy it.
|
||||
pbr := &Request{ServiceMethod: &r.ServiceMethod, Seq: &r.Seq}
|
||||
data, err := proto.Marshal(pbr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = WriteNetString(c.rwc, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Of course this is a protobuf! Trust me or detonate the program.
|
||||
data, err = proto.Marshal(body.(proto.Message))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = WriteNetString(c.rwc, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if flusher, ok := c.rwc.(flusher); ok {
|
||||
err = flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ReadResponseHeader - implement rpc.ClientCodec interface.
|
||||
func (c *pbClientCodec) ReadResponseHeader(r *rpc.Response) error {
|
||||
data, err := ReadNetString(c.rwc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rtmp := new(Response)
|
||||
err = proto.Unmarshal(data, rtmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.ServiceMethod = *rtmp.ServiceMethod
|
||||
r.Seq = *rtmp.Seq
|
||||
r.Error = *rtmp.Error
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadResponseBody - implement rpc.ClientCodec interface.
|
||||
func (c *pbClientCodec) ReadResponseBody(body interface{}) error {
|
||||
data, err := ReadNetString(c.rwc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if body != nil {
|
||||
return proto.Unmarshal(data, body.(proto.Message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close - implement rpc.ClientCodec interface.
|
||||
func (c *pbClientCodec) Close() error {
|
||||
return c.rwc.Close()
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
// Code generated by protoc-gen-go.
|
||||
// source: envelope.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/*
|
||||
Package pbrpc is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
envelope.proto
|
||||
|
||||
It has these top-level messages:
|
||||
Request
|
||||
Response
|
||||
*/
|
||||
package pbrpc
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import json "encoding/json"
|
||||
import math "math"
|
||||
|
||||
// Reference proto, json, and math imports to suppress error if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = &json.SyntaxError{}
|
||||
var _ = math.Inf
|
||||
|
||||
type Request struct {
|
||||
ServiceMethod *string `protobuf:"bytes,1,opt,name=service_method" json:"service_method,omitempty"`
|
||||
Seq *uint64 `protobuf:"fixed64,2,opt,name=seq" json:"seq,omitempty"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Request) Reset() { *m = Request{} }
|
||||
func (m *Request) String() string { return proto.CompactTextString(m) }
|
||||
func (*Request) ProtoMessage() {}
|
||||
|
||||
func (m *Request) GetServiceMethod() string {
|
||||
if m != nil && m.ServiceMethod != nil {
|
||||
return *m.ServiceMethod
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Request) GetSeq() uint64 {
|
||||
if m != nil && m.Seq != nil {
|
||||
return *m.Seq
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
ServiceMethod *string `protobuf:"bytes,1,opt,name=service_method" json:"service_method,omitempty"`
|
||||
Seq *uint64 `protobuf:"fixed64,2,opt,name=seq" json:"seq,omitempty"`
|
||||
Error *string `protobuf:"bytes,3,opt,name=error" json:"error,omitempty"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Response) Reset() { *m = Response{} }
|
||||
func (m *Response) String() string { return proto.CompactTextString(m) }
|
||||
func (*Response) ProtoMessage() {}
|
||||
|
||||
func (m *Response) GetServiceMethod() string {
|
||||
if m != nil && m.ServiceMethod != nil {
|
||||
return *m.ServiceMethod
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Response) GetSeq() uint64 {
|
||||
if m != nil && m.Seq != nil {
|
||||
return *m.Seq
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Response) GetError() string {
|
||||
if m != nil && m.Error != nil {
|
||||
return *m.Error
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
package pbrpc;
|
||||
|
||||
message Request {
|
||||
optional string service_method = 1;
|
||||
optional fixed64 seq = 2;
|
||||
}
|
||||
|
||||
message Response {
|
||||
optional string service_method = 1;
|
||||
optional fixed64 seq = 2;
|
||||
optional string error = 3;
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
package pbrpc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap"
|
||||
)
|
||||
|
||||
const codecName = "protobuf"
|
||||
|
||||
// DialHTTP with Protobuf codec.
|
||||
func DialHTTP(network, address string, connectTimeout time.Duration) (*rpc.Client, error) {
|
||||
return rpcwrap.DialHTTP(network, address, codecName, NewClientCodec, connectTimeout)
|
||||
}
|
||||
|
||||
// ServeRPC with Protobuf codec.
|
||||
func ServeRPC() {
|
||||
rpcwrap.ServeRPC(codecName, NewServerCodec)
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package pbrpc
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
// WriteNetString writes data to a big-endian netstring on a Writer.
|
||||
// Size is always a 32-bit unsigned int.
|
||||
func WriteNetString(w io.Writer, data []byte) (written int, err error) {
|
||||
size := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(size, uint32(len(data)))
|
||||
if written, err = w.Write(size); err != nil {
|
||||
return
|
||||
}
|
||||
return w.Write(data)
|
||||
}
|
||||
|
||||
// ReadNetString reads data from a big-endian netstring.
|
||||
func ReadNetString(r io.Reader) (data []byte, err error) {
|
||||
sizeBuf := make([]byte, 4)
|
||||
_, err = r.Read(sizeBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size := binary.BigEndian.Uint32(sizeBuf)
|
||||
if size == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
data = make([]byte, size)
|
||||
_, err = r.Read(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
package pbrpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
)
|
||||
|
||||
type pbServerCodec struct {
|
||||
mu sync.Mutex
|
||||
rwc io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// NewServerCodec returns a new ServerCodec.
|
||||
func NewServerCodec(rwc io.ReadWriteCloser) rpc.ServerCodec {
|
||||
return &pbServerCodec{rwc: rwc}
|
||||
}
|
||||
|
||||
// ReadRequestHeader reads a Request.
|
||||
func (c *pbServerCodec) ReadRequestHeader(r *rpc.Request) error {
|
||||
data, err := ReadNetString(c.rwc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rtmp := new(Request)
|
||||
err = proto.Unmarshal(data, rtmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.ServiceMethod = *rtmp.ServiceMethod
|
||||
r.Seq = *rtmp.Seq
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadRequestBody reads a body structure from the codec.
|
||||
func (c *pbServerCodec) ReadRequestBody(body interface{}) error {
|
||||
data, err := ReadNetString(c.rwc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if body != nil {
|
||||
return proto.Unmarshal(data, body.(proto.Message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type flusher interface {
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// WriteResponse writes a response on the codec.
|
||||
func (c *pbServerCodec) WriteResponse(r *rpc.Response, body interface{}, last bool) (err error) {
|
||||
// Use a mutex to guarantee the header/body are written in the correct order.
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
rtmp := &Response{ServiceMethod: &r.ServiceMethod, Seq: &r.Seq, Error: &r.Error}
|
||||
data, err := proto.Marshal(rtmp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = WriteNetString(c.rwc, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if pb, ok := body.(proto.Message); ok {
|
||||
data, err = proto.Marshal(pb)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
data = nil
|
||||
}
|
||||
_, err = WriteNetString(c.rwc, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if flusher, ok := c.rwc.(flusher); ok {
|
||||
err = flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Close the underlying connection.
|
||||
func (c *pbServerCodec) Close() error {
|
||||
return c.rwc.Close()
|
||||
}
|
|
@ -1,829 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package rpcplus provides access to the exported methods of an object across a
|
||||
network or other I/O connection. A server registers an object, making it visible
|
||||
as a service with the name of the type of the object. After registration,
|
||||
exported methods of the object will be accessible remotely. A server may
|
||||
register multiple objects (services) of different types, but it is an error to
|
||||
register multiple objects of the same type.
|
||||
|
||||
Only methods that satisfy the following criteria will be made available for
|
||||
remote access (other methods will be ignored):
|
||||
|
||||
- the method is exported.
|
||||
- the method has two arguments, both exported (or builtin) types.
|
||||
- the method's second argument is either a pointer, or a function pointer.
|
||||
- the method has return type error.
|
||||
|
||||
In effect, the method must look schematically like one of these two:
|
||||
|
||||
func (t *T) MethodName(argType T1, replyType *T2) error
|
||||
func (t *T) MethodName(argType T1, sendReply func(interface{}) error) error
|
||||
|
||||
where T, T1 and T2 can be marshaled by encoding/gob.
|
||||
These requirements apply even if a different codec is used.
|
||||
(In the future, these requirements may soften for custom codecs.)
|
||||
|
||||
The method's first argument represents the arguments provided by the caller; the
|
||||
second argument represents the result parameters to be returned to the caller,
|
||||
or the function to call to send results.
|
||||
The method's return value, if non-nil, is passed back as a string that the
|
||||
client sees as if created by errors.New. If an error is returned, the reply
|
||||
parameter will not be sent back to the client.
|
||||
|
||||
The server may handle requests on a single connection by calling ServeConn. More
|
||||
typically it will create a network listener and call Accept or, for an HTTP
|
||||
listener, HandleHTTP and http.Serve.
|
||||
|
||||
A client wishing to use the service establishes a connection and then invokes
|
||||
NewClient on the connection. The convenience function Dial (DialHTTP) performs
|
||||
both steps for a raw network connection (an HTTP connection). The resulting
|
||||
Client object has two methods, Call and Go, that specify the service and method
|
||||
to call, a pointer containing the arguments, and a pointer to receive the result
|
||||
parameters. It also has a StreamGo method, that specifies a reply channel
|
||||
to receive the results in the case of streaming RPCs.
|
||||
|
||||
The Call method waits for the remote call to complete while the Go method
|
||||
launches the call asynchronously and signals completion using the Call
|
||||
structure's Done channel. The StreamGo method is always asynchronous.
|
||||
|
||||
Unless an explicit codec is set up, package encoding/gob is used to
|
||||
transport the data.
|
||||
|
||||
Here is a simple example. A server wishes to export an object of type Arith:
|
||||
|
||||
package server
|
||||
|
||||
type Args struct {
|
||||
A, B int
|
||||
}
|
||||
|
||||
type Quotient struct {
|
||||
Quo, Rem int
|
||||
}
|
||||
|
||||
type Arith int
|
||||
|
||||
func (t *Arith) Multiply(args *Args, reply *int) error {
|
||||
*reply = args.A * args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Divide(args *Args, quo *Quotient) error {
|
||||
if args.B == 0 {
|
||||
return errors.New("divide by zero")
|
||||
}
|
||||
quo.Quo = args.A / args.B
|
||||
quo.Rem = args.A % args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
The server calls (for HTTP service):
|
||||
|
||||
arith := new(Arith)
|
||||
rpc.Register(arith)
|
||||
rpc.HandleHTTP()
|
||||
l, e := net.Listen("tcp", ":1234")
|
||||
if e != nil {
|
||||
log.Fatal("listen error:", e)
|
||||
}
|
||||
go http.Serve(l, nil)
|
||||
|
||||
At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
|
||||
"Arith.Divide". To invoke one, a client first dials the server:
|
||||
|
||||
client, err := rpc.DialHTTP("tcp", serverAddress + ":1234")
|
||||
if err != nil {
|
||||
log.Fatal("dialing:", err)
|
||||
}
|
||||
|
||||
Then it can make a remote call:
|
||||
|
||||
// Synchronous call
|
||||
args := &server.Args{7,8}
|
||||
var reply int
|
||||
err = client.Call("Arith.Multiply", args, &reply)
|
||||
if err != nil {
|
||||
log.Fatal("arith error:", err)
|
||||
}
|
||||
fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply)
|
||||
|
||||
or:
|
||||
|
||||
// Asynchronous call
|
||||
quotient := new(Quotient)
|
||||
divCall := client.Go("Arith.Divide", args, "ient, nil)
|
||||
replyCall := <-divCall.Done // will be equal to divCall
|
||||
// check errors, print, etc.
|
||||
|
||||
A server implementation will often provide a simple, type-safe wrapper for the
|
||||
client.
|
||||
*/
|
||||
package rpcplus
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultRPCPath used as handler location for HandleHTTP
|
||||
DefaultRPCPath = "/_goRPC_"
|
||||
|
||||
// DefaultDebugPath used as debug handler location for HandleHTTP
|
||||
DefaultDebugPath = "/debug/rpc"
|
||||
)
|
||||
|
||||
// Precompute the reflect type for error. Can't use error directly
|
||||
// because Typeof takes an empty interface value. This is annoying.
|
||||
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
|
||||
|
||||
type methodType struct {
|
||||
sync.Mutex // protects counters
|
||||
method reflect.Method
|
||||
ArgType reflect.Type
|
||||
ReplyType reflect.Type
|
||||
ContextType reflect.Type
|
||||
stream bool
|
||||
numCalls uint
|
||||
}
|
||||
|
||||
func (m *methodType) TakesContext() bool {
|
||||
return m.ContextType != nil
|
||||
}
|
||||
|
||||
type service struct {
|
||||
name string // name of service
|
||||
rcvr reflect.Value // receiver of methods for the service
|
||||
typ reflect.Type // type of the receiver
|
||||
method map[string]*methodType // registered methods
|
||||
}
|
||||
|
||||
// Request is a header written before every RPC call. It is used internally
|
||||
// but documented here as an aid to debugging, such as when analyzing
|
||||
// network traffic.
|
||||
type Request struct {
|
||||
ServiceMethod string // format: "Service.Method"
|
||||
Seq uint64 // sequence number chosen by client
|
||||
next *Request // for free list in Server
|
||||
}
|
||||
|
||||
// Response is a header written before every RPC return. It is used internally
|
||||
// but documented here as an aid to debugging, such as when analyzing
|
||||
// network traffic.
|
||||
type Response struct {
|
||||
ServiceMethod string // echoes that of the Request
|
||||
Seq uint64 // echoes that of the request
|
||||
Error string // error, if any.
|
||||
next *Response // for free list in Server
|
||||
}
|
||||
|
||||
const lastStreamResponseError = "EOS"
|
||||
|
||||
// Server represents an RPC Server.
|
||||
type Server struct {
|
||||
mu sync.Mutex // protects the serviceMap
|
||||
serviceMap map[string]*service
|
||||
reqLock sync.Mutex // protects freeReq
|
||||
freeReq *Request
|
||||
respLock sync.Mutex // protects freeResp
|
||||
freeResp *Response
|
||||
}
|
||||
|
||||
// NewServer returns a new Server.
|
||||
func NewServer() *Server {
|
||||
return &Server{serviceMap: make(map[string]*service)}
|
||||
}
|
||||
|
||||
// DefaultServer is the default instance of *Server.
|
||||
var DefaultServer = NewServer()
|
||||
|
||||
// Is this an exported - upper case - name?
|
||||
func isExported(name string) bool {
|
||||
rune, _ := utf8.DecodeRuneInString(name)
|
||||
return unicode.IsUpper(rune)
|
||||
}
|
||||
|
||||
// Is this type exported or a builtin?
|
||||
func isExportedOrBuiltinType(t reflect.Type) bool {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
// PkgPath will be non-empty even for an exported type,
|
||||
// so we need to check the type name as well.
|
||||
return isExported(t.Name()) || t.PkgPath() == ""
|
||||
}
|
||||
|
||||
// Register publishes in the server the set of methods of the
|
||||
// receiver value that satisfy the following conditions:
|
||||
// - exported method
|
||||
// - two arguments, both pointers to exported structs
|
||||
// - one return value, of type error
|
||||
// It returns an error if the receiver is not an exported type or has no
|
||||
// suitable methods.
|
||||
// The client accesses each method using a string of the form "Type.Method",
|
||||
// where Type is the receiver's concrete type.
|
||||
func (server *Server) Register(rcvr interface{}) error {
|
||||
return server.register(rcvr, "", false)
|
||||
}
|
||||
|
||||
// RegisterName is like Register but uses the provided name for the type
|
||||
// instead of the receiver's concrete type.
|
||||
func (server *Server) RegisterName(name string, rcvr interface{}) error {
|
||||
return server.register(rcvr, name, true)
|
||||
}
|
||||
|
||||
// prepareMethod returns a methodType for the provided method or nil
|
||||
// in case if the method was unsuitable.
|
||||
func prepareMethod(method reflect.Method) *methodType {
|
||||
mtype := method.Type
|
||||
mname := method.Name
|
||||
var replyType, argType, contextType reflect.Type
|
||||
|
||||
stream := false
|
||||
// Method must be exported.
|
||||
if method.PkgPath != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch mtype.NumIn() {
|
||||
case 3:
|
||||
// normal method
|
||||
argType = mtype.In(1)
|
||||
replyType = mtype.In(2)
|
||||
contextType = nil
|
||||
case 4:
|
||||
// method that takes a context
|
||||
argType = mtype.In(2)
|
||||
replyType = mtype.In(3)
|
||||
contextType = mtype.In(1)
|
||||
default:
|
||||
log.Println("method", mname, "of", mtype, "has wrong number of ins:", mtype.NumIn())
|
||||
return nil
|
||||
}
|
||||
|
||||
// First arg need not be a pointer.
|
||||
if !isExportedOrBuiltinType(argType) {
|
||||
log.Println(mname, "argument type not exported:", argType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// the second argument will tell us if it's a streaming call
|
||||
// or a regular call
|
||||
if replyType.Kind() == reflect.Func {
|
||||
// this is a streaming call
|
||||
stream = true
|
||||
if replyType.NumIn() != 1 {
|
||||
log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn())
|
||||
return nil
|
||||
}
|
||||
if replyType.In(0).Kind() != reflect.Interface {
|
||||
log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0))
|
||||
return nil
|
||||
}
|
||||
if replyType.NumOut() != 1 {
|
||||
log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut())
|
||||
return nil
|
||||
}
|
||||
if returnType := replyType.Out(0); returnType != typeOfError {
|
||||
log.Println("method", mname, "sendReply returns", returnType.String(), "not error")
|
||||
return nil
|
||||
}
|
||||
|
||||
} else if replyType.Kind() != reflect.Ptr {
|
||||
log.Println("method", mname, "reply type not a pointer:", replyType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reply type must be exported.
|
||||
if !isExportedOrBuiltinType(replyType) {
|
||||
log.Println("method", mname, "reply type not exported:", replyType)
|
||||
return nil
|
||||
}
|
||||
// Method needs one out.
|
||||
if mtype.NumOut() != 1 {
|
||||
log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
|
||||
return nil
|
||||
}
|
||||
// The return type of the method must be error.
|
||||
if returnType := mtype.Out(0); returnType != typeOfError {
|
||||
log.Println("method", mname, "returns", returnType.String(), "not error")
|
||||
return nil
|
||||
}
|
||||
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
|
||||
}
|
||||
|
||||
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
|
||||
server.mu.Lock()
|
||||
defer server.mu.Unlock()
|
||||
if server.serviceMap == nil {
|
||||
server.serviceMap = make(map[string]*service)
|
||||
}
|
||||
s := new(service)
|
||||
s.typ = reflect.TypeOf(rcvr)
|
||||
s.rcvr = reflect.ValueOf(rcvr)
|
||||
sname := reflect.Indirect(s.rcvr).Type().Name()
|
||||
if useName {
|
||||
sname = name
|
||||
}
|
||||
if sname == "" {
|
||||
log.Fatal("rpc: no service name for type", s.typ.String())
|
||||
}
|
||||
if !isExported(sname) && !useName {
|
||||
s := "rpc Register: type " + sname + " is not exported"
|
||||
log.Print(s)
|
||||
return errors.New(s)
|
||||
}
|
||||
if _, present := server.serviceMap[sname]; present {
|
||||
return errors.New("rpc: service already defined: " + sname)
|
||||
}
|
||||
s.name = sname
|
||||
s.method = make(map[string]*methodType)
|
||||
|
||||
// Install the methods
|
||||
for m := 0; m < s.typ.NumMethod(); m++ {
|
||||
method := s.typ.Method(m)
|
||||
if mt := prepareMethod(method); mt != nil {
|
||||
s.method[method.Name] = mt
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.method) == 0 {
|
||||
s := "rpc Register: type " + sname + " has no exported methods of suitable type"
|
||||
log.Print(s)
|
||||
return errors.New(s)
|
||||
}
|
||||
server.serviceMap[s.name] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
// A value sent as a placeholder for the server's response value when the server
|
||||
// receives an invalid request. It is never decoded by the client since the Response
|
||||
// contains an error when it is used.
|
||||
var invalidRequest = struct{}{}
|
||||
|
||||
func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string, last bool) (err error) {
|
||||
resp := server.getResponse()
|
||||
// Encode the response header
|
||||
resp.ServiceMethod = req.ServiceMethod
|
||||
if errmsg != "" {
|
||||
resp.Error = errmsg
|
||||
reply = invalidRequest
|
||||
}
|
||||
resp.Seq = req.Seq
|
||||
sending.Lock()
|
||||
err = codec.WriteResponse(resp, reply, last)
|
||||
if err != nil {
|
||||
log.Println("rpc: writing response:", err)
|
||||
}
|
||||
sending.Unlock()
|
||||
server.freeResponse(resp)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *methodType) NumCalls() (n uint) {
|
||||
m.Lock()
|
||||
n = m.numCalls
|
||||
m.Unlock()
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *service) call(ctx context.Context, server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
|
||||
mtype.Lock()
|
||||
mtype.numCalls++
|
||||
mtype.Unlock()
|
||||
function := mtype.method.Func
|
||||
var returnValues []reflect.Value
|
||||
|
||||
if !mtype.stream {
|
||||
|
||||
// Invoke the method, providing a new value for the reply.
|
||||
if mtype.TakesContext() {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, replyv})
|
||||
} else {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, argv, replyv})
|
||||
}
|
||||
|
||||
// The return value for the method is an error.
|
||||
errInter := returnValues[0].Interface()
|
||||
errmsg := ""
|
||||
if errInter != nil {
|
||||
errmsg = errInter.(error).Error()
|
||||
}
|
||||
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true)
|
||||
server.freeRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
// declare a local error to see if we errored out already
|
||||
// keep track of the type, to make sure we return
|
||||
// the same one consistently
|
||||
var lastError error
|
||||
var firstType reflect.Type
|
||||
|
||||
sendReply := func(oneReply interface{}) error {
|
||||
|
||||
// we already triggered an error, we're done
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
// check the oneReply has the right type using reflection
|
||||
typ := reflect.TypeOf(oneReply)
|
||||
if firstType == nil {
|
||||
firstType = typ
|
||||
} else {
|
||||
if firstType != typ {
|
||||
log.Println("passing wrong type to sendReply",
|
||||
firstType, "!=", typ)
|
||||
lastError = errors.New("rpc: passing wrong type to sendReply")
|
||||
return lastError
|
||||
}
|
||||
}
|
||||
|
||||
lastError = server.sendResponse(sending, req, oneReply, codec, "", false)
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
// we manage to send, we're good
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invoke the method, providing a new value for the reply.
|
||||
if mtype.TakesContext() {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), argv, reflect.ValueOf(sendReply)})
|
||||
} else {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)})
|
||||
}
|
||||
errInter := returnValues[0].Interface()
|
||||
errmsg := ""
|
||||
if errInter != nil {
|
||||
// the function returned an error, we use that
|
||||
errmsg = errInter.(error).Error()
|
||||
} else if lastError != nil {
|
||||
// we had an error inside sendReply, we use that
|
||||
errmsg = lastError.Error()
|
||||
} else {
|
||||
// no error, we send the special EOS error
|
||||
errmsg = lastStreamResponseError
|
||||
}
|
||||
|
||||
// this is the last packet, we don't do anything with
|
||||
// the error here (well sendStreamResponse will log it
|
||||
// already)
|
||||
server.sendResponse(sending, req, nil, codec, errmsg, true)
|
||||
server.freeRequest(req)
|
||||
}
|
||||
|
||||
type gobServerCodec struct {
|
||||
rwc io.ReadWriteCloser
|
||||
dec *gob.Decoder
|
||||
enc *gob.Encoder
|
||||
encBuf *bufio.Writer
|
||||
}
|
||||
|
||||
func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
|
||||
return c.dec.Decode(r)
|
||||
}
|
||||
|
||||
func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
|
||||
return c.dec.Decode(body)
|
||||
}
|
||||
|
||||
func (c *gobServerCodec) WriteResponse(r *Response, body interface{}, last bool) (err error) {
|
||||
if err = c.enc.Encode(r); err != nil {
|
||||
return
|
||||
}
|
||||
if err = c.enc.Encode(body); err != nil {
|
||||
return
|
||||
}
|
||||
return c.encBuf.Flush()
|
||||
}
|
||||
|
||||
func (c *gobServerCodec) Close() error {
|
||||
return c.rwc.Close()
|
||||
}
|
||||
|
||||
// ServeConn runs the server on a single connection.
|
||||
// ServeConn blocks, serving the connection until the client hangs up.
|
||||
// The caller typically invokes ServeConn in a go statement.
|
||||
// ServeConn uses the gob wire format (see package gob) on the
|
||||
// connection. To use an alternate codec, use ServeCodec.
|
||||
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
|
||||
server.ServeConnWithContext(context.TODO(), conn)
|
||||
}
|
||||
|
||||
// ServeConnWithContext is like ServeConn but makes it possible to
|
||||
// pass a connection context to the RPC methods.
|
||||
func (server *Server) ServeConnWithContext(ctx context.Context, conn io.ReadWriteCloser) {
|
||||
buf := bufio.NewWriter(conn)
|
||||
srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
|
||||
server.ServeCodecWithContext(ctx, srv)
|
||||
}
|
||||
|
||||
// ServeCodec is like ServeConn but uses the specified codec to
|
||||
// decode requests and encode responses.
|
||||
func (server *Server) ServeCodec(codec ServerCodec) {
|
||||
server.ServeCodecWithContext(context.TODO(), codec)
|
||||
}
|
||||
|
||||
// ServeCodecWithContext is like ServeCodec but it makes it possible
|
||||
// to pass a connection context to the RPC methods.
|
||||
func (server *Server) ServeCodecWithContext(ctx context.Context, codec ServerCodec) {
|
||||
sending := new(sync.Mutex)
|
||||
for {
|
||||
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Println("rpc:", err)
|
||||
}
|
||||
if !keepReading {
|
||||
break
|
||||
}
|
||||
// send a response if we actually managed to read a header.
|
||||
if req != nil {
|
||||
server.sendResponse(sending, req, invalidRequest, codec, err.Error(), true)
|
||||
server.freeRequest(req)
|
||||
}
|
||||
continue
|
||||
}
|
||||
go service.call(ctx, server, sending, mtype, req, argv, replyv, codec)
|
||||
}
|
||||
codec.Close()
|
||||
}
|
||||
|
||||
func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
|
||||
if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
|
||||
return contextv
|
||||
}
|
||||
return reflect.Zero(m.ContextType)
|
||||
}
|
||||
|
||||
// ServeRequest is like ServeCodec but synchronously serves a single request.
|
||||
// It does not close the codec upon completion.
|
||||
func (server *Server) ServeRequest(codec ServerCodec) error {
|
||||
return server.ServeRequestWithContext(context.TODO(), codec)
|
||||
}
|
||||
|
||||
// ServeRequestWithContext is like ServeRequest but makes it possible
|
||||
// to pass a connection context to the RPC methods.
|
||||
func (server *Server) ServeRequestWithContext(ctx context.Context, codec ServerCodec) error {
|
||||
sending := new(sync.Mutex)
|
||||
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
|
||||
if err != nil {
|
||||
if !keepReading {
|
||||
return err
|
||||
}
|
||||
// send a response if we actually managed to read a header.
|
||||
if req != nil {
|
||||
server.sendResponse(sending, req, invalidRequest, codec, err.Error(), true)
|
||||
server.freeRequest(req)
|
||||
}
|
||||
return err
|
||||
}
|
||||
service.call(ctx, server, sending, mtype, req, argv, replyv, codec)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) getRequest() *Request {
|
||||
server.reqLock.Lock()
|
||||
req := server.freeReq
|
||||
if req == nil {
|
||||
req = new(Request)
|
||||
} else {
|
||||
server.freeReq = req.next
|
||||
*req = Request{}
|
||||
}
|
||||
server.reqLock.Unlock()
|
||||
return req
|
||||
}
|
||||
|
||||
func (server *Server) freeRequest(req *Request) {
|
||||
server.reqLock.Lock()
|
||||
req.next = server.freeReq
|
||||
server.freeReq = req
|
||||
server.reqLock.Unlock()
|
||||
}
|
||||
|
||||
func (server *Server) getResponse() *Response {
|
||||
server.respLock.Lock()
|
||||
resp := server.freeResp
|
||||
if resp == nil {
|
||||
resp = new(Response)
|
||||
} else {
|
||||
server.freeResp = resp.next
|
||||
*resp = Response{}
|
||||
}
|
||||
server.respLock.Unlock()
|
||||
return resp
|
||||
}
|
||||
|
||||
func (server *Server) freeResponse(resp *Response) {
|
||||
server.respLock.Lock()
|
||||
resp.next = server.freeResp
|
||||
server.freeResp = resp
|
||||
server.respLock.Unlock()
|
||||
}
|
||||
|
||||
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
|
||||
service, mtype, req, keepReading, err = server.readRequestHeader(codec)
|
||||
if err != nil {
|
||||
if !keepReading {
|
||||
return
|
||||
}
|
||||
// discard body
|
||||
codec.ReadRequestBody(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode the argument value.
|
||||
argIsValue := false // if true, need to indirect before calling.
|
||||
if mtype.ArgType.Kind() == reflect.Ptr {
|
||||
argv = reflect.New(mtype.ArgType.Elem())
|
||||
} else {
|
||||
argv = reflect.New(mtype.ArgType)
|
||||
argIsValue = true
|
||||
}
|
||||
// argv guaranteed to be a pointer now.
|
||||
if err = codec.ReadRequestBody(argv.Interface()); err != nil {
|
||||
return
|
||||
}
|
||||
if argIsValue {
|
||||
argv = argv.Elem()
|
||||
}
|
||||
|
||||
if !mtype.stream {
|
||||
replyv = reflect.New(mtype.ReplyType.Elem())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err error) {
|
||||
// Grab the request header.
|
||||
req = server.getRequest()
|
||||
err = codec.ReadRequestHeader(req)
|
||||
if err != nil {
|
||||
req = nil
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return
|
||||
}
|
||||
err = errors.New("rpc: server cannot decode request: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// We read the header successfully. If we see an error now,
|
||||
// we can still recover and move on to the next request.
|
||||
keepReading = true
|
||||
|
||||
serviceMethod := strings.Split(req.ServiceMethod, ".")
|
||||
if len(serviceMethod) != 2 {
|
||||
err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
|
||||
return
|
||||
}
|
||||
// Look up the request.
|
||||
server.mu.Lock()
|
||||
service = server.serviceMap[serviceMethod[0]]
|
||||
server.mu.Unlock()
|
||||
if service == nil {
|
||||
err = errors.New("rpc: can't find service " + req.ServiceMethod)
|
||||
return
|
||||
}
|
||||
mtype = service.method[serviceMethod[1]]
|
||||
if mtype == nil {
|
||||
err = errors.New("rpc: can't find method " + req.ServiceMethod)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Accept accepts connections on the listener and serves requests
|
||||
// for each incoming connection. Accept blocks; the caller typically
|
||||
// invokes it in a go statement.
|
||||
func (server *Server) Accept(lis net.Listener) {
|
||||
for {
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
log.Fatal("rpc.Serve: accept:", err.Error()) // TODO(r): exit?
|
||||
}
|
||||
go server.ServeConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Register publishes the receiver's methods in the DefaultServer.
|
||||
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
|
||||
|
||||
// RegisterName is like Register but uses the provided name for the type
|
||||
// instead of the receiver's concrete type.
|
||||
func RegisterName(name string, rcvr interface{}) error {
|
||||
return DefaultServer.RegisterName(name, rcvr)
|
||||
}
|
||||
|
||||
// A ServerCodec implements reading of RPC requests and writing of
|
||||
// RPC responses for the server side of an RPC session.
|
||||
// The server calls ReadRequestHeader and ReadRequestBody in pairs
|
||||
// to read requests from the connection, and it calls WriteResponse to
|
||||
// write a response back. The server calls Close when finished with the
|
||||
// connection. ReadRequestBody may be called with a nil
|
||||
// argument to force the body of the request to be read and discarded.
|
||||
type ServerCodec interface {
|
||||
ReadRequestHeader(*Request) error
|
||||
ReadRequestBody(interface{}) error
|
||||
WriteResponse(*Response, interface{}, bool) error
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
// ServeConn runs the DefaultServer on a single connection.
|
||||
// ServeConn blocks, serving the connection until the client hangs up.
|
||||
// The caller typically invokes ServeConn in a go statement.
|
||||
// ServeConn uses the gob wire format (see package gob) on the
|
||||
// connection. To use an alternate codec, use ServeCodec.
|
||||
func ServeConn(conn io.ReadWriteCloser) {
|
||||
ServeConnWithContext(context.TODO(), conn)
|
||||
}
|
||||
|
||||
// ServeConnWithContext is like ServeConn but it allows to pass a
|
||||
// connection context to the RPC methods.
|
||||
func ServeConnWithContext(ctx context.Context, conn io.ReadWriteCloser) {
|
||||
DefaultServer.ServeConnWithContext(ctx, conn)
|
||||
}
|
||||
|
||||
// ServeCodec is like ServeConn but uses the specified codec to
|
||||
// decode requests and encode responses.
|
||||
func ServeCodec(codec ServerCodec) {
|
||||
ServeCodecWithContext(context.TODO(), codec)
|
||||
}
|
||||
|
||||
// ServeCodecWithContext is like ServeCodec but it allows to pass a
|
||||
// connection context to the RPC methods.
|
||||
func ServeCodecWithContext(ctx context.Context, codec ServerCodec) {
|
||||
DefaultServer.ServeCodecWithContext(ctx, codec)
|
||||
}
|
||||
|
||||
// ServeRequest is like ServeCodec but synchronously serves a single request.
|
||||
// It does not close the codec upon completion.
|
||||
func ServeRequest(codec ServerCodec) error {
|
||||
return ServeRequestWithContext(context.TODO(), codec)
|
||||
|
||||
}
|
||||
|
||||
// ServeRequestWithContext is like ServeRequest but it allows to pass
|
||||
// a connection context to the RPC methods.
|
||||
func ServeRequestWithContext(ctx context.Context, codec ServerCodec) error {
|
||||
return DefaultServer.ServeRequestWithContext(ctx, codec)
|
||||
}
|
||||
|
||||
// Accept accepts connections on the listener and serves requests
|
||||
// to DefaultServer for each incoming connection.
|
||||
// Accept blocks; the caller typically invokes it in a go statement.
|
||||
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
|
||||
|
||||
// Can connect to RPC service using HTTP CONNECT to rpcPath.
|
||||
var connected = "200 Connected to Go RPC"
|
||||
|
||||
// ServeHTTP implements an http.Handler that answers RPC requests.
|
||||
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != "CONNECT" {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
io.WriteString(w, "405 must CONNECT\n")
|
||||
return
|
||||
}
|
||||
conn, _, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
|
||||
server.ServeConn(conn)
|
||||
}
|
||||
|
||||
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
|
||||
// and a debugging handler on debugPath.
|
||||
// It is still necessary to invoke http.Serve(), typically in a go statement.
|
||||
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
|
||||
http.Handle(rpcPath, server)
|
||||
http.Handle(debugPath, debugHTTP{server})
|
||||
}
|
||||
|
||||
// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
|
||||
// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
|
||||
// It is still necessary to invoke http.Serve(), typically in a go statement.
|
||||
func HandleHTTP() {
|
||||
DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
|
||||
}
|
|
@ -1,615 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package rpcplus
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var (
|
||||
newServer *Server
|
||||
serverAddr, newServerAddr string
|
||||
httpServerAddr string
|
||||
once, newOnce, httpOnce sync.Once
|
||||
)
|
||||
|
||||
const (
|
||||
newHTTPPath = "/foo"
|
||||
)
|
||||
|
||||
type Args struct {
|
||||
A, B int
|
||||
}
|
||||
|
||||
type Reply struct {
|
||||
C int
|
||||
}
|
||||
|
||||
type Arith int
|
||||
|
||||
// Some of Arith's methods have value args, some have pointer args. That's deliberate.
|
||||
|
||||
func (t *Arith) Add(args Args, reply *Reply) error {
|
||||
reply.C = args.A + args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Mul(args *Args, reply *Reply) error {
|
||||
reply.C = args.A * args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Div(args Args, reply *Reply) error {
|
||||
if args.B == 0 {
|
||||
return errors.New("divide by zero")
|
||||
}
|
||||
reply.C = args.A / args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) String(args *Args, reply *string) error {
|
||||
*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Scan(args string, reply *Reply) (err error) {
|
||||
_, err = fmt.Sscan(args, &reply.C)
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Arith) Error(args *Args, reply *Reply) error {
|
||||
panic("ERROR")
|
||||
}
|
||||
|
||||
func (t *Arith) TakesContext(context interface{}, args string, reply *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func listenTCP() (net.Listener, string) {
|
||||
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
|
||||
if e != nil {
|
||||
log.Fatalf("net.Listen tcp :0: %v", e)
|
||||
}
|
||||
return l, l.Addr().String()
|
||||
}
|
||||
|
||||
func startServer() {
|
||||
Register(new(Arith))
|
||||
|
||||
var l net.Listener
|
||||
l, serverAddr = listenTCP()
|
||||
log.Println("Test RPC server listening on", serverAddr)
|
||||
go Accept(l)
|
||||
|
||||
HandleHTTP()
|
||||
httpOnce.Do(startHTTPServer)
|
||||
}
|
||||
|
||||
func startNewServer() {
|
||||
newServer = NewServer()
|
||||
newServer.Register(new(Arith))
|
||||
|
||||
var l net.Listener
|
||||
l, newServerAddr = listenTCP()
|
||||
log.Println("NewServer test RPC server listening on", newServerAddr)
|
||||
go Accept(l)
|
||||
|
||||
newServer.HandleHTTP(newHTTPPath, "/bar")
|
||||
httpOnce.Do(startHTTPServer)
|
||||
}
|
||||
|
||||
func startHTTPServer() {
|
||||
server := httptest.NewServer(nil)
|
||||
httpServerAddr = server.Listener.Addr().String()
|
||||
log.Println("Test HTTP RPC server listening on", httpServerAddr)
|
||||
}
|
||||
|
||||
func TestRPC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
once.Do(startServer)
|
||||
testRPC(ctx, t, serverAddr)
|
||||
newOnce.Do(startNewServer)
|
||||
testRPC(ctx, t, newServerAddr)
|
||||
}
|
||||
|
||||
func testRPC(ctx context.Context, t *testing.T, addr string) {
|
||||
client, err := Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatal("dialing", err)
|
||||
}
|
||||
|
||||
// Synchronous calls
|
||||
args := &Args{7, 8}
|
||||
reply := new(Reply)
|
||||
err = client.Call(ctx, "Arith.Add", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
||||
}
|
||||
|
||||
// Nonexistent method
|
||||
args = &Args{7, 0}
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.BadOperation", args, reply)
|
||||
// expect an error
|
||||
if err == nil {
|
||||
t.Error("BadOperation: expected error")
|
||||
} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
|
||||
t.Errorf("BadOperation: expected can't find method error; got %q", err)
|
||||
}
|
||||
|
||||
// Unknown service
|
||||
args = &Args{7, 8}
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Unknown", args, reply)
|
||||
if err == nil {
|
||||
t.Error("expected error calling unknown service")
|
||||
} else if strings.Index(err.Error(), "method") < 0 {
|
||||
t.Error("expected error about method; got", err)
|
||||
}
|
||||
|
||||
// Out of order.
|
||||
args = &Args{7, 8}
|
||||
mulReply := new(Reply)
|
||||
mulCall := client.Go(ctx, "Arith.Mul", args, mulReply, nil)
|
||||
addReply := new(Reply)
|
||||
addCall := client.Go(ctx, "Arith.Add", args, addReply, nil)
|
||||
|
||||
addCall = <-addCall.Done
|
||||
if addCall.Error != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
|
||||
}
|
||||
if addReply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
|
||||
}
|
||||
|
||||
mulCall = <-mulCall.Done
|
||||
if mulCall.Error != nil {
|
||||
t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
|
||||
}
|
||||
if mulReply.C != args.A*args.B {
|
||||
t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
|
||||
}
|
||||
|
||||
// Error test
|
||||
args = &Args{7, 0}
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Div", args, reply)
|
||||
// expect an error: zero divide
|
||||
if err == nil {
|
||||
t.Error("Div: expected error")
|
||||
} else if err.Error() != "divide by zero" {
|
||||
t.Error("Div: expected divide by zero error; got", err)
|
||||
}
|
||||
|
||||
// Bad type.
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Add", reply, reply) // args, reply would be the correct thing to use
|
||||
if err == nil {
|
||||
t.Error("expected error calling Arith.Add with wrong arg type")
|
||||
} else if strings.Index(err.Error(), "type") < 0 {
|
||||
t.Error("expected error about type; got", err)
|
||||
}
|
||||
|
||||
// Non-struct argument
|
||||
const Val = 12345
|
||||
str := fmt.Sprint(Val)
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Scan", &str, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Scan: expected no error but got string %q", err.Error())
|
||||
} else if reply.C != Val {
|
||||
t.Errorf("Scan: expected %d got %d", Val, reply.C)
|
||||
}
|
||||
|
||||
// Non-struct reply
|
||||
args = &Args{27, 35}
|
||||
str = ""
|
||||
err = client.Call(ctx, "Arith.String", args, &str)
|
||||
if err != nil {
|
||||
t.Errorf("String: expected no error but got string %q", err.Error())
|
||||
}
|
||||
expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
|
||||
if str != expect {
|
||||
t.Errorf("String: expected %s got %s", expect, str)
|
||||
}
|
||||
|
||||
args = &Args{7, 8}
|
||||
reply = new(Reply)
|
||||
err = client.Call(ctx, "Arith.Mul", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Mul: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A*args.B {
|
||||
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
|
||||
}
|
||||
|
||||
// Takes context
|
||||
emptyString := ""
|
||||
err = client.Call(ctx, "Arith.TakesContext", "", &emptyString)
|
||||
if err != nil {
|
||||
t.Errorf("TakesContext: expected no error but got string %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
once.Do(startServer)
|
||||
testHTTPRPC(ctx, t, "")
|
||||
newOnce.Do(startNewServer)
|
||||
testHTTPRPC(ctx, t, newHTTPPath)
|
||||
}
|
||||
|
||||
func testHTTPRPC(ctx context.Context, t *testing.T, path string) {
|
||||
var client *Client
|
||||
var err error
|
||||
if path == "" {
|
||||
client, err = DialHTTP("tcp", httpServerAddr)
|
||||
} else {
|
||||
client, err = DialHTTPPath("tcp", httpServerAddr, path)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("dialing", err)
|
||||
}
|
||||
|
||||
// Synchronous calls
|
||||
args := &Args{7, 8}
|
||||
reply := new(Reply)
|
||||
err = client.Call(ctx, "Arith.Add", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
||||
}
|
||||
}
|
||||
|
||||
// CodecEmulator provides a client-like api and a ServerCodec interface.
|
||||
// Can be used to test ServeRequest.
|
||||
type CodecEmulator struct {
|
||||
server *Server
|
||||
serviceMethod string
|
||||
args *Args
|
||||
reply *Reply
|
||||
err error
|
||||
}
|
||||
|
||||
func (codec *CodecEmulator) Call(ctx context.Context, serviceMethod string, args *Args, reply *Reply) error {
|
||||
codec.serviceMethod = serviceMethod
|
||||
codec.args = args
|
||||
codec.reply = reply
|
||||
codec.err = nil
|
||||
var serverError error
|
||||
if codec.server == nil {
|
||||
serverError = ServeRequest(codec)
|
||||
} else {
|
||||
serverError = codec.server.ServeRequest(codec)
|
||||
}
|
||||
if codec.err == nil && serverError != nil {
|
||||
codec.err = serverError
|
||||
}
|
||||
return codec.err
|
||||
}
|
||||
|
||||
func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
|
||||
req.ServiceMethod = codec.serviceMethod
|
||||
req.Seq = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
|
||||
if codec.args == nil {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
*(argv.(*Args)) = *codec.args
|
||||
return nil
|
||||
}
|
||||
|
||||
func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}, last bool) error {
|
||||
if resp.Error != "" {
|
||||
codec.err = errors.New(resp.Error)
|
||||
} else {
|
||||
*codec.reply = *(reply.(*Reply))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (codec *CodecEmulator) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestServeRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
once.Do(startServer)
|
||||
testServeRequest(ctx, t, nil)
|
||||
newOnce.Do(startNewServer)
|
||||
testServeRequest(ctx, t, newServer)
|
||||
}
|
||||
|
||||
func testServeRequest(ctx context.Context, t *testing.T, server *Server) {
|
||||
client := CodecEmulator{server: server}
|
||||
|
||||
args := &Args{7, 8}
|
||||
reply := new(Reply)
|
||||
err := client.Call(ctx, "Arith.Add", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
||||
}
|
||||
|
||||
err = client.Call(ctx, "Arith.Add", nil, reply)
|
||||
if err == nil {
|
||||
t.Errorf("expected error calling Arith.Add with nil arg")
|
||||
}
|
||||
}
|
||||
|
||||
type ReplyNotPointer int
|
||||
type ArgNotPublic int
|
||||
type ReplyNotPublic int
|
||||
type local struct{}
|
||||
|
||||
func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check that registration handles lots of bad methods and a type with no suitable methods.
|
||||
func TestRegistrationError(t *testing.T) {
|
||||
err := Register(new(ReplyNotPointer))
|
||||
if err == nil {
|
||||
t.Errorf("expected error registering ReplyNotPointer")
|
||||
}
|
||||
err = Register(new(ArgNotPublic))
|
||||
if err == nil {
|
||||
t.Errorf("expected error registering ArgNotPublic")
|
||||
}
|
||||
err = Register(new(ReplyNotPublic))
|
||||
if err == nil {
|
||||
t.Errorf("expected error registering ReplyNotPublic")
|
||||
}
|
||||
}
|
||||
|
||||
type WriteFailCodec int
|
||||
|
||||
func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
|
||||
// the panic caused by this error used to not unlock a lock.
|
||||
return errors.New("fail")
|
||||
}
|
||||
|
||||
func (WriteFailCodec) ReadResponseHeader(*Response) error {
|
||||
select {}
|
||||
}
|
||||
|
||||
func (WriteFailCodec) ReadResponseBody(interface{}) error {
|
||||
select {}
|
||||
}
|
||||
|
||||
func (WriteFailCodec) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSendDeadlock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := NewClientWithCodec(WriteFailCodec(0))
|
||||
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
testSendDeadlock(ctx, client)
|
||||
testSendDeadlock(ctx, client)
|
||||
done <- true
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("deadlock")
|
||||
}
|
||||
}
|
||||
|
||||
func testSendDeadlock(ctx context.Context, client *Client) {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
args := &Args{7, 8}
|
||||
reply := new(Reply)
|
||||
client.Call(ctx, "Arith.Add", args, reply)
|
||||
}
|
||||
|
||||
func dialDirect() (*Client, error) {
|
||||
return Dial("tcp", serverAddr)
|
||||
}
|
||||
|
||||
func dialHTTP() (*Client, error) {
|
||||
return DialHTTP("tcp", httpServerAddr)
|
||||
}
|
||||
|
||||
func countMallocs(ctx context.Context, dial func() (*Client, error), t *testing.T) uint64 {
|
||||
once.Do(startServer)
|
||||
client, err := dial()
|
||||
if err != nil {
|
||||
t.Fatal("error dialing", err)
|
||||
}
|
||||
args := &Args{7, 8}
|
||||
reply := new(Reply)
|
||||
memstats := new(runtime.MemStats)
|
||||
runtime.ReadMemStats(memstats)
|
||||
mallocs := 0 - memstats.Mallocs
|
||||
const count = 100
|
||||
for i := 0; i < count; i++ {
|
||||
err := client.Call(ctx, "Arith.Add", args, reply)
|
||||
if err != nil {
|
||||
t.Errorf("Add: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A+args.B {
|
||||
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
||||
}
|
||||
}
|
||||
runtime.ReadMemStats(memstats)
|
||||
mallocs += memstats.Mallocs
|
||||
return mallocs / count
|
||||
}
|
||||
|
||||
func TestCountMallocs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(ctx, dialDirect, t))
|
||||
}
|
||||
|
||||
func TestCountMallocsOverHTTP(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(ctx, dialHTTP, t))
|
||||
}
|
||||
|
||||
type writeCrasher struct {
|
||||
done chan bool
|
||||
}
|
||||
|
||||
func (writeCrasher) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *writeCrasher) Read(p []byte) (int, error) {
|
||||
<-w.done
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (writeCrasher) Write(p []byte) (int, error) {
|
||||
return 0, errors.New("fake write failure")
|
||||
}
|
||||
|
||||
func TestClientWriteError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
w := &writeCrasher{done: make(chan bool)}
|
||||
client := NewClient(w)
|
||||
res := false
|
||||
err := client.Call(ctx, "foo", 1, &res)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if err.Error() != "fake write failure" {
|
||||
t.Error("unexpected value of error:", err)
|
||||
}
|
||||
w.done <- true
|
||||
}
|
||||
|
||||
func benchmarkEndToEnd(ctx context.Context, dial func() (*Client, error), b *testing.B) {
|
||||
b.StopTimer()
|
||||
once.Do(startServer)
|
||||
client, err := dial()
|
||||
if err != nil {
|
||||
b.Fatal("error dialing:", err)
|
||||
}
|
||||
|
||||
// Synchronous calls
|
||||
args := &Args{7, 8}
|
||||
procs := runtime.GOMAXPROCS(-1)
|
||||
N := int32(b.N)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(procs)
|
||||
b.StartTimer()
|
||||
|
||||
for p := 0; p < procs; p++ {
|
||||
go func() {
|
||||
reply := new(Reply)
|
||||
for atomic.AddInt32(&N, -1) >= 0 {
|
||||
err := client.Call(ctx, "Arith.Add", args, reply)
|
||||
if err != nil {
|
||||
b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
|
||||
}
|
||||
if reply.C != args.A+args.B {
|
||||
b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func benchmarkEndToEndAsync(ctx context.Context, dial func() (*Client, error), b *testing.B) {
|
||||
const MaxConcurrentCalls = 100
|
||||
b.StopTimer()
|
||||
once.Do(startServer)
|
||||
client, err := dial()
|
||||
if err != nil {
|
||||
b.Fatal("error dialing:", err)
|
||||
}
|
||||
|
||||
// Asynchronous calls
|
||||
args := &Args{7, 8}
|
||||
procs := 4 * runtime.GOMAXPROCS(-1)
|
||||
send := int32(b.N)
|
||||
recv := int32(b.N)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(procs)
|
||||
gate := make(chan bool, MaxConcurrentCalls)
|
||||
res := make(chan *Call, MaxConcurrentCalls)
|
||||
b.StartTimer()
|
||||
|
||||
for p := 0; p < procs; p++ {
|
||||
go func() {
|
||||
for atomic.AddInt32(&send, -1) >= 0 {
|
||||
gate <- true
|
||||
reply := new(Reply)
|
||||
client.Go(ctx, "Arith.Add", args, reply, res)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for call := range res {
|
||||
A := call.Args.(*Args).A
|
||||
B := call.Args.(*Args).B
|
||||
C := call.Reply.(*Reply).C
|
||||
if A+B != C {
|
||||
b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
|
||||
}
|
||||
<-gate
|
||||
if atomic.AddInt32(&recv, -1) == 0 {
|
||||
close(res)
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkEndToEnd(b *testing.B) {
|
||||
benchmarkEndToEnd(context.Background(), dialDirect, b)
|
||||
}
|
||||
|
||||
func BenchmarkEndToEndHTTP(b *testing.B) {
|
||||
benchmarkEndToEnd(context.Background(), dialHTTP, b)
|
||||
}
|
||||
|
||||
func BenchmarkEndToEndAsync(b *testing.B) {
|
||||
benchmarkEndToEndAsync(context.Background(), dialDirect, b)
|
||||
}
|
||||
|
||||
func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
|
||||
benchmarkEndToEndAsync(context.Background(), dialHTTP, b)
|
||||
}
|
|
@ -1,227 +0,0 @@
|
|||
package rpcplus
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
httpPath = "/srpc"
|
||||
)
|
||||
|
||||
type StreamingArgs struct {
|
||||
A int
|
||||
Count int
|
||||
// next two values have to be between 0 and Count-2 to trigger anything
|
||||
ErrorAt int // will trigger an error at the given spot,
|
||||
BadTypeAt int // will send the wrong type in sendReply
|
||||
}
|
||||
|
||||
type StreamingReply struct {
|
||||
C int
|
||||
Index int
|
||||
}
|
||||
|
||||
var errTriggeredInTheMiddle = errors.New("triggered error in middle")
|
||||
|
||||
type StreamingArith int
|
||||
|
||||
func (t *StreamingArith) Thrive(args StreamingArgs, sendReply func(reply interface{}) error) error {
|
||||
|
||||
for i := 0; i < args.Count; i++ {
|
||||
if i == args.ErrorAt {
|
||||
return errTriggeredInTheMiddle
|
||||
}
|
||||
if i == args.BadTypeAt {
|
||||
// send args instead of response
|
||||
sr := new(StreamingArgs)
|
||||
err := sendReply(sr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// log.Println(" Sending sample", i)
|
||||
sr := &StreamingReply{C: args.A, Index: i}
|
||||
err := sendReply(sr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// make a server, a cient, and connect them
|
||||
func makeLink(t *testing.T) (client *Client) {
|
||||
// start a server
|
||||
server := NewServer()
|
||||
err := server.Register(new(StreamingArith))
|
||||
if err != nil {
|
||||
t.Fatal("Register failed", err)
|
||||
}
|
||||
|
||||
// listen and handle queries
|
||||
var l net.Listener
|
||||
l, serverAddr = listenTCP()
|
||||
log.Println("Test RPC server listening on", serverAddr)
|
||||
go server.Accept(l)
|
||||
|
||||
// dial the client
|
||||
client, err = Dial("tcp", serverAddr)
|
||||
if err != nil {
|
||||
t.Fatal("dialing", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// this is a specific function so we can call it to check the link
|
||||
// is still active and well
|
||||
func callOnceAndCheck(t *testing.T, client *Client) {
|
||||
|
||||
args := &StreamingArgs{3, 5, -1, -1}
|
||||
rowChan := make(chan *StreamingReply, 10)
|
||||
c := client.StreamGo("StreamingArith.Thrive", args, rowChan)
|
||||
|
||||
count := 0
|
||||
for row := range rowChan {
|
||||
if row.Index != count {
|
||||
t.Fatal("unexpected value:", row.Index)
|
||||
}
|
||||
count++
|
||||
|
||||
// log.Println("Values: ", row.C, row.Index)
|
||||
}
|
||||
|
||||
if c.Error != nil {
|
||||
t.Fatal("unexpected error:", c.Error.Error())
|
||||
}
|
||||
|
||||
if count != 5 {
|
||||
t.Fatal("Didn't receive the right number of packets back:", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingRpc(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if testing.Short() {
|
||||
t.Skip("skipping wait-based test in short mode.")
|
||||
}
|
||||
|
||||
client := makeLink(t)
|
||||
|
||||
// Nonexistent method
|
||||
args := &StreamingArgs{7, 10, -1, -1}
|
||||
reply := new(StreamingReply)
|
||||
err := client.Call(ctx, "StreamingArith.BadOperation", args, reply)
|
||||
// expect an error
|
||||
if err == nil {
|
||||
t.Error("BadOperation: expected error")
|
||||
} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
|
||||
t.Errorf("BadOperation: expected can't find method error; got %q", err)
|
||||
}
|
||||
|
||||
// call that works
|
||||
callOnceAndCheck(t, client)
|
||||
|
||||
// call that may block forever (but won't!)
|
||||
args = &StreamingArgs{3, 100, -1, -1} // 100 is greater than the next 10
|
||||
rowChan := make(chan *StreamingReply, 10)
|
||||
client.StreamGo("StreamingArith.Thrive", args, rowChan)
|
||||
// read one guy, sleep a bit to make sure everything went
|
||||
// through, then close
|
||||
_, ok := <-rowChan
|
||||
if !ok {
|
||||
t.Fatal("unexpected closed channel")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// log.Println("Closing")
|
||||
client.Close()
|
||||
for range rowChan {
|
||||
}
|
||||
// log.Println("Closed")
|
||||
|
||||
// the sleep here is intended to show the log at the end of the input()
|
||||
// go routine, to make sure it existed. Not sure how to test it
|
||||
// programmatically (short of having a counter in the library
|
||||
// on how many input() threads we have?)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
func TestInterruptedCallByServer(t *testing.T) {
|
||||
|
||||
// make our client
|
||||
client := makeLink(t)
|
||||
|
||||
args := &StreamingArgs{3, 100, 30, -1} // 30 elements back, then error
|
||||
rowChan := make(chan *StreamingReply, 10)
|
||||
c := client.StreamGo("StreamingArith.Thrive", args, rowChan)
|
||||
|
||||
// check we get the error at the 30th call exactly
|
||||
count := 0
|
||||
for row := range rowChan {
|
||||
if row.Index != count {
|
||||
t.Fatal("unexpected value:", row.Index)
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != 30 {
|
||||
t.Fatal("received error before the right time:", count)
|
||||
}
|
||||
if c.Error.Error() != errTriggeredInTheMiddle.Error() {
|
||||
t.Fatal("received wrong error message:", c.Error)
|
||||
}
|
||||
|
||||
// make sure the wire is still in good shape
|
||||
callOnceAndCheck(t, client)
|
||||
|
||||
// then check a call that doesn't send anything, but errors out first
|
||||
args = &StreamingArgs{3, 100, 0, -1}
|
||||
rowChan = make(chan *StreamingReply, 10)
|
||||
c = client.StreamGo("StreamingArith.Thrive", args, rowChan)
|
||||
_, ok := <-rowChan
|
||||
if ok {
|
||||
t.Fatal("expected closed channel")
|
||||
}
|
||||
if c.Error.Error() != errTriggeredInTheMiddle.Error() {
|
||||
t.Fatal("received wrong error message:", c.Error)
|
||||
}
|
||||
|
||||
// make sure the wire is still in good shape
|
||||
callOnceAndCheck(t, client)
|
||||
}
|
||||
|
||||
func TestBadTypeByServer(t *testing.T) {
|
||||
|
||||
// make our client
|
||||
client := makeLink(t)
|
||||
|
||||
args := &StreamingArgs{3, 100, -1, 30} // 30 elements back, then bad
|
||||
rowChan := make(chan *StreamingReply, 10)
|
||||
c := client.StreamGo("StreamingArith.Thrive", args, rowChan)
|
||||
|
||||
// check we get the error at the 30th call exactly
|
||||
count := 0
|
||||
for row := range rowChan {
|
||||
if row.Index != count {
|
||||
t.Fatal("unexpected value:", row.Index)
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != 30 {
|
||||
t.Fatal("received error before the right time:", count)
|
||||
}
|
||||
if c.Error.Error() != "rpc: passing wrong type to sendReply" {
|
||||
t.Fatal("received wrong error message:", c.Error)
|
||||
}
|
||||
|
||||
// make sure the wire is still in good shape
|
||||
callOnceAndCheck(t, client)
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package bsonrpc provides codecs for bsonrpc communication
|
||||
package bsonrpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap"
|
||||
)
|
||||
|
||||
const (
|
||||
codecName = "bson"
|
||||
)
|
||||
|
||||
// ClientCodec holds required parameters for providing a client codec for
|
||||
// bsonrpc
|
||||
type ClientCodec struct {
|
||||
rwc io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// NewClientCodec creates a new client codec for bsonrpc communication
|
||||
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
|
||||
return &ClientCodec{conn}
|
||||
}
|
||||
|
||||
// DefaultBufferSize holds the default value for buffer size
|
||||
const DefaultBufferSize = 4096
|
||||
|
||||
// WriteRequest sends the request to the server
|
||||
func (cc *ClientCodec) WriteRequest(r *rpc.Request, body interface{}) error {
|
||||
buf := bytes2.NewChunkedWriter(DefaultBufferSize)
|
||||
if err := bson.MarshalToBuffer(buf, &RequestBson{r}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bson.MarshalToBuffer(buf, body); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := buf.WriteTo(cc.rwc)
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadResponseHeader reads the header of server response
|
||||
func (cc *ClientCodec) ReadResponseHeader(r *rpc.Response) error {
|
||||
return bson.UnmarshalFromStream(cc.rwc, &ResponseBson{r})
|
||||
}
|
||||
|
||||
// ReadResponseBody reads the body of server response
|
||||
func (cc *ClientCodec) ReadResponseBody(body interface{}) error {
|
||||
return bson.UnmarshalFromStream(cc.rwc, body)
|
||||
}
|
||||
|
||||
// Close closes the codec
|
||||
func (cc *ClientCodec) Close() error {
|
||||
return cc.rwc.Close()
|
||||
}
|
||||
|
||||
// ServerCodec holds required parameters for providing a server codec for
|
||||
// bsonrpc
|
||||
type ServerCodec struct {
|
||||
rwc io.ReadWriteCloser
|
||||
cw *bytes2.ChunkedWriter
|
||||
}
|
||||
|
||||
// NewServerCodec creates a new server codec for bsonrpc communication
|
||||
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
|
||||
return &ServerCodec{conn, bytes2.NewChunkedWriter(DefaultBufferSize)}
|
||||
}
|
||||
|
||||
// ReadRequestHeader reads the header of the request
|
||||
func (sc *ServerCodec) ReadRequestHeader(r *rpc.Request) error {
|
||||
return bson.UnmarshalFromStream(sc.rwc, &RequestBson{r})
|
||||
}
|
||||
|
||||
// ReadRequestBody reads the body of the request
|
||||
func (sc *ServerCodec) ReadRequestBody(body interface{}) error {
|
||||
return bson.UnmarshalFromStream(sc.rwc, body)
|
||||
}
|
||||
|
||||
// WriteResponse send the response of the request to the client
|
||||
func (sc *ServerCodec) WriteResponse(r *rpc.Response, body interface{}, last bool) error {
|
||||
if err := bson.MarshalToBuffer(sc.cw, &ResponseBson{r}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bson.MarshalToBuffer(sc.cw, body); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := sc.cw.WriteTo(sc.rwc)
|
||||
sc.cw.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the codec
|
||||
func (sc *ServerCodec) Close() error {
|
||||
return sc.rwc.Close()
|
||||
}
|
||||
|
||||
// DialHTTP dials a HTTP endpoint with bsonrpc codec
|
||||
func DialHTTP(network, address string, connectTimeout time.Duration) (*rpc.Client, error) {
|
||||
return rpcwrap.DialHTTP(network, address, codecName, NewClientCodec, connectTimeout)
|
||||
}
|
||||
|
||||
// ServeRPC serves bsonrpc codec with the default rpc server
|
||||
func ServeRPC() {
|
||||
rpcwrap.ServeRPC(codecName, NewServerCodec)
|
||||
}
|
||||
|
||||
// ServeCustomRPC serves bsonrpc codec with a custom rpc server
|
||||
func ServeCustomRPC(handler *http.ServeMux, server *rpc.Server) {
|
||||
rpcwrap.ServeCustomRPC(handler, server, codecName, NewServerCodec)
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bsonrpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
)
|
||||
|
||||
// RequestBson provides bson rpc request parameters
|
||||
type RequestBson struct {
|
||||
*rpc.Request
|
||||
}
|
||||
|
||||
// MarshalBson marshals request to the given writer with optional prefix
|
||||
func (req *RequestBson) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeString(buf, "ServiceMethod", req.ServiceMethod)
|
||||
bson.EncodeUint64(buf, "Seq", req.Seq)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson unmarshals request to the given byte buffer as verifying the
|
||||
// kind
|
||||
func (req *RequestBson) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
bson.VerifyObject(kind)
|
||||
bson.Next(buf, 4)
|
||||
|
||||
kind = bson.NextByte(buf)
|
||||
for kind != bson.EOO {
|
||||
key := bson.ReadCString(buf)
|
||||
switch key {
|
||||
case "ServiceMethod":
|
||||
req.ServiceMethod = bson.DecodeString(buf, kind)
|
||||
case "Seq":
|
||||
req.Seq = bson.DecodeUint64(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
kind = bson.NextByte(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseBson provides bson rpc request parameters
|
||||
type ResponseBson struct {
|
||||
*rpc.Response
|
||||
}
|
||||
|
||||
// MarshalBson marshals response to the given writer with optional prefix
|
||||
func (resp *ResponseBson) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
bson.EncodeString(buf, "ServiceMethod", resp.ServiceMethod)
|
||||
bson.EncodeUint64(buf, "Seq", resp.Seq)
|
||||
bson.EncodeString(buf, "Error", resp.Error)
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson unmarshals response to the given byte buffer as verifying the
|
||||
// kind
|
||||
func (resp *ResponseBson) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
bson.VerifyObject(kind)
|
||||
bson.Next(buf, 4)
|
||||
|
||||
kind = bson.NextByte(buf)
|
||||
for kind != bson.EOO {
|
||||
key := bson.ReadCString(buf)
|
||||
switch key {
|
||||
case "ServiceMethod":
|
||||
resp.ServiceMethod = bson.DecodeString(buf, kind)
|
||||
case "Seq":
|
||||
resp.Seq = bson.DecodeUint64(buf, kind)
|
||||
case "Error":
|
||||
resp.Error = bson.DecodeString(buf, kind)
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
kind = bson.NextByte(buf)
|
||||
}
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package bsonrpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
)
|
||||
|
||||
type reflectRequestBson struct {
|
||||
ServiceMethod string
|
||||
Seq uint64
|
||||
}
|
||||
|
||||
type extraRequestBson struct {
|
||||
Extra int
|
||||
ServiceMethod string
|
||||
Seq uint64
|
||||
}
|
||||
|
||||
func TestRequestBson(t *testing.T) {
|
||||
reflected, err := bson.Marshal(&reflectRequestBson{
|
||||
ServiceMethod: "aa",
|
||||
Seq: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
want := string(reflected)
|
||||
|
||||
custom := RequestBson{
|
||||
&rpc.Request{
|
||||
ServiceMethod: "aa",
|
||||
Seq: 1,
|
||||
},
|
||||
}
|
||||
encoded, err := bson.Marshal(&custom)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
got := string(encoded)
|
||||
if want != got {
|
||||
t.Errorf("want\n%#v, got\n%#v", want, got)
|
||||
}
|
||||
|
||||
unmarshalled := RequestBson{Request: new(rpc.Request)}
|
||||
err = bson.Unmarshal(encoded, &unmarshalled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if custom.ServiceMethod != unmarshalled.ServiceMethod {
|
||||
t.Errorf("want %v, got %#v", custom.ServiceMethod, unmarshalled.ServiceMethod)
|
||||
}
|
||||
if custom.Seq != unmarshalled.Seq {
|
||||
t.Errorf("want %v, got %#v", custom.Seq, unmarshalled.Seq)
|
||||
}
|
||||
|
||||
extra, err := bson.Marshal(&extraRequestBson{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = bson.Unmarshal(extra, &unmarshalled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
type reflectResponseBson struct {
|
||||
ServiceMethod string
|
||||
Seq uint64
|
||||
Error string
|
||||
}
|
||||
|
||||
type extraResponseBson struct {
|
||||
Extra int
|
||||
ServiceMethod string
|
||||
Seq uint64
|
||||
Error string
|
||||
}
|
||||
|
||||
func TestResponseBson(t *testing.T) {
|
||||
reflected, err := bson.Marshal(&reflectResponseBson{
|
||||
ServiceMethod: "aa",
|
||||
Seq: 1,
|
||||
Error: "err",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
want := string(reflected)
|
||||
|
||||
custom := ResponseBson{
|
||||
&rpc.Response{
|
||||
ServiceMethod: "aa",
|
||||
Seq: 1,
|
||||
Error: "err",
|
||||
},
|
||||
}
|
||||
encoded, err := bson.Marshal(&custom)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
got := string(encoded)
|
||||
if want != got {
|
||||
t.Errorf("want\n%#v, got\n%#v", want, got)
|
||||
}
|
||||
|
||||
unmarshalled := ResponseBson{Response: new(rpc.Response)}
|
||||
err = bson.Unmarshal(encoded, &unmarshalled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if custom.ServiceMethod != unmarshalled.ServiceMethod {
|
||||
t.Errorf("want %v, got %#v", custom.ServiceMethod, unmarshalled.ServiceMethod)
|
||||
}
|
||||
if custom.Seq != unmarshalled.Seq {
|
||||
t.Errorf("want %v, got %#v", custom.Seq, unmarshalled.Seq)
|
||||
}
|
||||
if custom.Error != unmarshalled.Error {
|
||||
t.Errorf("want %v, got %#v", custom.Error, unmarshalled.Error)
|
||||
}
|
||||
|
||||
extra, err := bson.Marshal(&extraResponseBson{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = bson.Unmarshal(extra, &unmarshalled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package jsonrpc provides wrappers for json rpc communication
|
||||
package jsonrpc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
oldjson "github.com/youtube/vitess/go/rpcplus/jsonrpc"
|
||||
"github.com/youtube/vitess/go/rpcwrap"
|
||||
)
|
||||
|
||||
// DialHTTP dials a json rpc HTTP endpoint
|
||||
func DialHTTP(network, address string, connectTimeout time.Duration) (*rpc.Client, error) {
|
||||
return rpcwrap.DialHTTP(network, address, "json", oldjson.NewClientCodec, connectTimeout)
|
||||
}
|
||||
|
||||
// ServeRPC serves a json rpc endpoint using default server
|
||||
func ServeRPC() {
|
||||
rpcwrap.ServeRPC("json", oldjson.NewServerCodec)
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
// Package proto provides protocol functions
|
||||
package proto
|
||||
|
||||
import "golang.org/x/net/context"
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
remoteAddrKey contextKey = 0
|
||||
)
|
||||
|
||||
// RemoteAddr accesses the remote address of the rpcwrap call connection in this context.
|
||||
func RemoteAddr(ctx context.Context) (addr string, ok bool) {
|
||||
val := ctx.Value(remoteAddrKey)
|
||||
if val == nil {
|
||||
return "", false
|
||||
}
|
||||
addr, ok = val.(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return addr, true
|
||||
}
|
||||
|
||||
// NewContext creates a default context satisfying context.Context
|
||||
func NewContext(remoteAddr string) context.Context {
|
||||
return context.WithValue(context.Background(), remoteAddrKey, remoteAddr)
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
package rpcwrap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcplus/jsonrpc"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
A, B int
|
||||
}
|
||||
|
||||
type Arith int
|
||||
|
||||
func (t *Arith) Success(ctx context.Context, args *Request, reply *int) error {
|
||||
*reply = args.A * args.B
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Arith) Fail(ctx context.Context, args *Request, reply *int) error {
|
||||
return errors.New("fail")
|
||||
}
|
||||
|
||||
func (t *Arith) Context(ctx context.Context, args *Request, reply *int) error {
|
||||
if data := ctx.Value("context"); data == nil {
|
||||
return errors.New("context is not set")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func startListeningWithContext(ctx context.Context) net.Listener {
|
||||
server := rpcplus.NewServer()
|
||||
server.Register(new(Arith))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
contextCreator := func(req *http.Request) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
ServeHTTPRPC(
|
||||
mux, // httpmuxer
|
||||
server, // rpcserver
|
||||
"json", // codec name
|
||||
jsonrpc.NewServerCodec, // jsoncodec
|
||||
contextCreator, // contextCreator
|
||||
)
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go http.Serve(l, mux)
|
||||
return l
|
||||
}
|
||||
|
||||
func startListening() net.Listener {
|
||||
return startListeningWithContext(context.Background())
|
||||
}
|
||||
|
||||
func createAddr(l net.Listener) string {
|
||||
return "http://" + l.Addr().String() + GetRpcPath("json")
|
||||
}
|
||||
|
||||
func TestSuccess(t *testing.T) {
|
||||
l := startListening()
|
||||
defer l.Close()
|
||||
|
||||
params := &Request{
|
||||
A: 7,
|
||||
B: 8,
|
||||
}
|
||||
|
||||
var r int
|
||||
|
||||
err := jsonrpc.NewHTTPClient(createAddr(l)).Call("Arith.Success", params, &r)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if r != 56 {
|
||||
t.Fatalf("Expected: 56, but got: %d", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFail(t *testing.T) {
|
||||
l := startListening()
|
||||
defer l.Close()
|
||||
|
||||
params := &Request{
|
||||
A: 7,
|
||||
B: 8,
|
||||
}
|
||||
|
||||
var r int
|
||||
|
||||
err := jsonrpc.NewHTTPClient(createAddr(l)).Call("Arith.Fail", params, &r)
|
||||
if err == nil {
|
||||
t.Fatal("Expected a non-nil err")
|
||||
}
|
||||
|
||||
if err.Error() != "fail" {
|
||||
t.Fatalf("Expected \"fail\" as err message, but got %s", err.Error())
|
||||
}
|
||||
|
||||
if r != 0 {
|
||||
t.Fatalf("Expected: 0, but got: %d", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "context", "value")
|
||||
l := startListeningWithContext(ctx)
|
||||
defer l.Close()
|
||||
|
||||
params := &Request{
|
||||
A: 7,
|
||||
B: 8,
|
||||
}
|
||||
|
||||
var r int
|
||||
|
||||
err := jsonrpc.NewHTTPClient(createAddr(l)).Call("Arith.Context", params, &r)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
}
|
|
@ -1,209 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package rpcwrap provides wrappers for rpcplus package
|
||||
package rpcwrap
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap/proto"
|
||||
"github.com/youtube/vitess/go/stats"
|
||||
)
|
||||
|
||||
const (
|
||||
connected = "200 Connected to Go RPC"
|
||||
)
|
||||
|
||||
var (
|
||||
connCount = stats.NewInt("connection-count")
|
||||
connAccepted = stats.NewInt("connection-accepted")
|
||||
)
|
||||
|
||||
// ClientCodecFactory holds pattern for other client codec factories
|
||||
type ClientCodecFactory func(conn io.ReadWriteCloser) rpc.ClientCodec
|
||||
|
||||
// BufferedConnection holds connection data for codecs
|
||||
type BufferedConnection struct {
|
||||
isClosed bool
|
||||
*bufio.Reader
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
// NewBufferedConnection creates a new Buffered Connection
|
||||
func NewBufferedConnection(conn io.ReadWriteCloser) *BufferedConnection {
|
||||
connCount.Add(1)
|
||||
connAccepted.Add(1)
|
||||
return &BufferedConnection{false, bufio.NewReader(conn), conn}
|
||||
}
|
||||
|
||||
// Close closes the buffered connection
|
||||
// FIXME(sougou/szopa): Find a better way to track connection count.
|
||||
func (bc *BufferedConnection) Close() error {
|
||||
if !bc.isClosed {
|
||||
bc.isClosed = true
|
||||
connCount.Add(-1)
|
||||
}
|
||||
return bc.WriteCloser.Close()
|
||||
}
|
||||
|
||||
// DialHTTP connects to a go HTTP RPC server using the specified codec.
|
||||
// use 0 as connectTimeout for no timeout
|
||||
func DialHTTP(network, address, codecName string, cFactory ClientCodecFactory, connectTimeout time.Duration) (*rpc.Client, error) {
|
||||
return dialHTTP(network, address, codecName, cFactory, connectTimeout)
|
||||
}
|
||||
|
||||
func dialHTTP(network, address, codecName string, cFactory ClientCodecFactory, connectTimeout time.Duration) (*rpc.Client, error) {
|
||||
var err error
|
||||
var conn net.Conn
|
||||
if connectTimeout != 0 {
|
||||
conn, err = net.DialTimeout(network, address, connectTimeout)
|
||||
} else {
|
||||
conn, err = net.Dial(network, address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = io.WriteString(conn, "CONNECT "+GetRpcPath(codecName)+" HTTP/1.0\n\n")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Require successful HTTP response
|
||||
// before switching to RPC protocol.
|
||||
buffered := NewBufferedConnection(conn)
|
||||
resp, err := http.ReadResponse(buffered.Reader, &http.Request{Method: "CONNECT"})
|
||||
if err == nil && resp.Status == connected {
|
||||
return rpc.NewClientWithCodec(cFactory(buffered)), nil
|
||||
}
|
||||
if err == nil {
|
||||
err = errors.New("unexpected HTTP response: " + resp.Status)
|
||||
}
|
||||
conn.Close()
|
||||
return nil, &net.OpError{Op: "dial-http", Net: network + " " + address, Addr: nil, Err: err}
|
||||
}
|
||||
|
||||
// ServerCodecFactory holds pattern for other server codec factories
|
||||
type ServerCodecFactory func(conn io.ReadWriteCloser) rpc.ServerCodec
|
||||
|
||||
// ServeRPC handles rpc requests using the hijack scheme of rpc
|
||||
func ServeRPC(codecName string, cFactory ServerCodecFactory) {
|
||||
http.Handle(GetRpcPath(codecName), &rpcHandler{cFactory, rpc.DefaultServer})
|
||||
}
|
||||
|
||||
// ServeCustomRPC serves the given rpc requests with the provided ServeMux
|
||||
func ServeCustomRPC(handler *http.ServeMux, server *rpc.Server, codecName string, cFactory ServerCodecFactory) {
|
||||
handler.Handle(GetRpcPath(codecName), &rpcHandler{cFactory, server})
|
||||
}
|
||||
|
||||
// rpcHandler handles rpc queries for a 'CONNECT' method.
|
||||
type rpcHandler struct {
|
||||
cFactory ServerCodecFactory
|
||||
server *rpc.Server
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler's ServeHTTP
|
||||
func (h *rpcHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != "CONNECT" {
|
||||
c.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
c.WriteHeader(http.StatusMethodNotAllowed)
|
||||
io.WriteString(c, "405 must CONNECT\n")
|
||||
return
|
||||
}
|
||||
conn, _, err := c.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
log.Errorf("rpc hijacking %s: %v", req.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
|
||||
codec := h.cFactory(NewBufferedConnection(conn))
|
||||
ctx := proto.NewContext(req.RemoteAddr)
|
||||
h.server.ServeCodecWithContext(ctx, codec)
|
||||
}
|
||||
|
||||
// GetRpcPath returns the toplevel path used for serving RPCs over HTTP
|
||||
func GetRpcPath(codecName string) string {
|
||||
return "/_" + codecName + "_rpc_"
|
||||
}
|
||||
|
||||
// ServeHTTPRPC serves the given http rpc requests with the provided ServeMux
|
||||
func ServeHTTPRPC(
|
||||
handler *http.ServeMux,
|
||||
server *rpc.Server,
|
||||
codecName string,
|
||||
cFactory ServerCodecFactory,
|
||||
contextCreator func(*http.Request) context.Context) {
|
||||
|
||||
handler.Handle(
|
||||
GetRpcPath(codecName),
|
||||
&httpRPCHandler{
|
||||
cFactory: cFactory,
|
||||
server: server,
|
||||
contextCreator: contextCreator,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// httpRPCHandler handles rpc queries for a all types of HTTP requests, does not
|
||||
// maintain a persistent connection.
|
||||
type httpRPCHandler struct {
|
||||
cFactory ServerCodecFactory
|
||||
server *rpc.Server
|
||||
// contextCreator creates an application specific context, while creating
|
||||
// the context it should not read the request body nor write anything to
|
||||
// headers
|
||||
contextCreator func(*http.Request) context.Context
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler's ServeHTTP
|
||||
func (h *httpRPCHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
|
||||
codec := h.cFactory(&httpReadWriteCloser{rw: c, req: req})
|
||||
|
||||
var ctx context.Context
|
||||
|
||||
if h.contextCreator != nil {
|
||||
ctx = h.contextCreator(req)
|
||||
} else {
|
||||
ctx = proto.NewContext(req.RemoteAddr)
|
||||
}
|
||||
|
||||
h.server.ServeRequestWithContext(
|
||||
ctx,
|
||||
codec,
|
||||
)
|
||||
|
||||
codec.Close()
|
||||
}
|
||||
|
||||
// httpReadWriteCloser wraps http.ResponseWriter and http.Request, with the help
|
||||
// of those, implements ReadWriteCloser interface
|
||||
type httpReadWriteCloser struct {
|
||||
rw http.ResponseWriter
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
// Read implements Reader interface
|
||||
func (i *httpReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||
return i.req.Body.Read(p)
|
||||
}
|
||||
|
||||
// Write implements Writer interface
|
||||
func (i *httpReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return i.rw.Write(p)
|
||||
}
|
||||
|
||||
// Close implements Closer interface
|
||||
func (i *httpReadWriteCloser) Close() error {
|
||||
return i.req.Body.Close()
|
||||
}
|
|
@ -1,137 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqltypes
|
||||
|
||||
// File has been manually edited. Do not regenerate.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
querypb "github.com/youtube/vitess/go/vt/proto/query"
|
||||
)
|
||||
|
||||
// BSONField is a temporary struct for backward compatibility.
|
||||
type BSONField struct {
|
||||
Name string
|
||||
Type int64
|
||||
Flags int64
|
||||
}
|
||||
|
||||
// MarshalBson bson-encodes Result.
|
||||
func (result *Result) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
// []*query.Field
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "Fields")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
var f BSONField
|
||||
for _i, _v1 := range result.Fields {
|
||||
// *query.Field
|
||||
// This part was manually changed.
|
||||
f.Name = _v1.Name
|
||||
f.Type, f.Flags = TypeToMySQL(_v1.Type)
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, bson.Itoa(_i))
|
||||
bson.MarshalToBuffer(buf, &f)
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
bson.EncodeUint64(buf, "RowsAffected", result.RowsAffected)
|
||||
bson.EncodeUint64(buf, "InsertId", result.InsertID)
|
||||
// [][]Value
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "Rows")
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v2 := range result.Rows {
|
||||
// []Value
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, bson.Itoa(_i))
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
for _i, _v3 := range _v2 {
|
||||
_v3.MarshalBson(buf, bson.Itoa(_i))
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into Result.
|
||||
func (result *Result) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for Result", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
switch bson.ReadCString(buf) {
|
||||
case "Fields":
|
||||
// []*query.Field
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for result.Fields", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
result.Fields = make([]*querypb.Field, 0, 8)
|
||||
var f BSONField
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v1 *querypb.Field
|
||||
// *query.Field
|
||||
_v1 = new(querypb.Field)
|
||||
bson.UnmarshalFromBuffer(buf, &f)
|
||||
_v1.Name = f.Name
|
||||
_v1.Type = MySQLToType(f.Type, f.Flags)
|
||||
result.Fields = append(result.Fields, _v1)
|
||||
}
|
||||
}
|
||||
case "RowsAffected":
|
||||
result.RowsAffected = bson.DecodeUint64(buf, kind)
|
||||
case "InsertId":
|
||||
result.InsertID = bson.DecodeUint64(buf, kind)
|
||||
case "Rows":
|
||||
// [][]Value
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for result.Rows", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
result.Rows = make([][]Value, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v2 []Value
|
||||
// []Value
|
||||
if kind != bson.Null {
|
||||
if kind != bson.Array {
|
||||
panic(bson.NewBsonError("unexpected kind %v for _v2", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
_v2 = make([]Value, 0, 8)
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.SkipIndex(buf)
|
||||
var _v3 Value
|
||||
_v3.UnmarshalBson(buf, kind)
|
||||
_v2 = append(_v2, _v3)
|
||||
}
|
||||
}
|
||||
result.Rows = append(result.Rows, _v2)
|
||||
}
|
||||
}
|
||||
default:
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,76 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
querypb "github.com/youtube/vitess/go/vt/proto/query"
|
||||
)
|
||||
|
||||
type TestCase struct {
|
||||
qr Result
|
||||
encoded string
|
||||
}
|
||||
|
||||
var testcases = []TestCase{
|
||||
// Empty
|
||||
{
|
||||
qr: Result{},
|
||||
encoded: "E\x00\x00\x00\x04Fields\x00\x05\x00\x00\x00\x00\x12RowsAffected\x00\x00\x00\x00\x00\x00\x00\x00\x00\x12InsertId\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04Rows\x00\x05\x00\x00\x00\x00\x00",
|
||||
},
|
||||
// Only fields set
|
||||
{
|
||||
qr: Result{
|
||||
Fields: []*querypb.Field{
|
||||
{Name: "foo", Type: Int8},
|
||||
},
|
||||
},
|
||||
encoded: "x\x00\x00\x00\x04Fields\x008\x00\x00\x00\x030\x000\x00\x00\x00\x05Name\x00\x03\x00\x00\x00\x00foo\x12Type\x00\x01\x00\x00\x00\x00\x00\x00\x00\x12Flags\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x12RowsAffected\x00\x00\x00\x00\x00\x00\x00\x00\x00\x12InsertId\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04Rows\x00\x05\x00\x00\x00\x00\x00",
|
||||
},
|
||||
// two rows
|
||||
{
|
||||
qr: Result{
|
||||
Fields: []*querypb.Field{
|
||||
{Name: "foo", Type: VarChar},
|
||||
{Name: "bar", Type: Int64},
|
||||
{Name: "baz", Type: Float64},
|
||||
},
|
||||
Rows: [][]Value{
|
||||
{testVal(VarBinary, "abcd"), testVal(VarBinary, "1234"), testVal(VarBinary, "1.234")},
|
||||
{testVal(VarBinary, "efgh"), testVal(VarBinary, "5678"), testVal(VarBinary, "5.678")},
|
||||
},
|
||||
},
|
||||
encoded: "",
|
||||
},
|
||||
}
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
for caseno, tcase := range testcases {
|
||||
actual, err := bson.Marshal(&tcase.qr)
|
||||
if err != nil {
|
||||
t.Errorf("Error on %d: %v", caseno, err)
|
||||
}
|
||||
if tcase.encoded != "" && string(actual) != tcase.encoded {
|
||||
t.Errorf("Expecting vs actual for %d:\n%#v\n%#v", caseno, tcase.encoded, string(actual))
|
||||
}
|
||||
var newqr Result
|
||||
err = bson.Unmarshal(actual, &newqr)
|
||||
if err != nil {
|
||||
t.Errorf("Error on %d: %v", caseno, err)
|
||||
}
|
||||
if len(newqr.Fields) == 0 {
|
||||
newqr.Fields = nil
|
||||
}
|
||||
if len(newqr.Rows) == 0 {
|
||||
newqr.Rows = nil
|
||||
}
|
||||
if !reflect.DeepEqual(newqr, tcase.qr) {
|
||||
t.Errorf("Case: %d,\n%#v, want\n%#v", caseno, newqr, tcase.qr)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,7 +6,6 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -14,8 +13,6 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
"github.com/youtube/vitess/go/hack"
|
||||
querypb "github.com/youtube/vitess/go/vt/proto/query"
|
||||
)
|
||||
|
@ -308,32 +305,6 @@ func (v Value) IsBinary() bool {
|
|||
return IsBinary(v.typ)
|
||||
}
|
||||
|
||||
// MarshalBson marshals Value into bson.
|
||||
func (v Value) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
if key == "" {
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
defer lenWriter.Close()
|
||||
key = bson.MAGICTAG
|
||||
}
|
||||
if v.IsNull() {
|
||||
bson.EncodePrefix(buf, bson.Null, key)
|
||||
} else {
|
||||
bson.EncodeBinary(buf, key, v.val)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalBson unmarshals from bson.
|
||||
func (v *Value) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
if kind == bson.EOO {
|
||||
bson.Next(buf, 4)
|
||||
kind = bson.NextByte(buf)
|
||||
bson.ReadCString(buf)
|
||||
}
|
||||
if kind != bson.Null {
|
||||
*v = MakeString(bson.DecodeBinary(buf, kind))
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON should only be used for testing.
|
||||
// It's not a complete implementation.
|
||||
func (v Value) MarshalJSON() ([]byte, error) {
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"encoding/hex"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -335,19 +334,41 @@ func (blp *BinlogPlayer) ApplyBinlogEvents(ctx context.Context) error {
|
|||
}()
|
||||
}
|
||||
|
||||
var responseChan chan *binlogdatapb.BinlogTransaction
|
||||
var errFunc ErrFunc
|
||||
var stream BinlogTransactionStream
|
||||
if len(blp.tables) > 0 {
|
||||
responseChan, errFunc, err = blplClient.StreamTables(ctx, replication.EncodePosition(blp.position), blp.tables, blp.defaultCharset)
|
||||
stream, err = blplClient.StreamTables(ctx, replication.EncodePosition(blp.position), blp.tables, blp.defaultCharset)
|
||||
} else {
|
||||
responseChan, errFunc, err = blplClient.StreamKeyRange(ctx, replication.EncodePosition(blp.position), blp.keyRange, blp.defaultCharset)
|
||||
stream, err = blplClient.StreamKeyRange(ctx, replication.EncodePosition(blp.position), blp.keyRange, blp.defaultCharset)
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("Error sending streaming query to binlog server: %v", err)
|
||||
return fmt.Errorf("error sending streaming query to binlog server: %v", err)
|
||||
}
|
||||
|
||||
for response := range responseChan {
|
||||
for {
|
||||
// get the response
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
switch err {
|
||||
case context.Canceled:
|
||||
return nil
|
||||
default:
|
||||
// if the context is canceled, we
|
||||
// return nil (some RPC
|
||||
// implementations will remap the
|
||||
// context error to their own errors)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
}
|
||||
return fmt.Errorf("Error received from Stream %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// process the transaction
|
||||
for {
|
||||
ok, err = blp.processTransaction(response)
|
||||
if err != nil {
|
||||
|
@ -366,24 +387,6 @@ func (blp *BinlogPlayer) ApplyBinlogEvents(ctx context.Context) error {
|
|||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
switch err := errFunc(); err {
|
||||
case nil:
|
||||
return io.EOF
|
||||
case context.Canceled:
|
||||
return nil
|
||||
default:
|
||||
// if the context is canceled, we return nil (some RPC
|
||||
// implementations will remap the context error to their own
|
||||
// errors)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
}
|
||||
return fmt.Errorf("Error received from ServeBinlog %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateBlpCheckpoint returns the statements required to create
|
||||
|
|
|
@ -22,8 +22,21 @@ This file contains the API and registration mechanism for binlog player client.
|
|||
|
||||
var binlogPlayerProtocol = flag.String("binlog_player_protocol", "grpc", "the protocol to download binlogs from a vttablet")
|
||||
|
||||
// ErrFunc is a return value for streaming events
|
||||
type ErrFunc func() error
|
||||
// StreamEventStream is the interface of the object returned by
|
||||
// ServeUpdateStream
|
||||
type StreamEventStream interface {
|
||||
// Recv returns the next StreamEvent, or an error if the RPC was
|
||||
// interrupted.
|
||||
Recv() (*binlogdatapb.StreamEvent, error)
|
||||
}
|
||||
|
||||
// BinlogTransactionStream is the interface of the object returned by
|
||||
// StreamTables and StreamKeyRange
|
||||
type BinlogTransactionStream interface {
|
||||
// Recv returns the next BinlogTransaction, or an error if the RPC was
|
||||
// interrupted.
|
||||
Recv() (*binlogdatapb.BinlogTransaction, error)
|
||||
}
|
||||
|
||||
// Client is the interface all clients must satisfy
|
||||
type Client interface {
|
||||
|
@ -35,15 +48,15 @@ type Client interface {
|
|||
|
||||
// Ask the server to stream binlog updates.
|
||||
// Should return context.Canceled if the context is canceled.
|
||||
ServeUpdateStream(ctx context.Context, position string) (chan *binlogdatapb.StreamEvent, ErrFunc, error)
|
||||
ServeUpdateStream(ctx context.Context, position string) (StreamEventStream, error)
|
||||
|
||||
// Ask the server to stream updates related to the provided tables.
|
||||
// Should return context.Canceled if the context is canceled.
|
||||
StreamTables(ctx context.Context, position string, tables []string, charset *binlogdatapb.Charset) (chan *binlogdatapb.BinlogTransaction, ErrFunc, error)
|
||||
StreamTables(ctx context.Context, position string, tables []string, charset *binlogdatapb.Charset) (BinlogTransactionStream, error)
|
||||
|
||||
// Ask the server to stream updates related to the provided keyrange.
|
||||
// Should return context.Canceled if the context is canceled.
|
||||
StreamKeyRange(ctx context.Context, position string, keyRange *topodatapb.KeyRange, charset *binlogdatapb.Charset) (chan *binlogdatapb.BinlogTransaction, ErrFunc, error)
|
||||
StreamKeyRange(ctx context.Context, position string, keyRange *topodatapb.KeyRange, charset *binlogdatapb.Charset) (BinlogTransactionStream, error)
|
||||
}
|
||||
|
||||
// ClientFactory is the factory method to create a Client
|
||||
|
|
|
@ -90,37 +90,34 @@ func (fake *FakeBinlogStreamer) ServeUpdateStream(position string, sendReply fun
|
|||
|
||||
func testServeUpdateStream(t *testing.T, bpc binlogplayer.Client) {
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := bpc.ServeUpdateStream(ctx, testUpdateStreamRequest)
|
||||
stream, err := bpc.ServeUpdateStream(ctx, testUpdateStreamRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
if se, ok := <-c; !ok {
|
||||
t.Fatalf("got no response")
|
||||
if se, err := stream.Recv(); err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
} else {
|
||||
if !reflect.DeepEqual(*se, *testStreamEvent) {
|
||||
t.Errorf("got wrong result, got \n%#v expected \n%#v", *se, *testStreamEvent)
|
||||
}
|
||||
}
|
||||
if se, ok := <-c; ok {
|
||||
if se, err := stream.Recv(); err == nil {
|
||||
t.Fatalf("got a response when error expected: %v", se)
|
||||
}
|
||||
if err := errFunc(); err != nil {
|
||||
t.Errorf("got unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testServeUpdateStreamPanics(t *testing.T, bpc binlogplayer.Client) {
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := bpc.ServeUpdateStream(ctx, testUpdateStreamRequest)
|
||||
stream, err := bpc.ServeUpdateStream(ctx, testUpdateStreamRequest)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
if se, ok := <-c; ok {
|
||||
if se, err := stream.Recv(); err == nil {
|
||||
t.Fatalf("got a response when error expected: %v", se)
|
||||
}
|
||||
err = errFunc()
|
||||
if err == nil || !strings.Contains(err.Error(), "test-triggered panic") {
|
||||
t.Errorf("wrong error from panic: %v", err)
|
||||
} else {
|
||||
if !strings.Contains(err.Error(), "test-triggered panic") {
|
||||
t.Errorf("wrong error from panic: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,37 +173,34 @@ func (fake *FakeBinlogStreamer) StreamKeyRange(position string, keyRange *topoda
|
|||
|
||||
func testStreamKeyRange(t *testing.T, bpc binlogplayer.Client) {
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := bpc.StreamKeyRange(ctx, testKeyRangeRequest.Position, testKeyRangeRequest.KeyRange, testKeyRangeRequest.Charset)
|
||||
stream, err := bpc.StreamKeyRange(ctx, testKeyRangeRequest.Position, testKeyRangeRequest.KeyRange, testKeyRangeRequest.Charset)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
if se, ok := <-c; !ok {
|
||||
t.Fatalf("got no response")
|
||||
if se, err := stream.Recv(); err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
} else {
|
||||
if !reflect.DeepEqual(*se, *testBinlogTransaction) {
|
||||
t.Errorf("got wrong result, got %v expected %v", *se, *testBinlogTransaction)
|
||||
}
|
||||
}
|
||||
if se, ok := <-c; ok {
|
||||
if se, err := stream.Recv(); err == nil {
|
||||
t.Fatalf("got a response when error expected: %v", se)
|
||||
}
|
||||
if err := errFunc(); err != nil {
|
||||
t.Errorf("got unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testStreamKeyRangePanics(t *testing.T, bpc binlogplayer.Client) {
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := bpc.StreamKeyRange(ctx, testKeyRangeRequest.Position, testKeyRangeRequest.KeyRange, testKeyRangeRequest.Charset)
|
||||
stream, err := bpc.StreamKeyRange(ctx, testKeyRangeRequest.Position, testKeyRangeRequest.KeyRange, testKeyRangeRequest.Charset)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
if se, ok := <-c; ok {
|
||||
if se, err := stream.Recv(); err == nil {
|
||||
t.Fatalf("got a response when error expected: %v", se)
|
||||
}
|
||||
err = errFunc()
|
||||
if err == nil || !strings.Contains(err.Error(), "test-triggered panic") {
|
||||
t.Errorf("wrong error from panic: %v", err)
|
||||
} else {
|
||||
if !strings.Contains(err.Error(), "test-triggered panic") {
|
||||
t.Errorf("wrong error from panic: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,37 +237,34 @@ func (fake *FakeBinlogStreamer) StreamTables(position string, tables []string, c
|
|||
|
||||
func testStreamTables(t *testing.T, bpc binlogplayer.Client) {
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := bpc.StreamTables(ctx, testTablesRequest.Position, testTablesRequest.Tables, testTablesRequest.Charset)
|
||||
stream, err := bpc.StreamTables(ctx, testTablesRequest.Position, testTablesRequest.Tables, testTablesRequest.Charset)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
if se, ok := <-c; !ok {
|
||||
t.Fatalf("got no response")
|
||||
if se, err := stream.Recv(); err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
} else {
|
||||
if !reflect.DeepEqual(*se, *testBinlogTransaction) {
|
||||
t.Errorf("got wrong result, got %v expected %v", *se, *testBinlogTransaction)
|
||||
}
|
||||
}
|
||||
if se, ok := <-c; ok {
|
||||
if se, err := stream.Recv(); err == nil {
|
||||
t.Fatalf("got a response when error expected: %v", se)
|
||||
}
|
||||
if err := errFunc(); err != nil {
|
||||
t.Errorf("got unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testStreamTablesPanics(t *testing.T, bpc binlogplayer.Client) {
|
||||
ctx := context.Background()
|
||||
c, errFunc, err := bpc.StreamTables(ctx, testTablesRequest.Position, testTablesRequest.Tables, testTablesRequest.Charset)
|
||||
stream, err := bpc.StreamTables(ctx, testTablesRequest.Position, testTablesRequest.Tables, testTablesRequest.Charset)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
if se, ok := <-c; ok {
|
||||
if se, err := stream.Recv(); err == nil {
|
||||
t.Fatalf("got a response when error expected: %v", se)
|
||||
}
|
||||
err = errFunc()
|
||||
if err == nil || !strings.Contains(err.Error(), "test-triggered panic") {
|
||||
t.Errorf("wrong error from panic: %v", err)
|
||||
} else {
|
||||
if !strings.Contains(err.Error(), "test-triggered panic") {
|
||||
t.Errorf("wrong error from panic: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package grpcbinlogplayer
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
@ -41,95 +40,77 @@ func (client *client) Close() {
|
|||
client.cc.Close()
|
||||
}
|
||||
|
||||
func (client *client) ServeUpdateStream(ctx context.Context, position string) (chan *binlogdatapb.StreamEvent, binlogplayer.ErrFunc, error) {
|
||||
response := make(chan *binlogdatapb.StreamEvent, 10)
|
||||
type serveUpdateStreamAdapter struct {
|
||||
stream binlogservicepb.UpdateStream_StreamUpdateClient
|
||||
}
|
||||
|
||||
func (s *serveUpdateStreamAdapter) Recv() (*binlogdatapb.StreamEvent, error) {
|
||||
r, err := s.stream.Recv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.StreamEvent, nil
|
||||
}
|
||||
|
||||
func (client *client) ServeUpdateStream(ctx context.Context, position string) (binlogplayer.StreamEventStream, error) {
|
||||
query := &binlogdatapb.StreamUpdateRequest{
|
||||
Position: position,
|
||||
}
|
||||
|
||||
stream, err := client.c.StreamUpdate(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
var finalErr error
|
||||
go func() {
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
finalErr = err
|
||||
}
|
||||
close(response)
|
||||
return
|
||||
}
|
||||
response <- r.StreamEvent
|
||||
}
|
||||
}()
|
||||
return response, func() error {
|
||||
return finalErr
|
||||
}, nil
|
||||
return &serveUpdateStreamAdapter{stream}, nil
|
||||
}
|
||||
|
||||
func (client *client) StreamKeyRange(ctx context.Context, position string, keyRange *topodatapb.KeyRange, charset *binlogdatapb.Charset) (chan *binlogdatapb.BinlogTransaction, binlogplayer.ErrFunc, error) {
|
||||
response := make(chan *binlogdatapb.BinlogTransaction, 10)
|
||||
type serveStreamKeyRangeAdapter struct {
|
||||
stream binlogservicepb.UpdateStream_StreamKeyRangeClient
|
||||
}
|
||||
|
||||
func (s *serveStreamKeyRangeAdapter) Recv() (*binlogdatapb.BinlogTransaction, error) {
|
||||
r, err := s.stream.Recv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.BinlogTransaction, nil
|
||||
}
|
||||
|
||||
func (client *client) StreamKeyRange(ctx context.Context, position string, keyRange *topodatapb.KeyRange, charset *binlogdatapb.Charset) (binlogplayer.BinlogTransactionStream, error) {
|
||||
query := &binlogdatapb.StreamKeyRangeRequest{
|
||||
Position: position,
|
||||
KeyRange: keyRange,
|
||||
Charset: charset,
|
||||
}
|
||||
|
||||
stream, err := client.c.StreamKeyRange(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
var finalErr error
|
||||
go func() {
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
finalErr = err
|
||||
}
|
||||
close(response)
|
||||
return
|
||||
}
|
||||
response <- r.BinlogTransaction
|
||||
}
|
||||
}()
|
||||
return response, func() error {
|
||||
return finalErr
|
||||
}, nil
|
||||
return &serveStreamKeyRangeAdapter{stream}, nil
|
||||
}
|
||||
|
||||
func (client *client) StreamTables(ctx context.Context, position string, tables []string, charset *binlogdatapb.Charset) (chan *binlogdatapb.BinlogTransaction, binlogplayer.ErrFunc, error) {
|
||||
response := make(chan *binlogdatapb.BinlogTransaction, 10)
|
||||
type serveStreamTablesAdapter struct {
|
||||
stream binlogservicepb.UpdateStream_StreamTablesClient
|
||||
}
|
||||
|
||||
func (s *serveStreamTablesAdapter) Recv() (*binlogdatapb.BinlogTransaction, error) {
|
||||
r, err := s.stream.Recv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.BinlogTransaction, nil
|
||||
}
|
||||
|
||||
func (client *client) StreamTables(ctx context.Context, position string, tables []string, charset *binlogdatapb.Charset) (binlogplayer.BinlogTransactionStream, error) {
|
||||
query := &binlogdatapb.StreamTablesRequest{
|
||||
Position: position,
|
||||
Tables: tables,
|
||||
Charset: charset,
|
||||
}
|
||||
|
||||
stream, err := client.c.StreamTables(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
var finalErr error
|
||||
go func() {
|
||||
for {
|
||||
r, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
finalErr = err
|
||||
}
|
||||
close(response)
|
||||
return
|
||||
}
|
||||
response <- r.BinlogTransaction
|
||||
}
|
||||
}()
|
||||
return response, func() error {
|
||||
return finalErr
|
||||
}, nil
|
||||
return &serveStreamTablesAdapter{stream}, nil
|
||||
}
|
||||
|
||||
// Registration as a factory
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
package gorpccallerid
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/vt/callerid"
|
||||
|
||||
querypb "github.com/youtube/vitess/go/vt/proto/query"
|
||||
vtrpcpb "github.com/youtube/vitess/go/vt/proto/vtrpc"
|
||||
)
|
||||
|
||||
// CallerID is the BSON implementation of the proto3 vtrpc.CallerID
|
||||
type CallerID struct {
|
||||
Principal string
|
||||
Component string
|
||||
Subcomponent string
|
||||
}
|
||||
|
||||
// VTGateCallerID is the BSON implementation of the proto3 query.VTGateCallerID
|
||||
type VTGateCallerID struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
// GoRPCImmediateCallerID creates new ImmediateCallerID(querypb.VTGateCallerID)
|
||||
// from GoRPC's VTGateCallerID
|
||||
func GoRPCImmediateCallerID(v *VTGateCallerID) *querypb.VTGateCallerID {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return callerid.NewImmediateCallerID(v.Username)
|
||||
}
|
||||
|
||||
// GoRPCEffectiveCallerID creates new EffectiveCallerID(vtrpcpb.CallerID)
|
||||
// from GoRPC's CallerID
|
||||
func GoRPCEffectiveCallerID(c *CallerID) *vtrpcpb.CallerID {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return callerid.NewEffectiveCallerID(c.Principal, c.Component, c.Subcomponent)
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
package gorpccallerid
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/callerid/testsuite"
|
||||
)
|
||||
|
||||
func TestGoRPCCallerID(t *testing.T) {
|
||||
im := VTGateCallerID{
|
||||
Username: testsuite.FakeUsername,
|
||||
}
|
||||
ef := CallerID{
|
||||
Principal: testsuite.FakePrincipal,
|
||||
Component: testsuite.FakeComponent,
|
||||
Subcomponent: testsuite.FakeSubcomponent,
|
||||
}
|
||||
// Test nil cases
|
||||
if n := GoRPCImmediateCallerID(nil); n != nil {
|
||||
t.Errorf("Expect nil from GoRPCImmediateCallerID(nil), but got %v", n)
|
||||
}
|
||||
if n := GoRPCEffectiveCallerID(nil); n != nil {
|
||||
t.Errorf("Expect nil from GoRPCEffectiveCallerID(nil), but got %v", n)
|
||||
}
|
||||
testsuite.RunTests(t, GoRPCImmediateCallerID(&im), GoRPCEffectiveCallerID(&ef))
|
||||
}
|
|
@ -1,38 +0,0 @@
|
|||
package callinfo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
||||
"github.com/youtube/vitess/go/rpcwrap/proto"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// RPCWrapCallInfo takes a context generated by rpcwrap, and
|
||||
// returns one that has CallInfo filled in.
|
||||
func RPCWrapCallInfo(ctx context.Context) context.Context {
|
||||
remoteAddr, _ := proto.RemoteAddr(ctx)
|
||||
return NewContext(ctx, &rpcWrapCallInfoImpl{
|
||||
remoteAddr: remoteAddr,
|
||||
})
|
||||
}
|
||||
|
||||
type rpcWrapCallInfoImpl struct {
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func (rwci *rpcWrapCallInfoImpl) RemoteAddr() string {
|
||||
return rwci.remoteAddr
|
||||
}
|
||||
|
||||
func (rwci *rpcWrapCallInfoImpl) Username() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (rwci *rpcWrapCallInfoImpl) Text() string {
|
||||
return fmt.Sprintf("%s", rwci.remoteAddr)
|
||||
}
|
||||
|
||||
func (rwci *rpcWrapCallInfoImpl) HTML() template.HTML {
|
||||
return template.HTML("<b>RemoteAddr:</b> " + rwci.remoteAddr + "</br>\n")
|
||||
}
|
|
@ -5,13 +5,8 @@
|
|||
package replication
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
||||
// GTID represents a Global Transaction ID, also known as Transaction Group ID.
|
||||
|
@ -101,98 +96,3 @@ func MustDecodeGTID(s string) GTID {
|
|||
}
|
||||
return gtid
|
||||
}
|
||||
|
||||
// GTIDField is a concrete struct that contains a GTID interface value. This can
|
||||
// be used as a field inside marshalable structs, which cannot contain interface
|
||||
// values because there would be no way to know which concrete type to
|
||||
// instantiate upon unmarshaling.
|
||||
//
|
||||
// Note: GTIDField should not implement GTID, because it would tend to create
|
||||
// subtle bugs. For example, the compiler would allow something like this:
|
||||
//
|
||||
// GTIDField{googleGTID{1234}} == googleGTID{1234}
|
||||
//
|
||||
// But it would evaluate to false (because the underlying types don't match),
|
||||
// which is probably not what was expected.
|
||||
type GTIDField struct {
|
||||
Value GTID
|
||||
}
|
||||
|
||||
// String returns a string representation of the underlying GTID. If the
|
||||
// GTID value is nil, it returns "<nil>" in the style of Sprintf("%v", nil).
|
||||
func (gf GTIDField) String() string {
|
||||
if gf.Value == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return gf.Value.String()
|
||||
}
|
||||
|
||||
// MarshalBson bson-encodes GTIDField.
|
||||
func (gf GTIDField) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
||||
bson.EncodeOptionalPrefix(buf, bson.Object, key)
|
||||
|
||||
lenWriter := bson.NewLenWriter(buf)
|
||||
|
||||
if gf.Value != nil {
|
||||
// The name of the bson field is the MySQL flavor.
|
||||
bson.EncodeString(buf, gf.Value.Flavor(), gf.Value.String())
|
||||
}
|
||||
|
||||
lenWriter.Close()
|
||||
}
|
||||
|
||||
// UnmarshalBson bson-decodes into GTIDField.
|
||||
func (gf *GTIDField) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for GTIDField", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
// We expect exactly zero or one fields in this bson object.
|
||||
kind = bson.NextByte(buf)
|
||||
if kind == bson.EOO {
|
||||
// The GTID was nil, nothing to do.
|
||||
return
|
||||
}
|
||||
|
||||
// The field name is the MySQL flavor.
|
||||
flavor := bson.ReadCString(buf)
|
||||
value := bson.DecodeString(buf, kind)
|
||||
|
||||
// Check for and consume the end byte.
|
||||
if kind = bson.NextByte(buf); kind != bson.EOO {
|
||||
panic(bson.NewBsonError("too many fields for GTIDField"))
|
||||
}
|
||||
|
||||
// Parse the value.
|
||||
gtid, err := ParseGTID(flavor, value)
|
||||
if err != nil {
|
||||
panic(bson.NewBsonError("invalid value %v for GTIDField: %v", value, err))
|
||||
}
|
||||
gf.Value = gtid
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (gf GTIDField) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(EncodeGTID(gf.Value))
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (gf *GTIDField) UnmarshalJSON(buf []byte) error {
|
||||
var s string
|
||||
err := json.Unmarshal(buf, &s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
gf.Value, err = DecodeGTID(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,9 +9,10 @@ package replication
|
|||
// transactions that came before it, but in others a more complex structure is
|
||||
// required.
|
||||
//
|
||||
// GTIDSet is wrapped by ReplicationPosition, which is a concrete struct that
|
||||
// enables JSON and BSON marshaling. Most code outside of this package should
|
||||
// use ReplicationPosition rather than GTIDSet.
|
||||
// GTIDSet is wrapped by replication.Position, which is a concrete struct.
|
||||
// When sending a GTIDSet over RPCs, encode/decode it as a string.
|
||||
// Most code outside of this package should use replication.Position rather
|
||||
// than GTIDSet.
|
||||
type GTIDSet interface {
|
||||
// String returns the canonical printed form of the set as expected by a
|
||||
// particular flavor of MySQL.
|
||||
|
|
|
@ -5,11 +5,8 @@
|
|||
package replication
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
func TestParseGTID(t *testing.T) {
|
||||
|
@ -179,244 +176,6 @@ func TestDecodeGTIDWithSeparator(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGTIDFieldString(t *testing.T) {
|
||||
input := GTIDField{fakeGTID{flavor: "gahgah", value: "googoo"}}
|
||||
want := "googoo"
|
||||
if got := input.String(); got != want {
|
||||
t.Errorf("%#v.String() = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGTIDFieldStringNil(t *testing.T) {
|
||||
input := GTIDField{nil}
|
||||
want := "<nil>"
|
||||
if got := input.String(); got != want {
|
||||
t.Errorf("%#v.String() = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGTIDFieldEqual(t *testing.T) {
|
||||
input1 := GTIDField{fakeGTID{flavor: "poo", value: "bah"}}
|
||||
input2 := GTIDField{fakeGTID{flavor: "poo", value: "bah"}}
|
||||
want := true
|
||||
|
||||
if got := (input1 == input2); got != want {
|
||||
t.Errorf("(%#v == %#v) = %v, want %v", input1, input2, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGTIDFieldNotEqual(t *testing.T) {
|
||||
input1 := GTIDField{fakeGTID{flavor: "poo", value: "bah"}}
|
||||
input2 := GTIDField{fakeGTID{flavor: "foo", value: "bah"}}
|
||||
want := false
|
||||
|
||||
if got := (input1 == input2); got != want {
|
||||
t.Errorf("(%#v == %#v) = %v, want %v", input1, input2, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBsonMarshalUnmarshalGTIDField(t *testing.T) {
|
||||
gtidParsers["golf"] = func(s string) (GTID, error) {
|
||||
return fakeGTID{flavor: "golf", value: s}, nil
|
||||
}
|
||||
input := fakeGTID{flavor: "golf", value: "par"}
|
||||
want := fakeGTID{flavor: "golf", value: "par"}
|
||||
|
||||
buf, err := bson.Marshal(GTIDField{input})
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var gotField GTIDField
|
||||
if err = bson.Unmarshal(buf, &gotField); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := gotField.Value; got != want {
|
||||
t.Errorf("marshal->unmarshal mismatch, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBsonMarshalUnmarshalGTIDFieldPointer(t *testing.T) {
|
||||
gtidParsers["golf"] = func(s string) (GTID, error) {
|
||||
return fakeGTID{flavor: "golf", value: s}, nil
|
||||
}
|
||||
input := fakeGTID{flavor: "golf", value: "par"}
|
||||
want := fakeGTID{flavor: "golf", value: "par"}
|
||||
|
||||
buf, err := bson.Marshal(>IDField{input})
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var gotField GTIDField
|
||||
if err = bson.Unmarshal(buf, &gotField); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := gotField.Value; got != want {
|
||||
t.Errorf("marshal->unmarshal mismatch, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBsonMarshalUnmarshalGTIDFieldInStruct(t *testing.T) {
|
||||
gtidParsers["golf"] = func(s string) (GTID, error) {
|
||||
return fakeGTID{flavor: "golf", value: s}, nil
|
||||
}
|
||||
input := fakeGTID{flavor: "golf", value: "par"}
|
||||
want := fakeGTID{flavor: "golf", value: "par"}
|
||||
|
||||
type mystruct struct {
|
||||
GTIDField
|
||||
}
|
||||
|
||||
buf, err := bson.Marshal(&mystruct{GTIDField{input}})
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var gotStruct mystruct
|
||||
if err = bson.Unmarshal(buf, &gotStruct); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := gotStruct.GTIDField.Value; got != want {
|
||||
t.Errorf("marshal->unmarshal mismatch, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBsonMarshalUnmarshalNilGTID(t *testing.T) {
|
||||
gtidParsers["golf"] = func(s string) (GTID, error) {
|
||||
return fakeGTID{flavor: "golf", value: s}, nil
|
||||
}
|
||||
input := GTID(nil)
|
||||
want := GTID(nil)
|
||||
|
||||
buf, err := bson.Marshal(GTIDField{input})
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var gotField GTIDField
|
||||
if err = bson.Unmarshal(buf, &gotField); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := gotField.Value; got != want {
|
||||
t.Errorf("marshal->unmarshal mismatch, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonMarshalGTIDField(t *testing.T) {
|
||||
input := GTIDField{fakeGTID{flavor: "golf", value: "par"}}
|
||||
want := `"golf/par"`
|
||||
|
||||
buf, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := string(buf); got != want {
|
||||
t.Errorf("json.Marshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonMarshalGTIDFieldPointer(t *testing.T) {
|
||||
input := GTIDField{fakeGTID{flavor: "golf", value: "par"}}
|
||||
want := `"golf/par"`
|
||||
|
||||
buf, err := json.Marshal(&input)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := string(buf); got != want {
|
||||
t.Errorf("json.Marshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonUnmarshalGTIDField(t *testing.T) {
|
||||
gtidParsers["golf"] = func(s string) (GTID, error) {
|
||||
return fakeGTID{flavor: "golf", value: s}, nil
|
||||
}
|
||||
input := `"golf/par"`
|
||||
want := GTIDField{fakeGTID{flavor: "golf", value: "par"}}
|
||||
|
||||
var got GTIDField
|
||||
err := json.Unmarshal([]byte(input), &got)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("json.Unmarshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonMarshalGTIDFieldInStruct(t *testing.T) {
|
||||
input := GTIDField{fakeGTID{flavor: "golf", value: "par"}}
|
||||
want := `{"GTIDField":"golf/par"}`
|
||||
|
||||
type mystruct struct {
|
||||
GTIDField GTIDField
|
||||
}
|
||||
|
||||
buf, err := json.Marshal(&mystruct{input})
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := string(buf); got != want {
|
||||
t.Errorf("json.Marshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonUnmarshalGTIDFieldInStruct(t *testing.T) {
|
||||
gtidParsers["golf"] = func(s string) (GTID, error) {
|
||||
return fakeGTID{flavor: "golf", value: s}, nil
|
||||
}
|
||||
input := `{"GTIDField":"golf/par"}`
|
||||
want := GTIDField{fakeGTID{flavor: "golf", value: "par"}}
|
||||
|
||||
var gotStruct struct {
|
||||
GTIDField GTIDField
|
||||
}
|
||||
err := json.Unmarshal([]byte(input), &gotStruct)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if got := gotStruct.GTIDField; got != want {
|
||||
t.Errorf("json.Unmarshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonMarshalNilGTID(t *testing.T) {
|
||||
input := GTIDField{nil}
|
||||
want := `""`
|
||||
|
||||
buf, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got := string(buf); got != want {
|
||||
t.Errorf("json.Marshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonUnmarshalNilGTID(t *testing.T) {
|
||||
input := `""`
|
||||
want := GTIDField{nil}
|
||||
|
||||
var got GTIDField
|
||||
err := json.Unmarshal([]byte(input), &got)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("json.Unmarshal(%#v) = %#v, want %#v", input, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeGTID struct {
|
||||
flavor, value string
|
||||
}
|
||||
|
|
|
@ -1,55 +0,0 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package rpc contains RPC-related structs shared between many components.
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
// Unused is a placeholder type for args and reply values that aren't used.
|
||||
// For example, a server might declare a method without args or reply:
|
||||
//
|
||||
// func (s *Server) SomeMethod(ctx *Context, args *rpc.Unused, reply *rpc.Unused)
|
||||
//
|
||||
// The client would then call it like this:
|
||||
//
|
||||
// client.rpcCall("SomeMethod", &rpc.Unused{}, &rpc.Unused{}, waitTime)
|
||||
//
|
||||
// Using Unused ensures that, when the server doesn't care about a value, it
|
||||
// will silently ignore any value that is passed. With previous placeholder
|
||||
// values, such as strings, certain values might result in panics in the BSON
|
||||
// library.
|
||||
//
|
||||
// If a method declared with Unused as its args or reply is changed to accept
|
||||
// a real struct, old clients will encode an empty struct, which will be
|
||||
// silently ignored by the new servers. New clients talking to old servers will
|
||||
// send the real struct, which will be silently ignored. This allows args or
|
||||
// replies to be added to existing methods that previously did not use them.
|
||||
type Unused struct{}
|
||||
|
||||
// UnmarshalBson skips over an encoded Unused in a backward-compatible way.
|
||||
// Previous versions of Unused would BSON-encode a naked string value, which
|
||||
// bson.EncodeSimple() would encode as a document with a field named
|
||||
// bson.MAGICTAG. Here we skip over any document or sub-document without looking
|
||||
// for the special MAGICTAG field name.
|
||||
func (u *Unused) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
||||
switch kind {
|
||||
case bson.EOO, bson.Object:
|
||||
// valid
|
||||
case bson.Null:
|
||||
return
|
||||
default:
|
||||
panic(bson.NewBsonError("unexpected kind %v for Unused", kind))
|
||||
}
|
||||
bson.Next(buf, 4)
|
||||
|
||||
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
|
||||
bson.ReadCString(buf)
|
||||
bson.Skip(buf, kind)
|
||||
}
|
||||
}
|
|
@ -1,113 +0,0 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package rpc contains RPC-related structs shared between many components.
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
)
|
||||
|
||||
func TestUnmarshalStringIntoUnused(t *testing.T) {
|
||||
// Previous versions would marshal a naked empty string for unused args.
|
||||
str := ""
|
||||
buf, err := bson.Marshal(&str)
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// Check that a new-style server expecting Unused{} can handle a naked empty
|
||||
// string sent by an old-style client.
|
||||
var unused Unused
|
||||
if err := bson.Unmarshal(buf, &unused); err != nil {
|
||||
t.Errorf("bson.Unmarshal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEmptyStructIntoUnused(t *testing.T) {
|
||||
buf, err := bson.Marshal(&struct{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// Check that a new-style server expecting Unused{} can handle an empty struct.
|
||||
var unused Unused
|
||||
if err := bson.Unmarshal(buf, &unused); err != nil {
|
||||
t.Errorf("bson.Unmarshal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStructIntoUnused(t *testing.T) {
|
||||
buf, err := bson.Marshal(&struct{ A, B string }{"A", "B"})
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// Check that a new-style server expecting Unused{} can handle an actual
|
||||
// struct being sent.
|
||||
var unused Unused
|
||||
if err := bson.Unmarshal(buf, &unused); err != nil {
|
||||
t.Errorf("bson.Unmarshal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalUnusedIntoString(t *testing.T) {
|
||||
buf, err := bson.Marshal(&Unused{})
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// Check that it's safe for new clients to send Unused{} to an old server
|
||||
// expecting the naked empty string convention.
|
||||
var str string
|
||||
if err := bson.Unmarshal(buf, &str); err != nil {
|
||||
t.Errorf("bson.Unmarshal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStructIntoString(t *testing.T) {
|
||||
buf, err := bson.Marshal(&struct{ A, B string }{"A", "B"})
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// This fails. That's why you can't upgrade a method from unused (either old
|
||||
// or new style) to a real struct with public fields, if there are still
|
||||
// servers around that expect old-style string.
|
||||
var str string
|
||||
if err := bson.Unmarshal(buf, &str); err == nil {
|
||||
t.Errorf("expected error from bson.Unmarshal, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalStringIntoStruct(t *testing.T) {
|
||||
str := ""
|
||||
buf, err := bson.Marshal(&str)
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// This fails. That's why you can't have old-style (empty string) clients
|
||||
// talking to servers that already expect a real struct (not Unused).
|
||||
var out struct{ A, B string }
|
||||
if err := bson.Unmarshal(buf, &out); err == nil {
|
||||
t.Errorf("expected error from bson.Unmarshal, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEmptyStructIntoStruct(t *testing.T) {
|
||||
buf, err := bson.Marshal(&struct{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("bson.Marshal: %v", err)
|
||||
}
|
||||
|
||||
// It should always be possible to add fields to something that's already a
|
||||
// struct. The struct name is irrelevant since it's never encoded.
|
||||
var out struct{ A, B string }
|
||||
if err := bson.Unmarshal(buf, &out); err != nil {
|
||||
t.Errorf("bson.Unmarshal: %v", err)
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package servenv
|
||||
|
||||
import (
|
||||
log "github.com/golang/glog"
|
||||
rpc "github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
|
||||
)
|
||||
|
||||
// Register registers a bsonrpc service according to serviceMap
|
||||
func Register(name string, rcvr interface{}) {
|
||||
if serviceMap["bsonrpc-vt-"+name] {
|
||||
log.Infof("Registering %v for bsonrpc over vt port, disable it with -bsonrpc-vt-%v service_map parameter", name, name)
|
||||
rpc.Register(rcvr)
|
||||
} else {
|
||||
log.Infof("Not registering %v for bsonrpc over vt port, enable it with bsonrpc-vt-%v service_map parameter", name, name)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeRPC will deal with bson rpc serving
|
||||
func ServeRPC() {
|
||||
// rpc.HandleHTTP registers the default GOB handler at /_goRPC_
|
||||
// and the debug RPC service at /debug/rpc (it displays a list
|
||||
// of registered services and their methods).
|
||||
if serviceMap["gob-vt"] {
|
||||
log.Infof("Registering GOB handler and /debug/rpc URL for vt port")
|
||||
rpc.HandleHTTP()
|
||||
}
|
||||
|
||||
// and register the regular bsonrpc too.
|
||||
bsonrpc.ServeRPC()
|
||||
}
|
|
@ -20,7 +20,6 @@ var (
|
|||
func Run(port int) {
|
||||
populateListeningURL()
|
||||
onRunHooks.Fire()
|
||||
ServeRPC()
|
||||
serveGRPC()
|
||||
|
||||
l, err := proc.Listen(fmt.Sprintf("%v", port))
|
||||
|
|
|
@ -16,8 +16,7 @@ var (
|
|||
serviceMapFlag flagutil.StringListValue
|
||||
|
||||
// serviceMap is the used version of the service map.
|
||||
// init() functions will add default values to it (using
|
||||
// InitServiceMap and InitServiceMapForBsonRpcService).
|
||||
// init() functions can add default values to it (using InitServiceMap).
|
||||
// service_map command line parameter will alter the map.
|
||||
// Can only be used after servenv.Init has been called.
|
||||
serviceMap = make(map[string]bool)
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
|
||||
topodatapb "github.com/youtube/vitess/go/vt/proto/topodata"
|
||||
)
|
||||
|
||||
|
@ -56,44 +54,3 @@ func TestExtraFieldsJson(t *testing.T) {
|
|||
t.Errorf("Cannot re-decode struct without field: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingFieldsBson(t *testing.T) {
|
||||
swra := &slaveWasRestartedTestArgs{
|
||||
Parent: &topodatapb.TabletAlias{
|
||||
Uid: 1,
|
||||
Cell: "aa",
|
||||
},
|
||||
ExpectedMasterAddr: "a1",
|
||||
ExpectedMasterIPAddr: "i1",
|
||||
}
|
||||
data, err := bson.Marshal(swra)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
output := &SlaveWasRestartedArgs{}
|
||||
err = bson.Unmarshal(data, output)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraFieldsBson(t *testing.T) {
|
||||
swra := &SlaveWasRestartedArgs{
|
||||
Parent: &topodatapb.TabletAlias{
|
||||
Uid: 1,
|
||||
Cell: "aa",
|
||||
},
|
||||
}
|
||||
data, err := bson.Marshal(swra)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
output := &slaveWasRestartedTestArgs{}
|
||||
|
||||
err = bson.Unmarshal(data, output)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -82,60 +82,40 @@ func (fbc *fakeBinlogClient) Close() {
|
|||
}
|
||||
|
||||
// ServeUpdateStream is part of the binlogplayer.Client interface
|
||||
func (fbc *fakeBinlogClient) ServeUpdateStream(ctx context.Context, position string) (chan *binlogdatapb.StreamEvent, binlogplayer.ErrFunc, error) {
|
||||
return nil, nil, fmt.Errorf("Should never be called")
|
||||
func (fbc *fakeBinlogClient) ServeUpdateStream(ctx context.Context, position string) (binlogplayer.StreamEventStream, error) {
|
||||
return nil, fmt.Errorf("Should never be called")
|
||||
}
|
||||
|
||||
type testStreamEventAdapter struct {
|
||||
c chan *binlogdatapb.BinlogTransaction
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (t *testStreamEventAdapter) Recv() (*binlogdatapb.BinlogTransaction, error) {
|
||||
select {
|
||||
case bt := <-t.c:
|
||||
return bt, nil
|
||||
case <-t.ctx.Done():
|
||||
return nil, t.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// StreamTables is part of the binlogplayer.Client interface
|
||||
func (fbc *fakeBinlogClient) StreamTables(ctx context.Context, position string, tables []string, charset *binlogdatapb.Charset) (chan *binlogdatapb.BinlogTransaction, binlogplayer.ErrFunc, error) {
|
||||
func (fbc *fakeBinlogClient) StreamTables(ctx context.Context, position string, tables []string, charset *binlogdatapb.Charset) (binlogplayer.BinlogTransactionStream, error) {
|
||||
actualTables := strings.Join(tables, ",")
|
||||
if actualTables != fbc.expectedTables {
|
||||
return nil, nil, fmt.Errorf("Got wrong tables %v, expected %v", actualTables, fbc.expectedTables)
|
||||
return nil, fmt.Errorf("Got wrong tables %v, expected %v", actualTables, fbc.expectedTables)
|
||||
}
|
||||
|
||||
c := make(chan *binlogdatapb.BinlogTransaction)
|
||||
var finalErr error
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case bt := <-fbc.tablesChannel:
|
||||
c <- bt
|
||||
case <-ctx.Done():
|
||||
finalErr = ctx.Err()
|
||||
close(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return c, func() error {
|
||||
return finalErr
|
||||
}, nil
|
||||
return &testStreamEventAdapter{c: fbc.tablesChannel, ctx: ctx}, nil
|
||||
}
|
||||
|
||||
// StreamKeyRange is part of the binlogplayer.Client interface
|
||||
func (fbc *fakeBinlogClient) StreamKeyRange(ctx context.Context, position string, keyRange *topodatapb.KeyRange, charset *binlogdatapb.Charset) (chan *binlogdatapb.BinlogTransaction, binlogplayer.ErrFunc, error) {
|
||||
func (fbc *fakeBinlogClient) StreamKeyRange(ctx context.Context, position string, keyRange *topodatapb.KeyRange, charset *binlogdatapb.Charset) (binlogplayer.BinlogTransactionStream, error) {
|
||||
actualKeyRange := key.KeyRangeString(keyRange)
|
||||
if actualKeyRange != fbc.expectedKeyRange {
|
||||
return nil, nil, fmt.Errorf("Got wrong keyrange %v, expected %v", actualKeyRange, fbc.expectedKeyRange)
|
||||
return nil, fmt.Errorf("Got wrong keyrange %v, expected %v", actualKeyRange, fbc.expectedKeyRange)
|
||||
}
|
||||
|
||||
c := make(chan *binlogdatapb.BinlogTransaction)
|
||||
var finalErr error
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case bt := <-fbc.keyRangeChannel:
|
||||
c <- bt
|
||||
case <-ctx.Done():
|
||||
finalErr = ctx.Err()
|
||||
close(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return c, func() error {
|
||||
return finalErr
|
||||
}, nil
|
||||
return &testStreamEventAdapter{c: fbc.keyRangeChannel, ctx: ctx}, nil
|
||||
}
|
||||
|
||||
// fakeTabletConn implement TabletConn interface. We only care about the
|
||||
|
|
|
@ -243,16 +243,6 @@ func (agent *ActionAgent) runHealthCheck(targetTabletType topodatapb.TabletType)
|
|||
if isServing {
|
||||
// We are not healthy or should not be running the query service.
|
||||
//
|
||||
// We do NOT enter lameduck in this case, because we should only hit this
|
||||
// in the following scenarios:
|
||||
//
|
||||
// * Healthcheck fails: We're probably serving errors anyway, so no point.
|
||||
// * Replication lag exceeds unhealthy threshold: This is very rare, so it
|
||||
// isn't worth optimizing the potential 1s of errors away. It will also
|
||||
// go away when vtgate is the only one looking at lag.
|
||||
// * We're in a special state where queryservice should be disabled
|
||||
// despite being non-SPARE: This is not a live serving instance anyway.
|
||||
//
|
||||
// We don't care if the QueryService state actually changed because we'll
|
||||
// broadcast the latest health status after this immediately anway.
|
||||
_ /* state changed */, err := agent.disallowQueries(tablet.Type,
|
||||
|
|
|
@ -6,12 +6,16 @@ package tabletmanager
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/binlog/binlogplayer"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl"
|
||||
"github.com/youtube/vitess/go/vt/tabletmanager/actionnode"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver"
|
||||
|
@ -109,7 +113,7 @@ func (fhc *fakeHealthCheck) HTMLName() template.HTML {
|
|||
return template.HTML("fakeHealthCheck")
|
||||
}
|
||||
|
||||
func createTestAgent(ctx context.Context, t *testing.T) *ActionAgent {
|
||||
func createTestAgent(ctx context.Context, t *testing.T) (*ActionAgent, chan<- *binlogplayer.VtClientMock) {
|
||||
ts := zktestserver.New(t, []string{"cell1"})
|
||||
|
||||
if err := ts.CreateKeyspace(ctx, "test_keyspace", &topodatapb.Keyspace{}); err != nil {
|
||||
|
@ -138,19 +142,33 @@ func createTestAgent(ctx context.Context, t *testing.T) *ActionAgent {
|
|||
|
||||
mysqlDaemon := &mysqlctl.FakeMysqlDaemon{MysqlPort: 3306}
|
||||
agent := NewTestActionAgent(ctx, ts, tabletAlias, port, 0, mysqlDaemon)
|
||||
agent.BinlogPlayerMap = NewBinlogPlayerMap(ts, nil, nil)
|
||||
|
||||
vtClientMocksChannel := make(chan *binlogplayer.VtClientMock, 1)
|
||||
agent.BinlogPlayerMap = NewBinlogPlayerMap(ts, mysqlDaemon, func() binlogplayer.VtClient {
|
||||
return <-vtClientMocksChannel
|
||||
})
|
||||
|
||||
agent.HealthReporter = &fakeHealthCheck{}
|
||||
|
||||
return agent
|
||||
return agent, vtClientMocksChannel
|
||||
}
|
||||
|
||||
// TestHealthCheckControlsQueryService verifies that a tablet going healthy
|
||||
// starts the query service, and going unhealthy stops it.
|
||||
func TestHealthCheckControlsQueryService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := createTestAgent(ctx, t)
|
||||
agent, _ := createTestAgent(ctx, t)
|
||||
targetTabletType := topodatapb.TabletType_REPLICA
|
||||
|
||||
// Consume the first health broadcast triggered by ActionAgent.Start():
|
||||
// (SPARE, SERVING) goes to (SPARE, NOT_SERVING).
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// first health check, should change us to replica, and update the
|
||||
// mysql port to 3306
|
||||
before := time.Now()
|
||||
|
@ -175,10 +193,15 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 12)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_REPLICA {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
}
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 12); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, true, topodatapb.TabletType_REPLICA); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// now make the tablet unhealthy
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 13 * time.Second
|
||||
|
@ -201,9 +224,41 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 13)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_SPARE {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
want := topodatapb.TabletType_SPARE
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != want {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, want)
|
||||
}
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 13); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// QueryService disabled since we are unhealthy now.
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_REPLICA); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Consume second health broadcast (runHealthCheck() called refreshTablet()
|
||||
// which broadcasts since we go from REPLICA to SPARE and into lameduck.)
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 13); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// NOTE: No state change here because the type during lameduck is still
|
||||
// REPLICA and the QueryService is already set to NOT_SERVING.
|
||||
//
|
||||
// Consume third health broadcast (runHealthCheck() called refreshTablet()
|
||||
// which broadcasts that the QueryService state changed from REPLICA to SPARE
|
||||
// (NOT_SERVING was already set before when we went into lameduck).)
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 13); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// After the lameduck grace period, the type changed from REPLICA to SPARE.
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := expectBroadcastDataEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChangesEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,10 +266,19 @@ func TestHealthCheckControlsQueryService(t *testing.T) {
|
|||
// query service, it should not go healthy
|
||||
func TestQueryServiceNotStarting(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := createTestAgent(ctx, t)
|
||||
agent, _ := createTestAgent(ctx, t)
|
||||
targetTabletType := topodatapb.TabletType_REPLICA
|
||||
agent.QueryServiceControl.(*tabletservermock.Controller).SetServingTypeError = fmt.Errorf("test cannot start query service")
|
||||
|
||||
// Consume the first health broadcast triggered by ActionAgent.Start():
|
||||
// (SPARE, SERVING) goes to (SPARE, NOT_SERVING).
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
before := time.Now()
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err := agent.TopoServer.GetTablet(ctx, tabletAlias)
|
||||
|
@ -240,15 +304,31 @@ func TestQueryServiceNotStarting(t *testing.T) {
|
|||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_SPARE {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
}
|
||||
|
||||
if err := expectBroadcastDataEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChangesEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryServiceStopped verifies that if a healthy tablet's query
|
||||
// service is shut down, the tablet goes unhealthy
|
||||
func TestQueryServiceStopped(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := createTestAgent(ctx, t)
|
||||
agent, _ := createTestAgent(ctx, t)
|
||||
targetTabletType := topodatapb.TabletType_REPLICA
|
||||
|
||||
// Consume the first health broadcast triggered by ActionAgent.Start():
|
||||
// (SPARE, SERVING) goes to (SPARE, NOT_SERVING).
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// first health check, should change us to replica
|
||||
before := time.Now()
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 14 * time.Second
|
||||
|
@ -269,9 +349,16 @@ func TestQueryServiceStopped(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 14)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_REPLICA {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
want := topodatapb.TabletType_REPLICA
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != want {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, want)
|
||||
}
|
||||
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 14); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, true, want); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// shut down query service and prevent it from starting again
|
||||
|
@ -300,12 +387,33 @@ func TestQueryServiceStopped(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
bd := waitForBroadcastData(t, agent.QueryServiceControl, 15)
|
||||
if bd.RealtimeStats.HealthError != "test cannot start query service" {
|
||||
t.Errorf("unexpected HealthError: %v", *bd)
|
||||
want = topodatapb.TabletType_REPLICA
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != want {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, want)
|
||||
}
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_REPLICA {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
if bd, err := expectBroadcastData(agent.QueryServiceControl, 15); err == nil {
|
||||
if bd.RealtimeStats.HealthError != "test cannot start query service" {
|
||||
t.Errorf("unexpected HealthError: %v", *bd)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, want); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Consume second health broadcast (runHealthCheck() called refreshTablet()
|
||||
// which broadcasts since we go from REPLICA to SPARE and into lameduck.)
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 15); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// NOTE: No more broadcasts or state changes since SetServingTypeError is set
|
||||
// on the mocked controller and this disables its SetServingType().
|
||||
|
||||
if err := expectBroadcastDataEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChangesEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -313,9 +421,18 @@ func TestQueryServiceStopped(t *testing.T) {
|
|||
// query service in a tablet.
|
||||
func TestTabletControl(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := createTestAgent(ctx, t)
|
||||
agent, _ := createTestAgent(ctx, t)
|
||||
targetTabletType := topodatapb.TabletType_REPLICA
|
||||
|
||||
// Consume the first health broadcast triggered by ActionAgent.Start():
|
||||
// (SPARE, SERVING) goes to (SPARE, NOT_SERVING).
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// first health check, should change us to replica
|
||||
before := time.Now()
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 16 * time.Second
|
||||
|
@ -336,9 +453,14 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 16)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_REPLICA {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != targetTabletType {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, targetTabletType)
|
||||
}
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 16); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, true, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// now update the shard
|
||||
|
@ -372,6 +494,15 @@ func TestTabletControl(t *testing.T) {
|
|||
t.Errorf("UpdateStream should be running")
|
||||
}
|
||||
|
||||
// Consume the health broadcast which was triggered due to the QueryService
|
||||
// state change from SERVING to NOT_SERVING.
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 16); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check running a health check will not start it again
|
||||
before = time.Now()
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 17 * time.Second
|
||||
|
@ -392,10 +523,13 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 17)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_REPLICA {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != targetTabletType {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, targetTabletType)
|
||||
}
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 17); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// NOTE: No state change here since nothing has changed.
|
||||
|
||||
// go unhealthy, check we go to spare and QS is not running
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportError = fmt.Errorf("tablet is unhealthy")
|
||||
|
@ -418,9 +552,29 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 18)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_SPARE {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 18); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// NOTE: No state change here since QueryService is already NOT_SERVING.
|
||||
want := topodatapb.TabletType_SPARE
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != want {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, want)
|
||||
}
|
||||
// Consume second health broadcast (runHealthCheck() called refreshTablet()
|
||||
// which broadcasts since we go from REPLICA to SPARE into lameduck.)
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 18); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Consume third health broadcast (runHealthCheck() called refreshTablet()
|
||||
// which broadcasts since the QueryService state changes from REPLICA to SPARE.
|
||||
// TODO(mberlin): With this, the cached TabletControl in the agent is also
|
||||
// cleared since it was only meant for REPLICA and now we are a SPARE.
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 18); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// go back healthy, check QS is still not running
|
||||
|
@ -444,9 +598,208 @@ func TestTabletControl(t *testing.T) {
|
|||
if agent._healthyTime.Sub(before) < 0 {
|
||||
t.Errorf("runHealthCheck did not update agent._healthyTime")
|
||||
}
|
||||
waitForBroadcastData(t, agent.QueryServiceControl, 19)
|
||||
if agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType != topodatapb.TabletType_REPLICA {
|
||||
t.Errorf("invalid tabletserver target: %v", agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType)
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 19); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != targetTabletType {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, targetTabletType)
|
||||
}
|
||||
// NOTE: At this point in time, the QueryService is actually visible as
|
||||
// SERVING since the previous change from REPLICA to SPARE cleared the
|
||||
// cached TabletControl and now the healthcheck assumes that the REPLICA type
|
||||
// is allowed to serve. This problem will be fixed when the healthcheck calls
|
||||
// refreshTablet() due to the seen state change from SPARE to REPLICA. Then,
|
||||
// the topology is read again and TabletControl becomes effective again.
|
||||
// TODO(mberlin): Fix this bug.
|
||||
if err := expectStateChange(agent.QueryServiceControl, true, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// QueryService changed back from SERVING to NOT_SERVING since refreshTablet()
|
||||
// re-read the topology and saw that REPLICA is still not allowed to serve.
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 19); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := expectBroadcastDataEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChangesEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueryServiceChangeImmediateHealthcheckResponse verifies that a change
|
||||
// of the QueryService state or the tablet type will result into a broadcast
|
||||
// of a StreamHealthResponse message.
|
||||
func TestStateChangeImmediateHealthBroadcast(t *testing.T) {
|
||||
// BinlogPlayer will fail in the second retry because we don't fully mock
|
||||
// it. Retry faster to make it fail faster.
|
||||
flag.Set("binlog_player_retry_delay", "100ms")
|
||||
|
||||
ctx := context.Background()
|
||||
agent, vtClientMocksChannel := createTestAgent(ctx, t)
|
||||
targetTabletType := topodatapb.TabletType_MASTER
|
||||
|
||||
// Consume the first health broadcast triggered by ActionAgent.Start():
|
||||
// (SPARE, SERVING) goes to (SPARE, NOT_SERVING).
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, topodatapb.TabletType_SPARE); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Run health check to get changed from SPARE to MASTER.
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 20 * time.Second
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err := agent.TopoServer.GetTablet(ctx, tabletAlias)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTablet failed: %v", err)
|
||||
}
|
||||
if ti.Type != targetTabletType {
|
||||
t.Errorf("First health check failed to go to replica: %v", ti.Type)
|
||||
}
|
||||
if !agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should be running")
|
||||
}
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != targetTabletType {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, targetTabletType)
|
||||
}
|
||||
if _, err := expectBroadcastData(agent.QueryServiceControl, 20); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, true, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Simulate a vertical split resharding where we set SourceShards in the topo
|
||||
// and enable filtered replication.
|
||||
si, err := agent.TopoServer.GetShard(ctx, "test_keyspace", "0")
|
||||
if err != nil {
|
||||
t.Fatalf("GetShard failed: %v", err)
|
||||
}
|
||||
si.SourceShards = []*topodatapb.Shard_SourceShard{
|
||||
{
|
||||
Uid: 1,
|
||||
Keyspace: "source_keyspace",
|
||||
Shard: "0",
|
||||
Tables: []string{
|
||||
"table1",
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := agent.TopoServer.UpdateShard(ctx, si); err != nil {
|
||||
t.Fatalf("UpdateShard failed: %v", err)
|
||||
}
|
||||
// Mock out the BinlogPlayer client. Tell the BinlogPlayer not to start.
|
||||
vtClientMock := binlogplayer.NewVtClientMock()
|
||||
vtClientMock.Result = &sqltypes.Result{
|
||||
Fields: nil,
|
||||
RowsAffected: 1,
|
||||
InsertID: 0,
|
||||
Rows: [][]sqltypes.Value{
|
||||
{
|
||||
sqltypes.MakeString([]byte("MariaDB/0-1-1234")),
|
||||
sqltypes.MakeString([]byte("DontStart")),
|
||||
},
|
||||
},
|
||||
}
|
||||
vtClientMocksChannel <- vtClientMock
|
||||
|
||||
// Refresh the tablet state, as vtworker would do.
|
||||
// Since we change the QueryService state, we'll also trigger a health broadcast.
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 21 * time.Second
|
||||
agent.RPCWrapLockAction(ctx, actionnode.TabletActionRefreshState, "", "", true, func() error {
|
||||
agent.RefreshState(ctx)
|
||||
return nil
|
||||
})
|
||||
// (Destination) MASTER with enabled filtered replication mustn't serve anymore.
|
||||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
// Consume health broadcast sent out due to QueryService state change from
|
||||
// (MASTER, SERVING) to (MASTER, NOT_SERVING).
|
||||
// Since we didn't run healthcheck again yet, the broadcast data contains the
|
||||
// cached replication lag of 20 instead of 21.
|
||||
if bd, err := expectBroadcastData(agent.QueryServiceControl, 20); err == nil {
|
||||
if bd.RealtimeStats.BinlogPlayersCount != 1 {
|
||||
t.Fatalf("filtered replication must be enabled: %v", bd)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, false, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Running a healthcheck won't put the QueryService back to SERVING.
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 22 * time.Second
|
||||
agent.runHealthCheck(targetTabletType)
|
||||
ti, err = agent.TopoServer.GetTablet(ctx, tabletAlias)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTablet failed: %v", err)
|
||||
}
|
||||
if ti.Type != targetTabletType {
|
||||
t.Errorf("Health check failed to go to replica: %v", ti.Type)
|
||||
}
|
||||
if agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
if got := agent.QueryServiceControl.(*tabletservermock.Controller).CurrentTarget.TabletType; got != targetTabletType {
|
||||
t.Errorf("invalid tabletserver target: got = %v, want = %v", got, targetTabletType)
|
||||
}
|
||||
if bd, err := expectBroadcastData(agent.QueryServiceControl, 22); err == nil {
|
||||
if bd.RealtimeStats.BinlogPlayersCount != 1 {
|
||||
t.Fatalf("filtered replication must be still running: %v", bd)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// NOTE: No state change here since nothing has changed.
|
||||
|
||||
// Simulate migration to destination master i.e. remove SourceShards.
|
||||
si, err = agent.TopoServer.GetShard(ctx, "test_keyspace", "0")
|
||||
if err != nil {
|
||||
t.Fatalf("GetShard failed: %v", err)
|
||||
}
|
||||
si.SourceShards = nil
|
||||
if err = agent.TopoServer.UpdateShard(ctx, si); err != nil {
|
||||
t.Fatalf("UpdateShard failed: %v", err)
|
||||
}
|
||||
// Refresh the tablet state, as vtctl MigrateServedFrom would do.
|
||||
// This should also trigger a health broadcast since the QueryService state
|
||||
// changes from NOT_SERVING to SERVING.
|
||||
agent.HealthReporter.(*fakeHealthCheck).reportReplicationDelay = 23 * time.Second
|
||||
agent.RPCWrapLockAction(ctx, actionnode.TabletActionRefreshState, "", "", true, func() error {
|
||||
agent.RefreshState(ctx)
|
||||
return nil
|
||||
})
|
||||
// QueryService changed from NOT_SERVING to SERVING.
|
||||
if !agent.QueryServiceControl.IsServing() {
|
||||
t.Errorf("Query service should not be running")
|
||||
}
|
||||
// Since we didn't run healthcheck again yet, the broadcast data contains the
|
||||
// cached replication lag of 22 instead of 23.
|
||||
if bd, err := expectBroadcastData(agent.QueryServiceControl, 22); err == nil {
|
||||
if bd.RealtimeStats.BinlogPlayersCount != 0 {
|
||||
t.Fatalf("filtered replication must be disabled now: %v", bd)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChange(agent.QueryServiceControl, true, targetTabletType); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := expectBroadcastDataEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := expectStateChangesEmpty(agent.QueryServiceControl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -454,7 +807,7 @@ func TestTabletControl(t *testing.T) {
|
|||
// return an error
|
||||
func TestOldHealthCheck(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agent := createTestAgent(ctx, t)
|
||||
agent, _ := createTestAgent(ctx, t)
|
||||
*healthCheckInterval = 20 * time.Second
|
||||
agent._healthy = nil
|
||||
|
||||
|
@ -477,19 +830,53 @@ func TestOldHealthCheck(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// waitForBroadcastData is used by tests to get the first BroadcastData that's
|
||||
// recent enough, without relying on the precise number of times TabletManager
|
||||
// calls BroadcastHealth() in the meantime.
|
||||
func waitForBroadcastData(t *testing.T, qsc tabletserver.Controller, secondsBehindMaster uint32) *tabletservermock.BroadcastData {
|
||||
timer := time.NewTimer(10 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case bd := <-qsc.(*tabletservermock.Controller).BroadcastData:
|
||||
if bd.RealtimeStats.SecondsBehindMaster == secondsBehindMaster {
|
||||
return bd
|
||||
}
|
||||
case <-timer.C:
|
||||
t.Fatalf("Timed out waiting for SecondsBehindMaster = %v", secondsBehindMaster)
|
||||
}
|
||||
// expectBroadcastData checks that runHealthCheck() broadcasted the expected
|
||||
// stats (going the value for secondsBehindMaster).
|
||||
// Note that it may be necessary to call this function twice when
|
||||
// runHealthCheck() also calls freshTablet() which might trigger another
|
||||
// broadcast e.g. because we went from REPLICA to SPARE and into lameduck.
|
||||
func expectBroadcastData(qsc tabletserver.Controller, secondsBehindMaster uint32) (*tabletservermock.BroadcastData, error) {
|
||||
bd := <-qsc.(*tabletservermock.Controller).BroadcastData
|
||||
if got := bd.RealtimeStats.SecondsBehindMaster; got != secondsBehindMaster {
|
||||
return nil, fmt.Errorf("unexpected BroadcastData. got: %v want: %v got bd: %+v", got, secondsBehindMaster, bd)
|
||||
}
|
||||
return bd, nil
|
||||
}
|
||||
|
||||
// expectBroadcastDataEmpty closes the health broadcast channel and verifies
|
||||
// that all broadcasted messages were consumed by expectBroadcastData().
|
||||
func expectBroadcastDataEmpty(qsc tabletserver.Controller) error {
|
||||
c := qsc.(*tabletservermock.Controller).BroadcastData
|
||||
close(c)
|
||||
bd, ok := <-c
|
||||
if ok {
|
||||
return fmt.Errorf("BroadcastData channel should have been consumed, but was not: %v", bd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// expectStateChange verifies that the test changed the QueryService state
|
||||
// to the expected state (serving or not, specific tablet type).
|
||||
func expectStateChange(qsc tabletserver.Controller, serving bool, tabletType topodatapb.TabletType) error {
|
||||
want := &tabletservermock.StateChange{
|
||||
Serving: serving,
|
||||
TabletType: tabletType,
|
||||
}
|
||||
got := <-qsc.(*tabletservermock.Controller).StateChanges
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
return fmt.Errorf("unexpected state change. got: %v want: %v got", got, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// expectStateChangesEmpty closes the StateChange channel and verifies
|
||||
// that all sent state changes were consumed by expectStateChange().
|
||||
func expectStateChangesEmpty(qsc tabletserver.Controller) error {
|
||||
c := qsc.(*tabletservermock.Controller).StateChanges
|
||||
close(c)
|
||||
sc, ok := <-c
|
||||
if ok {
|
||||
return fmt.Errorf("StateChanges channel should have been consumed, but was not: %v", sc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -247,6 +247,7 @@ func (agent *ActionAgent) changeCallback(ctx context.Context, oldTablet, newTabl
|
|||
log.Errorf("Can't start query service for MASTER+REPLICA mode: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if stateChanged, err := agent.allowQueries(newTablet.Type); err == nil {
|
||||
// If the state changed, broadcast to vtgate.
|
||||
// (e.g. this happens when the tablet was already master, but it just
|
||||
|
@ -269,6 +270,7 @@ func (agent *ActionAgent) changeCallback(ctx context.Context, oldTablet, newTabl
|
|||
agent.broadcastHealth()
|
||||
time.Sleep(*gracePeriod)
|
||||
}
|
||||
|
||||
if stateChanged, err := agent.disallowQueries(newTablet.Type, disallowQueryReason); err == nil {
|
||||
// If the state changed, broadcast to vtgate.
|
||||
// (e.g. this happens when the tablet was already master, but it just
|
||||
|
|
|
@ -289,7 +289,7 @@ func (conn *gRPCQueryClient) StreamHealth(ctx context.Context) (tabletconn.Strea
|
|||
return conn.c.StreamHealth(ctx, &querypb.StreamHealthRequest{})
|
||||
}
|
||||
|
||||
// Close closes underlying bsonrpc.
|
||||
// Close closes underlying gRPC channel.
|
||||
func (conn *gRPCQueryClient) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
)
|
||||
|
||||
// This test makes sure the go rpc service works
|
||||
func TestGoRPCTabletConn(t *testing.T) {
|
||||
func TestGRPCTabletConn(t *testing.T) {
|
||||
// fake service
|
||||
service := tabletconntest.CreateFakeServer(t)
|
||||
|
||||
|
|
|
@ -105,7 +105,7 @@ func (stats *LogStats) RewrittenSQL() string {
|
|||
}
|
||||
|
||||
// SizeOfResponse returns the approximate size of the response in
|
||||
// bytes (this does not take in account BSON encoding). It will return
|
||||
// bytes (this does not take in account protocol encoding). It will return
|
||||
// 0 for streaming requests.
|
||||
func (stats *LogStats) SizeOfResponse() int {
|
||||
if stats.Rows == nil {
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче