зеркало из https://github.com/github/vitess-gh.git
RPC methods take a context argument.
* * * The context object is passed around. * * * the context is passed as interface{} * * * Bring back old function signatures. * * * Context-taking RPC method can declare their arguments using the real context type. * * * Use inteface{} for context in rpcplus.
This commit is contained in:
Родитель
c9c355d859
Коммит
f4024cdfa5
|
@ -135,5 +135,11 @@ func (c *serverCodec) Close() error {
|
|||
// ServeConn blocks, serving the connection until the client hangs up.
|
||||
// The caller typically invokes ServeConn in a go statement.
|
||||
func ServeConn(conn io.ReadWriteCloser) {
|
||||
rpc.ServeCodec(NewServerCodec(conn))
|
||||
ServeConnWithContext(conn, nil)
|
||||
}
|
||||
|
||||
// ServeConnWithContext is like ServeConn but it allows to pass a
|
||||
// connection context to the RPC methods.
|
||||
func ServeConnWithContext(conn io.ReadWriteCloser, context interface{}) {
|
||||
rpc.ServeCodecWithContext(NewServerCodec(conn), context)
|
||||
}
|
||||
|
|
|
@ -150,12 +150,17 @@ const (
|
|||
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
|
||||
|
||||
type methodType struct {
|
||||
sync.Mutex // protects counters
|
||||
method reflect.Method
|
||||
ArgType reflect.Type
|
||||
ReplyType reflect.Type
|
||||
stream bool
|
||||
numCalls uint
|
||||
sync.Mutex // protects counters
|
||||
method reflect.Method
|
||||
ArgType reflect.Type
|
||||
ReplyType reflect.Type
|
||||
ContextType reflect.Type
|
||||
stream bool
|
||||
numCalls uint
|
||||
}
|
||||
|
||||
func (m methodType) TakesContext() bool {
|
||||
return m.ContextType != nil
|
||||
}
|
||||
|
||||
type service struct {
|
||||
|
@ -239,6 +244,86 @@ func (server *Server) RegisterName(name string, rcvr interface{}) error {
|
|||
return server.register(rcvr, name, true)
|
||||
}
|
||||
|
||||
// prepareMethod returns a methodType for the provided method or nil
|
||||
// in case if the method was unsuitable.
|
||||
func prepareMethod(method reflect.Method) *methodType {
|
||||
mtype := method.Type
|
||||
mname := method.Name
|
||||
var replyType, argType, contextType reflect.Type
|
||||
|
||||
stream := false
|
||||
// Method must be exported.
|
||||
if method.PkgPath != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch mtype.NumIn() {
|
||||
case 3:
|
||||
// normal method
|
||||
argType = mtype.In(1)
|
||||
replyType = mtype.In(2)
|
||||
contextType = nil
|
||||
case 4:
|
||||
// method that takes a context
|
||||
argType = mtype.In(2)
|
||||
replyType = mtype.In(3)
|
||||
contextType = mtype.In(1)
|
||||
default:
|
||||
log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
|
||||
return nil
|
||||
}
|
||||
|
||||
// First arg need not be a pointer.
|
||||
if !isExportedOrBuiltinType(argType) {
|
||||
log.Println(mname, "argument type not exported:", argType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// the second argument will tell us if it's a streaming call
|
||||
// or a regular call
|
||||
if replyType.Kind() == reflect.Func {
|
||||
// this is a streaming call
|
||||
stream = true
|
||||
if replyType.NumIn() != 1 {
|
||||
log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn())
|
||||
return nil
|
||||
}
|
||||
if replyType.In(0).Kind() != reflect.Interface {
|
||||
log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0))
|
||||
return nil
|
||||
}
|
||||
if replyType.NumOut() != 1 {
|
||||
log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut())
|
||||
return nil
|
||||
}
|
||||
if returnType := replyType.Out(0); returnType != typeOfError {
|
||||
log.Println("method", mname, "sendReply returns", returnType.String(), "not error")
|
||||
return nil
|
||||
}
|
||||
|
||||
} else if replyType.Kind() != reflect.Ptr {
|
||||
log.Println("method", mname, "reply type not a pointer:", replyType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reply type must be exported.
|
||||
if !isExportedOrBuiltinType(replyType) {
|
||||
log.Println("method", mname, "reply type not exported:", replyType)
|
||||
return nil
|
||||
}
|
||||
// Method needs one out.
|
||||
if mtype.NumOut() != 1 {
|
||||
log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
|
||||
return nil
|
||||
}
|
||||
// The return type of the method must be error.
|
||||
if returnType := mtype.Out(0); returnType != typeOfError {
|
||||
log.Println("method", mname, "returns", returnType.String(), "not error")
|
||||
return nil
|
||||
}
|
||||
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
|
||||
}
|
||||
|
||||
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
|
||||
server.mu.Lock()
|
||||
defer server.mu.Unlock()
|
||||
|
@ -269,72 +354,9 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
|
|||
// Install the methods
|
||||
for m := 0; m < s.typ.NumMethod(); m++ {
|
||||
method := s.typ.Method(m)
|
||||
mtype := method.Type
|
||||
mname := method.Name
|
||||
var replyType reflect.Type
|
||||
stream := false
|
||||
// Method must be exported.
|
||||
if method.PkgPath != "" {
|
||||
continue
|
||||
if mt := prepareMethod(method); mt != nil {
|
||||
s.method[method.Name] = mt
|
||||
}
|
||||
// Method needs three ins: receiver, *args, *reply.
|
||||
// Or: receiver, *args, sendReply func.
|
||||
if mtype.NumIn() != 3 {
|
||||
log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
|
||||
continue
|
||||
}
|
||||
|
||||
// First arg need not be a pointer.
|
||||
argType := mtype.In(1)
|
||||
if !isExportedOrBuiltinType(argType) {
|
||||
log.Println(mname, "argument type not exported:", argType)
|
||||
continue
|
||||
}
|
||||
|
||||
// the second argument will tell us if it's a streaming call
|
||||
// or a regular call
|
||||
replyType = mtype.In(2)
|
||||
if replyType.Kind() == reflect.Func {
|
||||
// this is a streaming call
|
||||
stream = true
|
||||
if replyType.NumIn() != 1 {
|
||||
log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn())
|
||||
continue
|
||||
}
|
||||
if replyType.In(0).Kind() != reflect.Interface {
|
||||
log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0))
|
||||
continue
|
||||
}
|
||||
if replyType.NumOut() != 1 {
|
||||
log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut())
|
||||
continue
|
||||
}
|
||||
if returnType := replyType.Out(0); returnType != typeOfError {
|
||||
log.Println("method", mname, "sendReply returns", returnType.String(), "not error")
|
||||
continue
|
||||
}
|
||||
|
||||
} else if replyType.Kind() != reflect.Ptr {
|
||||
log.Println("method", mname, "reply type not a pointer:", replyType)
|
||||
continue
|
||||
}
|
||||
|
||||
// Reply type must be exported.
|
||||
if !isExportedOrBuiltinType(replyType) {
|
||||
log.Println("method", mname, "reply type not exported:", replyType)
|
||||
continue
|
||||
}
|
||||
// Method needs one out.
|
||||
if mtype.NumOut() != 1 {
|
||||
log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
|
||||
continue
|
||||
}
|
||||
// The return type of the method must be error.
|
||||
if returnType := mtype.Out(0); returnType != typeOfError {
|
||||
log.Println("method", mname, "returns", returnType.String(), "not error")
|
||||
continue
|
||||
}
|
||||
s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType, stream: stream}
|
||||
}
|
||||
|
||||
if len(s.method) == 0 {
|
||||
|
@ -377,14 +399,22 @@ func (m *methodType) NumCalls() (n uint) {
|
|||
return n
|
||||
}
|
||||
|
||||
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
|
||||
func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec, context interface{}) {
|
||||
mtype.Lock()
|
||||
mtype.numCalls++
|
||||
mtype.Unlock()
|
||||
function := mtype.method.Func
|
||||
var returnValues []reflect.Value
|
||||
|
||||
if !mtype.stream {
|
||||
|
||||
// Invoke the method, providing a new value for the reply.
|
||||
returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
|
||||
if mtype.TakesContext() {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(context), argv, replyv})
|
||||
} else {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, argv, replyv})
|
||||
}
|
||||
|
||||
// The return value for the method is an error.
|
||||
errInter := returnValues[0].Interface()
|
||||
errmsg := ""
|
||||
|
@ -393,63 +423,68 @@ func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, r
|
|||
}
|
||||
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg, true)
|
||||
server.freeRequest(req)
|
||||
} else {
|
||||
// declare a local error to see if we errored out already
|
||||
// keep track of the type, to make sure we return
|
||||
// the same one consistently
|
||||
var lastError error
|
||||
var firstType reflect.Type
|
||||
|
||||
sendReply := func(oneReply interface{}) error {
|
||||
|
||||
// we already triggered an error, we're done
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
// check the oneReply has the right type using reflection
|
||||
typ := reflect.TypeOf(oneReply)
|
||||
if firstType == nil {
|
||||
firstType = typ
|
||||
} else {
|
||||
if firstType != typ {
|
||||
log.Println("passing wrong type to sendReply",
|
||||
firstType, "!=", typ)
|
||||
lastError = errors.New("rpc: passing wrong type to sendReply")
|
||||
return lastError
|
||||
}
|
||||
}
|
||||
|
||||
lastError = server.sendResponse(sending, req, oneReply, codec, "", false)
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
// we manage to send, we're good
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invoke the method, providing a new value for the reply.
|
||||
returnValues := function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)})
|
||||
errInter := returnValues[0].Interface()
|
||||
errmsg := ""
|
||||
if errInter != nil {
|
||||
// the function returned an error, we use that
|
||||
errmsg = errInter.(error).Error()
|
||||
} else if lastError != nil {
|
||||
// we had an error inside sendReply, we use that
|
||||
errmsg = lastError.Error()
|
||||
} else {
|
||||
// no error, we send the special EOS error
|
||||
errmsg = lastStreamResponseError
|
||||
}
|
||||
|
||||
// this is the last packet, we don't do anything with
|
||||
// the error here (well sendStreamResponse will log it
|
||||
// already)
|
||||
server.sendResponse(sending, req, nil, codec, errmsg, true)
|
||||
server.freeRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
// declare a local error to see if we errored out already
|
||||
// keep track of the type, to make sure we return
|
||||
// the same one consistently
|
||||
var lastError error
|
||||
var firstType reflect.Type
|
||||
|
||||
sendReply := func(oneReply interface{}) error {
|
||||
|
||||
// we already triggered an error, we're done
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
// check the oneReply has the right type using reflection
|
||||
typ := reflect.TypeOf(oneReply)
|
||||
if firstType == nil {
|
||||
firstType = typ
|
||||
} else {
|
||||
if firstType != typ {
|
||||
log.Println("passing wrong type to sendReply",
|
||||
firstType, "!=", typ)
|
||||
lastError = errors.New("rpc: passing wrong type to sendReply")
|
||||
return lastError
|
||||
}
|
||||
}
|
||||
|
||||
lastError = server.sendResponse(sending, req, oneReply, codec, "", false)
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
// we manage to send, we're good
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invoke the method, providing a new value for the reply.
|
||||
if mtype.TakesContext() {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(context), argv, reflect.ValueOf(sendReply)})
|
||||
} else {
|
||||
returnValues = function.Call([]reflect.Value{s.rcvr, argv, reflect.ValueOf(sendReply)})
|
||||
}
|
||||
errInter := returnValues[0].Interface()
|
||||
errmsg := ""
|
||||
if errInter != nil {
|
||||
// the function returned an error, we use that
|
||||
errmsg = errInter.(error).Error()
|
||||
} else if lastError != nil {
|
||||
// we had an error inside sendReply, we use that
|
||||
errmsg = lastError.Error()
|
||||
} else {
|
||||
// no error, we send the special EOS error
|
||||
errmsg = lastStreamResponseError
|
||||
}
|
||||
|
||||
// this is the last packet, we don't do anything with
|
||||
// the error here (well sendStreamResponse will log it
|
||||
// already)
|
||||
server.sendResponse(sending, req, nil, codec, errmsg, true)
|
||||
server.freeRequest(req)
|
||||
}
|
||||
|
||||
type gobServerCodec struct {
|
||||
|
@ -487,14 +522,26 @@ func (c *gobServerCodec) Close() error {
|
|||
// ServeConn uses the gob wire format (see package gob) on the
|
||||
// connection. To use an alternate codec, use ServeCodec.
|
||||
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
|
||||
server.ServeConnWithContext(conn, nil)
|
||||
}
|
||||
|
||||
// ServeConnWithContext is like ServeConn but makes it possible to
|
||||
// pass a connection context to the RPC methods.
|
||||
func (server *Server) ServeConnWithContext(conn io.ReadWriteCloser, context interface{}) {
|
||||
buf := bufio.NewWriter(conn)
|
||||
srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
|
||||
server.ServeCodec(srv)
|
||||
server.ServeCodecWithContext(srv, context)
|
||||
}
|
||||
|
||||
// ServeCodec is like ServeConn but uses the specified codec to
|
||||
// decode requests and encode responses.
|
||||
func (server *Server) ServeCodec(codec ServerCodec) {
|
||||
server.ServeCodecWithContext(codec, nil)
|
||||
}
|
||||
|
||||
// ServeCodecWithContext is like ServeCodec but it makes it possible
|
||||
// to pass a connection context to the RPC methods.
|
||||
func (server *Server) ServeCodecWithContext(codec ServerCodec, context interface{}) {
|
||||
sending := new(sync.Mutex)
|
||||
for {
|
||||
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
|
||||
|
@ -512,14 +559,27 @@ func (server *Server) ServeCodec(codec ServerCodec) {
|
|||
}
|
||||
continue
|
||||
}
|
||||
go service.call(server, sending, mtype, req, argv, replyv, codec)
|
||||
go service.call(server, sending, mtype, req, argv, replyv, codec, context)
|
||||
}
|
||||
codec.Close()
|
||||
}
|
||||
|
||||
func (mtype methodType) prepareContext(context interface{}) reflect.Value {
|
||||
if contextv := reflect.ValueOf(context); contextv.IsValid() {
|
||||
return contextv
|
||||
}
|
||||
return reflect.Zero(mtype.ContextType)
|
||||
}
|
||||
|
||||
// ServeRequest is like ServeCodec but synchronously serves a single request.
|
||||
// It does not close the codec upon completion.
|
||||
func (server *Server) ServeRequest(codec ServerCodec) error {
|
||||
return server.ServeRequestWithContext(codec, nil)
|
||||
}
|
||||
|
||||
// ServeRequestWithContext is like ServeRequest but makes it possible
|
||||
// to pass a connection context to the RPC methods.
|
||||
func (server *Server) ServeRequestWithContext(codec ServerCodec, context interface{}) error {
|
||||
sending := new(sync.Mutex)
|
||||
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
|
||||
if err != nil {
|
||||
|
@ -533,7 +593,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
service.call(server, sending, mtype, req, argv, replyv, codec)
|
||||
service.call(server, sending, mtype, req, argv, replyv, codec, context)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -690,19 +750,38 @@ type ServerCodec interface {
|
|||
// ServeConn uses the gob wire format (see package gob) on the
|
||||
// connection. To use an alternate codec, use ServeCodec.
|
||||
func ServeConn(conn io.ReadWriteCloser) {
|
||||
DefaultServer.ServeConn(conn)
|
||||
ServeConnWithContext(conn, nil)
|
||||
}
|
||||
|
||||
// ServeConnWithContext is like ServeConn but it allows to pass a
|
||||
// connection context to the RPC methods.
|
||||
func ServeConnWithContext(conn io.ReadWriteCloser, context interface{}) {
|
||||
DefaultServer.ServeConnWithContext(conn, context)
|
||||
}
|
||||
|
||||
// ServeCodec is like ServeConn but uses the specified codec to
|
||||
// decode requests and encode responses.
|
||||
func ServeCodec(codec ServerCodec) {
|
||||
DefaultServer.ServeCodec(codec)
|
||||
ServeCodecWithContext(codec, nil)
|
||||
}
|
||||
|
||||
// ServeCodecWithContext is like ServeCodec but it allows to pass a
|
||||
// connection context to the RPC methods.
|
||||
func ServeCodecWithContext(codec ServerCodec, context interface{}) {
|
||||
DefaultServer.ServeCodecWithContext(codec, context)
|
||||
}
|
||||
|
||||
// ServeRequest is like ServeCodec but synchronously serves a single request.
|
||||
// It does not close the codec upon completion.
|
||||
func ServeRequest(codec ServerCodec) error {
|
||||
return DefaultServer.ServeRequest(codec)
|
||||
return ServeRequestWithContext(codec, nil)
|
||||
|
||||
}
|
||||
|
||||
// ServeRequestWithContext is like ServeRequest but it allows to pass
|
||||
// a connection context to the RPC methods.
|
||||
func ServeRequestWithContext(codec ServerCodec, context interface{}) error {
|
||||
return DefaultServer.ServeRequestWithContext(codec, context)
|
||||
}
|
||||
|
||||
// Accept accepts connections on the listener and serves requests
|
||||
|
|
|
@ -74,6 +74,10 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
|
|||
panic("ERROR")
|
||||
}
|
||||
|
||||
func (t *Arith) TakesContext(context interface{}, args string, reply *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func listenTCP() (net.Listener, string) {
|
||||
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
|
||||
if e != nil {
|
||||
|
@ -233,6 +237,13 @@ func testRPC(t *testing.T, addr string) {
|
|||
if reply.C != args.A*args.B {
|
||||
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
|
||||
}
|
||||
|
||||
// Takes context
|
||||
emptyString := ""
|
||||
err = client.Call("Arith.TakesContext", "", &emptyString)
|
||||
if err != nil {
|
||||
t.Errorf("TakesContext: expected no error but got string %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"code.google.com/p/vitess/go/relog"
|
||||
rpc "code.google.com/p/vitess/go/rpcplus"
|
||||
"code.google.com/p/vitess/go/rpcwrap/proto"
|
||||
)
|
||||
|
||||
// UnusedArgument is a type used to indicate an argument that is
|
||||
|
@ -64,10 +65,10 @@ func LoadCredentials(filename string) error {
|
|||
|
||||
// Authenticate returns true if it the client manages to authenticate
|
||||
// the codec in at most maxRequest number of requests.
|
||||
func Authenticate(c rpc.ServerCodec) (bool, error) {
|
||||
func Authenticate(c rpc.ServerCodec, context *proto.Context) (bool, error) {
|
||||
auth := newAuthenticatedCodec(c)
|
||||
for i := 0; i < CRAMMD5MaxRequests; i++ {
|
||||
err := AuthenticationServer.ServeRequest(auth)
|
||||
err := AuthenticationServer.ServeRequestWithContext(auth, context)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -90,7 +91,7 @@ func (a *AuthenticatorCRAMMD5) GetNewChallenge(_ UnusedArgument, reply *GetNewCh
|
|||
}
|
||||
|
||||
// Authenticate checks if the client proof is correct.
|
||||
func (a *AuthenticatorCRAMMD5) Authenticate(req *AuthenticateRequest, reply *AuthenticateReply) error {
|
||||
func (a *AuthenticatorCRAMMD5) Authenticate(context proto.Context, req *AuthenticateRequest, reply *AuthenticateReply) error {
|
||||
username := strings.SplitN(req.Proof, " ", 2)[0]
|
||||
secrets, ok := a.Credentials[username]
|
||||
if !ok {
|
||||
|
@ -103,6 +104,7 @@ func (a *AuthenticatorCRAMMD5) Authenticate(req *AuthenticateRequest, reply *Aut
|
|||
}
|
||||
for _, secret := range secrets {
|
||||
if expected := CRAMMD5GetExpected(username, secret, req.state.challenge); expected == req.Proof {
|
||||
context.Username = username
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
package proto
|
||||
|
||||
type Context struct {
|
||||
RemoteAddr string
|
||||
Username string
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"code.google.com/p/vitess/go/relog"
|
||||
rpc "code.google.com/p/vitess/go/rpcplus"
|
||||
"code.google.com/p/vitess/go/rpcwrap/auth"
|
||||
"code.google.com/p/vitess/go/rpcwrap/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -121,11 +122,11 @@ func RegisterAuthenticated(rcvr interface{}) error {
|
|||
|
||||
// ServeCodec calls ServeCodec for the appropriate server
|
||||
// (authenticated or default).
|
||||
func (h *rpcHandler) ServeCodec(c rpc.ServerCodec) {
|
||||
func (h *rpcHandler) ServeCodecWithContext(c rpc.ServerCodec, context *proto.Context) {
|
||||
if h.useAuth {
|
||||
AuthenticatedServer.ServeCodec(c)
|
||||
AuthenticatedServer.ServeCodecWithContext(c, context)
|
||||
} else {
|
||||
rpc.ServeCodec(c)
|
||||
rpc.ServeCodecWithContext(c, context)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,9 +143,9 @@ func (h *rpcHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
|
||||
codec := h.cFactory(NewBufferedConnection(conn))
|
||||
|
||||
context := &proto.Context{RemoteAddr: req.RemoteAddr}
|
||||
if h.useAuth {
|
||||
if authenticated, err := auth.Authenticate(codec); !authenticated {
|
||||
if authenticated, err := auth.Authenticate(codec, context); !authenticated {
|
||||
if err != nil {
|
||||
relog.Error("authentication erred at %s: %v", req.RemoteAddr, err)
|
||||
}
|
||||
|
@ -152,8 +153,7 @@ func (h *rpcHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.ServeCodec(codec)
|
||||
h.ServeCodecWithContext(codec, context)
|
||||
}
|
||||
|
||||
func GetRpcPath(codecName string, auth bool) string {
|
||||
|
@ -171,7 +171,7 @@ type httpHandler struct {
|
|||
func (hh *httpHandler) ServeHTTP(c http.ResponseWriter, req *http.Request) {
|
||||
conn := &httpConnectionBroker{c, req.Body}
|
||||
codec := hh.cFactory(conn)
|
||||
if err := rpc.ServeRequest(codec); err != nil {
|
||||
if err := rpc.ServeRequestWithContext(codec, &proto.Context{RemoteAddr: req.RemoteAddr}); err != nil {
|
||||
relog.Error("rpcwrap: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package proto
|
|||
import (
|
||||
mproto "code.google.com/p/vitess/go/mysql/proto"
|
||||
"code.google.com/p/vitess/go/rpcwrap"
|
||||
rpcproto "code.google.com/p/vitess/go/rpcwrap/proto"
|
||||
)
|
||||
|
||||
// defines the RPC services
|
||||
|
@ -17,21 +18,21 @@ type SqlQuery interface {
|
|||
// FIXME(sugu) Note the client will support both returning an
|
||||
// int64 or a structure. Using the structure will be rolled
|
||||
// out after the client is rolled out.
|
||||
Begin(session *Session, transactionId *int64) error
|
||||
Commit(session *Session, noOutput *string) error
|
||||
Rollback(session *Session, noOutput *string) error
|
||||
Begin(context *rpcproto.Context, session *Session, transactionId *int64) error
|
||||
Commit(context *rpcproto.Context, session *Session, noOutput *string) error
|
||||
Rollback(context *rpcproto.Context, session *Session, noOutput *string) error
|
||||
|
||||
CreateReserved(session *Session, connectionInfo *ConnectionInfo) error
|
||||
CloseReserved(session *Session, noOutput *string) error
|
||||
|
||||
Execute(query *Query, reply *mproto.QueryResult) error
|
||||
StreamExecute(query *Query, sendReply func(reply interface{}) error) error
|
||||
ExecuteBatch(queryList *QueryList, reply *QueryResultList) error
|
||||
Execute(context *rpcproto.Context, query *Query, reply *mproto.QueryResult) error
|
||||
StreamExecute(context *rpcproto.Context, query *Query, sendReply func(reply interface{}) error) error
|
||||
ExecuteBatch(context *rpcproto.Context, queryList *QueryList, reply *QueryResultList) error
|
||||
|
||||
Invalidate(cacheInvalidate *CacheInvalidate, noOutput *string) error
|
||||
InvalidateForDDL(ddl *DDLInvalidate, noOutput *string) error
|
||||
Invalidate(context *rpcproto.Context, cacheInvalidate *CacheInvalidate, noOutput *string) error
|
||||
InvalidateForDDL(context *rpcproto.Context, ddl *DDLInvalidate, noOutput *string) error
|
||||
|
||||
Ping(query *string, reply *string) error
|
||||
Ping(context *rpcproto.Context, query *string, reply *string) error
|
||||
}
|
||||
|
||||
// helper method to register the server (does interface checking)
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"code.google.com/p/vitess/go/mysql"
|
||||
mproto "code.google.com/p/vitess/go/mysql/proto"
|
||||
"code.google.com/p/vitess/go/relog"
|
||||
rpcproto "code.google.com/p/vitess/go/rpcwrap/proto"
|
||||
"code.google.com/p/vitess/go/stats"
|
||||
"code.google.com/p/vitess/go/vt/sqlparser"
|
||||
"code.google.com/p/vitess/go/vt/tabletserver/proto"
|
||||
|
@ -275,9 +276,10 @@ func (sq *SqlQuery) GetSessionId(sessionParams *proto.SessionParams, sessionInfo
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) Begin(session *proto.Session, transactionId *int64) (err error) {
|
||||
func (sq *SqlQuery) Begin(context *rpcproto.Context, session *proto.Session, transactionId *int64) (err error) {
|
||||
logStats := newSqlQueryStats("Begin")
|
||||
logStats.OriginalSql = "begin"
|
||||
logStats.context = context
|
||||
logStats.StartTime = time.Now()
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
|
@ -301,9 +303,10 @@ func (sq *SqlQuery) Begin(session *proto.Session, transactionId *int64) (err err
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) Commit(session *proto.Session, noOutput *string) (err error) {
|
||||
func (sq *SqlQuery) Commit(context *rpcproto.Context, session *proto.Session, noOutput *string) (err error) {
|
||||
logStats := newSqlQueryStats("Commit")
|
||||
logStats.OriginalSql = "commit"
|
||||
logStats.context = context
|
||||
logStats.StartTime = time.Now()
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
|
@ -336,9 +339,10 @@ func (sq *SqlQuery) invalidateRows(logStats *sqlQueryStats, dirtyTables map[stri
|
|||
}
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) Rollback(session *proto.Session, noOutput *string) (err error) {
|
||||
func (sq *SqlQuery) Rollback(context *rpcproto.Context, session *proto.Session, noOutput *string) (err error) {
|
||||
logStats := newSqlQueryStats("Rollback")
|
||||
logStats.StartTime = time.Now()
|
||||
logStats.context = context
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
sqlQueryLogger.Send(logStats)
|
||||
|
@ -388,9 +392,9 @@ func handleExecError(query *proto.Query, err *error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) Execute(query *proto.Query, reply *mproto.QueryResult) (err error) {
|
||||
func (sq *SqlQuery) Execute(context *rpcproto.Context, query *proto.Query, reply *mproto.QueryResult) (err error) {
|
||||
logStats := newSqlQueryStats("Execute")
|
||||
|
||||
logStats.context = context
|
||||
logStats.StartTime = time.Now()
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
|
@ -499,8 +503,9 @@ func (sq *SqlQuery) Execute(query *proto.Query, reply *mproto.QueryResult) (err
|
|||
|
||||
// the first QueryResult will have Fields set (and Rows nil)
|
||||
// the subsequent QueryResult will have Rows set (and Fields nil)
|
||||
func (sq *SqlQuery) StreamExecute(query *proto.Query, sendReply func(reply interface{}) error) (err error) {
|
||||
func (sq *SqlQuery) StreamExecute(context *rpcproto.Context, query *proto.Query, sendReply func(reply interface{}) error) (err error) {
|
||||
logStats := newSqlQueryStats("StreamExecute")
|
||||
logStats.context = context
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
sqlQueryLogger.Send(logStats)
|
||||
|
@ -552,7 +557,8 @@ func (sq *SqlQuery) StreamExecute(query *proto.Query, sendReply func(reply inter
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryResultList) (err error) {
|
||||
func (sq *SqlQuery) ExecuteBatch(context *rpcproto.Context, queryList *proto.QueryList, reply *proto.QueryResultList) (err error) {
|
||||
|
||||
defer handleError(&err)
|
||||
ql := queryList.List
|
||||
if len(ql) == 0 {
|
||||
|
@ -576,7 +582,7 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
|
|||
if session.TransactionId != 0 {
|
||||
panic(NewTabletError(FAIL, "Nested transactions disallowed"))
|
||||
}
|
||||
if err = sq.Begin(&session, &session.TransactionId); err != nil {
|
||||
if err = sq.Begin(context, &session, &session.TransactionId); err != nil {
|
||||
return err
|
||||
}
|
||||
begin_called = true
|
||||
|
@ -585,7 +591,7 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
|
|||
if !begin_called {
|
||||
panic(NewTabletError(FAIL, "Cannot commit without begin"))
|
||||
}
|
||||
if err = sq.Commit(&session, &noOutput); err != nil {
|
||||
if err = sq.Commit(context, &session, &noOutput); err != nil {
|
||||
return err
|
||||
}
|
||||
session.TransactionId = 0
|
||||
|
@ -596,9 +602,9 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
|
|||
query.ConnectionId = session.ConnectionId
|
||||
query.SessionId = session.SessionId
|
||||
var localReply mproto.QueryResult
|
||||
if err = sq.Execute(&query, &localReply); err != nil {
|
||||
if err = sq.Execute(context, &query, &localReply); err != nil {
|
||||
if begin_called {
|
||||
sq.Rollback(&session, &noOutput)
|
||||
sq.Rollback(context, &session, &noOutput)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -606,14 +612,16 @@ func (sq *SqlQuery) ExecuteBatch(queryList *proto.QueryList, reply *proto.QueryR
|
|||
}
|
||||
}
|
||||
if begin_called {
|
||||
sq.Rollback(&session, &noOutput)
|
||||
sq.Rollback(context, &session, &noOutput)
|
||||
panic(NewTabletError(FAIL, "begin called with no commit"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) Invalidate(cacheInvalidate *proto.CacheInvalidate, noOutput *string) (err error) {
|
||||
func (sq *SqlQuery) Invalidate(context *rpcproto.Context, cacheInvalidate *proto.CacheInvalidate, noOutput *string) (err error) {
|
||||
logStats := newSqlQueryStats("Invalidate")
|
||||
logStats.StartTime = time.Now()
|
||||
logStats.context = context
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
sqlQueryLogger.Send(logStats)
|
||||
|
@ -649,8 +657,10 @@ func (sq *SqlQuery) Invalidate(cacheInvalidate *proto.CacheInvalidate, noOutput
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) InvalidateForDDL(ddl *proto.DDLInvalidate, noOutput *string) (err error) {
|
||||
func (sq *SqlQuery) InvalidateForDDL(context *rpcproto.Context, ddl *proto.DDLInvalidate, noOutput *string) (err error) {
|
||||
logStats := newSqlQueryStats("InvalidateForDDL")
|
||||
logStats.context = context
|
||||
logStats.StartTime = time.Now()
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
sqlQueryLogger.Send(logStats)
|
||||
|
@ -678,8 +688,10 @@ func (sq *SqlQuery) InvalidateForDDL(ddl *proto.DDLInvalidate, noOutput *string)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sq *SqlQuery) Ping(query *string, reply *string) error {
|
||||
func (sq *SqlQuery) Ping(context *rpcproto.Context, query *string, reply *string) error {
|
||||
logStats := newSqlQueryStats("Ping")
|
||||
logStats.StartTime = time.Now()
|
||||
logStats.context = context
|
||||
defer func() {
|
||||
logStats.EndTime = time.Now()
|
||||
sqlQueryLogger.Send(logStats)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"code.google.com/p/vitess/go/relog"
|
||||
"code.google.com/p/vitess/go/rpcwrap/proto"
|
||||
"code.google.com/p/vitess/go/streamlog"
|
||||
)
|
||||
|
||||
|
@ -37,6 +38,7 @@ type sqlQueryStats struct {
|
|||
CacheInvalidations int64
|
||||
QuerySources byte
|
||||
Rows [][]interface{}
|
||||
context *proto.Context
|
||||
}
|
||||
|
||||
func newSqlQueryStats(methodName string) *sqlQueryStats {
|
||||
|
@ -94,9 +96,9 @@ func (stats *sqlQueryStats) FmtBindVariables() string {
|
|||
for k, v := range stats.BindVariables {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
scrubbed[k] = "string " + string(len(val))
|
||||
scrubbed[k] = fmt.Sprintf("string %v", len(val))
|
||||
case []byte:
|
||||
scrubbed[k] = "bytes " + string(len(val))
|
||||
scrubbed[k] = fmt.Sprintf("bytes %v", len(val))
|
||||
default:
|
||||
scrubbed[k] = v
|
||||
}
|
||||
|
@ -133,11 +135,21 @@ func (stats *sqlQueryStats) FmtQuerySources() string {
|
|||
return strings.Join(sources[:n], ",")
|
||||
}
|
||||
|
||||
func (log sqlQueryStats) RemoteAddr() string {
|
||||
return log.context.RemoteAddr
|
||||
}
|
||||
|
||||
func (log sqlQueryStats) Username() string {
|
||||
return log.context.Username
|
||||
}
|
||||
|
||||
//String returns a tab separated list of logged fields.
|
||||
func (log sqlQueryStats) String() string {
|
||||
return fmt.Sprintf(
|
||||
"%v\t%v\t%v\t%v\t%v\t%q\t%v\t%v\t%q\t%v\t%v\t%v\t%v\t%v\t%v\t%v\t",
|
||||
"%v\t%v\t%v\t%v\t%v\t%v\t%v\t%q\t%v\t%v\t%q\t%v\t%v\t%v\t%v\t%v\t%v\t%v\t",
|
||||
log.Method,
|
||||
log.RemoteAddr(),
|
||||
log.Username(),
|
||||
log.StartTime,
|
||||
log.EndTime,
|
||||
log.TotalTime(),
|
||||
|
|
|
@ -59,8 +59,11 @@ class TestAuthentication(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def tearDownVTOCC(klass):
|
||||
klass.process.kill()
|
||||
klass.process.wait()
|
||||
try:
|
||||
klass.process.kill()
|
||||
klass.process.wait()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
return self.conn.client.call(*args, **kwargs)
|
||||
|
@ -112,7 +115,7 @@ if __name__=="__main__":
|
|||
|
||||
try:
|
||||
TestAuthentication.setUpVTOCC(options.dbconfig, options.auth_credentials)
|
||||
unittest.main()
|
||||
unittest.main(argv=["auth_test.py"])
|
||||
finally:
|
||||
print "Waiting for vtocc to terminate...",
|
||||
TestAuthentication.tearDownVTOCC()
|
||||
|
|
|
@ -14,6 +14,8 @@ class Log(object):
|
|||
def __init__(self, line):
|
||||
self.line = line
|
||||
(self.method,
|
||||
self.remote_address,
|
||||
self.username,
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
self.total_time,
|
||||
|
@ -99,6 +101,10 @@ class Log(object):
|
|||
if rewritten != self.rewritten_sql:
|
||||
self.fail("Bad rewritten SQL", rewritten, self.rewritten_sql)
|
||||
|
||||
def check_remote_address(self, case):
|
||||
if not self.remote_address.startswith(case.remote_address):
|
||||
return self.fail("Bad RemoteAddr", case.remote_address, self.remote_address)
|
||||
|
||||
def check_number_of_queries(self, case):
|
||||
if case.rewritten is not None and int(self.number_of_queries) != len(case.rewritten):
|
||||
return self.fail("wrong number of queries", len(case.rewritten), int(self.number_of_queries))
|
||||
|
@ -106,7 +112,8 @@ class Log(object):
|
|||
class Case(object):
|
||||
def __init__(self, sql, bindings=None, result=None, rewritten=None, doc='',
|
||||
cache_table="vtocc_cached", query_plan=None, cache_hits=None,
|
||||
cache_misses=None, cache_absent=None, cache_invalidations=None):
|
||||
cache_misses=None, cache_absent=None, cache_invalidations=None,
|
||||
remote_address="[::1]"):
|
||||
# For all cache_* parameters, a number n means "check this value
|
||||
# is exactly n," while None means "I am not interested in this
|
||||
# value, leave it alone."
|
||||
|
@ -123,6 +130,7 @@ class Case(object):
|
|||
self.cache_misses = cache_misses
|
||||
self.cache_absent = cache_absent
|
||||
self.cache_invalidations = cache_invalidations
|
||||
self.remote_address = remote_address
|
||||
|
||||
def normalizelog(self, data):
|
||||
return [line.split("INFO: ")[-1]
|
||||
|
|
|
@ -95,6 +95,9 @@ cases = [
|
|||
sql='select /* order */ * from vtocc_a order by id desc',
|
||||
result=[(1L, 2L, 'bcde', 'fghi'), (1L, 1L, 'abcd', 'efgh')],
|
||||
rewritten='select /* order */ * from vtocc_a order by id desc limit 10001'),
|
||||
Case(doc='string in bindings are not shown in logs',
|
||||
sql='select /* limit */ %(somestring)s, eid, id from vtocc_a limit %(a)s',
|
||||
bindings={"somestring": "Ala ma kota.", "a": 1}),
|
||||
|
||||
MultiCase(
|
||||
'simple insert',
|
||||
|
|
Загрузка…
Ссылка в новой задаче