RPC methods take a context argument.

* * *
The context object is passed around.
* * *
the context is passed as interface{}
* * *
Bring back old function signatures.
* * *
Context-taking RPC method can declare their arguments using the real context type.
* * *
Use inteface{} for context in rpcplus.
This commit is contained in:
Ric Szopa 2012-10-24 12:49:59 -07:00
Родитель c9c355d859
Коммит f4024cdfa5
12 изменённых файлов: 321 добавлений и 178 удалений

Просмотреть файл

@ -135,5 +135,11 @@ func (c *serverCodec) Close() error {
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
func ServeConn(conn io.ReadWriteCloser) {
rpc.ServeCodec(NewServerCodec(conn))
ServeConnWithContext(conn, nil)
}
// ServeConnWithContext is like ServeConn but it allows to pass a
// connection context to the RPC methods.
func ServeConnWithContext(conn io.ReadWriteCloser, context interface{}) {
rpc.ServeCodecWithContext(NewServerCodec(conn), context)
}

Просмотреть файл

@ -150,12 +150,17 @@ const (
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
type methodType struct {
sync.Mutex // protects counters
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
stream bool
numCalls uint
sync.Mutex // protects counters
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
ContextType reflect.Type
stream bool
numCalls uint
}
func (m methodType) TakesContext() bool {
return m.ContextType != nil
}
type service struct {
@ -239,6 +244,86 @@ func (server *Server) RegisterName(name string, rcvr interface{}) error {
return server.register(rcvr, name, true)
}
// prepareMethod returns a methodType for the provided method or nil
// in case if the method was unsuitable.
func prepareMethod(method reflect.Method) *methodType {
mtype := method.Type
mname := method.Name
var replyType, argType, contextType reflect.Type
stream := false
// Method must be exported.
if method.PkgPath != "" {
return nil
}
switch mtype.NumIn() {
case 3:
// normal method
argType = mtype.In(1)
replyType = mtype.In(2)
contextType = nil
case 4:
// method that takes a context
argType = mtype.In(2)
replyType = mtype.In(3)
contextType = mtype.In(1)
default:
log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
return nil
}
// First arg need not be a pointer.
if !isExportedOrBuiltinType(argType) {
log.Println(mname, "argument type not exported:", argType)
return nil
}
// the second argument will tell us if it's a streaming call
// or a regular call
if replyType.Kind() == reflect.Func {
// this is a streaming call
stream = true
if replyType.NumIn() != 1 {
log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn())
return nil
}
if replyType.In(0).Kind() != reflect.Interface {
log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0))
return nil
}
if replyType.NumOut() != 1 {
log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut())
return nil
}
if returnType := replyType.Out(0); returnType != typeOfError {
log.Println("method", mname, "sendReply returns", returnType.String(), "not error")
return nil
}
} else if replyType.Kind() != reflect.Ptr {
log.Println("method", mname, "reply type not a pointer:", replyType)
return nil
}
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
log.Println("method", mname, "reply type not exported:", replyType)
return nil
}
// Method needs one out.
if mtype.NumOut() != 1 {
log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
return nil
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
log.Println("method", mname, "returns", returnType.String(), "not error")
return nil
}
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
}
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
server.mu.Lock()
defer server.mu.Unlock()
@ -269,72 +354,9 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
// Install the methods
for m := 0; m < s.typ.NumMethod(); m++ {
method := s.typ.Method(m)
mtype := method.Type
mname := method.Name
var replyType reflect.Type
stream := false
// Method must be exported.
if method.PkgPath != "" {
continue
if mt := prepareMethod(method); mt != nil {
s.method[method.Name] = mt
}
// Method needs three ins: receiver, *args, *reply.
// Or: receiver, *args, sendReply func.
if mtype.NumIn() != 3 {
log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
continue
}
// First arg need not be a pointer.
argType := mtype.In(1)
if !isExportedOrBuiltinType(argType) {
log.Println(mname, "argument type not exported:", argType)
continue
}
// the second argument will tell us if it's a streaming call
// or a regular call
replyType = mtype.In(2)
if replyType.Kind() == reflect.Func {
// this is a streaming call
stream = true
if replyType.NumIn() != 1 {
log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn())
continue
}
if replyType.In(0).Kind() != reflect.Interface {
log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0))
continue
}
if replyType.NumOut() != 1 {
log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut())
continue
}
if returnType := replyType.Out(0); returnType != typeOfError {
log.Println("method", mname, "sendReply returns", returnType.String(), "not error")
continue
}
} else if replyType.Kind() != reflect.Ptr {
log.Println("method", mname, "reply type not a pointer:", replyType)
continue
}
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
log.Println("method", mname, "reply type not exported:", replyType)
continue
}
// Method needs one out.
if mtype.NumOut() != 1 {
log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
continue
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
log.Println("method", mname, "returns", returnType.String(), "not error")
continue
}
s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType, stream: stream}
}
if len(s.method) == 0 {
@ -377,14 +399,22 @@ func (m *methodType) NumCalls() (n uint) {
return n
}
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec, context interface{}) {
mtype.Lock()
mtype.numCalls++
mtype.Unlock()
function := mtype.method.Func
var returnValues []reflect.Value
if !mtype.stream {
// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
if mtype.TakesContext() {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(context), argv, replyv})
} else {
returnValues = function.Call([]reflect.Value{s.rcvr, argv, replyv})
}
// The return value for the method is an error.
errInter := returnValues[0].Interface()
errmsg := ""
@ -393,63 +423,68 @@ func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, r
}
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true)
server.freeRequest(req)
} else {
// declare a local error to see if we errored out already
// keep track of the type, to make sure we return
// the same one consistently
var lastError error
var firstType reflect.Type
sendReply := func(oneReply interface{}) error {
// we already triggered an error, we're done
if lastError != nil {
return lastError
}
// check the oneReply has the right type using reflection
typ := reflect.TypeOf(oneReply)
if firstType == nil {
firstType = typ
} else {
if firstType != typ {
log.Println("passing wrong type to sendReply",
firstType, "!=", typ)
lastError = errors.New("rpc: passing wrong type to sendReply")
return lastError
}
}
lastError = server.sendResponse(sending, req, oneReply, codec, "", false)
if lastError != nil {
return lastError
}
// we manage to send, we're good
return nil
}
// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)})
errInter := returnValues[0].Interface()
errmsg := ""
if errInter != nil {
// the function returned an error, we use that
errmsg = errInter.(error).Error()
} else if lastError != nil {
// we had an error inside sendReply, we use that
errmsg = lastError.Error()
} else {
// no error, we send the special EOS error
errmsg = lastStreamResponseError
}
// this is the last packet, we don't do anything with
// the error here (well sendStreamResponse will log it
// already)
server.sendResponse(sending, req, nil, codec, errmsg, true)
server.freeRequest(req)
return
}
// declare a local error to see if we errored out already
// keep track of the type, to make sure we return
// the same one consistently
var lastError error
var firstType reflect.Type
sendReply := func(oneReply interface{}) error {
// we already triggered an error, we're done
if lastError != nil {
return lastError
}
// check the oneReply has the right type using reflection
typ := reflect.TypeOf(oneReply)
if firstType == nil {
firstType = typ
} else {
if firstType != typ {
log.Println("passing wrong type to sendReply",
firstType, "!=", typ)
lastError = errors.New("rpc: passing wrong type to sendReply")
return lastError
}
}
lastError = server.sendResponse(sending, req, oneReply, codec, "", false)
if lastError != nil {
return lastError
}
// we manage to send, we're good
return nil
}
// Invoke the method, providing a new value for the reply.
if mtype.TakesContext() {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(context), argv, reflect.ValueOf(sendReply)})
} else {
returnValues = function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)})
}
errInter := returnValues[0].Interface()
errmsg := ""
if errInter != nil {
// the function returned an error, we use that
errmsg = errInter.(error).Error()
} else if lastError != nil {
// we had an error inside sendReply, we use that
errmsg = lastError.Error()
} else {
// no error, we send the special EOS error
errmsg = lastStreamResponseError
}
// this is the last packet, we don't do anything with
// the error here (well sendStreamResponse will log it
// already)
server.sendResponse(sending, req, nil, codec, errmsg, true)
server.freeRequest(req)
}
type gobServerCodec struct {
@ -487,14 +522,26 @@ func (c *gobServerCodec) Close() error {
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
server.ServeConnWithContext(conn, nil)
}
// ServeConnWithContext is like ServeConn but makes it possible to
// pass a connection context to the RPC methods.
func (server *Server) ServeConnWithContext(conn io.ReadWriteCloser, context interface{}) {
buf := bufio.NewWriter(conn)
srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
server.ServeCodec(srv)
server.ServeCodecWithContext(srv, context)
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
server.ServeCodecWithContext(codec, nil)
}
// ServeCodecWithContext is like ServeCodec but it makes it possible
// to pass a connection context to the RPC methods.
func (server *Server) ServeCodecWithContext(codec ServerCodec, context interface{}) {
sending := new(sync.Mutex)
for {
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
@ -512,14 +559,27 @@ func (server *Server) ServeCodec(codec ServerCodec) {
}
continue
}
go service.call(server, sending, mtype, req, argv, replyv, codec)
go service.call(server, sending, mtype, req, argv, replyv, codec, context)
}
codec.Close()
}
func (mtype methodType) prepareContext(context interface{}) reflect.Value {
if contextv := reflect.ValueOf(context); contextv.IsValid() {
return contextv
}
return reflect.Zero(mtype.ContextType)
}
// ServeRequest is like ServeCodec but synchronously serves a single request.
// It does not close the codec upon completion.
func (server *Server) ServeRequest(codec ServerCodec) error {
return server.ServeRequestWithContext(codec, nil)
}
// ServeRequestWithContext is like ServeRequest but makes it possible
// to pass a connection context to the RPC methods.
func (server *Server) ServeRequestWithContext(codec ServerCodec, context interface{}) error {
sending := new(sync.Mutex)
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
@ -533,7 +593,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error {
}
return err
}
service.call(server, sending, mtype, req, argv, replyv, codec)
service.call(server, sending, mtype, req, argv, replyv, codec, context)
return nil
}
@ -690,19 +750,38 @@ type ServerCodec interface {
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
func ServeConn(conn io.ReadWriteCloser) {
DefaultServer.ServeConn(conn)
ServeConnWithContext(conn, nil)
}
// ServeConnWithContext is like ServeConn but it allows to pass a
// connection context to the RPC methods.
func ServeConnWithContext(conn io.ReadWriteCloser, context interface{}) {
DefaultServer.ServeConnWithContext(conn, context)
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func ServeCodec(codec ServerCodec) {
DefaultServer.ServeCodec(codec)
ServeCodecWithContext(codec, nil)
}
// ServeCodecWithContext is like ServeCodec but it allows to pass a
// connection context to the RPC methods.
func ServeCodecWithContext(codec ServerCodec, context interface{}) {
DefaultServer.ServeCodecWithContext(codec, context)
}
// ServeRequest is like ServeCodec but synchronously serves a single request.
// It does not close the codec upon completion.
func ServeRequest(codec ServerCodec) error {
return DefaultServer.ServeRequest(codec)
return ServeRequestWithContext(codec, nil)
}
// ServeRequestWithContext is like ServeRequest but it allows to pass
// a connection context to the RPC methods.
func ServeRequestWithContext(codec ServerCodec, context interface{}) error {
return DefaultServer.ServeRequestWithContext(codec, context)
}
// Accept accepts connections on the listener and serves requests

Просмотреть файл

@ -74,6 +74,10 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
panic("ERROR")
}
func (t *Arith) TakesContext(context interface{}, args string, reply *string) error {
return nil
}
func listenTCP() (net.Listener, string) {
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
if e != nil {
@ -233,6 +237,13 @@ func testRPC(t *testing.T, addr string) {
if reply.C != args.A*args.B {
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
}
// Takes context
emptyString := ""
err = client.Call("Arith.TakesContext", "", &emptyString)
if err != nil {
t.Errorf("TakesContext: expected no error but got string %q", err.Error())
}
}
func TestHTTP(t *testing.T) {

Просмотреть файл

@ -9,6 +9,7 @@ import (
"code.google.com/p/vitess/go/relog"
rpc "code.google.com/p/vitess/go/rpcplus"
"code.google.com/p/vitess/go/rpcwrap/proto"
)
// UnusedArgument is a type used to indicate an argument that is
@ -64,10 +65,10 @@ func LoadCredentials(filename string) error {
// Authenticate returns true if it the client manages to authenticate
// the codec in at most maxRequest number of requests.
func Authenticate(c rpc.ServerCodec) (bool, error) {
func Authenticate(c rpc.ServerCodec, context *proto.Context) (bool, error) {
auth := newAuthenticatedCodec(c)
for i := 0; i < CRAMMD5MaxRequests; i++ {
err := AuthenticationServer.ServeRequest(auth)
err := AuthenticationServer.ServeRequestWithContext(auth, context)
if err != nil {
return false, err
}
@ -90,7 +91,7 @@ func (a *AuthenticatorCRAMMD5) GetNewChallenge(_ UnusedArgument, reply *GetNewCh
}
// Authenticate checks if the client proof is correct.
func (a *AuthenticatorCRAMMD5) Authenticate(req *AuthenticateRequest, reply *AuthenticateReply) error {
func (a *AuthenticatorCRAMMD5) Authenticate(context proto.Context, req *AuthenticateRequest, reply *AuthenticateReply) error {
username := strings.SplitN(req.Proof, " ", 2)[0]
secrets, ok := a.Credentials[username]
if !ok {
@ -103,6 +104,7 @@ func (a *AuthenticatorCRAMMD5) Authenticate(req *AuthenticateRequest, reply *Aut
}
for _, secret := range secrets {
if expected := CRAMMD5GetExpected(username, secret, req.state.challenge); expected == req.Proof {
context.Username = username
return nil
}
}

Просмотреть файл

@ -0,0 +1,6 @@
package proto
type Context struct {
RemoteAddr string
Username string
}

Просмотреть файл

@ -15,6 +15,7 @@ import (
"code.google.com/p/vitess/go/relog"
rpc "code.google.com/p/vitess/go/rpcplus"
"code.google.com/p/vitess/go/rpcwrap/auth"
"code.google.com/p/vitess/go/rpcwrap/proto"
)
const (
@ -121,11 +122,11 @@ func RegisterAuthenticated(rcvr interface{}) error {
// ServeCodec calls ServeCodec for the appropriate server
// (authenticated or default).
func (h *rpcHandler) ServeCodec(c rpc.ServerCodec) {
func (h *rpcHandler) ServeCodecWithContext(c rpc.ServerCodec, context *proto.Context) {
if h.useAuth {
AuthenticatedServer.ServeCodec(c)
AuthenticatedServer.ServeCodecWithContext(c, context)
} else {
rpc.ServeCodec(c)
rpc.ServeCodecWithContext(c, context)
}
}
@ -142,9 +143,9 @@ func (h *rpcHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
codec := h.cFactory(NewBufferedConnection(conn))
context := &proto.Context{RemoteAddr: req.RemoteAddr}
if h.useAuth {
if authenticated, err := auth.Authenticate(codec); !authenticated {
if authenticated, err := auth.Authenticate(codec, context); !authenticated {
if err != nil {
relog.Error("authentication erred at %s: %v", req.RemoteAddr, err)
}
@ -152,8 +153,7 @@ func (h *rpcHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
return
}
}
h.ServeCodec(codec)
h.ServeCodecWithContext(codec, context)
}
func GetRpcPath(codecName string, auth bool) string {
@ -171,7 +171,7 @@ type httpHandler struct {
func (hh *httpHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
conn := &httpConnectionBroker{c, req.Body}
codec := hh.cFactory(conn)
if err := rpc.ServeRequest(codec); err != nil {
if err := rpc.ServeRequestWithContext(codec, &proto.Context{RemoteAddr: req.RemoteAddr}); err != nil {
relog.Error("rpcwrap: %v", err)
}
}

Просмотреть файл

@ -7,6 +7,7 @@ package proto
import (
mproto "code.google.com/p/vitess/go/mysql/proto"
"code.google.com/p/vitess/go/rpcwrap"
rpcproto "code.google.com/p/vitess/go/rpcwrap/proto"
)
// defines the RPC services
@ -17,21 +18,21 @@ type SqlQuery interface {
// FIXME(sugu) Note the client will support both returning an
// int64 or a structure. Using the structure will be rolled
// out after the client is rolled out.
Begin(session *Session, transactionId *int64) error
Commit(session *Session, noOutput *string) error
Rollback(session *Session, noOutput *string) error
Begin(context *rpcproto.Context, session *Session, transactionId *int64) error
Commit(context *rpcproto.Context, session *Session, noOutput *string) error
Rollback(context *rpcproto.Context, session *Session, noOutput *string) error
CreateReserved(session *Session, connectionInfo *ConnectionInfo) error
CloseReserved(session *Session, noOutput *string) error
Execute(query *Query, reply *mproto.QueryResult) error
StreamExecute(query *Query, sendReply func(reply interface{}) error) error
ExecuteBatch(queryList *QueryList, reply *QueryResultList) error
Execute(context *rpcproto.Context, query *Query, reply *mproto.QueryResult) error
StreamExecute(context *rpcproto.Context, query *Query, sendReply func(reply interface{}) error) error
ExecuteBatch(context *rpcproto.Context, queryList *QueryList, reply *QueryResultList) error
Invalidate(cacheInvalidate *CacheInvalidate, noOutput *string) error
InvalidateForDDL(ddl *DDLInvalidate, noOutput *string) error
Invalidate(context *rpcproto.Context, cacheInvalidate *CacheInvalidate, noOutput *string) error
InvalidateForDDL(context *rpcproto.Context, ddl *DDLInvalidate, noOutput *string) error
Ping(query *string, reply *string) error
Ping(context *rpcproto.Context, query *string, reply *string) error
}
// helper method to register the server (does interface checking)

Просмотреть файл

@ -17,6 +17,7 @@ import (
"code.google.com/p/vitess/go/mysql"
mproto "code.google.com/p/vitess/go/mysql/proto"
"code.google.com/p/vitess/go/relog"
rpcproto "code.google.com/p/vitess/go/rpcwrap/proto"
"code.google.com/p/vitess/go/stats"
"code.google.com/p/vitess/go/vt/sqlparser"
"code.google.com/p/vitess/go/vt/tabletserver/proto"
@ -275,9 +276,10 @@ func (sq *SqlQuery) GetSessionId(sessionParams *proto.SessionParams, sessionInfo
return nil
}
func (sq *SqlQuery) Begin(session *proto.Session, transactionId *int64) (err error) {
func (sq *SqlQuery) Begin(context *rpcproto.Context, session *proto.Session, transactionId *int64) (err error) {
logStats := newSqlQueryStats("Begin")
logStats.OriginalSql = "begin"
logStats.context = context
logStats.StartTime = time.Now()
defer func() {
logStats.EndTime = time.Now()
@ -301,9 +303,10 @@ func (sq *SqlQuery) Begin(session *proto.Session, transactionId *int64) (err err
return nil
}
func (sq *SqlQuery) Commit(session *proto.Session, noOutput *string) (err error) {
func (sq *SqlQuery) Commit(context *rpcproto.Context, session *proto.Session, noOutput *string) (err error) {
logStats := newSqlQueryStats("Commit")
logStats.OriginalSql = "commit"
logStats.context = context
logStats.StartTime = time.Now()
defer func() {
logStats.EndTime = time.Now()
@ -336,9 +339,10 @@ func (sq *SqlQuery) invalidateRows(logStats *sqlQueryStats, dirtyTables map[stri
}
}
func (sq *SqlQuery) Rollback(session *proto.Session, noOutput *string) (err error) {
func (sq *SqlQuery) Rollback(context *rpcproto.Context, session *proto.Session, noOutput *string) (err error) {
logStats := newSqlQueryStats("Rollback")
logStats.StartTime = time.Now()
logStats.context = context
defer func() {
logStats.EndTime = time.Now()
sqlQueryLogger.Send(logStats)
@ -388,9 +392,9 @@ func handleExecError(query *proto.Query, err *error) {
}
}
func (sq *SqlQuery) Execute(query *proto.Query, reply *mproto.QueryResult) (err error) {
func (sq *SqlQuery) Execute(context *rpcproto.Context, query *proto.Query, reply *mproto.QueryResult) (err error) {
logStats := newSqlQueryStats("Execute")
logStats.context = context
logStats.StartTime = time.Now()
defer func() {
logStats.EndTime = time.Now()
@ -499,8 +503,9 @@ func (sq *SqlQuery) Execute(query *proto.Query, reply *mproto.QueryResult) (err
// the first QueryResult will have Fields set (and Rows nil)
// the subsequent QueryResult will have Rows set (and Fields nil)
func (sq *SqlQuery) StreamExecute(query *proto.Query, sendReply func(reply interface{}) error) (err error) {
func (sq *SqlQuery) StreamExecute(context *rpcproto.Context, query *proto.Query, sendReply func(reply interface{}) error) (err error) {
logStats := newSqlQueryStats("StreamExecute")
logStats.context = context
defer func() {
logStats.EndTime = time.Now()
sqlQueryLogger.Send(logStats)
@ -552,7 +557,8 @@ func (sq *SqlQuery) StreamExecute(query *proto.Query, sendReply func(reply inter
return nil
}
func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryResultList) (err error) {
func (sq *SqlQuery) ExecuteBatch(context *rpcproto.Context, queryList *proto.QueryList, reply *proto.QueryResultList) (err error) {
defer handleError(&err)
ql := queryList.List
if len(ql) == 0 {
@ -576,7 +582,7 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
if session.TransactionId != 0 {
panic(NewTabletError(FAIL, "Nested transactions disallowed"))
}
if err = sq.Begin(&session, &session.TransactionId); err != nil {
if err = sq.Begin(context, &session, &session.TransactionId); err != nil {
return err
}
begin_called = true
@ -585,7 +591,7 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
if !begin_called {
panic(NewTabletError(FAIL, "Cannot commit without begin"))
}
if err = sq.Commit(&session, &noOutput); err != nil {
if err = sq.Commit(context, &session, &noOutput); err != nil {
return err
}
session.TransactionId = 0
@ -596,9 +602,9 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
query.ConnectionId = session.ConnectionId
query.SessionId = session.SessionId
var localReply mproto.QueryResult
if err = sq.Execute(&query, &localReply); err != nil {
if err = sq.Execute(context, &query, &localReply); err != nil {
if begin_called {
sq.Rollback(&session, &noOutput)
sq.Rollback(context, &session, &noOutput)
}
return err
}
@ -606,14 +612,16 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
}
}
if begin_called {
sq.Rollback(&session, &noOutput)
sq.Rollback(context, &session, &noOutput)
panic(NewTabletError(FAIL, "begin called with no commit"))
}
return nil
}
func (sq *SqlQuery) Invalidate(cacheInvalidate *proto.CacheInvalidate, noOutput *string) (err error) {
func (sq *SqlQuery) Invalidate(context *rpcproto.Context, cacheInvalidate *proto.CacheInvalidate, noOutput *string) (err error) {
logStats := newSqlQueryStats("Invalidate")
logStats.StartTime = time.Now()
logStats.context = context
defer func() {
logStats.EndTime = time.Now()
sqlQueryLogger.Send(logStats)
@ -649,8 +657,10 @@ func (sq *SqlQuery) Invalidate(cacheInvalidate *proto.CacheInvalidate, noOutput
return nil
}
func (sq *SqlQuery) InvalidateForDDL(ddl *proto.DDLInvalidate, noOutput *string) (err error) {
func (sq *SqlQuery) InvalidateForDDL(context *rpcproto.Context, ddl *proto.DDLInvalidate, noOutput *string) (err error) {
logStats := newSqlQueryStats("InvalidateForDDL")
logStats.context = context
logStats.StartTime = time.Now()
defer func() {
logStats.EndTime = time.Now()
sqlQueryLogger.Send(logStats)
@ -678,8 +688,10 @@ func (sq *SqlQuery) InvalidateForDDL(ddl *proto.DDLInvalidate, noOutput *string)
return nil
}
func (sq *SqlQuery) Ping(query *string, reply *string) error {
func (sq *SqlQuery) Ping(context *rpcproto.Context, query *string, reply *string) error {
logStats := newSqlQueryStats("Ping")
logStats.StartTime = time.Now()
logStats.context = context
defer func() {
logStats.EndTime = time.Now()
sqlQueryLogger.Send(logStats)

Просмотреть файл

@ -8,6 +8,7 @@ import (
"time"
"code.google.com/p/vitess/go/relog"
"code.google.com/p/vitess/go/rpcwrap/proto"
"code.google.com/p/vitess/go/streamlog"
)
@ -37,6 +38,7 @@ type sqlQueryStats struct {
CacheInvalidations int64
QuerySources byte
Rows [][]interface{}
context *proto.Context
}
func newSqlQueryStats(methodName string) *sqlQueryStats {
@ -94,9 +96,9 @@ func (stats *sqlQueryStats) FmtBindVariables() string {
for k, v := range stats.BindVariables {
switch val := v.(type) {
case string:
scrubbed[k] = "string " + string(len(val))
scrubbed[k] = fmt.Sprintf("string %v", len(val))
case []byte:
scrubbed[k] = "bytes " + string(len(val))
scrubbed[k] = fmt.Sprintf("bytes %v", len(val))
default:
scrubbed[k] = v
}
@ -133,11 +135,21 @@ func (stats *sqlQueryStats) FmtQuerySources() string {
return strings.Join(sources[:n], ",")
}
func (log sqlQueryStats) RemoteAddr() string {
return log.context.RemoteAddr
}
func (log sqlQueryStats) Username() string {
return log.context.Username
}
//String returns a tab separated list of logged fields.
func (log sqlQueryStats) String() string {
return fmt.Sprintf(
"%v\t%v\t%v\t%v\t%v\t%q\t%v\t%v\t%q\t%v\t%v\t%v\t%v\t%v\t%v\t%v\t",
"%v\t%v\t%v\t%v\t%v\t%v\t%v\t%q\t%v\t%v\t%q\t%v\t%v\t%v\t%v\t%v\t%v\t%v\t",
log.Method,
log.RemoteAddr(),
log.Username(),
log.StartTime,
log.EndTime,
log.TotalTime(),

Просмотреть файл

@ -59,8 +59,11 @@ class TestAuthentication(unittest.TestCase):
@classmethod
def tearDownVTOCC(klass):
klass.process.kill()
klass.process.wait()
try:
klass.process.kill()
klass.process.wait()
except AttributeError:
pass
def call(self, *args, **kwargs):
return self.conn.client.call(*args, **kwargs)
@ -112,7 +115,7 @@ if __name__=="__main__":
try:
TestAuthentication.setUpVTOCC(options.dbconfig, options.auth_credentials)
unittest.main()
unittest.main(argv=["auth_test.py"])
finally:
print "Waiting for vtocc to terminate...",
TestAuthentication.tearDownVTOCC()

Просмотреть файл

@ -14,6 +14,8 @@ class Log(object):
def __init__(self, line):
self.line = line
(self.method,
self.remote_address,
self.username,
self.start_time,
self.end_time,
self.total_time,
@ -99,6 +101,10 @@ class Log(object):
if rewritten != self.rewritten_sql:
self.fail("Bad rewritten SQL", rewritten, self.rewritten_sql)
def check_remote_address(self, case):
if not self.remote_address.startswith(case.remote_address):
return self.fail("Bad RemoteAddr", case.remote_address, self.remote_address)
def check_number_of_queries(self, case):
if case.rewritten is not None and int(self.number_of_queries) != len(case.rewritten):
return self.fail("wrong number of queries", len(case.rewritten), int(self.number_of_queries))
@ -106,7 +112,8 @@ class Log(object):
class Case(object):
def __init__(self, sql, bindings=None, result=None, rewritten=None, doc='',
cache_table="vtocc_cached", query_plan=None, cache_hits=None,
cache_misses=None, cache_absent=None, cache_invalidations=None):
cache_misses=None, cache_absent=None, cache_invalidations=None,
remote_address="[::1]"):
# For all cache_* parameters, a number n means "check this value
# is exactly n," while None means "I am not interested in this
# value, leave it alone."
@ -123,6 +130,7 @@ class Case(object):
self.cache_misses = cache_misses
self.cache_absent = cache_absent
self.cache_invalidations = cache_invalidations
self.remote_address = remote_address
def normalizelog(self, data):
return [line.split("INFO: ")[-1]

Просмотреть файл

@ -95,6 +95,9 @@ cases = [
sql='select /* order */ * from vtocc_a order by id desc',
result=[(1L, 2L, 'bcde', 'fghi'), (1L, 1L, 'abcd', 'efgh')],
rewritten='select /* order */ * from vtocc_a order by id desc limit 10001'),
Case(doc='string in bindings are not shown in logs',
sql='select /* limit */ %(somestring)s, eid, id from vtocc_a limit %(a)s',
bindings={"somestring": "Ala ma kota.", "a": 1}),
MultiCase(
'simple insert',