ssh: add server side multi-step authentication
Add support for sending back partial success to the client while handling authentication in the server. This is implemented by a special error that can be returned by any of the authentication methods, which contains the authentication methods to offer next. This patch is based on CL 399075 with some minor changes and the addition of test cases. Fixes golang/go#17889 Fixes golang/go#61447 Fixes golang/go#64974 Co-authored-by: Peter Verraedt <peter.verraedt@kuleuven.be> Change-Id: I05c8f913bb407d22c2e41c4cbe965e36ab4739b0 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/516355 Reviewed-by: Andrew Lytvynov <awly@tailscale.com> Reviewed-by: Than McIntosh <thanm@google.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Filippo Valsorda <filippo@golang.org> Auto-Submit: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Родитель
8d0d405eed
Коммит
6f79b5a20c
120
ssh/server.go
120
ssh/server.go
|
@ -426,6 +426,35 @@ func (l ServerAuthError) Error() string {
|
|||
return "[" + strings.Join(errs, ", ") + "]"
|
||||
}
|
||||
|
||||
// ServerAuthCallbacks defines server-side authentication callbacks.
|
||||
type ServerAuthCallbacks struct {
|
||||
// PasswordCallback behaves like [ServerConfig.PasswordCallback].
|
||||
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
|
||||
|
||||
// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
|
||||
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
||||
|
||||
// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
|
||||
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
|
||||
|
||||
// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
|
||||
GSSAPIWithMICConfig *GSSAPIWithMICConfig
|
||||
}
|
||||
|
||||
// PartialSuccessError can be returned by any of the [ServerConfig]
|
||||
// authentication callbacks to indicate to the client that authentication has
|
||||
// partially succeeded, but further steps are required.
|
||||
type PartialSuccessError struct {
|
||||
// Next defines the authentication callbacks to apply to further steps. The
|
||||
// available methods communicated to the client are based on the non-nil
|
||||
// ServerAuthCallbacks fields.
|
||||
Next ServerAuthCallbacks
|
||||
}
|
||||
|
||||
func (p *PartialSuccessError) Error() string {
|
||||
return "ssh: authenticated with partial success"
|
||||
}
|
||||
|
||||
// ErrNoAuth is the error value returned if no
|
||||
// authentication method has been passed yet. This happens as a normal
|
||||
// part of the authentication loop, since the client first tries
|
||||
|
@ -441,6 +470,15 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
|
|||
authFailures := 0
|
||||
var authErrs []error
|
||||
var displayedBanner bool
|
||||
partialSuccessReturned := false
|
||||
// Set the initial authentication callbacks from the config. They can be
|
||||
// changed if a PartialSuccessError is returned.
|
||||
authConfig := ServerAuthCallbacks{
|
||||
PasswordCallback: config.PasswordCallback,
|
||||
PublicKeyCallback: config.PublicKeyCallback,
|
||||
KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
|
||||
GSSAPIWithMICConfig: config.GSSAPIWithMICConfig,
|
||||
}
|
||||
|
||||
userAuthLoop:
|
||||
for {
|
||||
|
@ -471,6 +509,11 @@ userAuthLoop:
|
|||
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
|
||||
}
|
||||
|
||||
if s.user != userAuthReq.User && partialSuccessReturned {
|
||||
return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
|
||||
s.user, userAuthReq.User)
|
||||
}
|
||||
|
||||
s.user = userAuthReq.User
|
||||
|
||||
if !displayedBanner && config.BannerCallback != nil {
|
||||
|
@ -491,20 +534,17 @@ userAuthLoop:
|
|||
|
||||
switch userAuthReq.Method {
|
||||
case "none":
|
||||
if config.NoClientAuth {
|
||||
// We don't allow none authentication after a partial success
|
||||
// response.
|
||||
if config.NoClientAuth && !partialSuccessReturned {
|
||||
if config.NoClientAuthCallback != nil {
|
||||
perms, authErr = config.NoClientAuthCallback(s)
|
||||
} else {
|
||||
authErr = nil
|
||||
}
|
||||
}
|
||||
|
||||
// allow initial attempt of 'none' without penalty
|
||||
if authFailures == 0 {
|
||||
authFailures--
|
||||
}
|
||||
case "password":
|
||||
if config.PasswordCallback == nil {
|
||||
if authConfig.PasswordCallback == nil {
|
||||
authErr = errors.New("ssh: password auth not configured")
|
||||
break
|
||||
}
|
||||
|
@ -518,17 +558,17 @@ userAuthLoop:
|
|||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
|
||||
perms, authErr = config.PasswordCallback(s, password)
|
||||
perms, authErr = authConfig.PasswordCallback(s, password)
|
||||
case "keyboard-interactive":
|
||||
if config.KeyboardInteractiveCallback == nil {
|
||||
if authConfig.KeyboardInteractiveCallback == nil {
|
||||
authErr = errors.New("ssh: keyboard-interactive auth not configured")
|
||||
break
|
||||
}
|
||||
|
||||
prompter := &sshClientKeyboardInteractive{s}
|
||||
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
|
||||
perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
|
||||
case "publickey":
|
||||
if config.PublicKeyCallback == nil {
|
||||
if authConfig.PublicKeyCallback == nil {
|
||||
authErr = errors.New("ssh: publickey auth not configured")
|
||||
break
|
||||
}
|
||||
|
@ -562,11 +602,18 @@ userAuthLoop:
|
|||
if !ok {
|
||||
candidate.user = s.user
|
||||
candidate.pubKeyData = pubKeyData
|
||||
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
|
||||
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
|
||||
candidate.result = checkSourceAddress(
|
||||
candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
|
||||
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
|
||||
|
||||
if (candidate.result == nil || isPartialSuccessError) &&
|
||||
candidate.perms != nil &&
|
||||
candidate.perms.CriticalOptions != nil &&
|
||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
|
||||
if err := checkSourceAddress(
|
||||
s.RemoteAddr(),
|
||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
|
||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
|
||||
candidate.result = err
|
||||
}
|
||||
}
|
||||
cache.add(candidate)
|
||||
}
|
||||
|
@ -578,8 +625,8 @@ userAuthLoop:
|
|||
if len(payload) > 0 {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
|
||||
if candidate.result == nil {
|
||||
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
|
||||
if candidate.result == nil || isPartialSuccessError {
|
||||
okMsg := userAuthPubKeyOkMsg{
|
||||
Algo: algo,
|
||||
PubKey: pubKeyData,
|
||||
|
@ -629,11 +676,11 @@ userAuthLoop:
|
|||
perms = candidate.perms
|
||||
}
|
||||
case "gssapi-with-mic":
|
||||
if config.GSSAPIWithMICConfig == nil {
|
||||
if authConfig.GSSAPIWithMICConfig == nil {
|
||||
authErr = errors.New("ssh: gssapi-with-mic auth not configured")
|
||||
break
|
||||
}
|
||||
gssapiConfig := config.GSSAPIWithMICConfig
|
||||
gssapiConfig := authConfig.GSSAPIWithMICConfig
|
||||
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
|
||||
if err != nil {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
|
@ -689,7 +736,28 @@ userAuthLoop:
|
|||
break userAuthLoop
|
||||
}
|
||||
|
||||
var failureMsg userAuthFailureMsg
|
||||
|
||||
if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
|
||||
// After a partial success error we don't allow changing the user
|
||||
// name and execute the NoClientAuthCallback.
|
||||
partialSuccessReturned = true
|
||||
|
||||
// In case a partial success is returned, the server may send
|
||||
// a new set of authentication methods.
|
||||
authConfig = partialSuccess.Next
|
||||
|
||||
// Reset pubkey cache, as the new PublicKeyCallback might
|
||||
// accept a different set of public keys.
|
||||
cache = pubKeyCache{}
|
||||
|
||||
// Send back a partial success message to the user.
|
||||
failureMsg.PartialSuccess = true
|
||||
} else {
|
||||
// Allow initial attempt of 'none' without penalty.
|
||||
if authFailures > 0 || userAuthReq.Method != "none" {
|
||||
authFailures++
|
||||
}
|
||||
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
|
||||
// If we have hit the max attempts, don't bother sending the
|
||||
// final SSH_MSG_USERAUTH_FAILURE message, since there are
|
||||
|
@ -714,24 +782,24 @@ userAuthLoop:
|
|||
// to match that behavior.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var failureMsg userAuthFailureMsg
|
||||
if config.PasswordCallback != nil {
|
||||
if authConfig.PasswordCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "password")
|
||||
}
|
||||
if config.PublicKeyCallback != nil {
|
||||
if authConfig.PublicKeyCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "publickey")
|
||||
}
|
||||
if config.KeyboardInteractiveCallback != nil {
|
||||
if authConfig.KeyboardInteractiveCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
|
||||
}
|
||||
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
|
||||
config.GSSAPIWithMICConfig.AllowLogin != nil {
|
||||
if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
|
||||
authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
|
||||
}
|
||||
|
||||
if len(failureMsg.Methods) == 0 {
|
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
|
||||
return nil, errors.New("ssh: no authentication methods available")
|
||||
}
|
||||
|
||||
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
|
||||
|
|
|
@ -0,0 +1,412 @@
|
|||
// Copyright 2024 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
|
||||
c1, c2, err := netPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("netPipe: %v", err)
|
||||
}
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
var serverAuthErrors []error
|
||||
|
||||
serverConfig.AddHostKey(testSigners["rsa"])
|
||||
serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
|
||||
serverAuthErrors = append(serverAuthErrors, err)
|
||||
}
|
||||
go newServer(c1, serverConfig)
|
||||
c, _, _, err := NewClientConn(c2, "", clientConfig)
|
||||
if err == nil {
|
||||
c.Close()
|
||||
}
|
||||
return serverAuthErrors, err
|
||||
}
|
||||
|
||||
func TestMultiStepAuth(t *testing.T) {
|
||||
// This user can login with password, public key or public key + password.
|
||||
username := "testuser"
|
||||
// This user can login with public key + password only.
|
||||
usernameSecondFactor := "testuser_second_factor"
|
||||
errPwdAuthFailed := errors.New("password auth failed")
|
||||
errWrongSequence := errors.New("wrong sequence")
|
||||
|
||||
serverConfig := &ServerConfig{
|
||||
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
|
||||
if conn.User() == usernameSecondFactor {
|
||||
return nil, errWrongSequence
|
||||
}
|
||||
if conn.User() == username && string(password) == clientPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errPwdAuthFailed
|
||||
},
|
||||
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
|
||||
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
|
||||
if conn.User() == usernameSecondFactor {
|
||||
return nil, &PartialSuccessError{
|
||||
Next: ServerAuthCallbacks{
|
||||
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
|
||||
if string(password) == clientPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errPwdAuthFailed
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
|
||||
},
|
||||
}
|
||||
|
||||
clientConfig := &ClientConfig{
|
||||
User: usernameSecondFactor,
|
||||
Auth: []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
Password(clientPassword),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("client login error: %s", err)
|
||||
}
|
||||
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - partial success
|
||||
// - nil
|
||||
if len(serverAuthErrors) != 3 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||||
t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
|
||||
}
|
||||
// Now test a wrong sequence.
|
||||
clientConfig.Auth = []AuthMethod{
|
||||
Password(clientPassword),
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err == nil {
|
||||
t.Fatal("client login with wrong sequence must fail")
|
||||
}
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - wrong sequence
|
||||
// - partial success
|
||||
if len(serverAuthErrors) != 3 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if serverAuthErrors[1] != errWrongSequence {
|
||||
t.Fatal("server not returned wrong sequence")
|
||||
}
|
||||
if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
|
||||
t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
|
||||
}
|
||||
// Now test using a correct sequence but a wrong password before the right
|
||||
// one.
|
||||
n := 0
|
||||
passwords := []string{"WRONG", "WRONG", clientPassword}
|
||||
clientConfig.Auth = []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
RetryableAuthMethod(PasswordCallback(func() (string, error) {
|
||||
p := passwords[n]
|
||||
n++
|
||||
return p, nil
|
||||
}), 3),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("client login error: %s", err)
|
||||
}
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - partial success
|
||||
// - wrong password
|
||||
// - wrong password
|
||||
// - nil
|
||||
if len(serverAuthErrors) != 5 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
if serverAuthErrors[2] != errPwdAuthFailed {
|
||||
t.Fatal("server not returned password authentication failed")
|
||||
}
|
||||
if serverAuthErrors[3] != errPwdAuthFailed {
|
||||
t.Fatal("server not returned password authentication failed")
|
||||
}
|
||||
// Only password authentication should fail.
|
||||
clientConfig.Auth = []AuthMethod{
|
||||
Password(clientPassword),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err == nil {
|
||||
t.Fatal("client login with password only must fail")
|
||||
}
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - wrong sequence
|
||||
if len(serverAuthErrors) != 2 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if serverAuthErrors[1] != errWrongSequence {
|
||||
t.Fatal("server not returned wrong sequence")
|
||||
}
|
||||
|
||||
// Only public key authentication should fail.
|
||||
clientConfig.Auth = []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err == nil {
|
||||
t.Fatal("client login with public key only must fail")
|
||||
}
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - partial success
|
||||
if len(serverAuthErrors) != 2 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
|
||||
// Public key and wrong password.
|
||||
clientConfig.Auth = []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
Password("WRONG"),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err == nil {
|
||||
t.Fatal("client login with wrong password after public key must fail")
|
||||
}
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - partial success
|
||||
// - password auth failed
|
||||
if len(serverAuthErrors) != 3 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
if serverAuthErrors[2] != errPwdAuthFailed {
|
||||
t.Fatal("server not returned password authentication failed")
|
||||
}
|
||||
|
||||
// Public key, public key again and then correct password. Public key
|
||||
// authentication is attempted only once because the partial success error
|
||||
// returns only "password" as the allowed authentication method.
|
||||
clientConfig.Auth = []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
Password(clientPassword),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("client login error: %s", err)
|
||||
}
|
||||
// The error sequence is:
|
||||
// - no auth passed yet
|
||||
// - partial success
|
||||
// - nil
|
||||
if len(serverAuthErrors) != 3 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
|
||||
// The unrestricted username can do anything
|
||||
clientConfig = &ClientConfig{
|
||||
User: username,
|
||||
Auth: []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
Password(clientPassword),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
_, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("unrestricted client login error: %s", err)
|
||||
}
|
||||
|
||||
clientConfig = &ClientConfig{
|
||||
User: username,
|
||||
Auth: []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
_, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("unrestricted client login error: %s", err)
|
||||
}
|
||||
|
||||
clientConfig = &ClientConfig{
|
||||
User: username,
|
||||
Auth: []AuthMethod{
|
||||
Password(clientPassword),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
_, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("unrestricted client login error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicAuthCallbacks(t *testing.T) {
|
||||
user1 := "user1"
|
||||
user2 := "user2"
|
||||
errInvalidCredentials := errors.New("invalid credentials")
|
||||
|
||||
serverConfig := &ServerConfig{
|
||||
NoClientAuth: true,
|
||||
NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
|
||||
switch conn.User() {
|
||||
case user1:
|
||||
return nil, &PartialSuccessError{
|
||||
Next: ServerAuthCallbacks{
|
||||
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
|
||||
if conn.User() == user1 && string(password) == clientPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errInvalidCredentials
|
||||
},
|
||||
},
|
||||
}
|
||||
case user2:
|
||||
return nil, &PartialSuccessError{
|
||||
Next: ServerAuthCallbacks{
|
||||
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
|
||||
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
|
||||
if conn.User() == user2 {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
return nil, errInvalidCredentials
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil, errInvalidCredentials
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
clientConfig := &ClientConfig{
|
||||
User: user1,
|
||||
Auth: []AuthMethod{
|
||||
Password(clientPassword),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("client login error: %s", err)
|
||||
}
|
||||
// The error sequence is:
|
||||
// - partial success
|
||||
// - nil
|
||||
if len(serverAuthErrors) != 2 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
|
||||
clientConfig = &ClientConfig{
|
||||
User: user2,
|
||||
Auth: []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("client login error: %s", err)
|
||||
}
|
||||
// The error sequence is:
|
||||
// - partial success
|
||||
// - nil
|
||||
if len(serverAuthErrors) != 2 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
|
||||
// user1 cannot login with public key
|
||||
clientConfig = &ClientConfig{
|
||||
User: user1,
|
||||
Auth: []AuthMethod{
|
||||
PublicKeys(testSigners["rsa"]),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err == nil {
|
||||
t.Fatal("user1 login with public key must fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no supported methods remain") {
|
||||
t.Errorf("got %v, expected 'no supported methods remain'", err)
|
||||
}
|
||||
if len(serverAuthErrors) != 1 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
// user2 cannot login with password
|
||||
clientConfig = &ClientConfig{
|
||||
User: user2,
|
||||
Auth: []AuthMethod{
|
||||
Password(clientPassword),
|
||||
},
|
||||
HostKeyCallback: InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
|
||||
if err == nil {
|
||||
t.Fatal("user2 login with password must fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no supported methods remain") {
|
||||
t.Errorf("got %v, expected 'no supported methods remain'", err)
|
||||
}
|
||||
if len(serverAuthErrors) != 1 {
|
||||
t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
|
||||
}
|
||||
if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
|
||||
t.Fatal("server not returned partial success")
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче