feat: add mTLS to CNS (#2751)
* feat: add UseMTLS config * feat: add mTLS auth for CNS * test: add testdata for mTLS tests * chore: add logs on TLS config retrieval * lint: in tests * refactor: use CNS logger, not ACN logger * refactor: add guards to mtlsRootCAsFromCertificate and unit tests * lint: fix lint errors * test: include HTTP listener tests for when TLS/mTLS is enabled * chore: add log for stopping the TLS listener * test: add test helper to create certificates for testing instead of using hardcoded pem file * test: assert non-TLS service has no TLSSettings * test: refactor TestMtlsRootCAsFromCertificate to table-based tests * refactor: pull listener addresses from listener and remove redundant struct field for tls address
This commit is contained in:
Родитель
13f7037fd9
Коммит
de225e4d34
|
@ -20,6 +20,7 @@
|
|||
"TLSPort": "10091",
|
||||
"TLSSubjectName": "",
|
||||
"UseHTTPS": false,
|
||||
"UseMTLS": false,
|
||||
"WireserverIP": "168.63.129.16",
|
||||
"KeyVaultSettings": {
|
||||
"URL": "",
|
||||
|
|
|
@ -55,6 +55,7 @@ type CNSConfig struct {
|
|||
TLSSubjectName string
|
||||
TelemetrySettings TelemetrySettings
|
||||
UseHTTPS bool
|
||||
UseMTLS bool
|
||||
WatchPods bool `json:"-"`
|
||||
WireserverIP string
|
||||
}
|
||||
|
|
|
@ -87,6 +87,7 @@ func TestReadConfigFromFile(t *testing.T) {
|
|||
PopulateHomeAzCacheRetryIntervalSecs: 60,
|
||||
},
|
||||
UseHTTPS: true,
|
||||
UseMTLS: true,
|
||||
WireserverIP: "168.63.129.16",
|
||||
},
|
||||
wantErr: false,
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
"TelemetryBatchSizeBytes": 16384
|
||||
},
|
||||
"UseHTTPS": true,
|
||||
"UseMTLS": true,
|
||||
"WireserverIP": "168.63.129.16",
|
||||
"AZRSettings": {
|
||||
"PopulateHomeAzCacheRetryIntervalSecs": 60
|
||||
|
|
|
@ -6,6 +6,7 @@ package cns
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -190,6 +191,18 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
|
|||
},
|
||||
}
|
||||
|
||||
if tlsSettings.UseMTLS {
|
||||
rootCAs, err := mtlsRootCAsFromCertificate(&tlsCert)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConfig.ClientCAs = rootCAs
|
||||
tlsConfig.RootCAs = rootCAs
|
||||
}
|
||||
|
||||
logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
|
@ -224,9 +237,51 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
|
|||
},
|
||||
}
|
||||
|
||||
if tlsSettings.UseMTLS {
|
||||
tlsCert := cr.GetCertificate()
|
||||
rootCAs, err := mtlsRootCAsFromCertificate(tlsCert)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConfig.ClientCAs = rootCAs
|
||||
tlsConfig.RootCAs = rootCAs
|
||||
}
|
||||
|
||||
logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
|
||||
|
||||
return &tlsConfig, nil
|
||||
}
|
||||
|
||||
// Given a TLS cert, return the root CAs
|
||||
func mtlsRootCAsFromCertificate(tlsCert *tls.Certificate) (*x509.CertPool, error) {
|
||||
switch {
|
||||
case tlsCert == nil || len(tlsCert.Certificate) == 0:
|
||||
return nil, errors.New("no certificate provided")
|
||||
case len(tlsCert.Certificate) == 1:
|
||||
certs := x509.NewCertPool()
|
||||
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing self signed cert")
|
||||
}
|
||||
certs.AddCert(cert)
|
||||
|
||||
return certs, nil
|
||||
default:
|
||||
certs := x509.NewCertPool()
|
||||
// given a fullchain cert, we skip leaf cert at index 0 because
|
||||
// we only want intermediate and root certs in the cert pool for mTLS
|
||||
for _, certBytes := range tlsCert.Certificate[1:] {
|
||||
cert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing root certs")
|
||||
}
|
||||
certs.AddCert(cert)
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (service *Service) StartListener(config *common.ServiceConfig) error {
|
||||
log.Debugf("[Azure CNS] Going to start listener: %+v", config)
|
||||
|
||||
|
|
|
@ -786,6 +786,7 @@ func main() {
|
|||
KeyVaultCertificateName: cnsconfig.KeyVaultSettings.CertificateName,
|
||||
MSIResourceID: cnsconfig.MSISettings.ResourceID,
|
||||
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
|
||||
UseMTLS: cnsconfig.UseMTLS,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,324 @@
|
|||
// Copyright 2017 Microsoft. All rights reserved.
|
||||
// MIT License
|
||||
|
||||
package cns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-container-networking/cns/common"
|
||||
"github.com/Azure/azure-container-networking/cns/logger"
|
||||
acn "github.com/Azure/azure-container-networking/common"
|
||||
serverTLS "github.com/Azure/azure-container-networking/server/tls"
|
||||
"github.com/Azure/azure-container-networking/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewService(t *testing.T) {
|
||||
logger.InitLogger("azure-cns.log", 0, 0, "/")
|
||||
mockStore := store.NewMockStore("test")
|
||||
|
||||
config := &common.ServiceConfig{
|
||||
Name: "test",
|
||||
Version: "1.0",
|
||||
ChannelMode: "Direct",
|
||||
Store: mockStore,
|
||||
}
|
||||
|
||||
t.Run("NewService", func(t *testing.T) {
|
||||
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &Service{}, svc)
|
||||
|
||||
svc.SetOption(acn.OptCnsURL, "")
|
||||
svc.SetOption(acn.OptCnsPort, "")
|
||||
|
||||
require.Empty(t, config.TLSSettings)
|
||||
|
||||
err = svc.Initialize(config)
|
||||
t.Cleanup(func() {
|
||||
svc.Uninitialize()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.StartListener(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
t.Cleanup(func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("NewServiceWithTLS", func(t *testing.T) {
|
||||
testCertFilePath := createTestCertificate(t)
|
||||
|
||||
config.TLSSettings = serverTLS.TlsSettings{
|
||||
TLSPort: "10091",
|
||||
TLSSubjectName: "localhost",
|
||||
TLSCertificatePath: testCertFilePath,
|
||||
}
|
||||
|
||||
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &Service{}, svc)
|
||||
|
||||
svc.SetOption(acn.OptCnsURL, "")
|
||||
svc.SetOption(acn.OptCnsPort, "")
|
||||
|
||||
err = svc.Initialize(config)
|
||||
t.Cleanup(func() {
|
||||
svc.Uninitialize()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.StartListener(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
ServerName: config.TLSSettings.TLSSubjectName,
|
||||
// #nosec G402 for test purposes only
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TLS listener
|
||||
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
resp, err := tlsClient.Do(req)
|
||||
t.Cleanup(func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// HTTP listener
|
||||
httpClient := &http.Client{}
|
||||
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
resp, err = httpClient.Do(req)
|
||||
t.Cleanup(func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
|
||||
testCertFilePath := createTestCertificate(t)
|
||||
|
||||
config.TLSSettings = serverTLS.TlsSettings{
|
||||
TLSPort: "10091",
|
||||
TLSSubjectName: "localhost",
|
||||
TLSCertificatePath: testCertFilePath,
|
||||
UseMTLS: true,
|
||||
}
|
||||
|
||||
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &Service{}, svc)
|
||||
|
||||
svc.SetOption(acn.OptCnsURL, "")
|
||||
svc.SetOption(acn.OptCnsPort, "")
|
||||
|
||||
err = svc.Initialize(config)
|
||||
t.Cleanup(func() {
|
||||
svc.Uninitialize()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.StartListener(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: mTLSConfig,
|
||||
},
|
||||
}
|
||||
|
||||
// TLS listener
|
||||
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
t.Cleanup(func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// HTTP listener
|
||||
httpClient := &http.Client{}
|
||||
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
resp, err = httpClient.Do(req)
|
||||
t.Cleanup(func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMtlsRootCAsFromCertificate(t *testing.T) {
|
||||
testCertFilePath := createTestCertificate(t)
|
||||
|
||||
tlsSettings := serverTLS.TlsSettings{
|
||||
TLSCertificatePath: testCertFilePath,
|
||||
}
|
||||
tlsCertRetriever, err := serverTLS.GetTlsCertificateRetriever(tlsSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := tlsCertRetriever.GetCertificate()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := tlsCertRetriever.GetPrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *tls.Certificate
|
||||
wantErr bool
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "returns root CA pool when provided a single self-signed CA cert",
|
||||
cert: &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw},
|
||||
PrivateKey: key,
|
||||
Leaf: cert,
|
||||
},
|
||||
wantErr: false,
|
||||
wantErrMsg: "",
|
||||
},
|
||||
{
|
||||
name: "returns root CA pool when provided with a full cert chain",
|
||||
cert: &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw, cert.Raw},
|
||||
PrivateKey: key,
|
||||
Leaf: cert,
|
||||
},
|
||||
wantErr: false,
|
||||
wantErrMsg: "",
|
||||
},
|
||||
{
|
||||
name: "does not return root CA pool when provided with nil",
|
||||
cert: nil,
|
||||
wantErr: true,
|
||||
wantErrMsg: "no certificate provided",
|
||||
},
|
||||
{
|
||||
name: "does not return root CA pool when provided with empty cert",
|
||||
cert: &tls.Certificate{},
|
||||
wantErr: true,
|
||||
wantErrMsg: "no certificate provided",
|
||||
},
|
||||
{
|
||||
name: "does not return root CA pool when provided with single invalid cert",
|
||||
cert: &tls.Certificate{
|
||||
Certificate: [][]byte{[]byte("invalid leaf cert")},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrMsg: "parsing self signed cert",
|
||||
},
|
||||
{
|
||||
name: "does not return root CA pool when provided with invalid full chain cert",
|
||||
cert: &tls.Certificate{
|
||||
Certificate: [][]byte{[]byte("invalid leaf cert"), []byte("invalid root CA cert")},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrMsg: "parsing root certs",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r, err := mtlsRootCAsFromCertificate(tt.cert)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.wantErrMsg)
|
||||
assert.Nil(t, r)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTestCertificate is a test helper that creates a test certificate
|
||||
// and writes it to a temporary file that is cleaned up after the test.
|
||||
// Returns the path to the test certificate file
|
||||
func createTestCertificate(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
t.Log("Creating test certificate...")
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "foo.com",
|
||||
},
|
||||
DNSNames: []string{"localhost", "127.0.0.1", "example.com"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(3 * time.Hour),
|
||||
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
// Create certificate with the template and keys
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cert PEM
|
||||
pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
require.NotNil(t, pemCert)
|
||||
|
||||
// Private Key PEM
|
||||
privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
|
||||
require.NotNil(t, pemKey)
|
||||
|
||||
pemCert = append(pemCert, pemKey...)
|
||||
|
||||
// Write PEM cert and key to a file in a temp dir
|
||||
testCertFilePath := filepath.Join(t.TempDir(), "dummy.pem")
|
||||
err = os.WriteFile(testCertFilePath, pemCert, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Log("Created test certificate file at: ", testCertFilePath)
|
||||
|
||||
return testCertFilePath
|
||||
}
|
|
@ -100,6 +100,7 @@ func (l *Listener) Stop() {
|
|||
if l.tlsListener != nil {
|
||||
// Stop servicing requests on secure listener
|
||||
_ = l.tlsListener.Close()
|
||||
log.Printf("[Listener] Stopped listening on tls endpoint %s", l.tlsListener.Addr())
|
||||
}
|
||||
|
||||
// Delete the unix socket.
|
||||
|
@ -107,7 +108,7 @@ func (l *Listener) Stop() {
|
|||
_ = os.Remove(l.localAddress)
|
||||
}
|
||||
|
||||
log.Printf("[Listener] Stopped listening on %s", l.localAddress)
|
||||
log.Printf("[Listener] Stopped listening on %s", l.listener.Addr())
|
||||
}
|
||||
|
||||
// GetMux returns the HTTP mux for the listener.
|
||||
|
|
|
@ -13,6 +13,7 @@ type TlsSettings struct {
|
|||
KeyVaultCertificateName string
|
||||
MSIResourceID string
|
||||
KeyVaultCertificateRefreshInterval time.Duration
|
||||
UseMTLS bool
|
||||
}
|
||||
|
||||
func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче