зеркало из https://github.com/Azure/ARO-RP.git
add encrypt pkg
This commit is contained in:
Родитель
e86f3b7a0d
Коммит
468621f73c
|
@ -61,6 +61,10 @@ required = [
|
|||
name = "github.com/ugorji/go"
|
||||
version = "1.1.7"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
|
||||
[[override]]
|
||||
name = "k8s.io/api"
|
||||
branch = "origin-4.3-kubernetes-1.16.2"
|
||||
|
@ -92,7 +96,3 @@ required = [
|
|||
[[prune.project]]
|
||||
name = "github.com/openshift/installer"
|
||||
unused-packages = false
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
|
|
|
@ -30,7 +30,7 @@ func monitor(ctx context.Context, log *logrus.Entry) error {
|
|||
}
|
||||
defer m.Close()
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid)
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func rp(ctx context.Context, log *logrus.Entry) error {
|
|||
}
|
||||
defer m.Close()
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid)
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, m, uuid, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -67,6 +67,10 @@
|
|||
"tenantId": "[subscription().tenantId]",
|
||||
"objectId": "[parameters('adminObjectId')]",
|
||||
"permissions": {
|
||||
"secrets": [
|
||||
"set",
|
||||
"list"
|
||||
],
|
||||
"certificates": [
|
||||
"delete",
|
||||
"get",
|
||||
|
|
|
@ -308,6 +308,11 @@ locations.
|
|||
--name rp-server \
|
||||
--file secrets/localhost.pem \
|
||||
>/dev/null
|
||||
az keyvault secret set \
|
||||
--vault-name "$KEYVAULT_PREFIX-service" \
|
||||
--name "encryption-key" \
|
||||
--value $(openssl rand -base64 32) \
|
||||
>/dev/null
|
||||
```
|
||||
|
||||
1. Create nameserver records in the parent DNS zone:
|
||||
|
|
|
@ -5,14 +5,13 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/ugorji/go/codec"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/database"
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/metrics/noop"
|
||||
|
@ -29,7 +28,7 @@ func run(ctx context.Context, log *logrus.Entry) error {
|
|||
return err
|
||||
}
|
||||
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, "")
|
||||
db, err := database.NewDatabase(ctx, log.WithField("component", "database"), env, &noop.Noop{}, "", true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -39,16 +38,7 @@ func run(ctx context.Context, log *logrus.Entry) error {
|
|||
return err
|
||||
}
|
||||
|
||||
h := &codec.JsonHandle{
|
||||
Indent: 4,
|
||||
}
|
||||
|
||||
err = api.AddExtensions(&h.BasicHandle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return codec.NewEncoder(os.Stdout, h).Encode(doc)
|
||||
return json.NewEncoder(os.Stdout).Encode(doc)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -4,37 +4,9 @@ package api
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/ugorji/go/codec"
|
||||
)
|
||||
|
||||
// AddExtensions adds extensions to a ugorji/go/codec to enable it to serialise
|
||||
// our types properly
|
||||
func AddExtensions(h *codec.BasicHandle) error {
|
||||
err := h.AddExt(reflect.TypeOf(&rsa.PrivateKey{}), 0, func(v reflect.Value) ([]byte, error) {
|
||||
if reflect.DeepEqual(v.Elem().Interface(), rsa.PrivateKey{}) {
|
||||
return nil, nil
|
||||
}
|
||||
return x509.MarshalPKCS1PrivateKey(v.Interface().(*rsa.PrivateKey)), nil
|
||||
}, func(v reflect.Value, b []byte) error {
|
||||
key, err := x509.ParsePKCS1PrivateKey(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Elem().Set(reflect.ValueOf(key).Elem())
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON marshals an InstallPhase
|
||||
func (p InstallPhase) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(p.String())
|
||||
|
|
|
@ -4,7 +4,6 @@ package api
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -22,6 +21,12 @@ type OpenShiftCluster struct {
|
|||
Properties Properties `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// SecureBytes represents encrypted []byte
|
||||
type SecureBytes []byte
|
||||
|
||||
// SecureString represents encrypted string
|
||||
type SecureString string
|
||||
|
||||
// Properties represents an OpenShift cluster's properties
|
||||
type Properties struct {
|
||||
MissingFields
|
||||
|
@ -80,9 +85,9 @@ type Properties struct {
|
|||
|
||||
StorageSuffix string `json:"storageSuffix,omitempty"`
|
||||
|
||||
SSHKey *rsa.PrivateKey `json:"sshKey,omitempty"`
|
||||
AdminKubeconfig []byte `json:"adminKubeconfig,omitempty"`
|
||||
KubeadminPassword string `json:"kubeadminPassword,omitempty"`
|
||||
SSHKey SecureBytes `json:"sshKey,omitempty"`
|
||||
AdminKubeconfig SecureBytes `json:"adminKubeconfig,omitempty"`
|
||||
KubeadminPassword SecureString `json:"kubeadminPassword,omitempty"`
|
||||
}
|
||||
|
||||
// ProvisioningState represents a provisioning state
|
||||
|
@ -119,7 +124,7 @@ type ServicePrincipalProfile struct {
|
|||
|
||||
TenantID string `json:"tenantId,omitempty"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
ClientSecret SecureString `json:"clientSecret,omitempty"`
|
||||
}
|
||||
|
||||
// NetworkProfile represents a network profile
|
||||
|
|
|
@ -31,7 +31,7 @@ func (c *openShiftClusterConverter) ToExternal(oc *api.OpenShiftCluster) interfa
|
|||
},
|
||||
ServicePrincipalProfile: ServicePrincipalProfile{
|
||||
ClientID: oc.Properties.ServicePrincipalProfile.ClientID,
|
||||
ClientSecret: oc.Properties.ServicePrincipalProfile.ClientSecret,
|
||||
ClientSecret: string(oc.Properties.ServicePrincipalProfile.ClientSecret),
|
||||
},
|
||||
NetworkProfile: NetworkProfile{
|
||||
PodCIDR: oc.Properties.NetworkProfile.PodCIDR,
|
||||
|
@ -121,7 +121,7 @@ func (c *openShiftClusterConverter) ToInternal(_oc interface{}, out *api.OpenShi
|
|||
out.Properties.ClusterProfile.ResourceGroupID = oc.Properties.ClusterProfile.ResourceGroupID
|
||||
out.Properties.ConsoleProfile.URL = oc.Properties.ConsoleProfile.URL
|
||||
out.Properties.ServicePrincipalProfile.ClientID = oc.Properties.ServicePrincipalProfile.ClientID
|
||||
out.Properties.ServicePrincipalProfile.ClientSecret = oc.Properties.ServicePrincipalProfile.ClientSecret
|
||||
out.Properties.ServicePrincipalProfile.ClientSecret = api.SecureString(oc.Properties.ServicePrincipalProfile.ClientSecret)
|
||||
out.Properties.NetworkProfile.PodCIDR = oc.Properties.NetworkProfile.PodCIDR
|
||||
out.Properties.NetworkProfile.ServiceCIDR = oc.Properties.NetworkProfile.ServiceCIDR
|
||||
out.Properties.MasterProfile.VMSize = api.VMSize(oc.Properties.MasterProfile.VMSize)
|
||||
|
|
|
@ -87,7 +87,7 @@ func (v *openShiftClusterValidator) Dynamic(ctx context.Context, oc *api.OpenShi
|
|||
|
||||
func (dv *openShiftClusterDynamicValidator) validateServicePrincipalProfile() (autorest.Authorizer, error) {
|
||||
spp := &dv.oc.Properties.ServicePrincipalProfile
|
||||
conf := auth.NewClientCredentialsConfig(spp.ClientID, spp.ClientSecret, spp.TenantID)
|
||||
conf := auth.NewClientCredentialsConfig(spp.ClientID, string(spp.ClientSecret), spp.TenantID)
|
||||
|
||||
token, err := conf.ServicePrincipalToken()
|
||||
if err != nil {
|
||||
|
@ -107,7 +107,7 @@ func (dv *openShiftClusterDynamicValidator) validateServicePrincipalProfile() (a
|
|||
|
||||
func (dv *openShiftClusterDynamicValidator) validateServicePrincipalRole() error {
|
||||
spp := &dv.oc.Properties.ServicePrincipalProfile
|
||||
conf := auth.NewClientCredentialsConfig(spp.ClientID, spp.ClientSecret, spp.TenantID)
|
||||
conf := auth.NewClientCredentialsConfig(spp.ClientID, string(spp.ClientSecret), spp.TenantID)
|
||||
conf.Resource = azure.PublicCloud.GraphEndpoint
|
||||
|
||||
token, err := conf.ServicePrincipalToken()
|
||||
|
|
|
@ -17,7 +17,7 @@ type openShiftClusterCredentialsConverter struct{}
|
|||
func (*openShiftClusterCredentialsConverter) ToExternal(oc *api.OpenShiftCluster) interface{} {
|
||||
out := &OpenShiftClusterCredentials{
|
||||
KubeadminUsername: "kubeadmin",
|
||||
KubeadminPassword: oc.Properties.KubeadminPassword,
|
||||
KubeadminPassword: string(oc.Properties.KubeadminPassword),
|
||||
}
|
||||
|
||||
return out
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
@ -40,9 +41,12 @@ func (m *Manager) Create(ctx context.Context) error {
|
|||
|
||||
m.doc, err = m.db.PatchWithLease(ctx, m.doc.Key, func(doc *api.OpenShiftClusterDocument) error {
|
||||
var err error
|
||||
|
||||
if doc.OpenShiftCluster.Properties.SSHKey == nil {
|
||||
doc.OpenShiftCluster.Properties.SSHKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
sshKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
doc.OpenShiftCluster.Properties.SSHKey = x509.MarshalPKCS1PrivateKey(sshKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -81,7 +85,12 @@ func (m *Manager) Create(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
sshkey, err := ssh.NewPublicKey(&m.doc.OpenShiftCluster.Properties.SSHKey.PublicKey)
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(m.doc.OpenShiftCluster.Properties.SSHKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sshkey, err := ssh.NewPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -111,7 +120,7 @@ func (m *Manager) Create(ctx context.Context) error {
|
|||
Azure: &icazure.Credentials{
|
||||
TenantID: m.doc.OpenShiftCluster.Properties.ServicePrincipalProfile.TenantID,
|
||||
ClientID: m.doc.OpenShiftCluster.Properties.ServicePrincipalProfile.ClientID,
|
||||
ClientSecret: m.doc.OpenShiftCluster.Properties.ServicePrincipalProfile.ClientSecret,
|
||||
ClientSecret: string(m.doc.OpenShiftCluster.Properties.ServicePrincipalProfile.ClientSecret),
|
||||
SubscriptionID: r.SubscriptionID,
|
||||
},
|
||||
}
|
||||
|
@ -224,7 +233,7 @@ func (m *Manager) Create(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
i, err := install.NewInstaller(m.log, m.env, m.db, m.doc)
|
||||
i, err := install.NewInstaller(ctx, m.log, m.env, m.db, m.doc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -17,6 +18,7 @@ import (
|
|||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/metrics"
|
||||
dbmetrics "github.com/Azure/ARO-RP/pkg/metrics/statsd/cosmosdb"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
)
|
||||
|
||||
// Database represents a database
|
||||
|
@ -31,21 +33,18 @@ type Database struct {
|
|||
}
|
||||
|
||||
// NewDatabase returns a new Database
|
||||
func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, uuid string) (db *Database, err error) {
|
||||
databaseAccount, masterKey := env.CosmosDB()
|
||||
|
||||
h := &codec.JsonHandle{
|
||||
BasicHandle: codec.BasicHandle{
|
||||
DecodeOptions: codec.DecodeOptions{
|
||||
ErrorIfNoField: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = api.AddExtensions(&h.BasicHandle)
|
||||
func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m metrics.Interface, uuid string, decryptDatabase bool) (db *Database, err error) {
|
||||
var cipher encryption.Cipher
|
||||
if decryptDatabase {
|
||||
cipher, err = encryption.NewCipher(ctx, env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
databaseAccount, masterKey := env.CosmosDB()
|
||||
|
||||
h := newJSONHandle(cipher)
|
||||
|
||||
c := &http.Client{
|
||||
Transport: dbmetrics.New(log, &http.Transport{
|
||||
|
@ -90,3 +89,17 @@ func NewDatabase(ctx context.Context, log *logrus.Entry, env env.Interface, m me
|
|||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func newJSONHandle(cipher encryption.Cipher) *codec.JsonHandle {
|
||||
h := &codec.JsonHandle{
|
||||
BasicHandle: codec.BasicHandle{
|
||||
DecodeOptions: codec.DecodeOptions{
|
||||
ErrorIfNoField: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h.SetInterfaceExt(reflect.TypeOf(api.SecureBytes{}), 1, SecureBytesExt{Cipher: cipher})
|
||||
h.SetInterfaceExt(reflect.TypeOf((*api.SecureString)(nil)), 1, SecureStringExt{Cipher: cipher})
|
||||
return h
|
||||
}
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
package database
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"github.com/ugorji/go/codec"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
encrypt "github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
)
|
||||
|
||||
var _ codec.InterfaceExt = (*SecureBytesExt)(nil)
|
||||
|
||||
type SecureBytesExt struct {
|
||||
Cipher encryption.Cipher
|
||||
}
|
||||
|
||||
func (s SecureBytesExt) ConvertExt(v interface{}) interface{} {
|
||||
data := v.(api.SecureBytes)
|
||||
if s.Cipher != nil {
|
||||
encrypted, err := s.Cipher.Encrypt(string(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
func (s SecureBytesExt) UpdateExt(dest interface{}, v interface{}) {
|
||||
output := dest.(*api.SecureBytes)
|
||||
if s.Cipher != nil {
|
||||
decrypted, err := s.Cipher.Decrypt(v.(string))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
*output = api.SecureBytes(decrypted)
|
||||
return
|
||||
}
|
||||
*output = api.SecureBytes(v.(string))
|
||||
}
|
||||
|
||||
var _ codec.InterfaceExt = (*SecureStringExt)(nil)
|
||||
|
||||
type SecureStringExt struct {
|
||||
Cipher encrypt.Cipher
|
||||
}
|
||||
|
||||
func (s SecureStringExt) ConvertExt(v interface{}) interface{} {
|
||||
data := v.(api.SecureString)
|
||||
if s.Cipher != nil {
|
||||
encrypted, err := s.Cipher.Encrypt(string(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
func (s SecureStringExt) UpdateExt(dest interface{}, v interface{}) {
|
||||
output := dest.(*api.SecureString)
|
||||
if s.Cipher != nil {
|
||||
decrypted, err := s.Cipher.Decrypt(v.(string))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
*output = api.SecureString(decrypted)
|
||||
return
|
||||
}
|
||||
*output = api.SecureString(v.(string))
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
package database
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ugorji/go/codec"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/api"
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
)
|
||||
|
||||
type testStruct struct {
|
||||
SecureBytes api.SecureBytes
|
||||
SecureString api.SecureString
|
||||
Bytes []byte
|
||||
Str string
|
||||
}
|
||||
|
||||
func TestExtensions(t *testing.T) {
|
||||
encryption.RandRead = func(b []byte) (n int, err error) {
|
||||
b = make([]byte, len(b))
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
key := make([]byte, 32)
|
||||
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
|
||||
base64.StdEncoding.Encode(keybase64, key)
|
||||
env := &env.Test{TestSecret: keybase64}
|
||||
|
||||
cipher, err := encryption.NewCipher(context.Background(), env)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input func(input *testStruct)
|
||||
inputCodec *codec.JsonHandle
|
||||
output func(input *testStruct)
|
||||
outputCodec *codec.JsonHandle
|
||||
}{
|
||||
{
|
||||
name: "noop",
|
||||
},
|
||||
{
|
||||
name: "SecureByte - encrypt - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("test")
|
||||
},
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("test")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecureByte - encrypt - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("test")
|
||||
},
|
||||
inputCodec: newJSONHandle(cipher),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
// empty string encoded
|
||||
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAjzuUWlGQbchgDen4li0A5g=="
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
{
|
||||
name: "SecureByte - raw - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("test")
|
||||
},
|
||||
outputCodec: newJSONHandle(cipher),
|
||||
},
|
||||
{
|
||||
name: "SecureByte - raw - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureBytes = []byte("ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=")
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
{
|
||||
name: "SecureString - encrypt - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "test"
|
||||
},
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "test"
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecureString - encrypt - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "test"
|
||||
},
|
||||
inputCodec: newJSONHandle(cipher),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
{
|
||||
name: "SecureString - raw - decrypt",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "test"
|
||||
},
|
||||
outputCodec: newJSONHandle(cipher),
|
||||
},
|
||||
{
|
||||
name: "SecureString - raw - raw",
|
||||
input: func(input *testStruct) {
|
||||
input.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
inputCodec: newJSONHandle(nil),
|
||||
output: func(output *testStruct) {
|
||||
output.SecureString = "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50="
|
||||
},
|
||||
outputCodec: newJSONHandle(nil),
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.inputCodec == nil {
|
||||
tt.inputCodec = newJSONHandle(cipher)
|
||||
}
|
||||
if tt.outputCodec == nil {
|
||||
tt.outputCodec = newJSONHandle(cipher)
|
||||
}
|
||||
|
||||
input := &testStruct{}
|
||||
if tt.input != nil {
|
||||
tt.input(input)
|
||||
}
|
||||
|
||||
output := &testStruct{}
|
||||
if tt.output != nil {
|
||||
tt.output(output)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
err = codec.NewEncoder(buf, tt.inputCodec).Encode(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
data, err := ioutil.ReadAll(buf)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
result := &testStruct{}
|
||||
err = codec.NewDecoder(bytes.NewReader(data), tt.outputCodec).Decode(result)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(output, result) {
|
||||
output, _ := json.Marshal(output)
|
||||
result, _ := json.Marshal(result)
|
||||
t.Errorf("\n wants: %s,'\ngot: %s", string(output), string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -554,6 +554,10 @@ func (g *generator) serviceKeyvault() *arm.Resource {
|
|||
mgmtkeyvault.Import,
|
||||
mgmtkeyvault.List,
|
||||
},
|
||||
Secrets: &[]mgmtkeyvault.SecretPermissions{
|
||||
mgmtkeyvault.SecretPermissionsSet,
|
||||
mgmtkeyvault.SecretPermissionsList,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
@ -28,8 +28,9 @@ type Interface interface {
|
|||
DialContext(context.Context, string, string) (net.Conn, error)
|
||||
Domain() string
|
||||
FPAuthorizer(string, string) (autorest.Authorizer, error)
|
||||
GetSecret(context.Context, string) (*rsa.PrivateKey, []*x509.Certificate, error)
|
||||
ManagedDomain(string) (string, error)
|
||||
GetSecret(context.Context, string) ([]byte, error)
|
||||
GetCertificateSecret(context.Context, string) (*rsa.PrivateKey, []*x509.Certificate, error)
|
||||
Listen() (net.Listener, error)
|
||||
VnetName() string
|
||||
Zones(vmSize string) ([]string, error)
|
||||
|
|
|
@ -100,7 +100,7 @@ func newProd(ctx context.Context, log *logrus.Entry, instancemetadata instanceme
|
|||
return nil, err
|
||||
}
|
||||
|
||||
fpPrivateKey, fpCertificates, err := p.GetSecret(ctx, "rp-firstparty")
|
||||
fpPrivateKey, fpCertificates, err := p.GetCertificateSecret(ctx, "rp-firstparty")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -265,13 +265,16 @@ func (p *prod) FPAuthorizer(tenantID, resource string) (autorest.Authorizer, err
|
|||
return autorest.NewBearerAuthorizer(sp), nil
|
||||
}
|
||||
|
||||
func (p *prod) GetSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) {
|
||||
func (p *prod) GetSecret(ctx context.Context, secretName string) ([]byte, error) {
|
||||
bundle, err := p.keyvault.GetSecret(ctx, p.serviceKeyvaultURI, secretName, "")
|
||||
return []byte(*bundle.Value), err
|
||||
}
|
||||
|
||||
func (p *prod) GetCertificateSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) {
|
||||
b, err := p.GetSecret(ctx, secretName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
b := []byte(*bundle.Value)
|
||||
for {
|
||||
var block *pem.Block
|
||||
block, b = pem.Decode(b)
|
||||
|
|
|
@ -25,6 +25,7 @@ type Test struct {
|
|||
TestResourceGroup string
|
||||
TestDomain string
|
||||
TestVNetName string
|
||||
TestSecret []byte
|
||||
|
||||
TLSKey *rsa.PrivateKey
|
||||
TLSCerts []*x509.Certificate
|
||||
|
@ -45,7 +46,7 @@ func (t *Test) FPAuthorizer(tenantID, resource string) (autorest.Authorizer, err
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (t *Test) GetSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) {
|
||||
func (t *Test) GetCertificateSecret(ctx context.Context, secretName string) (key *rsa.PrivateKey, certs []*x509.Certificate, err error) {
|
||||
switch secretName {
|
||||
case "rp-server":
|
||||
return t.TLSKey, t.TLSCerts, nil
|
||||
|
@ -54,6 +55,10 @@ func (t *Test) GetSecret(ctx context.Context, secretName string) (key *rsa.Priva
|
|||
}
|
||||
}
|
||||
|
||||
func (t *Test) GetSecret(ctx context.Context, secretName string) ([]byte, error) {
|
||||
return t.TestSecret, nil
|
||||
}
|
||||
|
||||
func (t *Test) Listen() (net.Listener, error) {
|
||||
return t.L, nil
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ func NewFrontend(ctx context.Context, baseLog *logrus.Entry, env env.Interface,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
key, certs, err := f.env.GetSecret(ctx, "rp-server")
|
||||
key, certs, err := f.env.GetCertificateSecret(ctx, "rp-server")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -4,9 +4,7 @@ package install
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -18,7 +16,6 @@ import (
|
|||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/Azure/go-autorest/autorest/to"
|
||||
"github.com/openshift/installer/pkg/asset/ignition/bootstrap"
|
||||
"github.com/openshift/installer/pkg/asset/installconfig"
|
||||
"github.com/openshift/installer/pkg/asset/kubeconfig"
|
||||
"github.com/openshift/installer/pkg/asset/releaseimage"
|
||||
|
@ -66,7 +63,6 @@ func (i *Installer) installStorage(ctx context.Context, installConfig *installco
|
|||
}
|
||||
|
||||
adminClient := g[reflect.TypeOf(&kubeconfig.AdminClient{})].(*kubeconfig.AdminClient)
|
||||
bootstrap := g[reflect.TypeOf(&bootstrap.Bootstrap{})].(*bootstrap.Bootstrap)
|
||||
|
||||
resourceGroup := i.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID[strings.LastIndexByte(i.doc.OpenShiftCluster.Properties.ClusterProfile.ResourceGroupID, '/')+1:]
|
||||
|
||||
|
@ -198,26 +194,10 @@ func (i *Installer) installStorage(ctx context.Context, installConfig *installco
|
|||
}
|
||||
|
||||
{
|
||||
blobService, err := i.getBlobService(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bootstrapIgn := blobService.GetContainerReference("ignition").GetBlobReference("bootstrap.ign")
|
||||
err = bootstrapIgn.CreateBlockBlobFromReader(bytes.NewReader(bootstrap.File.Data), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// the graph is quite big so we store it in a storage account instead of
|
||||
// in cosmosdb
|
||||
graph := blobService.GetContainerReference("aro").GetBlobReference("graph")
|
||||
b, err := json.MarshalIndent(g, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = graph.CreateBlockBlobFromReader(bytes.NewReader(b), nil)
|
||||
err := i.saveGraph(ctx, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ import (
|
|||
)
|
||||
|
||||
func (i *Installer) installResources(ctx context.Context) error {
|
||||
g, err := i.getGraph(ctx)
|
||||
g, err := i.loadGraph(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ func (i *Installer) installResources(ctx context.Context) error {
|
|||
{
|
||||
spp := &i.doc.OpenShiftCluster.Properties.ServicePrincipalProfile
|
||||
|
||||
conf := auth.NewClientCredentialsConfig(spp.ClientID, spp.ClientSecret, spp.TenantID)
|
||||
conf := auth.NewClientCredentialsConfig(spp.ClientID, string(spp.ClientSecret), spp.TenantID)
|
||||
conf.Resource = azure.PublicCloud.GraphEndpoint
|
||||
|
||||
spGraphAuthorizer, err := conf.Authorizer()
|
||||
|
|
|
@ -27,7 +27,7 @@ import (
|
|||
)
|
||||
|
||||
func (i *Installer) removeBootstrap(ctx context.Context) error {
|
||||
g, err := i.getGraph(ctx)
|
||||
g, err := i.loadGraph(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ func (i *Installer) removeBootstrap(ctx context.Context) error {
|
|||
doc.OpenShiftCluster.Properties.APIServerProfile.URL = "https://api." + installConfig.Config.ObjectMeta.Name + "." + installConfig.Config.BaseDomain + ":6443/"
|
||||
doc.OpenShiftCluster.Properties.IngressProfiles[0].IP = routerIP
|
||||
doc.OpenShiftCluster.Properties.ConsoleProfile.URL = "https://console-openshift-console.apps." + installConfig.Config.ObjectMeta.Name + "." + installConfig.Config.BaseDomain + "/"
|
||||
doc.OpenShiftCluster.Properties.KubeadminPassword = kubeadminPassword.Password
|
||||
doc.OpenShiftCluster.Properties.KubeadminPassword = api.SecureString(kubeadminPassword.Password)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
|
|
|
@ -4,8 +4,6 @@ package install
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/openshift/installer/pkg/asset"
|
||||
|
@ -143,40 +141,3 @@ func (g graph) resolve(a asset.Asset) (asset.Asset, error) {
|
|||
|
||||
return g[reflect.TypeOf(a)], nil
|
||||
}
|
||||
|
||||
func (g graph) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]asset.Asset{}
|
||||
for t, a := range g {
|
||||
m[t.String()] = a
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (g *graph) UnmarshalJSON(b []byte) error {
|
||||
if *g == nil {
|
||||
*g = graph{}
|
||||
}
|
||||
|
||||
var m map[string]json.RawMessage
|
||||
err := json.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for n, b := range m {
|
||||
t, found := registeredTypes[n]
|
||||
if !found {
|
||||
return fmt.Errorf("unregistered type %q", n)
|
||||
}
|
||||
|
||||
a := reflect.New(reflect.TypeOf(t).Elem()).Interface().(asset.Asset)
|
||||
err = json.Unmarshal(b, a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*g)[reflect.TypeOf(a)] = a
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,10 +4,13 @@ package install
|
|||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -16,6 +19,7 @@ import (
|
|||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/openshift/installer/pkg/asset/ignition/bootstrap"
|
||||
"github.com/openshift/installer/pkg/asset/installconfig"
|
||||
"github.com/openshift/installer/pkg/asset/releaseimage"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -28,6 +32,7 @@ import (
|
|||
"github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/resources"
|
||||
"github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/storage"
|
||||
"github.com/Azure/ARO-RP/pkg/util/dns"
|
||||
"github.com/Azure/ARO-RP/pkg/util/encryption"
|
||||
"github.com/Azure/ARO-RP/pkg/util/keyvault"
|
||||
"github.com/Azure/ARO-RP/pkg/util/privateendpoint"
|
||||
"github.com/Azure/ARO-RP/pkg/util/subnet"
|
||||
|
@ -38,6 +43,7 @@ type Installer struct {
|
|||
env env.Interface
|
||||
db database.OpenShiftClusters
|
||||
doc *api.OpenShiftClusterDocument
|
||||
cipher encryption.Cipher
|
||||
fpAuthorizer autorest.Authorizer
|
||||
|
||||
disks compute.DisksClient
|
||||
|
@ -54,7 +60,7 @@ type Installer struct {
|
|||
subnet subnet.Manager
|
||||
}
|
||||
|
||||
func NewInstaller(log *logrus.Entry, env env.Interface, db database.OpenShiftClusters, doc *api.OpenShiftClusterDocument) (*Installer, error) {
|
||||
func NewInstaller(ctx context.Context, log *logrus.Entry, env env.Interface, db database.OpenShiftClusters, doc *api.OpenShiftClusterDocument) (*Installer, error) {
|
||||
r, err := azure.ParseResourceID(doc.OpenShiftCluster.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -75,10 +81,16 @@ func NewInstaller(log *logrus.Entry, env env.Interface, db database.OpenShiftClu
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cipher, err := encryption.NewCipher(ctx, env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Installer{
|
||||
log: log,
|
||||
env: env,
|
||||
db: db,
|
||||
cipher: cipher,
|
||||
doc: doc,
|
||||
fpAuthorizer: fpAuthorizer,
|
||||
|
||||
|
@ -178,8 +190,8 @@ func (i *Installer) getBlobService(ctx context.Context) (*azstorage.BlobStorageC
|
|||
return &c, nil
|
||||
}
|
||||
|
||||
func (i *Installer) getGraph(ctx context.Context) (graph, error) {
|
||||
i.log.Print("retrieving graph")
|
||||
func (i *Installer) loadGraph(ctx context.Context) (graph, error) {
|
||||
i.log.Print("load graph")
|
||||
|
||||
blobService, err := i.getBlobService(ctx)
|
||||
if err != nil {
|
||||
|
@ -194,11 +206,50 @@ func (i *Installer) getGraph(ctx context.Context) (graph, error) {
|
|||
}
|
||||
defer rc.Close()
|
||||
|
||||
encrypted, err := ioutil.ReadAll(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output, err := i.cipher.Decrypt(string(encrypted))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g graph
|
||||
err = json.NewDecoder(rc).Decode(&g)
|
||||
err = json.Unmarshal([]byte(output), &g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (i *Installer) saveGraph(ctx context.Context, g graph) error {
|
||||
i.log.Print("save graph")
|
||||
|
||||
blobService, err := i.getBlobService(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bootstrap := g[reflect.TypeOf(&bootstrap.Bootstrap{})].(*bootstrap.Bootstrap)
|
||||
bootstrapIgn := blobService.GetContainerReference("ignition").GetBlobReference("bootstrap.ign")
|
||||
err = bootstrapIgn.CreateBlockBlobFromReader(bytes.NewReader(bootstrap.File.Data), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
graph := blobService.GetContainerReference("aro").GetBlobReference("graph")
|
||||
b, err := json.MarshalIndent(g, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output, err := i.cipher.Encrypt(string(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return graph.CreateBlockBlobFromReader(bytes.NewReader([]byte(output)), nil)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package install
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/openshift/installer/pkg/asset"
|
||||
)
|
||||
|
||||
func (g graph) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]asset.Asset{}
|
||||
for t, a := range g {
|
||||
m[t.String()] = a
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func (g *graph) UnmarshalJSON(b []byte) error {
|
||||
if *g == nil {
|
||||
*g = graph{}
|
||||
}
|
||||
|
||||
var m map[string]json.RawMessage
|
||||
err := json.Unmarshal(b, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for n, b := range m {
|
||||
t, found := registeredTypes[n]
|
||||
if !found {
|
||||
return fmt.Errorf("unregistered type %q", n)
|
||||
}
|
||||
|
||||
a := reflect.New(reflect.TypeOf(t).Elem()).Interface().(asset.Asset)
|
||||
err = json.Unmarshal(b, a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*g)[reflect.TypeOf(a)] = a
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package encryption
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
)
|
||||
|
||||
// encryptionSecretName must match key name in the service keyvault
|
||||
const (
|
||||
encryptionSecretName = "encryption-key"
|
||||
Prefix = "ENC*"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Cipher = (*aeadCipher)(nil)
|
||||
RandRead = rand.Read
|
||||
)
|
||||
|
||||
type Cipher interface {
|
||||
Decrypt(string) (string, error)
|
||||
Encrypt(string) (string, error)
|
||||
}
|
||||
|
||||
type aeadCipher struct {
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func NewCipher(ctx context.Context, env env.Interface) (Cipher, error) {
|
||||
keybase64, err := env.GetSecret(ctx, encryptionSecretName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := make([]byte, base64.StdEncoding.DecodedLen(len(keybase64)))
|
||||
n, err := base64.StdEncoding.Decode(key, keybase64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n < 32 {
|
||||
return nil, fmt.Errorf("chacha20poly1305: bad key length")
|
||||
}
|
||||
key = key[:32]
|
||||
|
||||
aead, err := chacha20poly1305.NewX(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &aeadCipher{
|
||||
aead: aead,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts input
|
||||
func (c *aeadCipher) Decrypt(input string) (string, error) {
|
||||
if !strings.HasPrefix(input, Prefix) {
|
||||
return input, nil
|
||||
}
|
||||
input = input[len(Prefix):]
|
||||
|
||||
r := make([]byte, base64.StdEncoding.DecodedLen(len(input)))
|
||||
r, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(r) >= 24 {
|
||||
nonce := r[0:24]
|
||||
data := r[24:]
|
||||
output, err := c.aead.Open(nil, nonce, data, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
return "", fmt.Errorf("error while decrypting message")
|
||||
}
|
||||
|
||||
// Encrypt encrypts input using 24 byte nonce
|
||||
func (c *aeadCipher) Encrypt(input string) (string, error) {
|
||||
nonce := make([]byte, chacha20poly1305.NonceSizeX)
|
||||
_, err := RandRead(nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
encrypted := c.aead.Seal(nil, nonce, []byte(input), nil)
|
||||
|
||||
var encryptedFinal []byte
|
||||
encryptedFinal = append(encryptedFinal, nonce...)
|
||||
encryptedFinal = append(encryptedFinal, encrypted...)
|
||||
|
||||
encryptedBase64 := make([]byte, base64.StdEncoding.EncodedLen(len(encryptedFinal)))
|
||||
base64.StdEncoding.Encode(encryptedBase64, encryptedFinal)
|
||||
|
||||
// return prefix+base64(nonce+encryptedFinal)
|
||||
var result []byte
|
||||
result = append(result, Prefix...)
|
||||
result = append(result, encryptedBase64...)
|
||||
return string(result), nil
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package encryption
|
||||
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the Apache License 2.0.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/ARO-RP/pkg/env"
|
||||
)
|
||||
|
||||
func TestEncryptRoundTrip(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
|
||||
base64.StdEncoding.Encode(keybase64, key)
|
||||
env := &env.Test{TestSecret: keybase64}
|
||||
|
||||
cipher, err := NewCipher(context.Background(), env)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
test := "secert"
|
||||
encrypted, err := cipher.Encrypt(test)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
decrypted, err := cipher.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if r := strings.Compare(test, decrypted); r != 0 {
|
||||
t.Error("encryption roundTrip failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
RandRead = func(b []byte) (n int, err error) {
|
||||
b = make([]byte, len(b))
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr string
|
||||
env func(e *env.Test)
|
||||
}{
|
||||
{
|
||||
name: "ok encrypt",
|
||||
input: "test",
|
||||
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
|
||||
wantErr: "",
|
||||
env: func(input *env.Test) {
|
||||
key := make([]byte, 32)
|
||||
keybase64 := make([]byte, base64.StdEncoding.EncodedLen(len(key)))
|
||||
base64.StdEncoding.Encode(keybase64, key)
|
||||
input.TestSecret = keybase64
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base64 key error",
|
||||
wantErr: "illegal base64 data at input byte 8",
|
||||
env: func(input *env.Test) {
|
||||
input.TestSecret = []byte("badsecret")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key too short",
|
||||
input: "test",
|
||||
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
|
||||
wantErr: "chacha20poly1305: bad key length",
|
||||
env: func(input *env.Test) {
|
||||
keybase64 := base64.StdEncoding.EncodeToString(make([]byte, 15))
|
||||
input.TestSecret = []byte(keybase64)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "key too long", // due to base64 approximations library truncates the secret to right lenhgt
|
||||
input: "test",
|
||||
expected: "ENC*AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADPvl/edTVlZfXuNqdeWf2B1jR50=",
|
||||
env: func(input *env.Test) {
|
||||
keybase64 := base64.StdEncoding.EncodeToString((make([]byte, 40)))
|
||||
input.TestSecret = []byte(keybase64)
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &env.Test{}
|
||||
if tt.env != nil {
|
||||
tt.env(e)
|
||||
}
|
||||
|
||||
cipher, err := NewCipher(context.Background(), e)
|
||||
if err != nil {
|
||||
if err.Error() != tt.wantErr {
|
||||
t.Errorf("\n wants: %s,'\ngot: %s", tt.wantErr, err.Error())
|
||||
t.FailNow()
|
||||
}
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
result, err := cipher.Encrypt(tt.input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if tt.expected != result {
|
||||
t.Errorf("\n wants: %s,'\ngot: %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -8,3 +8,5 @@ require (
|
|||
github.com/stretchr/testify v1.2.2
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894
|
||||
)
|
||||
|
||||
go 1.13
|
||||
|
|
Загрузка…
Ссылка в новой задаче