go.crypto/ssh: add hook for host key checking.
R=dave, agl CC=gobot, golang-dev https://golang.org/cl/9922043
This commit is contained in:
Родитель
b88b016522
Коммит
afdc305bc8
|
@ -26,6 +26,9 @@ type ClientConn struct {
|
|||
chanList // channels associated with this connection
|
||||
forwardList // forwarded tcpip connections from the remote side
|
||||
globalRequest
|
||||
|
||||
// Address as passed to the Dial function.
|
||||
dialAddress string
|
||||
}
|
||||
|
||||
type globalRequest struct {
|
||||
|
@ -35,11 +38,17 @@ type globalRequest struct {
|
|||
|
||||
// Client returns a new SSH client connection using c as the underlying transport.
|
||||
func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
|
||||
return clientWithAddress(c, "", config)
|
||||
}
|
||||
|
||||
func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
|
||||
conn := &ClientConn{
|
||||
transport: newTransport(c, config.rand()),
|
||||
config: config,
|
||||
globalRequest: globalRequest{response: make(chan interface{}, 1)},
|
||||
dialAddress: addr,
|
||||
}
|
||||
|
||||
if err := conn.handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("handshake failed: %v", err)
|
||||
|
@ -168,6 +177,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
if checker := c.config.HostKeyChecker; checker != nil {
|
||||
if err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, kexDHReply.HostKey); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
kInt, err := group.diffieHellman(kexDHReply.Y, x)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -445,7 +460,7 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Client(conn, config)
|
||||
return clientWithAddress(conn, addr, config)
|
||||
}
|
||||
|
||||
// A ClientConfig structure is used to configure a ClientConn. After one has
|
||||
|
@ -463,6 +478,11 @@ type ClientConfig struct {
|
|||
// of a particular RFC 4252 method will be used during authentication.
|
||||
Auth []ClientAuth
|
||||
|
||||
// HostKeyChecker, if not nil, is called during the cryptographic
|
||||
// handshake to validate the server's host key. A nil HostKeyChecker
|
||||
// implies that all host keys are accepted.
|
||||
HostKeyChecker HostKeyChecker
|
||||
|
||||
// Cryptographic-related configuration.
|
||||
Crypto CryptoConfig
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// authenticate authenticates with the remote server. See RFC 4252.
|
||||
|
@ -63,6 +64,17 @@ func keys(m map[string]bool) (s []string) {
|
|||
return
|
||||
}
|
||||
|
||||
// HostKeyChecker represents a database of known server host keys.
|
||||
type HostKeyChecker interface {
|
||||
// Check is called during the handshake to check server's
|
||||
// public key for unexpected changes. The hostKey argument is
|
||||
// in SSH wire format. It can be parsed using
|
||||
// ssh.ParsePublicKey. The address before DNS resolution is
|
||||
// passed in the addr argument, so the key can also be checked
|
||||
// against the hostname.
|
||||
Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
|
||||
}
|
||||
|
||||
// A ClientAuth represents an instance of an RFC 4252 authentication method.
|
||||
type ClientAuth interface {
|
||||
// auth authenticates user over transport t.
|
||||
|
|
|
@ -33,6 +33,25 @@ func TestRunCommandSuccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHostKeyCheck(t *testing.T) {
|
||||
server := newServer(t)
|
||||
defer server.Shutdown()
|
||||
|
||||
conf := clientConfig()
|
||||
k := conf.HostKeyChecker.(*storedHostKey)
|
||||
|
||||
// change the key.
|
||||
k.keys["ssh-rsa"][25]++
|
||||
|
||||
conn, err := server.TryDial(conf)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
t.Fatalf("dial should have failed.")
|
||||
} else if !strings.Contains(err.Error(), "host key mismatch") {
|
||||
t.Fatalf("'host key mismatch' not found in %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCommandFailed(t *testing.T) {
|
||||
server := newServer(t)
|
||||
defer server.Shutdown()
|
||||
|
|
|
@ -55,14 +55,25 @@ HostbasedAuthentication no
|
|||
`
|
||||
|
||||
var (
|
||||
configTmpl template.Template
|
||||
rsakey *rsa.PrivateKey
|
||||
configTmpl template.Template
|
||||
rsakey *rsa.PrivateKey
|
||||
serializedHostKey []byte
|
||||
)
|
||||
|
||||
func init() {
|
||||
template.Must(configTmpl.Parse(sshd_config))
|
||||
block, _ := pem.Decode([]byte(testClientPrivateKey))
|
||||
rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
|
||||
if block == nil {
|
||||
panic("pem.Decode ssh_host_rsa_key")
|
||||
}
|
||||
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
panic("ParsePKCS1PrivateKey: " + err.Error())
|
||||
}
|
||||
serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
|
||||
}
|
||||
|
||||
type server struct {
|
||||
|
@ -89,7 +100,29 @@ func username() string {
|
|||
return username
|
||||
}
|
||||
|
||||
type storedHostKey struct {
|
||||
// keys map from an algorithm string to binary key data.
|
||||
keys map[string][]byte
|
||||
}
|
||||
|
||||
func (k *storedHostKey) Add(algo string, public []byte) {
|
||||
if k.keys == nil {
|
||||
k.keys = map[string][]byte{}
|
||||
}
|
||||
k.keys[algo] = append([]byte(nil), public...)
|
||||
}
|
||||
|
||||
func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
|
||||
if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
|
||||
return errors.New("host key mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func clientConfig() *ssh.ClientConfig {
|
||||
keyChecker := storedHostKey{}
|
||||
keyChecker.Add("ssh-rsa", serializedHostKey)
|
||||
|
||||
kc := new(keychain)
|
||||
kc.keys = append(kc.keys, rsakey)
|
||||
config := &ssh.ClientConfig{
|
||||
|
@ -97,11 +130,12 @@ func clientConfig() *ssh.ClientConfig {
|
|||
Auth: []ssh.ClientAuth{
|
||||
ssh.ClientAuthKeyring(kc),
|
||||
},
|
||||
HostKeyChecker: &keyChecker,
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
||||
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
|
||||
sshd, err := exec.LookPath("sshd")
|
||||
if err != nil {
|
||||
s.t.Skipf("skipping test: %v", err)
|
||||
|
@ -123,7 +157,12 @@ func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
|||
s.Shutdown()
|
||||
s.t.Fatalf("s.cmd.Start: %v", err)
|
||||
}
|
||||
conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
|
||||
|
||||
return ssh.Client(&client{wc: w2, r: r1}, config)
|
||||
}
|
||||
|
||||
func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
||||
conn, err := s.TryDial(config)
|
||||
if err != nil {
|
||||
s.t.Fail()
|
||||
s.Shutdown()
|
||||
|
|
Загрузка…
Ссылка в новой задаче