зеркало из https://github.com/github/vitess-gh.git
test use of ldap client via mock
This commit is contained in:
Родитель
8fdd420c40
Коммит
1418695b7c
|
@ -3,9 +3,5 @@
|
|||
"ldapCert": "path/to/ldap-client-cert.pem",
|
||||
"ldapKey": "path/to/ldap-client-key.pem",
|
||||
"ldapCA": "path/to/ldap-client-ca.pem",
|
||||
"queryUser": "vitessLdapROUser",
|
||||
"queryPassword": "vitessLdapROUserPassword",
|
||||
"queryStr": "uid=%s,ou=users,ou=people,dc=example,dc=com",
|
||||
"getGroups": true,
|
||||
"groupQueryStr": "ou=groups,ou=people,dc=example,dc=com"
|
||||
"userDnPattern": "uid=%s,ou=users,ou=people,dc=example,dc=com"
|
||||
}
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
package ldapauthserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/mysqlconn"
|
||||
"github.com/youtube/vitess/go/netutil"
|
||||
"github.com/youtube/vitess/go/vt/servenv/grpcutils"
|
||||
"gopkg.in/ldap.v2"
|
||||
)
|
||||
|
@ -21,17 +19,32 @@ var (
|
|||
)
|
||||
|
||||
// AuthServerLdap implements AuthServer with an LDAP backend
|
||||
// * include port in ldapServer, "ldap.example.com:386"
|
||||
type AuthServerLdap struct {
|
||||
Config *AuthServerLdapConfig
|
||||
Client LdapClient
|
||||
}
|
||||
|
||||
// AuthServerLdapConfig holds the config for AuthServerLdap
|
||||
// * include port in ldapServer, "ldap.example.com:386"
|
||||
type AuthServerLdapConfig struct {
|
||||
ldapServer string
|
||||
ldapCert string
|
||||
ldapKey string
|
||||
ldapCA string
|
||||
queryUser string
|
||||
queryPassword string
|
||||
queryStr string
|
||||
getGroups bool
|
||||
groupQueryStr string
|
||||
userDnPattern string
|
||||
}
|
||||
|
||||
// LdapClient abstracts the call to Dial so we can mock it
|
||||
type LdapClient interface {
|
||||
Dial(network, server string) (ldap.Client, error)
|
||||
}
|
||||
|
||||
// LdapClientImpl is the real implementation of LdapClient
|
||||
type LdapClientImpl struct{}
|
||||
|
||||
// Dial calls the ldap.v2 library's Dial
|
||||
func (lci *LdapClientImpl) Dial(network, server string) (ldap.Client, error) {
|
||||
return ldap.Dial(network, server)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -43,7 +56,7 @@ func init() {
|
|||
log.Infof("Both mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are non-empty, can only use one.")
|
||||
return
|
||||
}
|
||||
ldapAuthServer := newAuthServerLdap()
|
||||
ldapAuthServer := &AuthServerLdap{Config: &AuthServerLdapConfig{}, Client: &LdapClientImpl{}}
|
||||
|
||||
data := []byte(*ldapAuthConfigString)
|
||||
if *ldapAuthConfigFile != "" {
|
||||
|
@ -53,16 +66,12 @@ func init() {
|
|||
log.Fatalf("Failed to read mysql_ldap_auth_config_file: %v", err)
|
||||
}
|
||||
}
|
||||
if err := json.Unmarshal(data, &ldapAuthServer); err != nil {
|
||||
if err := json.Unmarshal(data, ldapAuthServer.Config); err != nil {
|
||||
log.Fatalf("Error parsing AuthServerLdap config: %v", err)
|
||||
}
|
||||
mysqlconn.RegisterAuthServerImpl("ldap", ldapAuthServer)
|
||||
}
|
||||
|
||||
func newAuthServerLdap() *AuthServerLdap {
|
||||
return &AuthServerLdap{}
|
||||
}
|
||||
|
||||
// UseClearText is always true for AuthServerLdap
|
||||
func (asl *AuthServerLdap) UseClearText() bool {
|
||||
return true
|
||||
|
@ -78,23 +87,22 @@ func (asl *AuthServerLdap) ValidateHash(salt []byte, user string, authResponse [
|
|||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// ValidateClearText connects to the LDAP server over TLS,
|
||||
// searches for the user, attempts to bind as that user with the supplied password,
|
||||
// and, if so configured, queries for the user's groups, returning them in a
|
||||
// comma-separated string.
|
||||
// In reality, it runs whatever queries are supplied. See data/test/mysql_ldap_auth_config.json for an example.
|
||||
// It is recommended that queryUser have read-only privileges on ldapServer
|
||||
// ValidateClearText connects to the LDAP server over TLS
|
||||
// and attempts to bind as that user with the supplied password.
|
||||
// It returns the supplied username.
|
||||
func (asl *AuthServerLdap) ValidateClearText(username, password string) (string, error) {
|
||||
conn, err := ldap.Dial("tcp", asl.ldapServer)
|
||||
conn, err := asl.Client.Dial("tcp", asl.Config.ldapServer)
|
||||
defer conn.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Reconnect with TLS
|
||||
idx := strings.LastIndex(asl.ldapServer, ":") // allow users to (incorrectly) specify ipv6 without []
|
||||
serverName := asl.ldapServer[:idx]
|
||||
tlsConfig, err := grpcutils.TLSClientConfig(asl.ldapCert, asl.ldapKey, asl.ldapCA, serverName)
|
||||
// Reconnect with TLS ... why don't we simply DialTLS directly?
|
||||
serverName, _, err := netutil.SplitHostPort(asl.Config.ldapServer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tlsConfig, err := grpcutils.TLSClientConfig(asl.Config.ldapCert, asl.Config.ldapKey, asl.Config.ldapCA, serverName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -104,65 +112,9 @@ func (asl *AuthServerLdap) ValidateClearText(username, password string) (string,
|
|||
}
|
||||
|
||||
// queryUser can be read-only
|
||||
err = conn.Bind(asl.queryUser, asl.queryPassword)
|
||||
err = conn.Bind(fmt.Sprintf(asl.Config.userDnPattern, username), password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Search for the given username
|
||||
req := ldap.NewSearchRequest(
|
||||
fmt.Sprintf(asl.queryStr, username),
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
"(objectClass=organizationalPerson)",
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
|
||||
res, err := conn.Search(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(res.Entries) != 1 {
|
||||
return "", errors.New("User does not exist or too many entries returned")
|
||||
}
|
||||
|
||||
userdn := res.Entries[0].DN
|
||||
|
||||
// Bind as the user to verify their password
|
||||
err = conn.Bind(userdn, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !asl.getGroups {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Rebind as the query user for group query
|
||||
err = conn.Bind(asl.queryUser, asl.queryPassword)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req = ldap.NewSearchRequest(
|
||||
asl.groupQueryStr,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(memberUid=%s)", username),
|
||||
[]string{"cn"},
|
||||
nil,
|
||||
)
|
||||
res, err = conn.Search(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
sep := ""
|
||||
for _, entry := range res.Entries {
|
||||
for _, attr := range entry.Attributes {
|
||||
buffer.WriteString(sep)
|
||||
buffer.WriteString(attr.Values[0])
|
||||
sep = ","
|
||||
}
|
||||
}
|
||||
return buffer.String(), nil
|
||||
return username, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package ldapauthserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/ldap.v2"
|
||||
)
|
||||
|
||||
type MockLdapClient struct{}
|
||||
|
||||
func (mlc *MockLdapClient) Dial(network, server string) (ldap.Client, error) {
|
||||
return &MockLdapConn{}, nil
|
||||
}
|
||||
|
||||
type MockLdapConn struct{}
|
||||
|
||||
func (mlc *MockLdapConn) StartTLS(config *tls.Config) error { return nil }
|
||||
func (mlc *MockLdapConn) Bind(username, password string) error {
|
||||
if username != "testuser" || password != "testpass" {
|
||||
return fmt.Errorf("invalid credentials: %s, %s", username, password)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (mlc *MockLdapConn) Close() {}
|
||||
|
||||
func (mlc *MockLdapConn) Start() { panic("unimpl") }
|
||||
func (mlc *MockLdapConn) SetTimeout(time.Duration) { panic("unimpl") }
|
||||
func (mlc *MockLdapConn) SimpleBind(simpleBindRequest *ldap.SimpleBindRequest) (*ldap.SimpleBindResult, error) {
|
||||
panic("unimpl")
|
||||
}
|
||||
func (mlc *MockLdapConn) Add(addRequest *ldap.AddRequest) error { panic("unimpl") }
|
||||
func (mlc *MockLdapConn) Del(delRequest *ldap.DelRequest) error { panic("unimpl") }
|
||||
func (mlc *MockLdapConn) Modify(modifyRequest *ldap.ModifyRequest) error { panic("unimpl") }
|
||||
func (mlc *MockLdapConn) Compare(dn, attribute, value string) (bool, error) { panic("unimpl") }
|
||||
func (mlc *MockLdapConn) PasswordModify(passwordModifyRequest *ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) {
|
||||
panic("unimpl")
|
||||
}
|
||||
func (mlc *MockLdapConn) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||
panic("unimpl")
|
||||
}
|
||||
func (mlc *MockLdapConn) SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) {
|
||||
panic("unimpl")
|
||||
}
|
||||
|
||||
func TestValidateClearText(t *testing.T) {
|
||||
mockLdapConfig := &AuthServerLdapConfig{ldapServer: "ldap.test.com:386", userDnPattern: "%s"}
|
||||
asl := &AuthServerLdap{Config: mockLdapConfig, Client: &MockLdapClient{}}
|
||||
_, err := asl.ValidateClearText("testuser", "testpass")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthServerLdap failed to validate valid credentials. Got: %v", err)
|
||||
}
|
||||
|
||||
_, err = asl.ValidateClearText("invaliduser", "invalidpass")
|
||||
if err == nil {
|
||||
t.Fatalf("AuthServerLdap validated invalid credentials.")
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче