зеркало из https://github.com/Azure/ARO-RP.git
encryption fixups:
* pass cipher into database.NewDatabase, rather than bool * unexport as much as possible * remove backwards-compatibility and "read without key" options for now, adds too much complexity
This commit is contained in:
Родитель
468621f73c
Коммит
edd02eacbe
|
@ -62,8 +62,8 @@ required = [
|
|||
version = "1.1.7"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
branch = "master"
|
||||
|
||||
[[override]]
|
||||
name = "k8s.io/api"
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/metrics/statsd"
|
||||
pkgmonitor "github.com/Azure/ARO-RP/pkg/monitor"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
)
|
||||
|
||||
func monitor(ctx context.Context, log *logrus.Entry) error {
|
||||
|
@ -30,7 +31,12 @@ func monitor(ctx context.Context, log *logrus.Entry) error {
|
|||
}
|
||||
defer m.Close()
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true)
|
||||
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, cipher, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/frontend"
|
||||
"github.com/Azure/ARO-RP/pkg/metrics/statsd"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
)
|
||||
|
||||
func rp(ctx context.Context, log *logrus.Entry) error {
|
||||
|
@ -36,7 +37,12 @@ func rp(ctx context.Context, log *logrus.Entry) error {
|
|||
}
|
||||
defer m.Close()
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true)
|
||||
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, cipher, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -310,9 +310,9 @@ locations.
|
|||
>/dev/null
|
||||
az keyvault secret set \
|
||||
--vault-name "$KEYVAULT_PREFIX-service" \
|
||||
--name "encryption-key" \
|
||||
--value $(openssl rand -base64 32) \
|
||||
>/dev/null
|
||||
--name encryption-key \
|
||||
--value "$(openssl rand -base64 32)" \
|
||||
>/dev/null
|
||||
```
|
||||
|
||||
1. Create nameserver records in the parent DNS zone:
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/metrics/noop"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
utillog "github.com/Azure/ARO-RP/pkg/util/log"
|
||||
)
|
||||
|
||||
|
@ -28,7 +29,12 @@ func run(ctx context.Context, log *logrus.Entry) error {
|
|||
return err
|
||||
}
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, "", true)
|
||||
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, cipher, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -21,10 +21,10 @@ type OpenShiftCluster struct {
|
|||
Properties Properties `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// SecureBytes represents encrypted []byte
|
||||
// SecureBytes represents an encrypted []byte
|
||||
type SecureBytes []byte
|
||||
|
||||
// SecureString represents encrypted string
|
||||
// SecureString represents an encrypted string
|
||||
type SecureString string
|
||||
|
||||
// Properties represents an OpenShift cluster's properties
|
||||
|
|
|
@ -40,16 +40,13 @@ func (m *Manager) Create(ctx context.Context) error {
|
|||
resourceGroup := m.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID[strings.LastIndexByte(m.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID, '/')+1:]
|
||||
|
||||
m.doc, err = m.db.PatchWithLease(ctx, m.doc.Key, func(doc *api.OpenShiftClusterDocument) error {
|
||||
var err error
|
||||
if doc.OpenShiftCluster.Properties.SSHKey == nil {
|
||||
sshKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
doc.OpenShiftCluster.Properties.SSHKey = x509.MarshalPKCS1PrivateKey(sshKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if doc.OpenShiftCluster.Properties.StorageSuffix == "" {
|
||||
|
|
|
@ -33,15 +33,7 @@ type Database struct {
|
|||
}
|
||||
|
||||
// NewDatabase returns a new Database
|
||||
func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, uuid string, decryptDatabase bool) (db *Database, err error) {
|
||||
var cipher encryption.Cipher
|
||||
if decryptDatabase {
|
||||
cipher, err = encryption.NewCipher(ctx, env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, cipher encryption.Cipher, uuid string) (db *Database, err error) {
|
||||
databaseAccount, masterKey := env.CosmosDB()
|
||||
|
||||
h := newJSONHandle(cipher)
|
||||
|
@ -99,7 +91,7 @@ func newJSONHandle(cipher encryption.Cipher) *codec.JsonHandle {
|
|||
},
|
||||
}
|
||||
|
||||
h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, SecureBytesExt{Cipher: cipher})
|
||||
h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, SecureStringExt{Cipher: cipher})
|
||||
h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, secureBytesExt{cipher: cipher})
|
||||
h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, secureStringExt{cipher: cipher})
|
||||
return h
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ package database
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/ugorji/go/codec"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
|
@ -11,62 +13,60 @@ import (
|
|||
encrypt "github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
)
|
||||
|
||||
var _ codec.InterfaceExt = (*SecureBytesExt)(nil)
|
||||
var _ codec.InterfaceExt = (*secureBytesExt)(nil)
|
||||
|
||||
type SecureBytesExt struct {
|
||||
Cipher encryption.Cipher
|
||||
type secureBytesExt struct {
|
||||
cipher encryption.Cipher
|
||||
}
|
||||
|
||||
func (s SecureBytesExt) ConvertExt(v interface{}) interface{} {
|
||||
data := v.(api.SecureBytes)
|
||||
if s.Cipher != nil {
|
||||
encrypted, err := s.Cipher.Encrypt(string(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return encrypted
|
||||
func (s secureBytesExt) ConvertExt(v interface{}) interface{} {
|
||||
encrypted, err := s.cipher.Encrypt(v.(api.SecureBytes))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
func (s SecureBytesExt) UpdateExt(dest interface{}, v interface{}) {
|
||||
output := dest.(*api.SecureBytes)
|
||||
if s.Cipher != nil {
|
||||
decrypted, err := s.Cipher.Decrypt(v.(string))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
*output = api.SecureBytes(decrypted)
|
||||
return
|
||||
}
|
||||
*output = api.SecureBytes(v.(string))
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(encrypted))
|
||||
}
|
||||
|
||||
var _ codec.InterfaceExt = (*SecureStringExt)(nil)
|
||||
func (s secureBytesExt) UpdateExt(dest interface{}, v interface{}) {
|
||||
b, err := base64.StdEncoding.DecodeString(v.(string))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
type SecureStringExt struct {
|
||||
Cipher encrypt.Cipher
|
||||
b, err = s.cipher.Decrypt(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
*dest.(*api.SecureBytes) = b
|
||||
}
|
||||
|
||||
func (s SecureStringExt) ConvertExt(v interface{}) interface{} {
|
||||
data := v.(api.SecureString)
|
||||
if s.Cipher != nil {
|
||||
encrypted, err := s.Cipher.Encrypt(string(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
return string(data)
|
||||
var _ codec.InterfaceExt = (*secureStringExt)(nil)
|
||||
|
||||
type secureStringExt struct {
|
||||
cipher encrypt.Cipher
|
||||
}
|
||||
func (s SecureStringExt) UpdateExt(dest interface{}, v interface{}) {
|
||||
output := dest.(*api.SecureString)
|
||||
if s.Cipher != nil {
|
||||
decrypted, err := s.Cipher.Decrypt(v.(string))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
*output = api.SecureString(decrypted)
|
||||
return
|
||||
|
||||
func (s secureStringExt) ConvertExt(v interface{}) interface{} {
|
||||
encrypted, err := s.cipher.Encrypt([]byte(v.(api.SecureString)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
*output = api.SecureString(v.(string))
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(encrypted))
|
||||
}
|
||||
|
||||
func (s secureStringExt) UpdateExt(dest interface{}, v interface{}) {
|
||||
b, err := base64.StdEncoding.DecodeString(v.(string))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b, err = s.cipher.Decrypt(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
*dest.(*api.SecureString) = api.SecureString(b)
|
||||
}
|
||||
|
|
|
@ -4,11 +4,7 @@ package database
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -20,164 +16,63 @@ import (
|
|||
)
|
||||
|
||||
type testStruct struct {
|
||||
SecureBytes api.SecureBytes
|
||||
SecureString api.SecureString
|
||||
Bytes []byte
|
||||
Str string
|
||||
SecureBytes api.SecureBytes `json:"secureBytes,omitempty"`
|
||||
SecureString api.SecureString `json:"secureString,omitempty"`
|
||||
Bytes []byte `json:"bytes,omitempty"`
|
||||
String string `json:"string,omitempty"`
|
||||
}
|
||||
|
||||
func TestExtensions(t *testing.T) {
|
||||
encryption.RandRead = func(b []byte) (n int, err error) {
|
||||
b = make([]byte, len(b))
|
||||
return len(b), nil
|
||||
ctx := context.Background()
|
||||
|
||||
env := &env.Test{
|
||||
TestSecret: []byte("\x63\xb5\x59\xf0\x43\x34\x79\x49\x68\x46\xab\x8b\xce\xdb\xc1\x2d\x7a\x0b\x14\x86\x7e\x1a\xb2\xd7\x3a\x92\x4e\x98\x6c\x5e\xcb\xe1"),
|
||||
}
|
||||
|
||||
key := make([]byte, 32)
|
||||
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
|
||||
base64.StdEncoding.Encode(keybase64, key)
|
||||
env := &env.Test{TestSecret: keybase64}
|
||||
|
||||
cipher, err := encryption.NewCipher(context.Background(), env)
|
||||
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input func(input *testStruct)
|
||||
inputCodec *codec.JsonHandle
|
||||
output func(input *testStruct)
|
||||
outputCodec *codec.JsonHandle
|
||||
}{
|
||||
{
|
||||
name: "noop",
|
||||
},
|
||||
{
|
||||
name: "SecureByte - encrypt - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("test")
|
||||
},
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("test")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecureByte - encrypt - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("test")
|
||||
},
|
||||
inputCodec: newJSONHandle(cipher),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
// empty string encoded
|
||||
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAjzuUWlGQbchgDen4li0A5g=="
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
{
|
||||
name: "SecureByte - raw - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("test")
|
||||
},
|
||||
outputCodec: newJSONHandle(cipher),
|
||||
},
|
||||
{
|
||||
name: "SecureByte - raw - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
{
|
||||
name: "SecureString - encrypt - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "test"
|
||||
},
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "test"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecureString - encrypt - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "test"
|
||||
},
|
||||
inputCodec: newJSONHandle(cipher),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
{
|
||||
name: "SecureString - raw - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "test"
|
||||
},
|
||||
outputCodec: newJSONHandle(cipher),
|
||||
},
|
||||
{
|
||||
name: "SecureString - raw - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.inputCodec == nil {
|
||||
tt.inputCodec = newJSONHandle(cipher)
|
||||
}
|
||||
if tt.outputCodec == nil {
|
||||
tt.outputCodec = newJSONHandle(cipher)
|
||||
}
|
||||
h := newJSONHandle(cipher)
|
||||
|
||||
input := &testStruct{}
|
||||
if tt.input != nil {
|
||||
tt.input(input)
|
||||
}
|
||||
encrypted := []byte(`{
|
||||
"bytes": "Ynl0ZXM=",
|
||||
"secureBytes": "6w8Uah0zX40LRfkYHuU9UvLuGrBcHb7l8I2M6qTcmtclOGJNONfHqAuaJWifZj7dd8fI",
|
||||
"secureString": "YT+ZNR23JBvILGw1WBn6/NhtCj9LM14EXp5VR6XloD7CN1MfmvW5FEn9duRSPYbdr98tLQ==",
|
||||
"string": "string"
|
||||
}`)
|
||||
|
||||
output := &testStruct{}
|
||||
if tt.output != nil {
|
||||
tt.output(output)
|
||||
}
|
||||
decrypted := &testStruct{
|
||||
SecureBytes: []byte("securebytes"),
|
||||
SecureString: "securestring",
|
||||
Bytes: []byte("bytes"),
|
||||
String: "string",
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
err = codec.NewEncoder(buf, tt.inputCodec).Encode(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
data, err := ioutil.ReadAll(buf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
var ts testStruct
|
||||
err = codec.NewDecoderBytes(encrypted, h).Decode(&ts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result := &testStruct{}
|
||||
err = codec.NewDecoder(bytes.NewReader(data), tt.outputCodec).Decode(result)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(&ts, decrypted) {
|
||||
t.Errorf("%#v", &ts)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(output, result) {
|
||||
output, _ := json.Marshal(output)
|
||||
result, _ := json.Marshal(result)
|
||||
t.Errorf("\n wants: %s,'\ngot: %s", string(output), string(result))
|
||||
}
|
||||
})
|
||||
var enc []byte
|
||||
err = codec.NewEncoderBytes(&enc, h).Encode(ts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts = testStruct{}
|
||||
err = codec.NewDecoderBytes(encrypted, h).Decode(&ts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(&ts, decrypted) {
|
||||
t.Errorf("%#v", &ts)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,10 +28,10 @@ type Interface interface {
|
|||
DialContext(context.Context, string, string) (net.Conn, error)
|
||||
Domain() string
|
||||
FPAuthorizer(string, string) (autorest.Authorizer, error)
|
||||
ManagedDomain(string) (string, error)
|
||||
GetSecret(context.Context, string) ([]byte, error)
|
||||
GetCertificateSecret(context.Context, string) (*rsa.PrivateKey, []*x509.Certificate, error)
|
||||
GetSecret(context.Context, string) ([]byte, error)
|
||||
Listen() (net.Listener, error)
|
||||
ManagedDomain(string) (string, error)
|
||||
VnetName() string
|
||||
Zones(vmSize string) ([]string, error)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -265,16 +266,13 @@ func (p *prod) FPAuthorizer(tenantID, resource string) (autorest.Authorizer, err
|
|||
return autorest.NewBearerAuthorizer(sp), nil
|
||||
}
|
||||
|
||||
func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) {
|
||||
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
|
||||
return []byte(*bundle.Value), err
|
||||
}
|
||||
|
||||
func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) {
|
||||
b, err := p.GetSecret(ctx, secretName)
|
||||
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
b := []byte(*bundle.Value)
|
||||
for {
|
||||
var block *pem.Block
|
||||
block, b = pem.Decode(b)
|
||||
|
@ -314,6 +312,15 @@ func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key
|
|||
return key, certs, nil
|
||||
}
|
||||
|
||||
func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) {
|
||||
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.DecodeString(*bundle.Value)
|
||||
}
|
||||
|
||||
func (p *prod) Listen() (net.Listener, error) {
|
||||
return net.Listen("tcp", ":8443")
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ func NewInstaller(ctx context.Context, log *logrus.Entry, env env.Interface, db
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cipher, err := encryption.NewCipher(ctx, env)
|
||||
cipher, err := encryption.NewXChaCha20Poly1305(ctx, env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -211,13 +211,13 @@ func (i *Installer) loadGraph(ctx context.Context) (graph, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
output, err := i.cipher.Decrypt(string(encrypted))
|
||||
output, err := i.cipher.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g graph
|
||||
err = json.Unmarshal([]byte(output), &g)
|
||||
err = json.Unmarshal(output, &g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -246,7 +246,7 @@ func (i *Installer) saveGraph(ctx context.Context, g graph) error {
|
|||
return err
|
||||
}
|
||||
|
||||
output, err := i.cipher.Encrypt(string(b))
|
||||
output, err := i.cipher.Encrypt(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,112 +0,0 @@
|
|||
package encryption
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
)
|
||||
|
||||
// encryptionSecretName must match key name in the service keyvault
|
||||
const (
|
||||
encryptionSecretName = "encryption-key"
|
||||
Prefix = "ENC*"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Cipher = (*aeadCipher)(nil)
|
||||
RandRead = rand.Read
|
||||
)
|
||||
|
||||
type Cipher interface {
|
||||
Decrypt(string) (string, error)
|
||||
Encrypt(string) (string, error)
|
||||
}
|
||||
|
||||
type aeadCipher struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func NewCipher(ctx context.Context, env env.Interface) (Cipher, error) {
|
||||
keybase64, err := env.GetSecret(ctx, encryptionSecretName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := make([]byte, base64.StdEncoding.DecodedLen(len(keybase64)))
|
||||
n, err := base64.StdEncoding.Decode(key, keybase64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n < 32 {
|
||||
return nil, fmt.Errorf("chacha20poly1305: bad key length")
|
||||
}
|
||||
key = key[:32]
|
||||
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &aeadCipher{
|
||||
aead: aead,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts input
|
||||
func (c *aeadCipher) Decrypt(input string) (string, error) {
|
||||
if !strings.HasPrefix(input, Prefix) {
|
||||
return input, nil
|
||||
}
|
||||
input = input[len(Prefix):]
|
||||
|
||||
r := make([]byte, base64.StdEncoding.DecodedLen(len(input)))
|
||||
r, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(r) >= 24 {
|
||||
nonce := r[0:24]
|
||||
data := r[24:]
|
||||
output, err := c.aead.Open(nil, nonce, data, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
return "", fmt.Errorf("error while decrypting message")
|
||||
}
|
||||
|
||||
// Encrypt encrypts input using 24 byte nonce
|
||||
func (c *aeadCipher) Encrypt(input string) (string, error) {
|
||||
nonce := make([]byte, chacha20poly1305.NonceSizeX)
|
||||
_, err := RandRead(nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
encrypted := c.aead.Seal(nil, nonce, []byte(input), nil)
|
||||
|
||||
var encryptedFinal []byte
|
||||
encryptedFinal = append(encryptedFinal, nonce...)
|
||||
encryptedFinal = append(encryptedFinal, encrypted...)
|
||||
|
||||
encryptedBase64 := make([]byte, base64.StdEncoding.EncodedLen(len(encryptedFinal)))
|
||||
base64.StdEncoding.Encode(encryptedBase64, encryptedFinal)
|
||||
|
||||
// return prefix+base64(nonce+encryptedFinal)
|
||||
var result []byte
|
||||
result = append(result, Prefix...)
|
||||
result = append(result, encryptedBase64...)
|
||||
return string(result), nil
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
package encryption
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
)
|
||||
|
||||
func TestEncryptRoundTrip(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
|
||||
base64.StdEncoding.Encode(keybase64, key)
|
||||
env := &env.Test{TestSecret: keybase64}
|
||||
|
||||
cipher, err := NewCipher(context.Background(), env)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
test := "secert"
|
||||
encrypted, err := cipher.Encrypt(test)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
decrypted, err := cipher.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if r := strings.Compare(test, decrypted); r != 0 {
|
||||
t.Error("encryption roundTrip failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
RandRead = func(b []byte) (n int, err error) {
|
||||
b = make([]byte, len(b))
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr string
|
||||
env func(e *env.Test)
|
||||
}{
|
||||
{
|
||||
name: "ok encrypt",
|
||||
input: "test",
|
||||
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
|
||||
wantErr: "",
|
||||
env: func(input *env.Test) {
|
||||
key := make([]byte, 32)
|
||||
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
|
||||
base64.StdEncoding.Encode(keybase64, key)
|
||||
input.TestSecret = keybase64
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 key error",
|
||||
wantErr: "illegal base64 data at input byte 8",
|
||||
env: func(input *env.Test) {
|
||||
input.TestSecret = []byte("badsecret")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key too short",
|
||||
input: "test",
|
||||
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
|
||||
wantErr: "chacha20poly1305: bad key length",
|
||||
env: func(input *env.Test) {
|
||||
keybase64 := base64.StdEncoding.EncodeToString(make([]byte, 15))
|
||||
input.TestSecret = []byte(keybase64)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key too long", // due to base64 approximations library truncates the secret to right lenhgt
|
||||
input: "test",
|
||||
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
|
||||
env: func(input *env.Test) {
|
||||
keybase64 := base64.StdEncoding.EncodeToString((make([]byte, 40)))
|
||||
input.TestSecret = []byte(keybase64)
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &env.Test{}
|
||||
if tt.env != nil {
|
||||
tt.env(e)
|
||||
}
|
||||
|
||||
cipher, err := NewCipher(context.Background(), e)
|
||||
if err != nil {
|
||||
if err.Error() != tt.wantErr {
|
||||
t.Errorf("\n wants: %s,'\ngot: %s", tt.wantErr, err.Error())
|
||||
t.FailNow()
|
||||
}
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
result, err := cipher.Encrypt(tt.input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if tt.expected != result {
|
||||
t.Errorf("\n wants: %s,'\ngot: %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package encryption
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
)
|
||||
|
||||
const (
|
||||
encryptionSecretName = "encryption-key" // must match key name in the service keyvault
|
||||
)
|
||||
|
||||
var _ Cipher = (*aeadCipher)(nil)
|
||||
|
||||
type Cipher interface {
|
||||
Decrypt([]byte) ([]byte, error)
|
||||
Encrypt([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type aeadCipher struct {
|
||||
aead cipher.AEAD
|
||||
randRead func([]byte) (int, error)
|
||||
}
|
||||
|
||||
func NewXChaCha20Poly1305(ctx context.Context, env env.Interface) (Cipher, error) {
|
||||
key, err := env.GetSecret(ctx, encryptionSecretName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &aeadCipher{
|
||||
aead: aead,
|
||||
randRead: rand.Read,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *aeadCipher) Decrypt(input []byte) ([]byte, error) {
|
||||
if len(input) < chacha20poly1305.NonceSizeX {
|
||||
return nil, fmt.Errorf("encrypted value too short")
|
||||
}
|
||||
|
||||
nonce := input[:chacha20poly1305.NonceSizeX]
|
||||
data := input[chacha20poly1305.NonceSizeX:]
|
||||
|
||||
return c.aead.Open(nil, nonce, data, nil)
|
||||
}
|
||||
|
||||
func (c *aeadCipher) Encrypt(input []byte) ([]byte, error) {
|
||||
nonce := make([]byte, chacha20poly1305.NonceSizeX)
|
||||
|
||||
n, err := c.randRead(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n != chacha20poly1305.NonceSizeX {
|
||||
return nil, fmt.Errorf("rand.Read returned %d bytes, expected %d", n, chacha20poly1305.NonceSizeX)
|
||||
}
|
||||
|
||||
return append(nonce, c.aead.Seal(nil, nonce, input, nil)...), nil
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
package encryption
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
)
|
||||
|
||||
func TestNewXChaCha20Poly1305(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
key []byte
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
key: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
|
||||
},
|
||||
{
|
||||
name: "key too short",
|
||||
key: []byte("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"),
|
||||
wantErr: "chacha20poly1305: bad key length",
|
||||
},
|
||||
{
|
||||
name: "key too long",
|
||||
key: []byte("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"),
|
||||
wantErr: "chacha20poly1305: bad key length",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
env := &env.Test{
|
||||
TestSecret: tt.key,
|
||||
}
|
||||
|
||||
_, err := NewXChaCha20Poly1305(context.Background(), env)
|
||||
if err != nil && err.Error() != tt.wantErr ||
|
||||
err == nil && tt.wantErr != "" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestXChaCha20Poly1305Decrypt(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
key []byte
|
||||
input []byte
|
||||
wantDecrypted []byte
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
|
||||
input: []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"),
|
||||
wantDecrypted: []byte("test"),
|
||||
},
|
||||
{
|
||||
name: "invalid - encrypted value tampered with",
|
||||
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
|
||||
input: []byte("\xda\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"),
|
||||
wantErr: "chacha20poly1305: message authentication failed",
|
||||
},
|
||||
{
|
||||
name: "invalid - too short",
|
||||
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
|
||||
input: []byte("XXXXXXXXXXXXXXXXXXXXXXX"),
|
||||
wantErr: "encrypted value too short",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
env := &env.Test{
|
||||
TestSecret: tt.key,
|
||||
}
|
||||
|
||||
cipher, err := NewXChaCha20Poly1305(context.Background(), env)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
decrypted, err := cipher.Decrypt(tt.input)
|
||||
if err != nil && err.Error() != tt.wantErr ||
|
||||
err == nil && tt.wantErr != "" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(tt.wantDecrypted, decrypted) {
|
||||
t.Error(string(decrypted))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestXChaCha20Poly1305Encrypt(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
key []byte
|
||||
randRead func(b []byte) (int, error)
|
||||
input []byte
|
||||
wantEncrypted []byte
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
key: []byte("\x6a\x98\x95\x6b\x2b\xb2\x7e\xfd\x1b\x68\xdf\x5c\x40\xc3\x4f\x8b\xcf\xff\xe8\x17\xc2\x2d\xf6\x40\x2e\x5a\xb0\x15\x63\x4a\x2d\x2e"),
|
||||
randRead: func(b []byte) (int, error) {
|
||||
nonce := []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2")
|
||||
copy(b, nonce)
|
||||
return len(nonce), nil
|
||||
},
|
||||
input: []byte("test"),
|
||||
wantEncrypted: []byte("\xd9\x1c\x3c\x05\xb2\xf3\xc5\x93\x20\x9f\x9b\x67\x43\x8c\x0c\x3d\x9c\x33\x5b\x16\xd6\x9a\x9c\xf2\x9c\xf6\xe9\xbd\xdd\xe3\x1d\x54\xde\x41\xa2\x99\x56\x6a\xfc\x9a\xf3\x58\x73\x03"),
|
||||
},
|
||||
{
|
||||
name: "rand.Read error",
|
||||
key: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
|
||||
randRead: func(b []byte) (int, error) {
|
||||
return 0, fmt.Errorf("random error")
|
||||
},
|
||||
wantErr: "random error",
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
env := &env.Test{
|
||||
TestSecret: tt.key,
|
||||
}
|
||||
|
||||
cipher, err := NewXChaCha20Poly1305(context.Background(), env)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cipher.(*aeadCipher).randRead = tt.randRead
|
||||
|
||||
encrypted, err := cipher.Encrypt(tt.input)
|
||||
if err != nil && err.Error() != tt.wantErr ||
|
||||
err == nil && tt.wantErr != "" {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(tt.wantEncrypted, encrypted) {
|
||||
t.Error(hex.EncodeToString(encrypted))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче