diff --git a/go/cmd/vtgateclienttest/main.go b/go/cmd/vtgateclienttest/main.go index 4a58a43aa2..f2bdd97dfc 100644 --- a/go/cmd/vtgateclienttest/main.go +++ b/go/cmd/vtgateclienttest/main.go @@ -27,10 +27,12 @@ func main() { servenv.Init() // The implementation chain. - s := services.CreateServices() - for _, f := range vtgate.RegisterVTGates { - f(s) - } + servenv.OnRun(func() { + s := services.CreateServices() + for _, f := range vtgate.RegisterVTGates { + f(s) + } + }) servenv.RunDefault() } diff --git a/go/vt/servenv/grpc_server.go b/go/vt/servenv/grpc_server.go index 8268a4ffd3..43c940176b 100644 --- a/go/vt/servenv/grpc_server.go +++ b/go/vt/servenv/grpc_server.go @@ -47,17 +47,31 @@ var ( GRPCServer *grpc.Server ) +// isGRPCEnabled returns true if gRPC server is set +func isGRPCEnabled() bool { + if GRPCPort != nil && *GRPCPort != 0 { + return true + } + + if SocketFile != nil && *SocketFile != "" { + return true + } + + return false +} + // createGRPCServer create the gRPC server we will be using. // It has to be called after flags are parsed, but before // services register themselves. func createGRPCServer() { // skip if not registered - if GRPCPort == nil || *GRPCPort == 0 { + if !isGRPCEnabled() { + log.Infof("Skipping gRPC server creation") return } var opts []grpc.ServerOption - if *GRPCCert != "" && *GRPCKey != "" { + if GRPCPort != nil && *GRPCCert != "" && *GRPCKey != "" { config := &tls.Config{} // load the server cert and key @@ -120,7 +134,7 @@ func RegisterGRPCFlags() { func GRPCCheckServiceMap(name string) bool { // Silently fail individual services if gRPC is not enabled in // the first place (either on a grpc port or on the socket file) - if (GRPCPort == nil || *GRPCPort == 0) && (SocketFile == nil || *SocketFile == "") { + if !isGRPCEnabled() { return false } diff --git a/go/vt/servenv/run.go b/go/vt/servenv/run.go index 26a96ff765..3b5f34bfa8 100644 --- a/go/vt/servenv/run.go +++ b/go/vt/servenv/run.go @@ -22,6 +22,7 @@ func Run(port int) { createGRPCServer() onRunHooks.Fire() serveGRPC() + serveSocketFile() l, err := proc.Listen(fmt.Sprintf("%v", port)) if err != nil { diff --git a/go/vt/servenv/unix_socket.go b/go/vt/servenv/unix_socket.go index f1030ee1d5..03ee83e416 100644 --- a/go/vt/servenv/unix_socket.go +++ b/go/vt/servenv/unix_socket.go @@ -13,16 +13,18 @@ import ( ) var ( - // The flags used when calling RegisterDefaultSocketFileFlags. + // SocketFile has the flag used when calling + // RegisterDefaultSocketFileFlags. SocketFile *string ) // serveSocketFile listen to the named socket and serves RPCs on it. -func serveSocketFile(name string) { - if name == "" { +func serveSocketFile() { + if SocketFile == nil || *SocketFile == "" { log.Infof("Not listening on socket file") return } + name := *SocketFile // try to delete if file exists if _, err := os.Stat(name); err == nil { @@ -41,12 +43,7 @@ func serveSocketFile(name string) { } // RegisterDefaultSocketFileFlags registers the default flags for listening -// to a socket. It also registers an OnRun callback to enable the listening -// socket. -// This needs to be called before flags are parsed. +// to a socket. This needs to be called before flags are parsed. func RegisterDefaultSocketFileFlags() { SocketFile = flag.String("socket_file", "", "Local unix socket file to listen on") - OnRun(func() { - serveSocketFile(*SocketFile) - }) } diff --git a/test/tablet.py b/test/tablet.py index 7e2ba5e1ad..81da188e5c 100644 --- a/test/tablet.py +++ b/test/tablet.py @@ -597,11 +597,15 @@ class Tablet(object): mysql_sock = os.path.join(self.tablet_dir, 'mysql.sock') mysqlctl_sock = os.path.join(self.tablet_dir, 'mysqlctl.sock') while True: - if os.path.exists(mysql_sock) and os.path.exists(mysqlctl_sock): + wait_for = [] + if not os.path.exists(mysql_sock): + wait_for.append(mysql_sock) + if not os.path.exists(mysqlctl_sock): + wait_for.append(mysqlctl_sock) + if not wait_for: return - timeout = utils.wait_step( - 'waiting for mysql and mysqlctl socket files: %s %s' % - (mysql_sock, mysqlctl_sock), timeout) + timeout = utils.wait_step('waiting for socket files: %s' % str(wait_for), + timeout, sleep_time=2.0) def _add_dbconfigs(self, args, repl_extra_flags=None): if repl_extra_flags is None: