From edd02eacbef20ab549402e70caa7c2b359971c8c Mon Sep 17 00:00:00 2001 From: Jim Minter Date: Sun, 9 Feb 2020 16:39:16 -0600 Subject: [PATCH] encryption fixups: * pass cipher into database.NewDatabase, rather than bool * unexport as much as possible * remove backwards-compatibility and "read without key" options for now, adds too much complexity --- Gopkg.toml | 2 +- cmd/aro/monitor.go | 8 +- cmd/aro/rp.go | 8 +- ...are-a-shared-rp-development-environment.md | 6 +- hack/db/db.go | 8 +- pkg/api/openshiftcluster.go | 4 +- pkg/backend/openshiftcluster/create.go | 5 +- pkg/database/database.go | 14 +- pkg/database/extensions.go | 94 ++++----- pkg/database/extensions_test.go | 195 ++++-------------- pkg/env/env.go | 4 +- pkg/env/prod.go | 19 +- pkg/install/install.go | 8 +- pkg/util/encryption/encrypt.go | 112 ---------- pkg/util/encryption/encrypt_test.go | 118 ----------- pkg/util/encryption/xchacha20poly1305.go | 74 +++++++ pkg/util/encryption/xchacha20poly1305_test.go | 153 ++++++++++++++ 17 files changed, 369 insertions(+), 463 deletions(-) delete mode 100644 pkg/util/encryption/encrypt.go delete mode 100644 pkg/util/encryption/encrypt_test.go create mode 100644 pkg/util/encryption/xchacha20poly1305.go create mode 100644 pkg/util/encryption/xchacha20poly1305_test.go diff --git a/Gopkg.toml b/Gopkg.toml index 6beb53b02..45ffe9571 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -62,8 +62,8 @@ required = [ version = "1.1.7" [[constraint]] - branch = "master" name = "golang.org/x/crypto" + branch = "master" [[override]] name = "k8s.io/api" diff --git a/cmd/aro/monitor.go b/cmd/aro/monitor.go index f56595ea0..c3790aa17 100644 --- a/cmd/aro/monitor.go +++ b/cmd/aro/monitor.go @@ -13,6 +13,7 @@ import ( "github.com/Azure/ARO-RP/pkg/env" "github.com/Azure/ARO-RP/pkg/metrics/statsd" pkgmonitor "github.com/Azure/ARO-RP/pkg/monitor" + "github.com/Azure/ARO-RP/pkg/util/encryption" ) func monitor(ctx context.Context, log *logrus.Entry) error { @@ -30,7 +31,12 @@ func monitor(ctx context.Context, log *logrus.Entry) error { } defer m.Close() - db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true) + cipher, err := encryption.NewXChaCha20Poly1305(ctx, env) + if err != nil { + return err + } + + db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, cipher, uuid) if err != nil { return err } diff --git a/cmd/aro/rp.go b/cmd/aro/rp.go index a78698b7a..80c67f9f5 100644 --- a/cmd/aro/rp.go +++ b/cmd/aro/rp.go @@ -19,6 +19,7 @@ import ( "github.com/Azure/ARO-RP/pkg/env" "github.com/Azure/ARO-RP/pkg/frontend" "github.com/Azure/ARO-RP/pkg/metrics/statsd" + "github.com/Azure/ARO-RP/pkg/util/encryption" ) func rp(ctx context.Context, log *logrus.Entry) error { @@ -36,7 +37,12 @@ func rp(ctx context.Context, log *logrus.Entry) error { } defer m.Close() - db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true) + cipher, err := encryption.NewXChaCha20Poly1305(ctx, env) + if err != nil { + return err + } + + db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, cipher, uuid) if err != nil { return err } diff --git a/docs/prepare-a-shared-rp-development-environment.md b/docs/prepare-a-shared-rp-development-environment.md index aca308928..ebca69313 100644 --- a/docs/prepare-a-shared-rp-development-environment.md +++ b/docs/prepare-a-shared-rp-development-environment.md @@ -310,9 +310,9 @@ locations. >/dev/null az keyvault secret set \ --vault-name "$KEYVAULT_PREFIX-service" \ - --name "encryption-key" \ - --value $(openssl rand -base64 32) \ - >/dev/null + --name encryption-key \ + --value "$(openssl rand -base64 32)" \ + >/dev/null ``` 1. Create nameserver records in the parent DNS zone: diff --git a/hack/db/db.go b/hack/db/db.go index 617a5fe04..14a56f778 100644 --- a/hack/db/db.go +++ b/hack/db/db.go @@ -15,6 +15,7 @@ import ( "github.com/Azure/ARO-RP/pkg/database" "github.com/Azure/ARO-RP/pkg/env" "github.com/Azure/ARO-RP/pkg/metrics/noop" + "github.com/Azure/ARO-RP/pkg/util/encryption" utillog "github.com/Azure/ARO-RP/pkg/util/log" ) @@ -28,7 +29,12 @@ func run(ctx context.Context, log *logrus.Entry) error { return err } - db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, "", true) + cipher, err := encryption.NewXChaCha20Poly1305(ctx, env) + if err != nil { + return err + } + + db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, cipher, "") if err != nil { return err } diff --git a/pkg/api/openshiftcluster.go b/pkg/api/openshiftcluster.go index baa28c56e..d4f068b33 100644 --- a/pkg/api/openshiftcluster.go +++ b/pkg/api/openshiftcluster.go @@ -21,10 +21,10 @@ type OpenShiftCluster struct { Properties Properties `json:"properties,omitempty"` } -// SecureBytes represents encrypted []byte +// SecureBytes represents an encrypted []byte type SecureBytes []byte -// SecureString represents encrypted string +// SecureString represents an encrypted string type SecureString string // Properties represents an OpenShift cluster's properties diff --git a/pkg/backend/openshiftcluster/create.go b/pkg/backend/openshiftcluster/create.go index 49387aac0..90ec3719b 100644 --- a/pkg/backend/openshiftcluster/create.go +++ b/pkg/backend/openshiftcluster/create.go @@ -40,16 +40,13 @@ func (m *Manager) Create(ctx context.Context) error { resourceGroup := m.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID[strings.LastIndexByte(m.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID, '/')+1:] m.doc, err = m.db.PatchWithLease(ctx, m.doc.Key, func(doc *api.OpenShiftClusterDocument) error { - var err error if doc.OpenShiftCluster.Properties.SSHKey == nil { sshKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return err } + doc.OpenShiftCluster.Properties.SSHKey = x509.MarshalPKCS1PrivateKey(sshKey) - if err != nil { - return err - } } if doc.OpenShiftCluster.Properties.StorageSuffix == "" { diff --git a/pkg/database/database.go b/pkg/database/database.go index 3046d4173..f942d600c 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -33,15 +33,7 @@ type Database struct { } // NewDatabase returns a new Database -func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, uuid string, decryptDatabase bool) (db *Database, err error) { - var cipher encryption.Cipher - if decryptDatabase { - cipher, err = encryption.NewCipher(ctx, env) - if err != nil { - return nil, err - } - } - +func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, cipher encryption.Cipher, uuid string) (db *Database, err error) { databaseAccount, masterKey := env.CosmosDB() h := newJSONHandle(cipher) @@ -99,7 +91,7 @@ func newJSONHandle(cipher encryption.Cipher) *codec.JsonHandle { }, } - h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, SecureBytesExt{Cipher: cipher}) - h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, SecureStringExt{Cipher: cipher}) + h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, secureBytesExt{cipher: cipher}) + h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, secureStringExt{cipher: cipher}) return h } diff --git a/pkg/database/extensions.go b/pkg/database/extensions.go index 30ea3b16d..d467fb2b3 100644 --- a/pkg/database/extensions.go +++ b/pkg/database/extensions.go @@ -4,6 +4,8 @@ package database // Licensed under the Apache License 2.0. import ( + "encoding/base64" + "github.com/ugorji/go/codec" "github.com/Azure/ARO-RP/pkg/api" @@ -11,62 +13,60 @@ import ( encrypt "github.com/Azure/ARO-RP/pkg/util/encryption" ) -var _ codec.InterfaceExt = (*SecureBytesExt)(nil) +var _ codec.InterfaceExt = (*secureBytesExt)(nil) -type SecureBytesExt struct { - Cipher encryption.Cipher +type secureBytesExt struct { + cipher encryption.Cipher } -func (s SecureBytesExt) ConvertExt(v interface{}) interface{} { - data := v.(api.SecureBytes) - if s.Cipher != nil { - encrypted, err := s.Cipher.Encrypt(string(data)) - if err != nil { - panic(err) - } - return encrypted +func (s secureBytesExt) ConvertExt(v interface{}) interface{} { + encrypted, err := s.cipher.Encrypt(v.(api.SecureBytes)) + if err != nil { + panic(err) } - return string(data) -} -func (s SecureBytesExt) UpdateExt(dest interface{}, v interface{}) { - output := dest.(*api.SecureBytes) - if s.Cipher != nil { - decrypted, err := s.Cipher.Decrypt(v.(string)) - if err != nil { - panic(err) - } - *output = api.SecureBytes(decrypted) - return - } - *output = api.SecureBytes(v.(string)) + + return base64.StdEncoding.EncodeToString([]byte(encrypted)) } -var _ codec.InterfaceExt = (*SecureStringExt)(nil) +func (s secureBytesExt) UpdateExt(dest interface{}, v interface{}) { + b, err := base64.StdEncoding.DecodeString(v.(string)) + if err != nil { + panic(err) + } -type SecureStringExt struct { - Cipher encrypt.Cipher + b, err = s.cipher.Decrypt(b) + if err != nil { + panic(err) + } + + *dest.(*api.SecureBytes) = b } -func (s SecureStringExt) ConvertExt(v interface{}) interface{} { - data := v.(api.SecureString) - if s.Cipher != nil { - encrypted, err := s.Cipher.Encrypt(string(data)) - if err != nil { - panic(err) - } - return encrypted - } - return string(data) +var _ codec.InterfaceExt = (*secureStringExt)(nil) + +type secureStringExt struct { + cipher encrypt.Cipher } -func (s SecureStringExt) UpdateExt(dest interface{}, v interface{}) { - output := dest.(*api.SecureString) - if s.Cipher != nil { - decrypted, err := s.Cipher.Decrypt(v.(string)) - if err != nil { - panic(err) - } - *output = api.SecureString(decrypted) - return + +func (s secureStringExt) ConvertExt(v interface{}) interface{} { + encrypted, err := s.cipher.Encrypt([]byte(v.(api.SecureString))) + if err != nil { + panic(err) } - *output = api.SecureString(v.(string)) + + return base64.StdEncoding.EncodeToString([]byte(encrypted)) +} + +func (s secureStringExt) UpdateExt(dest interface{}, v interface{}) { + b, err := base64.StdEncoding.DecodeString(v.(string)) + if err != nil { + panic(err) + } + + b, err = s.cipher.Decrypt(b) + if err != nil { + panic(err) + } + + *dest.(*api.SecureString) = api.SecureString(b) } diff --git a/pkg/database/extensions_test.go b/pkg/database/extensions_test.go index 702d2e50a..3bcae610a 100644 --- a/pkg/database/extensions_test.go +++ b/pkg/database/extensions_test.go @@ -4,11 +4,7 @@ package database // Licensed under the Apache License 2.0. import ( - "bytes" "context" - "encoding/base64" - "encoding/json" - "io/ioutil" "reflect" "testing" @@ -20,164 +16,63 @@ import ( ) type testStruct struct { - SecureBytes api.SecureBytes - SecureString api.SecureString - Bytes []byte - Str string + SecureBytes api.SecureBytes `json:"secureBytes,omitempty"` + SecureString api.SecureString `json:"secureString,omitempty"` + Bytes []byte `json:"bytes,omitempty"` + String string `json:"string,omitempty"` } func TestExtensions(t *testing.T) { - encryption.RandRead = func(b []byte) (n int, err error) { - b = make([]byte, len(b)) - return len(b), nil + ctx := context.Background() + + env := &env.Test{ + TestSecret: []byte("\x63\xb5\x59\xf0\x43\x34\x79\x49\x68\x46\xab\x8b\xce\xdb\xc1\x2d\x7a\x0b\x14\x86\x7e\x1a\xb2\xd7\x3a\x92\x4e\x98\x6c\x5e\xcb\xe1"), } - key := make([]byte, 32) - keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key))) - base64.StdEncoding.Encode(keybase64, key) - env := &env.Test{TestSecret: keybase64} - - cipher, err := encryption.NewCipher(context.Background(), env) + cipher, err := encryption.NewXChaCha20Poly1305(ctx, env) if err != nil { - t.Error(err) + t.Fatal(err) } - for _, tt := range []struct { - name string - input func(input *testStruct) - inputCodec *codec.JsonHandle - output func(input *testStruct) - outputCodec *codec.JsonHandle - }{ - { - name: "noop", - }, - { - name: "SecureByte - encrypt - decrypt", - input: func(input *testStruct) { - input.SecureBytes = []byte("test") - }, - output: func(output *testStruct) { - output.SecureBytes = []byte("test") - }, - }, - { - name: "SecureByte - encrypt - raw", - input: func(input *testStruct) { - input.SecureBytes = []byte("test") - }, - inputCodec: newJSONHandle(cipher), - output: func(output *testStruct) { - output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=") - // empty string encoded - output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAjzuUWlGQbchgDen4li0A5g==" - }, - outputCodec: newJSONHandle(nil), - }, - { - name: "SecureByte - raw - decrypt", - input: func(input *testStruct) { - input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=") - }, - inputCodec: newJSONHandle(nil), - output: func(output *testStruct) { - output.SecureBytes = []byte("test") - }, - outputCodec: newJSONHandle(cipher), - }, - { - name: "SecureByte - raw - raw", - input: func(input *testStruct) { - input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=") - }, - inputCodec: newJSONHandle(nil), - output: func(output *testStruct) { - output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=") - }, - outputCodec: newJSONHandle(nil), - }, - { - name: "SecureString - encrypt - decrypt", - input: func(input *testStruct) { - input.SecureString = "test" - }, - output: func(output *testStruct) { - output.SecureString = "test" - }, - }, - { - name: "SecureString - encrypt - raw", - input: func(input *testStruct) { - input.SecureString = "test" - }, - inputCodec: newJSONHandle(cipher), - output: func(output *testStruct) { - output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=" - }, - outputCodec: newJSONHandle(nil), - }, - { - name: "SecureString - raw - decrypt", - input: func(input *testStruct) { - input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=" - }, - inputCodec: newJSONHandle(nil), - output: func(output *testStruct) { - output.SecureString = "test" - }, - outputCodec: newJSONHandle(cipher), - }, - { - name: "SecureString - raw - raw", - input: func(input *testStruct) { - input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=" - }, - inputCodec: newJSONHandle(nil), - output: func(output *testStruct) { - output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=" - }, - outputCodec: newJSONHandle(nil), - }, - } { - t.Run(tt.name, func(t *testing.T) { - if tt.inputCodec == nil { - tt.inputCodec = newJSONHandle(cipher) - } - if tt.outputCodec == nil { - tt.outputCodec = newJSONHandle(cipher) - } + h := newJSONHandle(cipher) - input := &testStruct{} - if tt.input != nil { - tt.input(input) - } + encrypted := []byte(`{ + "bytes": "Ynl0ZXM=", + "secureBytes": "6w8Uah0zX40LRfkYHuU9UvLuGrBcHb7l8I2M6qTcmtclOGJNONfHqAuaJWifZj7dd8fI", + "secureString": "YT+ZNR23JBvILGw1WBn6/NhtCj9LM14EXp5VR6XloD7CN1MfmvW5FEn9duRSPYbdr98tLQ==", + "string": "string" + }`) - output := &testStruct{} - if tt.output != nil { - tt.output(output) - } + decrypted := &testStruct{ + SecureBytes: []byte("securebytes"), + SecureString: "securestring", + Bytes: []byte("bytes"), + String: "string", + } - buf := &bytes.Buffer{} - err = codec.NewEncoder(buf, tt.inputCodec).Encode(input) - if err != nil { - t.Error(err) - } - data, err := ioutil.ReadAll(buf) - if err != nil { - t.Error(err) - } + var ts testStruct + err = codec.NewDecoderBytes(encrypted, h).Decode(&ts) + if err != nil { + t.Fatal(err) + } - result := &testStruct{} - err = codec.NewDecoder(bytes.NewReader(data), tt.outputCodec).Decode(result) - if err != nil { - t.Error(err) - } + if !reflect.DeepEqual(&ts, decrypted) { + t.Errorf("%#v", &ts) + } - if !reflect.DeepEqual(output, result) { - output, _ := json.Marshal(output) - result, _ := json.Marshal(result) - t.Errorf("\n wants: %s,'\ngot: %s", string(output), string(result)) - } - }) + var enc []byte + err = codec.NewEncoderBytes(&enc, h).Encode(ts) + if err != nil { + t.Fatal(err) + } + + ts = testStruct{} + err = codec.NewDecoderBytes(encrypted, h).Decode(&ts) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(&ts, decrypted) { + t.Errorf("%#v", &ts) } } diff --git a/pkg/env/env.go b/pkg/env/env.go index 8966bc494..4a23262ae 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -28,10 +28,10 @@ type Interface interface { DialContext(context.Context, string, string) (net.Conn, error) Domain() string FPAuthorizer(string, string) (autorest.Authorizer, error) - ManagedDomain(string) (string, error) - GetSecret(context.Context, string) ([]byte, error) GetCertificateSecret(context.Context, string) (*rsa.PrivateKey, []*x509.Certificate, error) + GetSecret(context.Context, string) ([]byte, error) Listen() (net.Listener, error) + ManagedDomain(string) (string, error) VnetName() string Zones(vmSize string) ([]string, error) } diff --git a/pkg/env/prod.go b/pkg/env/prod.go index f6826b936..53fb0dfbb 100644 --- a/pkg/env/prod.go +++ b/pkg/env/prod.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rsa" "crypto/x509" + "encoding/base64" "encoding/pem" "fmt" "net" @@ -265,16 +266,13 @@ func (p *prod) FPAuthorizer(tenantID, resource string) (autorest.Authorizer, err return autorest.NewBearerAuthorizer(sp), nil } -func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) { - bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "") - return []byte(*bundle.Value), err -} - func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) { - b, err := p.GetSecret(ctx, secretName) + bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "") if err != nil { return nil, nil, err } + + b := []byte(*bundle.Value) for { var block *pem.Block block, b = pem.Decode(b) @@ -314,6 +312,15 @@ func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key return key, certs, nil } +func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) { + bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "") + if err != nil { + return nil, err + } + + return base64.StdEncoding.DecodeString(*bundle.Value) +} + func (p *prod) Listen() (net.Listener, error) { return net.Listen("tcp", ":8443") } diff --git a/pkg/install/install.go b/pkg/install/install.go index a96824c9a..ac905a093 100644 --- a/pkg/install/install.go +++ b/pkg/install/install.go @@ -81,7 +81,7 @@ func NewInstaller(ctx context.Context, log *logrus.Entry, env env.Interface, db return nil, err } - cipher, err := encryption.NewCipher(ctx, env) + cipher, err := encryption.NewXChaCha20Poly1305(ctx, env) if err != nil { return nil, err } @@ -211,13 +211,13 @@ func (i *Installer) loadGraph(ctx context.Context) (graph, error) { return nil, err } - output, err := i.cipher.Decrypt(string(encrypted)) + output, err := i.cipher.Decrypt(encrypted) if err != nil { return nil, err } var g graph - err = json.Unmarshal([]byte(output), &g) + err = json.Unmarshal(output, &g) if err != nil { return nil, err } @@ -246,7 +246,7 @@ func (i *Installer) saveGraph(ctx context.Context, g graph) error { return err } - output, err := i.cipher.Encrypt(string(b)) + output, err := i.cipher.Encrypt(b) if err != nil { return err } diff --git a/pkg/util/encryption/encrypt.go b/pkg/util/encryption/encrypt.go deleted file mode 100644 index 877033458..000000000 --- a/pkg/util/encryption/encrypt.go +++ /dev/null @@ -1,112 +0,0 @@ -package encryption - -// Copyright (c) Microsoft Corporation. -// Licensed under the Apache License 2.0. - -import ( - "context" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "fmt" - "strings" - - "golang.org/x/crypto/chacha20poly1305" - - "github.com/Azure/ARO-RP/pkg/env" -) - -// encryptionSecretName must match key name in the service keyvault -const ( - encryptionSecretName = "encryption-key" - Prefix = "ENC*" -) - -var ( - _ Cipher = (*aeadCipher)(nil) - RandRead = rand.Read -) - -type Cipher interface { - Decrypt(string) (string, error) - Encrypt(string) (string, error) -} - -type aeadCipher struct { - aead cipher.AEAD -} - -func NewCipher(ctx context.Context, env env.Interface) (Cipher, error) { - keybase64, err := env.GetSecret(ctx, encryptionSecretName) - if err != nil { - return nil, err - } - - key := make([]byte, base64.StdEncoding.DecodedLen(len(keybase64))) - n, err := base64.StdEncoding.Decode(key, keybase64) - if err != nil { - return nil, err - } - - if n < 32 { - return nil, fmt.Errorf("chacha20poly1305: bad key length") - } - key = key[:32] - - aead, err := chacha20poly1305.NewX(key) - if err != nil { - return nil, err - } - - return &aeadCipher{ - aead: aead, - }, nil -} - -// Decrypt decrypts input -func (c *aeadCipher) Decrypt(input string) (string, error) { - if !strings.HasPrefix(input, Prefix) { - return input, nil - } - input = input[len(Prefix):] - - r := make([]byte, base64.StdEncoding.DecodedLen(len(input))) - r, err := base64.StdEncoding.DecodeString(input) - if err != nil { - return "", err - } - - if len(r) >= 24 { - nonce := r[0:24] - data := r[24:] - output, err := c.aead.Open(nil, nonce, data, nil) - if err != nil { - return "", err - } - return string(output), nil - } - return "", fmt.Errorf("error while decrypting message") -} - -// Encrypt encrypts input using 24 byte nonce -func (c *aeadCipher) Encrypt(input string) (string, error) { - nonce := make([]byte, chacha20poly1305.NonceSizeX) - _, err := RandRead(nonce) - if err != nil { - return "", err - } - encrypted := c.aead.Seal(nil, nonce, []byte(input), nil) - - var encryptedFinal []byte - encryptedFinal = append(encryptedFinal, nonce...) - encryptedFinal = append(encryptedFinal, encrypted...) - - encryptedBase64 := make([]byte, base64.StdEncoding.EncodedLen(len(encryptedFinal))) - base64.StdEncoding.Encode(encryptedBase64, encryptedFinal) - - // return prefix+base64(nonce+encryptedFinal) - var result []byte - result = append(result, Prefix...) - result = append(result, encryptedBase64...) - return string(result), nil -} diff --git a/pkg/util/encryption/encrypt_test.go b/pkg/util/encryption/encrypt_test.go deleted file mode 100644 index 611aa97bf..000000000 --- a/pkg/util/encryption/encrypt_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package encryption - -// Copyright (c) Microsoft Corporation. -// Licensed under the Apache License 2.0. - -import ( - "context" - "encoding/base64" - "strings" - "testing" - - "github.com/Azure/ARO-RP/pkg/env" -) - -func TestEncryptRoundTrip(t *testing.T) { - key := make([]byte, 32) - keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key))) - base64.StdEncoding.Encode(keybase64, key) - env := &env.Test{TestSecret: keybase64} - - cipher, err := NewCipher(context.Background(), env) - if err != nil { - t.Error(err) - } - - test := "secert" - encrypted, err := cipher.Encrypt(test) - if err != nil { - t.Error(err) - } - - decrypted, err := cipher.Decrypt(encrypted) - if err != nil { - t.Error(err) - } - - if r := strings.Compare(test, decrypted); r != 0 { - t.Error("encryption roundTrip failed") - } -} - -func TestEncrypt(t *testing.T) { - RandRead = func(b []byte) (n int, err error) { - b = make([]byte, len(b)) - return len(b), nil - } - - for _, tt := range []struct { - name string - input string - expected string - wantErr string - env func(e *env.Test) - }{ - { - name: "ok encrypt", - input: "test", - expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=", - wantErr: "", - env: func(input *env.Test) { - key := make([]byte, 32) - keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key))) - base64.StdEncoding.Encode(keybase64, key) - input.TestSecret = keybase64 - }, - }, - { - name: "base64 key error", - wantErr: "illegal base64 data at input byte 8", - env: func(input *env.Test) { - input.TestSecret = []byte("badsecret") - }, - }, - { - name: "key too short", - input: "test", - expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=", - wantErr: "chacha20poly1305: bad key length", - env: func(input *env.Test) { - keybase64 := base64.StdEncoding.EncodeToString(make([]byte, 15)) - input.TestSecret = []byte(keybase64) - }, - }, - { - name: "key too long", // due to base64 approximations library truncates the secret to right lenhgt - input: "test", - expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=", - env: func(input *env.Test) { - keybase64 := base64.StdEncoding.EncodeToString((make([]byte, 40))) - input.TestSecret = []byte(keybase64) - }, - }, - } { - t.Run(tt.name, func(t *testing.T) { - e := &env.Test{} - if tt.env != nil { - tt.env(e) - } - - cipher, err := NewCipher(context.Background(), e) - if err != nil { - if err.Error() != tt.wantErr { - t.Errorf("\n wants: %s,'\ngot: %s", tt.wantErr, err.Error()) - t.FailNow() - } - t.SkipNow() - } - - result, err := cipher.Encrypt(tt.input) - if err != nil { - t.Error(err) - } - if tt.expected != result { - t.Errorf("\n wants: %s,'\ngot: %s", tt.expected, result) - } - }) - } -} diff --git a/pkg/util/encryption/xchacha20poly1305.go b/pkg/util/encryption/xchacha20poly1305.go new file mode 100644 index 000000000..9bd723c4b --- /dev/null +++ b/pkg/util/encryption/xchacha20poly1305.go @@ -0,0 +1,74 @@ +package encryption + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "context" + "crypto/cipher" + "crypto/rand" + "fmt" + + "golang.org/x/crypto/chacha20poly1305" + + "github.com/Azure/ARO-RP/pkg/env" +) + +const ( + encryptionSecretName = "encryption-key" // must match key name in the service keyvault +) + +var _ Cipher = (*aeadCipher)(nil) + +type Cipher interface { + Decrypt([]byte) ([]byte, error) + Encrypt([]byte) ([]byte, error) +} + +type aeadCipher struct { + aead cipher.AEAD + randRead func([]byte) (int, error) +} + +func NewXChaCha20Poly1305(ctx context.Context, env env.Interface) (Cipher, error) { + key, err := env.GetSecret(ctx, encryptionSecretName) + if err != nil { + return nil, err + } + + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return nil, err + } + + return &aeadCipher{ + aead: aead, + randRead: rand.Read, + }, nil +} + +func (c *aeadCipher) Decrypt(input []byte) ([]byte, error) { + if len(input) < chacha20poly1305.NonceSizeX { + return nil, fmt.Errorf("encrypted value too short") + } + + nonce := input[:chacha20poly1305.NonceSizeX] + data := input[chacha20poly1305.NonceSizeX:] + + return c.aead.Open(nil, nonce, data, nil) +} + +func (c *aeadCipher) Encrypt(input []byte) ([]byte, error) { + nonce := make([]byte, chacha20poly1305.NonceSizeX) + + n, err := c.randRead(nonce) + if err != nil { + return nil, err + } + + if n != chacha20poly1305.NonceSizeX { + return nil, fmt.Errorf("rand.Read returned %d bytes, expected %d", n, chacha20poly1305.NonceSizeX) + } + + return append(nonce, c.aead.Seal(nil, nonce, input, nil)...), nil +} diff --git a/pkg/util/encryption/xchacha20poly1305_test.go b/pkg/util/encryption/xchacha20poly1305_test.go new file mode 100644 index 000000000..eac91d0cc --- /dev/null +++ b/pkg/util/encryption/xchacha20poly1305_test.go @@ -0,0 +1,153 @@ +package encryption + +// Copyright (c) Microsoft Corporation. +// Licensed under the Apache License 2.0. + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "testing" + + "github.com/Azure/ARO-RP/pkg/env" +) + +func TestNewXChaCha20Poly1305(t *testing.T) { + for _, tt := range []struct { + name string + key []byte + wantErr string + }{ + { + name: "valid", + key: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + }, + { + name: "key too short", + key: []byte("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"), + wantErr: "chacha20poly1305: bad key length", + }, + { + name: "key too long", + key: []byte("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"), + wantErr: "chacha20poly1305: bad key length", + }, + } { + t.Run(tt.name, func(t *testing.T) { + env := &env.Test{ + TestSecret: tt.key, + } + + _, err := NewXChaCha20Poly1305(context.Background(), env) + if err != nil && err.Error() != tt.wantErr || + err == nil && tt.wantErr != "" { + t.Fatal(err) + } + }) + } +} + +func TestXChaCha20Poly1305Decrypt(t *testing.T) { + for _, tt := range []struct { + name string + key []byte + input []byte + wantDecrypted []byte + wantErr string + }{ + { + name: "valid", + key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"), + input: []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"), + wantDecrypted: []byte("test"), + }, + { + name: "invalid - encrypted value tampered with", + key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"), + input: []byte("\xda\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"), + wantErr: "chacha20poly1305: message authentication failed", + }, + { + name: "invalid - too short", + key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"), + input: []byte("XXXXXXXXXXXXXXXXXXXXXXX"), + wantErr: "encrypted value too short", + }, + } { + t.Run(tt.name, func(t *testing.T) { + env := &env.Test{ + TestSecret: tt.key, + } + + cipher, err := NewXChaCha20Poly1305(context.Background(), env) + if err != nil { + t.Fatal(err) + } + + decrypted, err := cipher.Decrypt(tt.input) + if err != nil && err.Error() != tt.wantErr || + err == nil && tt.wantErr != "" { + t.Fatal(err) + } + + if !bytes.Equal(tt.wantDecrypted, decrypted) { + t.Error(string(decrypted)) + } + }) + } +} + +func TestXChaCha20Poly1305Encrypt(t *testing.T) { + for _, tt := range []struct { + name string + key []byte + randRead func(b []byte) (int, error) + input []byte + wantEncrypted []byte + wantErr string + }{ + { + name: "valid", + key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"), + randRead: func(b []byte) (int, error) { + nonce := []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2") + copy(b, nonce) + return len(nonce), nil + }, + input: []byte("test"), + wantEncrypted: []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"), + }, + { + name: "rand.Read error", + key: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), + randRead: func(b []byte) (int, error) { + return 0, fmt.Errorf("random error") + }, + wantErr: "random error", + }, + } { + t.Run(tt.name, func(t *testing.T) { + env := &env.Test{ + TestSecret: tt.key, + } + + cipher, err := NewXChaCha20Poly1305(context.Background(), env) + if err != nil { + t.Fatal(err) + } + + cipher.(*aeadCipher).randRead = tt.randRead + + encrypted, err := cipher.Encrypt(tt.input) + if err != nil && err.Error() != tt.wantErr || + err == nil && tt.wantErr != "" { + t.Fatal(err) + } + + if !bytes.Equal(tt.wantEncrypted, encrypted) { + t.Error(hex.EncodeToString(encrypted)) + } + }) + } +}