diff --git a/data/test/mysql_ldap_auth_config.json b/data/test/mysql_ldap_auth_config.json index c6958d2df5..45f3e38590 100644 --- a/data/test/mysql_ldap_auth_config.json +++ b/data/test/mysql_ldap_auth_config.json @@ -3,9 +3,5 @@ "ldapCert": "path/to/ldap-client-cert.pem", "ldapKey": "path/to/ldap-client-key.pem", "ldapCA": "path/to/ldap-client-ca.pem", - "queryUser": "vitessLdapROUser", - "queryPassword": "vitessLdapROUserPassword", - "queryStr": "uid=%s,ou=users,ou=people,dc=example,dc=com", - "getGroups": true, - "groupQueryStr": "ou=groups,ou=people,dc=example,dc=com" + "userDnPattern": "uid=%s,ou=users,ou=people,dc=example,dc=com" } diff --git a/go/mysqlconn/ldapauthserver/auth_server_ldap.go b/go/mysqlconn/ldapauthserver/auth_server_ldap.go index ead32b443c..a3eb9f6199 100644 --- a/go/mysqlconn/ldapauthserver/auth_server_ldap.go +++ b/go/mysqlconn/ldapauthserver/auth_server_ldap.go @@ -1,16 +1,14 @@ package ldapauthserver import ( - "bytes" "encoding/json" - "errors" "flag" "fmt" "io/ioutil" - "strings" log "github.com/golang/glog" "github.com/youtube/vitess/go/mysqlconn" + "github.com/youtube/vitess/go/netutil" "github.com/youtube/vitess/go/vt/servenv/grpcutils" "gopkg.in/ldap.v2" ) @@ -21,17 +19,32 @@ var ( ) // AuthServerLdap implements AuthServer with an LDAP backend -// * include port in ldapServer, "ldap.example.com:386" type AuthServerLdap struct { + Config *AuthServerLdapConfig + Client LdapClient +} + +// AuthServerLdapConfig holds the config for AuthServerLdap +// * include port in ldapServer, "ldap.example.com:386" +type AuthServerLdapConfig struct { ldapServer string ldapCert string ldapKey string ldapCA string - queryUser string - queryPassword string - queryStr string - getGroups bool - groupQueryStr string + userDnPattern string +} + +// LdapClient abstracts the call to Dial so we can mock it +type LdapClient interface { + Dial(network, server string) (ldap.Client, error) +} + +// LdapClientImpl is the real implementation of LdapClient +type LdapClientImpl struct{} + +// Dial calls the ldap.v2 library's Dial +func (lci *LdapClientImpl) Dial(network, server string) (ldap.Client, error) { + return ldap.Dial(network, server) } func init() { @@ -43,7 +56,7 @@ func init() { log.Infof("Both mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are non-empty, can only use one.") return } - ldapAuthServer := newAuthServerLdap() + ldapAuthServer := &AuthServerLdap{Config: &AuthServerLdapConfig{}, Client: &LdapClientImpl{}} data := []byte(*ldapAuthConfigString) if *ldapAuthConfigFile != "" { @@ -53,16 +66,12 @@ func init() { log.Fatalf("Failed to read mysql_ldap_auth_config_file: %v", err) } } - if err := json.Unmarshal(data, &ldapAuthServer); err != nil { + if err := json.Unmarshal(data, ldapAuthServer.Config); err != nil { log.Fatalf("Error parsing AuthServerLdap config: %v", err) } mysqlconn.RegisterAuthServerImpl("ldap", ldapAuthServer) } -func newAuthServerLdap() *AuthServerLdap { - return &AuthServerLdap{} -} - // UseClearText is always true for AuthServerLdap func (asl *AuthServerLdap) UseClearText() bool { return true @@ -78,23 +87,22 @@ func (asl *AuthServerLdap) ValidateHash(salt []byte, user string, authResponse [ panic("unimplemented") } -// ValidateClearText connects to the LDAP server over TLS, -// searches for the user, attempts to bind as that user with the supplied password, -// and, if so configured, queries for the user's groups, returning them in a -// comma-separated string. -// In reality, it runs whatever queries are supplied. See data/test/mysql_ldap_auth_config.json for an example. -// It is recommended that queryUser have read-only privileges on ldapServer +// ValidateClearText connects to the LDAP server over TLS +// and attempts to bind as that user with the supplied password. +// It returns the supplied username. func (asl *AuthServerLdap) ValidateClearText(username, password string) (string, error) { - conn, err := ldap.Dial("tcp", asl.ldapServer) + conn, err := asl.Client.Dial("tcp", asl.Config.ldapServer) defer conn.Close() if err != nil { return "", err } - // Reconnect with TLS - idx := strings.LastIndex(asl.ldapServer, ":") // allow users to (incorrectly) specify ipv6 without [] - serverName := asl.ldapServer[:idx] - tlsConfig, err := grpcutils.TLSClientConfig(asl.ldapCert, asl.ldapKey, asl.ldapCA, serverName) + // Reconnect with TLS ... why don't we simply DialTLS directly? + serverName, _, err := netutil.SplitHostPort(asl.Config.ldapServer) + if err != nil { + return "", err + } + tlsConfig, err := grpcutils.TLSClientConfig(asl.Config.ldapCert, asl.Config.ldapKey, asl.Config.ldapCA, serverName) if err != nil { return "", err } @@ -104,65 +112,9 @@ func (asl *AuthServerLdap) ValidateClearText(username, password string) (string, } // queryUser can be read-only - err = conn.Bind(asl.queryUser, asl.queryPassword) + err = conn.Bind(fmt.Sprintf(asl.Config.userDnPattern, username), password) if err != nil { return "", err } - - // Search for the given username - req := ldap.NewSearchRequest( - fmt.Sprintf(asl.queryStr, username), - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - "(objectClass=organizationalPerson)", - []string{"dn"}, - nil, - ) - - res, err := conn.Search(req) - if err != nil { - return "", err - } - - if len(res.Entries) != 1 { - return "", errors.New("User does not exist or too many entries returned") - } - - userdn := res.Entries[0].DN - - // Bind as the user to verify their password - err = conn.Bind(userdn, password) - if err != nil { - return "", err - } - if !asl.getGroups { - return "", nil - } - - // Rebind as the query user for group query - err = conn.Bind(asl.queryUser, asl.queryPassword) - if err != nil { - return "", err - } - - req = ldap.NewSearchRequest( - asl.groupQueryStr, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - fmt.Sprintf("(memberUid=%s)", username), - []string{"cn"}, - nil, - ) - res, err = conn.Search(req) - if err != nil { - return "", err - } - var buffer bytes.Buffer - sep := "" - for _, entry := range res.Entries { - for _, attr := range entry.Attributes { - buffer.WriteString(sep) - buffer.WriteString(attr.Values[0]) - sep = "," - } - } - return buffer.String(), nil + return username, nil } diff --git a/go/mysqlconn/ldapauthserver/auth_server_ldap_test.go b/go/mysqlconn/ldapauthserver/auth_server_ldap_test.go new file mode 100644 index 0000000000..67d7b0747e --- /dev/null +++ b/go/mysqlconn/ldapauthserver/auth_server_ldap_test.go @@ -0,0 +1,60 @@ +package ldapauthserver + +import ( + "crypto/tls" + "fmt" + "testing" + "time" + + "gopkg.in/ldap.v2" +) + +type MockLdapClient struct{} + +func (mlc *MockLdapClient) Dial(network, server string) (ldap.Client, error) { + return &MockLdapConn{}, nil +} + +type MockLdapConn struct{} + +func (mlc *MockLdapConn) StartTLS(config *tls.Config) error { return nil } +func (mlc *MockLdapConn) Bind(username, password string) error { + if username != "testuser" || password != "testpass" { + return fmt.Errorf("invalid credentials: %s, %s", username, password) + } + return nil +} +func (mlc *MockLdapConn) Close() {} + +func (mlc *MockLdapConn) Start() { panic("unimpl") } +func (mlc *MockLdapConn) SetTimeout(time.Duration) { panic("unimpl") } +func (mlc *MockLdapConn) SimpleBind(simpleBindRequest *ldap.SimpleBindRequest) (*ldap.SimpleBindResult, error) { + panic("unimpl") +} +func (mlc *MockLdapConn) Add(addRequest *ldap.AddRequest) error { panic("unimpl") } +func (mlc *MockLdapConn) Del(delRequest *ldap.DelRequest) error { panic("unimpl") } +func (mlc *MockLdapConn) Modify(modifyRequest *ldap.ModifyRequest) error { panic("unimpl") } +func (mlc *MockLdapConn) Compare(dn, attribute, value string) (bool, error) { panic("unimpl") } +func (mlc *MockLdapConn) PasswordModify(passwordModifyRequest *ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) { + panic("unimpl") +} +func (mlc *MockLdapConn) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + panic("unimpl") +} +func (mlc *MockLdapConn) SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) { + panic("unimpl") +} + +func TestValidateClearText(t *testing.T) { + mockLdapConfig := &AuthServerLdapConfig{ldapServer: "ldap.test.com:386", userDnPattern: "%s"} + asl := &AuthServerLdap{Config: mockLdapConfig, Client: &MockLdapClient{}} + _, err := asl.ValidateClearText("testuser", "testpass") + if err != nil { + t.Fatalf("AuthServerLdap failed to validate valid credentials. Got: %v", err) + } + + _, err = asl.ValidateClearText("invaliduser", "invalidpass") + if err == nil { + t.Fatalf("AuthServerLdap validated invalid credentials.") + } +}