* feat: add UseMTLS config

* feat: add mTLS auth for CNS

* test: add testdata for mTLS tests

* chore: add logs on TLS config retrieval

* lint: in tests

* refactor: use CNS logger, not ACN logger

* refactor: add guards to mtlsRootCAsFromCertificate and unit tests

* lint: fix lint errors

* test: include HTTP listener tests for when TLS/mTLS is enabled

* chore: add log for stopping the TLS listener

* test: add test helper to create certificates for testing instead of using hardcoded pem file

* test: assert non-TLS service has no TLSSettings

* test: refactor TestMtlsRootCAsFromCertificate to table-based tests

* refactor: pull listener addresses from listener and remove redundant struct field for tls address
This commit is contained in:
Jackie Luc 2024-06-07 15:26:49 -07:00 коммит произвёл GitHub
Родитель 13f7037fd9
Коммит de225e4d34
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 387 добавлений и 1 удалений

Просмотреть файл

@ -20,6 +20,7 @@
"TLSPort": "10091",
"TLSSubjectName": "",
"UseHTTPS": false,
"UseMTLS": false,
"WireserverIP": "168.63.129.16",
"KeyVaultSettings": {
"URL": "",

Просмотреть файл

@ -55,6 +55,7 @@ type CNSConfig struct {
TLSSubjectName string
TelemetrySettings TelemetrySettings
UseHTTPS bool
UseMTLS bool
WatchPods bool `json:"-"`
WireserverIP string
}

Просмотреть файл

@ -87,6 +87,7 @@ func TestReadConfigFromFile(t *testing.T) {
PopulateHomeAzCacheRetryIntervalSecs: 60,
},
UseHTTPS: true,
UseMTLS: true,
WireserverIP: "168.63.129.16",
},
wantErr: false,

1
cns/configuration/testdata/good.json поставляемый
Просмотреть файл

@ -30,6 +30,7 @@
"TelemetryBatchSizeBytes": 16384
},
"UseHTTPS": true,
"UseMTLS": true,
"WireserverIP": "168.63.129.16",
"AZRSettings": {
"PopulateHomeAzCacheRetryIntervalSecs": 60

Просмотреть файл

@ -6,6 +6,7 @@ package cns
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
@ -190,6 +191,18 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error)
},
}
if tlsSettings.UseMTLS {
rootCAs, err := mtlsRootCAsFromCertificate(&tlsCert)
if err != nil {
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
}
logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
return tlsConfig, nil
}
@ -224,9 +237,51 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e
},
}
if tlsSettings.UseMTLS {
tlsCert := cr.GetCertificate()
rootCAs, err := mtlsRootCAsFromCertificate(tlsCert)
if err != nil {
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = rootCAs
tlsConfig.RootCAs = rootCAs
}
logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
return &tlsConfig, nil
}
// Given a TLS cert, return the root CAs
func mtlsRootCAsFromCertificate(tlsCert *tls.Certificate) (*x509.CertPool, error) {
switch {
case tlsCert == nil || len(tlsCert.Certificate) == 0:
return nil, errors.New("no certificate provided")
case len(tlsCert.Certificate) == 1:
certs := x509.NewCertPool()
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "parsing self signed cert")
}
certs.AddCert(cert)
return certs, nil
default:
certs := x509.NewCertPool()
// given a fullchain cert, we skip leaf cert at index 0 because
// we only want intermediate and root certs in the cert pool for mTLS
for _, certBytes := range tlsCert.Certificate[1:] {
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, errors.Wrap(err, "parsing root certs")
}
certs.AddCert(cert)
}
return certs, nil
}
}
func (service *Service) StartListener(config *common.ServiceConfig) error {
log.Debugf("[Azure CNS] Going to start listener: %+v", config)

Просмотреть файл

@ -786,6 +786,7 @@ func main() {
KeyVaultCertificateName: cnsconfig.KeyVaultSettings.CertificateName,
MSIResourceID: cnsconfig.MSISettings.ResourceID,
KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour,
UseMTLS: cnsconfig.UseMTLS,
}
}

324
cns/service_test.go Normal file
Просмотреть файл

