ssh: allow server auth callbacks to send additional banners

Add a new BannerError error type that auth callbacks can return to send
banner to the client. While the BannerCallback can send the initial
banner message, auth callbacks might want to communicate more
information to the client to help them diagnose failures.

Updates golang/go#64962

Change-Id: I97a26480ff4064b95a0a26042b0a5e19737cfb62
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/558695
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
Andrew Lytvynov 2024-01-25 18:32:22 -07:00 коммит произвёл Gopher Robot
Родитель 67b13616a5
Коммит 44c9b0ff9e
2 изменённых файлов: 104 добавлений и 0 удалений

Просмотреть файл

@ -462,6 +462,24 @@ func (p *PartialSuccessError) Error() string {
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
// BannerError is an error that can be returned by authentication handlers in
// ServerConfig to send a banner message to the client.
type BannerError struct {
Err error
Message string
}
func (b *BannerError) Unwrap() error {
return b.Err
}
func (b *BannerError) Error() string {
if b.Err == nil {
return b.Message
}
return b.Err.Error()
}
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
var cache pubKeyCache
@ -734,6 +752,18 @@ userAuthLoop:
config.AuthLogCallback(s, userAuthReq.Method, authErr)
}
var bannerErr *BannerError
if errors.As(authErr, &bannerErr) {
if bannerErr.Message != "" {
bannerMsg := &userAuthBannerMsg{
Message: bannerErr.Message,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
if authErr == nil {
break userAuthLoop
}

Просмотреть файл

@ -6,8 +6,10 @@ package ssh
import (
"errors"
"fmt"
"io"
"net"
"slices"
"strings"
"sync/atomic"
"testing"
@ -225,6 +227,78 @@ func TestNewServerConnValidationErrors(t *testing.T) {
}
}
func TestBannerError(t *testing.T) {
serverConfig := &ServerConfig{
BannerCallback: func(ConnMetadata) string {
return "banner from BannerCallback"
},
NoClientAuth: true,
NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
err := &BannerError{
Err: errors.New("error from NoClientAuthCallback"),
Message: "banner from NoClientAuthCallback",
}
return nil, fmt.Errorf("wrapped: %w", err)
},
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
return &Permissions{}, nil
},
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
return nil, &BannerError{
Err: errors.New("error from PublicKeyCallback"),
Message: "banner from PublicKeyCallback",
}
},
KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
return nil, &BannerError{
Err: nil, // make sure that a nil inner error is allowed
Message: "banner from KeyboardInteractiveCallback",
}
},
}
serverConfig.AddHostKey(testSigners["rsa"])
var banners []string
clientConfig := &ClientConfig{
User: "test",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
return []string{"letmein"}, nil
}),
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(msg string) error {
banners = append(banners, msg)
return nil
},
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
c, _, _, err := NewClientConn(c2, "", clientConfig)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
defer c.Close()
wantBanners := []string{
"banner from BannerCallback",
"banner from NoClientAuthCallback",
"banner from PublicKeyCallback",
"banner from KeyboardInteractiveCallback",
}
if !slices.Equal(banners, wantBanners) {
t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
}
}
type markerConn struct {
closed uint32
used uint32