go.crypto/ssh: add hook for host key checking.

R=dave, agl
CC=gobot, golang-dev
https://golang.org/cl/9922043
This commit is contained in:
Han-Wen Nienhuys 2013-06-21 12:46:35 -04:00 коммит произвёл Adam Langley
Родитель b88b016522
Коммит afdc305bc8
4 изменённых файлов: 95 добавлений и 5 удалений

Просмотреть файл

@ -26,6 +26,9 @@ type ClientConn struct {
chanList // channels associated with this connection
forwardList // forwarded tcpip connections from the remote side
globalRequest
// Address as passed to the Dial function.
dialAddress string
}
type globalRequest struct {
@ -35,11 +38,17 @@ type globalRequest struct {
// Client returns a new SSH client connection using c as the underlying transport.
func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
return clientWithAddress(c, "", config)
}
func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
conn := &ClientConn{
transport: newTransport(c, config.rand()),
config: config,
globalRequest: globalRequest{response: make(chan interface{}, 1)},
dialAddress: addr,
}
if err := conn.handshake(); err != nil {
conn.Close()
return nil, fmt.Errorf("handshake failed: %v", err)
@ -168,6 +177,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return nil, nil, err
}
if checker := c.config.HostKeyChecker; checker != nil {
if err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, kexDHReply.HostKey); err != nil {
return nil, nil, err
}
}
kInt, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil {
return nil, nil, err
@ -445,7 +460,7 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
if err != nil {
return nil, err
}
return Client(conn, config)
return clientWithAddress(conn, addr, config)
}
// A ClientConfig structure is used to configure a ClientConn. After one has
@ -463,6 +478,11 @@ type ClientConfig struct {
// of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth
// HostKeyChecker, if not nil, is called during the cryptographic
// handshake to validate the server's host key. A nil HostKeyChecker
// implies that all host keys are accepted.
HostKeyChecker HostKeyChecker
// Cryptographic-related configuration.
Crypto CryptoConfig
}

Просмотреть файл

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"net"
)
// authenticate authenticates with the remote server. See RFC 4252.
@ -63,6 +64,17 @@ func keys(m map[string]bool) (s []string) {
return
}
// HostKeyChecker represents a database of known server host keys.
type HostKeyChecker interface {
// Check is called during the handshake to check server's
// public key for unexpected changes. The hostKey argument is
// in SSH wire format. It can be parsed using
// ssh.ParsePublicKey. The address before DNS resolution is
// passed in the addr argument, so the key can also be checked
// against the hostname.
Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
}
// A ClientAuth represents an instance of an RFC 4252 authentication method.
type ClientAuth interface {
// auth authenticates user over transport t.

Просмотреть файл

@ -33,6 +33,25 @@ func TestRunCommandSuccess(t *testing.T) {
}
}
func TestHostKeyCheck(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
k := conf.HostKeyChecker.(*storedHostKey)
// change the key.
k.keys["ssh-rsa"][25]++
conn, err := server.TryDial(conf)
if err == nil {
conn.Close()
t.Fatalf("dial should have failed.")
} else if !strings.Contains(err.Error(), "host key mismatch") {
t.Fatalf("'host key mismatch' not found in %v", err)
}
}
func TestRunCommandFailed(t *testing.T) {
server := newServer(t)
defer server.Shutdown()

Просмотреть файл

@ -55,14 +55,25 @@ HostbasedAuthentication no
`
var (
configTmpl template.Template
rsakey *rsa.PrivateKey
configTmpl template.Template
rsakey *rsa.PrivateKey
serializedHostKey []byte
)
func init() {
template.Must(configTmpl.Parse(sshd_config))
block, _ := pem.Decode([]byte(testClientPrivateKey))
rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
if block == nil {
panic("pem.Decode ssh_host_rsa_key")
}
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
panic("ParsePKCS1PrivateKey: " + err.Error())
}
serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
}
type server struct {
@ -89,7 +100,29 @@ func username() string {
return username
}
type storedHostKey struct {
// keys map from an algorithm string to binary key data.
keys map[string][]byte
}
func (k *storedHostKey) Add(algo string, public []byte) {
if k.keys == nil {
k.keys = map[string][]byte{}
}
k.keys[algo] = append([]byte(nil), public...)
}
func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
return errors.New("host key mismatch")
}
return nil
}
func clientConfig() *ssh.ClientConfig {
keyChecker := storedHostKey{}
keyChecker.Add("ssh-rsa", serializedHostKey)
kc := new(keychain)
kc.keys = append(kc.keys, rsakey)
config := &ssh.ClientConfig{
@ -97,11 +130,12 @@ func clientConfig() *ssh.ClientConfig {
Auth: []ssh.ClientAuth{
ssh.ClientAuthKeyring(kc),
},
HostKeyChecker: &keyChecker,
}
return config
}
func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
@ -123,7 +157,12 @@ func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
s.Shutdown()
s.t.Fatalf("s.cmd.Start: %v", err)
}
conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
return ssh.Client(&client{wc: w2, r: r1}, config)
}
func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
conn, err := s.TryDial(config)
if err != nil {
s.t.Fail()
s.Shutdown()