* pass cipher into database.NewDatabase, rather than bool
* unexport as much as possible
* remove backwards-compatibility and "read without key" options for now, adds too much complexity
This commit is contained in:
Jim Minter 2020-02-09 16:39:16 -06:00 коммит произвёл Mangirdas Judeikis
Родитель 468621f73c
Коммит edd02eacbe
17 изменённых файлов: 369 добавлений и 463 удалений

Просмотреть файл

@ -62,8 +62,8 @@ required = [
version = "1.1.7"
[[constraint]]
branch = "master"
name = "golang.org/x/crypto"
branch = "master"
[[override]]
name = "k8s.io/api"

Просмотреть файл

@ -13,6 +13,7 @@ import (
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/metrics/statsd"
pkgmonitor "github.com/Azure/ARO-RP/pkg/monitor"
"github.com/Azure/ARO-RP/pkg/util/encryption"
)
func monitor(ctx context.Context, log *logrus.Entry) error {
@ -30,7 +31,12 @@ func monitor(ctx context.Context, log *logrus.Entry) error {
}
defer m.Close()
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true)
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
if err != nil {
return err
}
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, cipher, uuid)
if err != nil {
return err
}

Просмотреть файл

@ -19,6 +19,7 @@ import (
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/frontend"
"github.com/Azure/ARO-RP/pkg/metrics/statsd"
"github.com/Azure/ARO-RP/pkg/util/encryption"
)
func rp(ctx context.Context, log *logrus.Entry) error {
@ -36,7 +37,12 @@ func rp(ctx context.Context, log *logrus.Entry) error {
}
defer m.Close()
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true)
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
if err != nil {
return err
}
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, cipher, uuid)
if err != nil {
return err
}

Просмотреть файл

@ -310,9 +310,9 @@ locations.
>/dev/null
az keyvault secret set \
--vault-name "$KEYVAULT_PREFIX-service" \
--name "encryption-key" \
--value $(openssl rand -base64 32) \
>/dev/null
--name encryption-key \
--value "$(openssl rand -base64 32)" \
>/dev/null
```
1. Create nameserver records in the parent DNS zone:

Просмотреть файл

@ -15,6 +15,7 @@ import (
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/metrics/noop"
"github.com/Azure/ARO-RP/pkg/util/encryption"
utillog "github.com/Azure/ARO-RP/pkg/util/log"
)
@ -28,7 +29,12 @@ func run(ctx context.Context, log *logrus.Entry) error {
return err
}
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, "", true)
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
if err != nil {
return err
}
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, cipher, "")
if err != nil {
return err
}

Просмотреть файл

@ -21,10 +21,10 @@ type OpenShiftCluster struct {
Properties Properties `json:"properties,omitempty"`
}
// SecureBytes represents encrypted []byte
// SecureBytes represents an encrypted []byte
type SecureBytes []byte
// SecureString represents encrypted string
// SecureString represents an encrypted string
type SecureString string
// Properties represents an OpenShift cluster's properties

Просмотреть файл

@ -40,16 +40,13 @@ func (m *Manager) Create(ctx context.Context) error {
resourceGroup := m.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID[strings.LastIndexByte(m.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID, '/')+1:]
m.doc, err = m.db.PatchWithLease(ctx, m.doc.Key, func(doc *api.OpenShiftClusterDocument) error {
var err error
if doc.OpenShiftCluster.Properties.SSHKey == nil {
sshKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
doc.OpenShiftCluster.Properties.SSHKey = x509.MarshalPKCS1PrivateKey(sshKey)
if err != nil {
return err
}
}
if doc.OpenShiftCluster.Properties.StorageSuffix == "" {

Просмотреть файл

@ -33,15 +33,7 @@ type Database struct {
}
// NewDatabase returns a new Database
func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, uuid string, decryptDatabase bool) (db *Database, err error) {
var cipher encryption.Cipher
if decryptDatabase {
cipher, err = encryption.NewCipher(ctx, env)
if err != nil {
return nil, err
}
}
func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, cipher encryption.Cipher, uuid string) (db *Database, err error) {
databaseAccount, masterKey := env.CosmosDB()
h := newJSONHandle(cipher)
@ -99,7 +91,7 @@ func newJSONHandle(cipher encryption.Cipher) *codec.JsonHandle {
},
}
h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, SecureBytesExt{Cipher: cipher})
h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, SecureStringExt{Cipher: cipher})
h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, secureBytesExt{cipher: cipher})
h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, secureStringExt{cipher: cipher})
return h
}

Просмотреть файл

@ -4,6 +4,8 @@ package database
// Licensed under the Apache License 2.0.
import (
"encoding/base64"
"github.com/ugorji/go/codec"
"github.com/Azure/ARO-RP/pkg/api"
@ -11,62 +13,60 @@ import (
encrypt "github.com/Azure/ARO-RP/pkg/util/encryption"
)
var _ codec.InterfaceExt = (*SecureBytesExt)(nil)
var _ codec.InterfaceExt = (*secureBytesExt)(nil)
type SecureBytesExt struct {
Cipher encryption.Cipher
type secureBytesExt struct {
cipher encryption.Cipher
}
func (s SecureBytesExt) ConvertExt(v interface{}) interface{} {
data := v.(api.SecureBytes)
if s.Cipher != nil {
encrypted, err := s.Cipher.Encrypt(string(data))
if err != nil {
panic(err)
}
return encrypted
func (s secureBytesExt) ConvertExt(v interface{}) interface{} {
encrypted, err := s.cipher.Encrypt(v.(api.SecureBytes))
if err != nil {
panic(err)
}
return string(data)
}
func (s SecureBytesExt) UpdateExt(dest interface{}, v interface{}) {
output := dest.(*api.SecureBytes)
if s.Cipher != nil {
decrypted, err := s.Cipher.Decrypt(v.(string))
if err != nil {
panic(err)
}
*output = api.SecureBytes(decrypted)
return
}
*output = api.SecureBytes(v.(string))
return base64.StdEncoding.EncodeToString([]byte(encrypted))
}
var _ codec.InterfaceExt = (*SecureStringExt)(nil)
func (s secureBytesExt) UpdateExt(dest interface{}, v interface{}) {
b, err := base64.StdEncoding.DecodeString(v.(string))
if err != nil {
panic(err)
}
type SecureStringExt struct {
Cipher encrypt.Cipher
b, err = s.cipher.Decrypt(b)
if err != nil {
panic(err)
}
*dest.(*api.SecureBytes) = b
}
func (s SecureStringExt) ConvertExt(v interface{}) interface{} {
data := v.(api.SecureString)
if s.Cipher != nil {
encrypted, err := s.Cipher.Encrypt(string(data))
if err != nil {
panic(err)
}
return encrypted
}
return string(data)
var _ codec.InterfaceExt = (*secureStringExt)(nil)
type secureStringExt struct {
cipher encrypt.Cipher
}
func (s SecureStringExt) UpdateExt(dest interface{}, v interface{}) {
output := dest.(*api.SecureString)
if s.Cipher != nil {
decrypted, err := s.Cipher.Decrypt(v.(string))
if err != nil {
panic(err)
}
*output = api.SecureString(decrypted)
return
func (s secureStringExt) ConvertExt(v interface{}) interface{} {
encrypted, err := s.cipher.Encrypt([]byte(v.(api.SecureString)))
if err != nil {
panic(err)
}
*output = api.SecureString(v.(string))
return base64.StdEncoding.EncodeToString([]byte(encrypted))
}
func (s secureStringExt) UpdateExt(dest interface{}, v interface{}) {
b, err := base64.StdEncoding.DecodeString(v.(string))
if err != nil {
panic(err)
}
b, err = s.cipher.Decrypt(b)
if err != nil {
panic(err)
}
*dest.(*api.SecureString) = api.SecureString(b)
}

Просмотреть файл

@ -4,11 +4,7 @@ package database
// Licensed under the Apache License 2.0.
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io/ioutil"
"reflect"
"testing"
@ -20,164 +16,63 @@ import (
)
type testStruct struct {
SecureBytes api.SecureBytes
SecureString api.SecureString
Bytes []byte
Str string
SecureBytes api.SecureBytes `json:"secureBytes,omitempty"`
SecureString api.SecureString `json:"secureString,omitempty"`
Bytes []byte `json:"bytes,omitempty"`
String string `json:"string,omitempty"`
}
func TestExtensions(t *testing.T) {
encryption.RandRead = func(b []byte) (n int, err error) {
b = make([]byte, len(b))
return len(b), nil
ctx := context.Background()
env := &env.Test{
TestSecret: []byte("\x63\xb5\x59\xf0\x43\x34\x79\x49\x68\x46\xab\x8b\xce\xdb\xc1\x2d\x7a\x0b\x14\x86\x7e\x1a\xb2\xd7\x3a\x92\x4e\x98\x6c\x5e\xcb\xe1"),
}
key := make([]byte, 32)
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
base64.StdEncoding.Encode(keybase64, key)
env := &env.Test{TestSecret: keybase64}
cipher, err := encryption.NewCipher(context.Background(), env)
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
if err != nil {
t.Error(err)
t.Fatal(err)
}
for _, tt := range []struct {
name string
input func(input *testStruct)
inputCodec *codec.JsonHandle
output func(input *testStruct)
outputCodec *codec.JsonHandle
}{
{
name: "noop",
},
{
name: "SecureByte - encrypt - decrypt",
input: func(input *testStruct) {
input.SecureBytes = []byte("test")
},
output: func(output *testStruct) {
output.SecureBytes = []byte("test")
},
},
{
name: "SecureByte - encrypt - raw",
input: func(input *testStruct) {
input.SecureBytes = []byte("test")
},
inputCodec: newJSONHandle(cipher),
output: func(output *testStruct) {
output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
// empty string encoded
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAjzuUWlGQbchgDen4li0A5g=="
},
outputCodec: newJSONHandle(nil),
},
{
name: "SecureByte - raw - decrypt",
input: func(input *testStruct) {
input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
},
inputCodec: newJSONHandle(nil),
output: func(output *testStruct) {
output.SecureBytes = []byte("test")
},
outputCodec: newJSONHandle(cipher),
},
{
name: "SecureByte - raw - raw",
input: func(input *testStruct) {
input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
},
inputCodec: newJSONHandle(nil),
output: func(output *testStruct) {
output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
},
outputCodec: newJSONHandle(nil),
},
{
name: "SecureString - encrypt - decrypt",
input: func(input *testStruct) {
input.SecureString = "test"
},
output: func(output *testStruct) {
output.SecureString = "test"
},
},
{
name: "SecureString - encrypt - raw",
input: func(input *testStruct) {
input.SecureString = "test"
},
inputCodec: newJSONHandle(cipher),
output: func(output *testStruct) {
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
},
outputCodec: newJSONHandle(nil),
},
{
name: "SecureString - raw - decrypt",
input: func(input *testStruct) {
input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
},
inputCodec: newJSONHandle(nil),
output: func(output *testStruct) {
output.SecureString = "test"
},
outputCodec: newJSONHandle(cipher),
},
{
name: "SecureString - raw - raw",
input: func(input *testStruct) {
input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
},
inputCodec: newJSONHandle(nil),
output: func(output *testStruct) {
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
},
outputCodec: newJSONHandle(nil),
},
} {
t.Run(tt.name, func(t *testing.T) {
if tt.inputCodec == nil {
tt.inputCodec = newJSONHandle(cipher)
}
if tt.outputCodec == nil {
tt.outputCodec = newJSONHandle(cipher)
}
h := newJSONHandle(cipher)
input := &testStruct{}
if tt.input != nil {
tt.input(input)
}
encrypted := []byte(`{
"bytes": "Ynl0ZXM=",
"secureBytes": "6w8Uah0zX40LRfkYHuU9UvLuGrBcHb7l8I2M6qTcmtclOGJNONfHqAuaJWifZj7dd8fI",
"secureString": "YT+ZNR23JBvILGw1WBn6/NhtCj9LM14EXp5VR6XloD7CN1MfmvW5FEn9duRSPYbdr98tLQ==",
"string": "string"
}`)
output := &testStruct{}
if tt.output != nil {
tt.output(output)
}
decrypted := &testStruct{
SecureBytes: []byte("securebytes"),
SecureString: "securestring",
Bytes: []byte("bytes"),
String: "string",
}
buf := &bytes.Buffer{}
err = codec.NewEncoder(buf, tt.inputCodec).Encode(input)
if err != nil {
t.Error(err)
}
data, err := ioutil.ReadAll(buf)
if err != nil {
t.Error(err)
}
var ts testStruct
err = codec.NewDecoderBytes(encrypted, h).Decode(&ts)
if err != nil {
t.Fatal(err)
}
result := &testStruct{}
err = codec.NewDecoder(bytes.NewReader(data), tt.outputCodec).Decode(result)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(&ts, decrypted) {
t.Errorf("%#v", &ts)
}
if !reflect.DeepEqual(output, result) {
output, _ := json.Marshal(output)
result, _ := json.Marshal(result)
t.Errorf("\n wants: %s,'\ngot: %s", string(output), string(result))
}
})
var enc []byte
err = codec.NewEncoderBytes(&enc, h).Encode(ts)
if err != nil {
t.Fatal(err)
}
ts = testStruct{}
err = codec.NewDecoderBytes(encrypted, h).Decode(&ts)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(&ts, decrypted) {
t.Errorf("%#v", &ts)
}
}

4
pkg/env/env.go поставляемый
Просмотреть файл

@ -28,10 +28,10 @@ type Interface interface {
DialContext(context.Context, string, string) (net.Conn, error)
Domain() string
FPAuthorizer(string, string) (autorest.Authorizer, error)
ManagedDomain(string) (string, error)
GetSecret(context.Context, string) ([]byte, error)
GetCertificateSecret(context.Context, string) (*rsa.PrivateKey, []*x509.Certificate, error)
GetSecret(context.Context, string) ([]byte, error)
Listen() (net.Listener, error)
ManagedDomain(string) (string, error)
VnetName() string
Zones(vmSize string) ([]string, error)
}

19
pkg/env/prod.go поставляемый
Просмотреть файл

@ -7,6 +7,7 @@ import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"net"
@ -265,16 +266,13 @@ func (p *prod) FPAuthorizer(tenantID, resource string) (autorest.Authorizer, err
return autorest.NewBearerAuthorizer(sp), nil
}
func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) {
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
return []byte(*bundle.Value), err
}
func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) {
b, err := p.GetSecret(ctx, secretName)
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
if err != nil {
return nil, nil, err
}
b := []byte(*bundle.Value)
for {
var block *pem.Block
block, b = pem.Decode(b)
@ -314,6 +312,15 @@ func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key
return key, certs, nil
}
func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) {
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
if err != nil {
return nil, err
}
return base64.StdEncoding.DecodeString(*bundle.Value)
}
func (p *prod) Listen() (net.Listener, error) {
return net.Listen("tcp", ":8443")
}

Просмотреть файл

@ -81,7 +81,7 @@ func NewInstaller(ctx context.Context, log *logrus.Entry, env env.Interface, db
return nil, err
}
cipher, err := encryption.NewCipher(ctx, env)
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
if err != nil {
return nil, err
}
@ -211,13 +211,13 @@ func (i *Installer) loadGraph(ctx context.Context) (graph, error) {
return nil, err
}
output, err := i.cipher.Decrypt(string(encrypted))
output, err := i.cipher.Decrypt(encrypted)
if err != nil {
return nil, err
}
var g graph
err = json.Unmarshal([]byte(output), &g)
err = json.Unmarshal(output, &g)
if err != nil {
return nil, err
}
@ -246,7 +246,7 @@ func (i *Installer) saveGraph(ctx context.Context, g graph) error {
return err
}
output, err := i.cipher.Encrypt(string(b))
output, err := i.cipher.Encrypt(b)
if err != nil {
return err
}

Просмотреть файл

@ -1,112 +0,0 @@
package encryption
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
"golang.org/x/crypto/chacha20poly1305"
"github.com/Azure/ARO-RP/pkg/env"
)
// encryptionSecretName must match key name in the service keyvault
const (
encryptionSecretName = "encryption-key"
Prefix = "ENC*"
)
var (
_ Cipher = (*aeadCipher)(nil)
RandRead = rand.Read
)
type Cipher interface {
Decrypt(string) (string, error)
Encrypt(string) (string, error)
}
type aeadCipher struct {
aead cipher.AEAD
}
func NewCipher(ctx context.Context, env env.Interface) (Cipher, error) {
keybase64, err := env.GetSecret(ctx, encryptionSecretName)
if err != nil {
return nil, err
}
key := make([]byte, base64.StdEncoding.DecodedLen(len(keybase64)))
n, err := base64.StdEncoding.Decode(key, keybase64)
if err != nil {
return nil, err
}
if n < 32 {
return nil, fmt.Errorf("chacha20poly1305: bad key length")
}
key = key[:32]
aead, err := chacha20poly1305.NewX(key)
if err != nil {
return nil, err
}
return &aeadCipher{
aead: aead,
}, nil
}
// Decrypt decrypts input
func (c *aeadCipher) Decrypt(input string) (string, error) {
if !strings.HasPrefix(input, Prefix) {
return input, nil
}
input = input[len(Prefix):]
r := make([]byte, base64.StdEncoding.DecodedLen(len(input)))
r, err := base64.StdEncoding.DecodeString(input)
if err != nil {
return "", err
}
if len(r) >= 24 {
nonce := r[0:24]
data := r[24:]
output, err := c.aead.Open(nil, nonce, data, nil)
if err != nil {
return "", err
}
return string(output), nil
}
return "", fmt.Errorf("error while decrypting message")
}
// Encrypt encrypts input using 24 byte nonce
func (c *aeadCipher) Encrypt(input string) (string, error) {
nonce := make([]byte, chacha20poly1305.NonceSizeX)
_, err := RandRead(nonce)
if err != nil {
return "", err
}
encrypted := c.aead.Seal(nil, nonce, []byte(input), nil)
var encryptedFinal []byte
encryptedFinal = append(encryptedFinal, nonce...)
encryptedFinal = append(encryptedFinal, encrypted...)
encryptedBase64 := make([]byte, base64.StdEncoding.EncodedLen(len(encryptedFinal)))
base64.StdEncoding.Encode(encryptedBase64, encryptedFinal)
// return prefix+base64(nonce+encryptedFinal)
var result []byte
result = append(result, Prefix...)
result = append(result, encryptedBase64...)
return string(result), nil
}

Просмотреть файл

@ -1,118 +0,0 @@
package encryption
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"encoding/base64"
"strings"
"testing"
"github.com/Azure/ARO-RP/pkg/env"
)
func TestEncryptRoundTrip(t *testing.T) {
key := make([]byte, 32)
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
base64.StdEncoding.Encode(keybase64, key)
env := &env.Test{TestSecret: keybase64}
cipher, err := NewCipher(context.Background(), env)
if err != nil {
t.Error(err)
}
test := "secert"
encrypted, err := cipher.Encrypt(test)
if err != nil {
t.Error(err)
}
decrypted, err := cipher.Decrypt(encrypted)
if err != nil {
t.Error(err)
}
if r := strings.Compare(test, decrypted); r != 0 {
t.Error("encryption roundTrip failed")
}
}
func TestEncrypt(t *testing.T) {
RandRead = func(b []byte) (n int, err error) {
b = make([]byte, len(b))
return len(b), nil
}
for _, tt := range []struct {
name string
input string
expected string
wantErr string
env func(e *env.Test)
}{
{
name: "ok encrypt",
input: "test",
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
wantErr: "",
env: func(input *env.Test) {
key := make([]byte, 32)
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
base64.StdEncoding.Encode(keybase64, key)
input.TestSecret = keybase64
},
},
{
name: "base64 key error",
wantErr: "illegal base64 data at input byte 8",
env: func(input *env.Test) {
input.TestSecret = []byte("badsecret")
},
},
{
name: "key too short",
input: "test",
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
wantErr: "chacha20poly1305: bad key length",
env: func(input *env.Test) {
keybase64 := base64.StdEncoding.EncodeToString(make([]byte, 15))
input.TestSecret = []byte(keybase64)
},
},
{
name: "key too long", // due to base64 approximations library truncates the secret to right lenhgt
input: "test",
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
env: func(input *env.Test) {
keybase64 := base64.StdEncoding.EncodeToString((make([]byte, 40)))
input.TestSecret = []byte(keybase64)
},
},
} {
t.Run(tt.name, func(t *testing.T) {
e := &env.Test{}
if tt.env != nil {
tt.env(e)
}
cipher, err := NewCipher(context.Background(), e)
if err != nil {
if err.Error() != tt.wantErr {
t.Errorf("\n wants: %s,'\ngot: %s", tt.wantErr, err.Error())
t.FailNow()
}
t.SkipNow()
}
result, err := cipher.Encrypt(tt.input)
if err != nil {
t.Error(err)
}
if tt.expected != result {
t.Errorf("\n wants: %s,'\ngot: %s", tt.expected, result)
}
})
}
}

Просмотреть файл

@ -0,0 +1,74 @@
package encryption
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"context"
"crypto/cipher"
"crypto/rand"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
"github.com/Azure/ARO-RP/pkg/env"
)
const (
encryptionSecretName = "encryption-key" // must match key name in the service keyvault
)
var _ Cipher = (*aeadCipher)(nil)
type Cipher interface {
Decrypt([]byte) ([]byte, error)
Encrypt([]byte) ([]byte, error)
}
type aeadCipher struct {
aead cipher.AEAD
randRead func([]byte) (int, error)
}
func NewXChaCha20Poly1305(ctx context.Context, env env.Interface) (Cipher, error) {
key, err := env.GetSecret(ctx, encryptionSecretName)
if err != nil {
return nil, err
}
aead, err := chacha20poly1305.NewX(key)
if err != nil {
return nil, err
}
return &aeadCipher{
aead: aead,
randRead: rand.Read,
}, nil
}
func (c *aeadCipher) Decrypt(input []byte) ([]byte, error) {
if len(input) < chacha20poly1305.NonceSizeX {
return nil, fmt.Errorf("encrypted value too short")
}
nonce := input[:chacha20poly1305.NonceSizeX]
data := input[chacha20poly1305.NonceSizeX:]
return c.aead.Open(nil, nonce, data, nil)
}
func (c *aeadCipher) Encrypt(input []byte) ([]byte, error) {
nonce := make([]byte, chacha20poly1305.NonceSizeX)
n, err := c.randRead(nonce)
if err != nil {
return nil, err
}
if n != chacha20poly1305.NonceSizeX {
return nil, fmt.Errorf("rand.Read returned %d bytes, expected %d", n, chacha20poly1305.NonceSizeX)
}
return append(nonce, c.aead.Seal(nil, nonce, input, nil)...), nil
}

Просмотреть файл

@ -0,0 +1,153 @@
package encryption
// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"testing"
"github.com/Azure/ARO-RP/pkg/env"
)
func TestNewXChaCha20Poly1305(t *testing.T) {
for _, tt := range []struct {
name string
key []byte
wantErr string
}{
{
name: "valid",
key: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
},
{
name: "key too short",
key: []byte("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"),
wantErr: "chacha20poly1305: bad key length",
},
{
name: "key too long",
key: []byte("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"),
wantErr: "chacha20poly1305: bad key length",
},
} {
t.Run(tt.name, func(t *testing.T) {
env := &env.Test{
TestSecret: tt.key,
}
_, err := NewXChaCha20Poly1305(context.Background(), env)
if err != nil && err.Error() != tt.wantErr ||
err == nil && tt.wantErr != "" {
t.Fatal(err)
}
})
}
}
func TestXChaCha20Poly1305Decrypt(t *testing.T) {
for _, tt := range []struct {
name string
key []byte
input []byte
wantDecrypted []byte
wantErr string
}{
{
name: "valid",
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
input: []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"),
wantDecrypted: []byte("test"),
},
{
name: "invalid - encrypted value tampered with",
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
input: []byte("\xda\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"),
wantErr: "chacha20poly1305: message authentication failed",
},
{
name: "invalid - too short",
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
input: []byte("XXXXXXXXXXXXXXXXXXXXXXX"),
wantErr: "encrypted value too short",
},
} {
t.Run(tt.name, func(t *testing.T) {
env := &env.Test{
TestSecret: tt.key,
}
cipher, err := NewXChaCha20Poly1305(context.Background(), env)
if err != nil {
t.Fatal(err)
}
decrypted, err := cipher.Decrypt(tt.input)
if err != nil && err.Error() != tt.wantErr ||
err == nil && tt.wantErr != "" {
t.Fatal(err)
}
if !bytes.Equal(tt.wantDecrypted, decrypted) {
t.Error(string(decrypted))
}
})
}
}
func TestXChaCha20Poly1305Encrypt(t *testing.T) {
for _, tt := range []struct {
name string
key []byte
randRead func(b []byte) (int, error)
input []byte
wantEncrypted []byte
wantErr string
}{
{
name: "valid",
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
randRead: func(b []byte) (int, error) {
nonce := []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2")
copy(b, nonce)
return len(nonce), nil
},
input: []byte("test"),
wantEncrypted: []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"),
},
{
name: "rand.Read error",
key: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
randRead: func(b []byte) (int, error) {
return 0, fmt.Errorf("random error")
},
wantErr: "random error",
},
} {
t.Run(tt.name, func(t *testing.T) {
env := &env.Test{
TestSecret: tt.key,
}
cipher, err := NewXChaCha20Poly1305(context.Background(), env)
if err != nil {
t.Fatal(err)
}
cipher.(*aeadCipher).randRead = tt.randRead
encrypted, err := cipher.Encrypt(tt.input)
if err != nil && err.Error() != tt.wantErr ||
err == nil && tt.wantErr != "" {
t.Fatal(err)
}
if !bytes.Equal(tt.wantEncrypted, encrypted) {
t.Error(hex.EncodeToString(encrypted))
}
})
}
}