@ -0,0 +1,324 @@
// Copyright 2017 Microsoft. All rights reserved.
// MIT License
package cns
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/Azure/azure-container-networking/cns/common"
"github.com/Azure/azure-container-networking/cns/logger"
acn "github.com/Azure/azure-container-networking/common"
serverTLS "github.com/Azure/azure-container-networking/server/tls"
"github.com/Azure/azure-container-networking/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewService(t *testing.T) {
logger.InitLogger("azure-cns.log", 0, 0, "/")
mockStore := store.NewMockStore("test")
config := &common.ServiceConfig{
Name: "test",
Version: "1.0",
ChannelMode: "Direct",
Store: mockStore,
}
t.Run("NewService", func(t *testing.T) {
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)
svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")
require.Empty(t, config.TLSSettings)
err = svc.Initialize(config)
t.Cleanup(func() {
svc.Uninitialize()
})
require.NoError(t, err)
err = svc.StartListener(config)
require.NoError(t, err)
client := &http.Client{}
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
})
t.Run("NewServiceWithTLS", func(t *testing.T) {
testCertFilePath := createTestCertificate(t)
config.TLSSettings = serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
}
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)
svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")
err = svc.Initialize(config)
t.Cleanup(func() {
svc.Uninitialize()
})
require.NoError(t, err)
err = svc.StartListener(config)
require.NoError(t, err)
tlsClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
ServerName: config.TLSSettings.TLSSubjectName,
// #nosec G402 for test purposes only
InsecureSkipVerify: true,
},
},
}
// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
require.NoError(t, err)
resp, err := tlsClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
})
t.Run("NewServiceWithMutualTLS", func(t *testing.T) {
testCertFilePath := createTestCertificate(t)
config.TLSSettings = serverTLS.TlsSettings{
TLSPort: "10091",
TLSSubjectName: "localhost",
TLSCertificatePath: testCertFilePath,
UseMTLS: true,
}
svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store)
require.NoError(t, err)
require.IsType(t, &Service{}, svc)
svc.SetOption(acn.OptCnsURL, "")
svc.SetOption(acn.OptCnsPort, "")
err = svc.Initialize(config)
t.Cleanup(func() {
svc.Uninitialize()
})
require.NoError(t, err)
err = svc.StartListener(config)
require.NoError(t, err)
mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings)
require.NoError(t, err)
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: mTLSConfig,
},
}
// TLS listener
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
// HTTP listener
httpClient := &http.Client{}
req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody)
require.NoError(t, err)
resp, err = httpClient.Do(req)
t.Cleanup(func() {
resp.Body.Close()
})
require.NoError(t, err)
})
}
func TestMtlsRootCAsFromCertificate(t *testing.T) {
testCertFilePath := createTestCertificate(t)
tlsSettings := serverTLS.TlsSettings{
TLSCertificatePath: testCertFilePath,
}
tlsCertRetriever, err := serverTLS.GetTlsCertificateRetriever(tlsSettings)
require.NoError(t, err)
cert, err := tlsCertRetriever.GetCertificate()
require.NoError(t, err)
key, err := tlsCertRetriever.GetPrivateKey()
require.NoError(t, err)
tests := []struct {
name string
cert *tls.Certificate
wantErr bool
wantErrMsg string
}{
{
name: "returns root CA pool when provided a single self-signed CA cert",
cert: &tls.Certificate{
Certificate: [][]byte{cert.Raw},
PrivateKey: key,
Leaf: cert,
},
wantErr: false,
wantErrMsg: "",
},
{
name: "returns root CA pool when provided with a full cert chain",
cert: &tls.Certificate{
Certificate: [][]byte{cert.Raw, cert.Raw},
PrivateKey: key,
Leaf: cert,
},
wantErr: false,
wantErrMsg: "",
},
{
name: "does not return root CA pool when provided with nil",
cert: nil,
wantErr: true,
wantErrMsg: "no certificate provided",
},
{
name: "does not return root CA pool when provided with empty cert",
cert: &tls.Certificate{},
wantErr: true,
wantErrMsg: "no certificate provided",
},
{
name: "does not return root CA pool when provided with single invalid cert",
cert: &tls.Certificate{
Certificate: [][]byte{[]byte("invalid leaf cert")},
},
wantErr: true,
wantErrMsg: "parsing self signed cert",
},
{
name: "does not return root CA pool when provided with invalid full chain cert",
cert: &tls.Certificate{
Certificate: [][]byte{[]byte("invalid leaf cert"), []byte("invalid root CA cert")},
},
wantErr: true,
wantErrMsg: "parsing root certs",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, err := mtlsRootCAsFromCertificate(tt.cert)
if tt.wantErr {
require.Error(t, err)
require.ErrorContains(t, err, tt.wantErrMsg)
assert.Nil(t, r)
} else {
require.NoError(t, err)
assert.NotNil(t, r)
}
})
}
}
// createTestCertificate is a test helper that creates a test certificate
// and writes it to a temporary file that is cleaned up after the test.
// Returns the path to the test certificate file
func createTestCertificate(t *testing.T) string {
t.Helper()
t.Log("Creating test certificate...")
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "foo.com",
},
DNSNames: []string{"localhost", "127.0.0.1", "example.com"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(3 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
// Create certificate with the template and keys
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
// Cert PEM
pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
require.NotNil(t, pemCert)
// Private Key PEM
privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
require.NotNil(t, pemKey)
pemCert = append(pemCert, pemKey...)
// Write PEM cert and key to a file in a temp dir
testCertFilePath := filepath.Join(t.TempDir(), "dummy.pem")
err = os.WriteFile(testCertFilePath, pemCert, 0o600)
require.NoError(t, err)
t.Log("Created test certificate file at: ", testCertFilePath)
return testCertFilePath
}

Просмотреть файл

@ -100,6 +100,7 @@ func (l *Listener) Stop() {
if l.tlsListener != nil {
// Stop servicing requests on secure listener
_ = l.tlsListener.Close()
log.Printf("[Listener] Stopped listening on tls endpoint %s", l.tlsListener.Addr())
}
// Delete the unix socket.
@ -107,7 +108,7 @@ func (l *Listener) Stop() {
_ = os.Remove(l.localAddress)
}
log.Printf("[Listener] Stopped listening on %s", l.localAddress)
log.Printf("[Listener] Stopped listening on %s", l.listener.Addr())
}
// GetMux returns the HTTP mux for the listener.

Просмотреть файл

@ -13,6 +13,7 @@ type TlsSettings struct {
KeyVaultCertificateName string
MSIResourceID string
KeyVaultCertificateRefreshInterval time.Duration
UseMTLS bool
}
func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) {