diff --git a/ssh/keys.go b/ssh/keys.go index 34d95822..2261dc38 100644 --- a/ssh/keys.go +++ b/ssh/keys.go @@ -903,8 +903,8 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { // Implemented based on the documentation at // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) { - magic := append([]byte("openssh-key-v1"), 0) - if !bytes.Equal(magic, key[0:len(magic)]) { + const magic = "openssh-key-v1\x00" + if len(key) < len(magic) || string(key[:len(magic)]) != magic { return nil, errors.New("ssh: invalid openssh private key format") } remaining := key[len(magic):] diff --git a/ssh/keys_test.go b/ssh/keys_test.go index 9a90abc0..f28725f1 100644 --- a/ssh/keys_test.go +++ b/ssh/keys_test.go @@ -13,7 +13,9 @@ import ( "crypto/rsa" "crypto/x509" "encoding/base64" + "encoding/pem" "fmt" + "io" "reflect" "strings" "testing" @@ -498,3 +500,32 @@ func TestFingerprintSHA256(t *testing.T) { t.Errorf("got fingerprint %q want %q", fingerprint, want) } } + +func TestInvalidKeys(t *testing.T) { + keyTypes := []string{ + "RSA PRIVATE KEY", + "PRIVATE KEY", + "EC PRIVATE KEY", + "DSA PRIVATE KEY", + "OPENSSH PRIVATE KEY", + } + + for _, keyType := range keyTypes { + for _, dataLen := range []int{0, 1, 2, 5, 10, 20} { + data := make([]byte, dataLen) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + pem.Encode(&buf, &pem.Block{ + Type: keyType, + Bytes: data, + }) + + // This test is just to ensure that the function + // doesn't panic so the return value is ignored. + ParseRawPrivateKey(buf.Bytes()) + } + } +}