diff --git a/go/mysql/conn_params.go b/go/mysql/conn_params.go index d395634661..488cca1019 100644 --- a/go/mysql/conn_params.go +++ b/go/mysql/conn_params.go @@ -30,11 +30,12 @@ type ConnParams struct { // The following SSL flags are only used when flags |= 2048 // is set (CapabilityClientSSL). - SslCa string `json:"ssl_ca"` - SslCaPath string `json:"ssl_ca_path"` - SslCert string `json:"ssl_cert"` - SslKey string `json:"ssl_key"` - ServerName string `json:"server_name"` + SslCa string `json:"ssl_ca"` + SslCaPath string `json:"ssl_ca_path"` + SslCert string `json:"ssl_cert"` + SslKey string `json:"ssl_key"` + ServerName string `json:"server_name"` + ConnectTimeoutMs uint64 `json:"connect_timeout_ms"` // The following is only set when the deprecated "dbname" flags are // supplied and will be removed. diff --git a/go/vt/dbconfigs/dbconfigs.go b/go/vt/dbconfigs/dbconfigs.go index c9df1ab26c..1c0ee4173b 100644 --- a/go/vt/dbconfigs/dbconfigs.go +++ b/go/vt/dbconfigs/dbconfigs.go @@ -103,7 +103,7 @@ func registerBaseFlags() { flag.StringVar(&baseConfig.SslCert, "db_ssl_cert", "", "connection ssl certificate") flag.StringVar(&baseConfig.SslKey, "db_ssl_key", "", "connection ssl key") flag.StringVar(&baseConfig.ServerName, "db_server_name", "", "server name of the DB we are connecting to.") - + flag.Uint64Var(&baseConfig.ConnectTimeoutMs, "db_connect_timeout_ms", 0, "connection timeout to mysqld in milliseconds (0 for no timeout)") } // The flags will change the global singleton @@ -287,6 +287,7 @@ func Init(defaultSocketFile string) (*DBConfigs, error) { uc.param.SslKey = baseConfig.SslKey uc.param.ServerName = baseConfig.ServerName } + uc.param.ConnectTimeoutMs = baseConfig.ConnectTimeoutMs } // See if the CredentialsServer is working. We do not use the diff --git a/go/vt/dbconfigs/dbconfigs_test.go b/go/vt/dbconfigs/dbconfigs_test.go index d20b50539f..1917b1b375 100644 --- a/go/vt/dbconfigs/dbconfigs_test.go +++ b/go/vt/dbconfigs/dbconfigs_test.go @@ -261,6 +261,60 @@ func TestInit(t *testing.T) { } } +func TestInitTimeout(t *testing.T) { + f := saveDBConfigs() + defer f() + + baseConfig = mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "b", + Pass: "c", + DbName: "d", + UnixSocket: "e", + Charset: "f", + Flags: 2, + Flavor: "flavor", + ConnectTimeoutMs: 250, + } + dbConfigs = DBConfigs{ + userConfigs: map[string]*userConfig{ + App: { + param: mysql.ConnParams{ + Uname: "app", + Pass: "apppass", + }, + }, + }, + } + + dbc, err := Init("default") + if err != nil { + t.Fatal(err) + } + want := &DBConfigs{ + userConfigs: map[string]*userConfig{ + App: { + param: mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "app", + Pass: "apppass", + UnixSocket: "e", + Charset: "f", + Flags: 2, + Flavor: "flavor", + ConnectTimeoutMs: 250, + }, + }, + }, + } + + if !reflect.DeepEqual(dbc.userConfigs[App].param, want.userConfigs[App].param) { + t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[App].param, want.userConfigs[App].param) + } +} + func TestAccessors(t *testing.T) { dbc := &DBConfigs{ userConfigs: map[string]*userConfig{ diff --git a/go/vt/dbconnpool/connection.go b/go/vt/dbconnpool/connection.go index 294a731923..e1a2890a41 100644 --- a/go/vt/dbconnpool/connection.go +++ b/go/vt/dbconnpool/connection.go @@ -124,6 +124,11 @@ func NewDBConnection(info *mysql.ConnParams, mysqlStats *stats.Timings) (*DBConn return nil, err } ctx := context.Background() + if info.ConnectTimeoutMs != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(info.ConnectTimeoutMs)*time.Millisecond) + defer cancel() + } c, err := mysql.Connect(ctx, params) if err != nil { mysqlStats.Record("ConnectError", start)