343 строки
10 KiB
Go
343 строки
10 KiB
Go
// Copyright 2017 Microsoft. All rights reserved.
|
|
// MIT License
|
|
|
|
package cns
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/Azure/azure-container-networking/cns/common"
|
|
"github.com/Azure/azure-container-networking/cns/logger"
|
|
acn "github.com/Azure/azure-container-networking/common"
|
|
"github.com/Azure/azure-container-networking/keyvault"
|
|
localtls "github.com/Azure/azure-container-networking/server/tls"
|
|
"github.com/Azure/azure-container-networking/store"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
const (
|
|
defaultAPIServerPort = "10090"
|
|
genericData = "com.microsoft.azure.network.generic"
|
|
)
|
|
|
|
var errTLSConfig = errors.New("unsupported TLS version name from config")
|
|
|
|
// Service defines Container Networking Service.
|
|
type Service struct {
|
|
*common.Service
|
|
EndpointType string
|
|
Listener *acn.Listener
|
|
}
|
|
|
|
// NewService creates a new Service object.
|
|
func NewService(name, version, channelMode string, store store.KeyValueStore) (*Service, error) {
|
|
service, err := common.NewService(name, version, channelMode, store)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Service{
|
|
Service: service,
|
|
}, nil
|
|
}
|
|
|
|
func (service *Service) AddListener(config *common.ServiceConfig) error {
|
|
var (
|
|
err error
|
|
nodeURL *url.URL
|
|
)
|
|
|
|
// if cnsURL is empty the VM primary interface IP will be used
|
|
// if customer specifies -c option, then use this URL with warning message and it will be deprecated soon
|
|
cnsURL, ok := service.GetOption(acn.OptCnsURL).(string)
|
|
if !ok {
|
|
return errors.New("cnsURL type is wrong")
|
|
}
|
|
|
|
// if customer provides port number by -p option, then use VM IP with this port and localhost server also uses this port
|
|
// otherwise it will use defaultAPIServerPort 10090
|
|
cnsPort, ok := service.GetOption(acn.OptCnsPort).(string)
|
|
if !ok {
|
|
return errors.New("cnsPort type is wrong")
|
|
}
|
|
|
|
if cnsURL == "" {
|
|
config.Server.EnableLocalServer = true
|
|
// get VM primary interface's private IP
|
|
// if customer does use -p option, then use port number customers provide
|
|
if cnsPort == "" {
|
|
nodeURL, err = url.Parse(fmt.Sprintf("tcp://%s:%s", config.Server.PrimaryInterfaceIP, defaultAPIServerPort))
|
|
} else {
|
|
config.Server.Port = cnsPort
|
|
nodeURL, err = url.Parse(fmt.Sprintf("tcp://%s:%s", config.Server.PrimaryInterfaceIP, cnsPort))
|
|
}
|
|
|
|
if err != nil {
|
|
return errors.Wrap(err, "Failed to parse URL for legacy server")
|
|
}
|
|
} else {
|
|
// use the URL that customer provides by -c
|
|
logger.Printf("user specifies -c option")
|
|
|
|
// do not enable local server if customer uses -c option
|
|
config.Server.EnableLocalServer = false
|
|
nodeURL, err = url.Parse(cnsURL)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Failed to parse URL that customer provides")
|
|
}
|
|
}
|
|
|
|
logger.Debugf("CNS remote server url: %+v", nodeURL)
|
|
|
|
nodeListener, err := acn.NewListener(nodeURL)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Failed to construct url for node listener")
|
|
}
|
|
|
|
// only use TLS connection for DNC/CNS listener:
|
|
if config.TLSSettings.TLSPort != "" {
|
|
// listener.URL.Host will always be hostname:port, passed in to CNS via CNS command
|
|
// else it will default to localhost
|
|
// extract hostname and override tls port.
|
|
hostParts := strings.Split(nodeListener.URL.Host, ":")
|
|
tlsAddress := net.JoinHostPort(hostParts[0], config.TLSSettings.TLSPort)
|
|
|
|
// Start the listener and HTTP and HTTPS server.
|
|
tlsConfig, err := getTLSConfig(config.TLSSettings, config.ErrChan) //nolint
|
|
if err != nil {
|
|
logger.Printf("Failed to compose Tls Configuration with error: %+v", err)
|
|
return errors.Wrap(err, "could not get tls config")
|
|
}
|
|
|
|
if err := nodeListener.StartTLS(config.ErrChan, tlsConfig, tlsAddress); err != nil {
|
|
return errors.Wrap(err, "could not start tls")
|
|
}
|
|
}
|
|
|
|
service.Listener = nodeListener
|
|
logger.Debugf("[Azure CNS] Successfully initialized a service with config: %+v", config)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Initialize initializes the service and starts the listener.
|
|
func (service *Service) Initialize(config *common.ServiceConfig) error {
|
|
logger.Debugf("[Azure CNS] Going to initialize a service with config: %+v", config)
|
|
|
|
// Initialize the base service.
|
|
if err := service.Service.Initialize(config); err != nil {
|
|
return errors.Wrap(err, "failed to initialize")
|
|
}
|
|
|
|
if err := service.AddListener(config); err != nil {
|
|
return errors.Wrap(err, "failed to initialize listener")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) {
|
|
if tlsSettings.TLSCertificatePath != "" {
|
|
return getTLSConfigFromFile(tlsSettings)
|
|
}
|
|
|
|
if tlsSettings.KeyVaultURL != "" {
|
|
return getTLSConfigFromKeyVault(tlsSettings, errChan)
|
|
}
|
|
|
|
return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
|
|
}
|
|
|
|
func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
|
|
tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get certificate retriever")
|
|
}
|
|
|
|
leafCertificate, err := tlsCertRetriever.GetCertificate()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get certificate")
|
|
}
|
|
|
|
if leafCertificate == nil {
|
|
return nil, errors.New("certificate retrieval returned empty")
|
|
}
|
|
|
|
privateKey, err := tlsCertRetriever.GetPrivateKey()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get certificate private key")
|
|
}
|
|
|
|
tlsCert := tls.Certificate{
|
|
Certificate: [][]byte{leafCertificate.Raw},
|
|
PrivateKey: privateKey,
|
|
Leaf: leafCertificate,
|
|
}
|
|
minTLSVersionNumber, err := parseTLSVersionName(tlsSettings.MinTLSVersion)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "parsing MinTLSVersion from config")
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
MaxVersion: tls.VersionTLS13,
|
|
MinVersion: minTLSVersionNumber,
|
|
Certificates: []tls.Certificate{
|
|
tlsCert,
|
|
},
|
|
}
|
|
|
|
if tlsSettings.UseMTLS {
|
|
rootCAs, err := mtlsRootCAsFromCertificate(&tlsCert)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
|
|
}
|
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
tlsConfig.ClientCAs = rootCAs
|
|
tlsConfig.RootCAs = rootCAs
|
|
}
|
|
|
|
logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
|
|
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) {
|
|
credOpts := azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(tlsSettings.MSIResourceID)}
|
|
cred, err := azidentity.NewManagedIdentityCredential(&credOpts)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not create managed identity credential")
|
|
}
|
|
|
|
kvs, err := keyvault.NewShim(tlsSettings.KeyVaultURL, cred)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not create new keyvault shim")
|
|
}
|
|
|
|
ctx := context.TODO()
|
|
|
|
cr, err := keyvault.NewCertRefresher(ctx, kvs, logger.Log, tlsSettings.KeyVaultCertificateName)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not create new cert refresher")
|
|
}
|
|
|
|
go func() {
|
|
errChan <- cr.Refresh(ctx, tlsSettings.KeyVaultCertificateRefreshInterval)
|
|
}()
|
|
|
|
minTLSVersionNumber, err := parseTLSVersionName(tlsSettings.MinTLSVersion)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "parsing MinTLSVersion from config")
|
|
}
|
|
|
|
tlsConfig := tls.Config{
|
|
MinVersion: minTLSVersionNumber,
|
|
MaxVersion: tls.VersionTLS13,
|
|
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return cr.GetCertificate(), nil
|
|
},
|
|
}
|
|
|
|
if tlsSettings.UseMTLS {
|
|
tlsCert := cr.GetCertificate()
|
|
rootCAs, err := mtlsRootCAsFromCertificate(tlsCert)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
|
|
}
|
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
tlsConfig.ClientCAs = rootCAs
|
|
tlsConfig.RootCAs = rootCAs
|
|
}
|
|
|
|
logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
|
|
|
|
return &tlsConfig, nil
|
|
}
|
|
|
|
// Given a TLS cert, return the root CAs
|
|
func mtlsRootCAsFromCertificate(tlsCert *tls.Certificate) (*x509.CertPool, error) {
|
|
switch {
|
|
case tlsCert == nil || len(tlsCert.Certificate) == 0:
|
|
return nil, errors.New("no certificate provided")
|
|
case len(tlsCert.Certificate) == 1:
|
|
certs := x509.NewCertPool()
|
|
cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "parsing self signed cert")
|
|
}
|
|
certs.AddCert(cert)
|
|
|
|
return certs, nil
|
|
default:
|
|
certs := x509.NewCertPool()
|
|
// given a fullchain cert, we skip leaf cert at index 0 because
|
|
// we only want intermediate and root certs in the cert pool for mTLS
|
|
for _, certBytes := range tlsCert.Certificate[1:] {
|
|
cert, err := x509.ParseCertificate(certBytes)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "parsing root certs")
|
|
}
|
|
certs.AddCert(cert)
|
|
}
|
|
return certs, nil
|
|
}
|
|
}
|
|
|
|
func (service *Service) StartListener(config *common.ServiceConfig) error {
|
|
logger.Debugf("[Azure CNS] Going to start listener: %+v", config)
|
|
|
|
// Initialize the listener.
|
|
if service.Listener != nil {
|
|
logger.Debugf("[Azure CNS] Starting listener: %+v", config)
|
|
// Start the listener.
|
|
// continue to listen on the normal endpoint for http traffic, this will be supported
|
|
// for sometime until partners migrate fully to https
|
|
if err := service.Listener.Start(config.ErrChan); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
return fmt.Errorf("Failed to start a listener, it is not initialized, config %+v", config)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Uninitialize cleans up the plugin.
|
|
func (service *Service) Uninitialize() {
|
|
service.Listener.Stop()
|
|
service.Service.Uninitialize()
|
|
}
|
|
|
|
// ParseOptions returns generic options from a libnetwork request.
|
|
func (service *Service) ParseOptions(options OptionMap) OptionMap {
|
|
opt, _ := options[genericData].(OptionMap)
|
|
return opt
|
|
}
|
|
|
|
// SendErrorResponse sends and logs an error response.
|
|
func (service *Service) SendErrorResponse(w http.ResponseWriter, errMsg error) {
|
|
resp := errorResponse{errMsg.Error()}
|
|
err := acn.Encode(w, &resp)
|
|
logger.Errorf("[%s] %+v %s.", service.Name, &resp, err.Error())
|
|
}
|
|
|
|
// parseTLSVersionName returns the version number for the provided TLS version name
|
|
// (e.g. 0x0301)
|
|
func parseTLSVersionName(versionName string) (uint16, error) {
|
|
switch versionName {
|
|
case "TLS 1.2":
|
|
return tls.VersionTLS12, nil
|
|
case "TLS 1.3":
|
|
return tls.VersionTLS13, nil
|
|
default:
|
|
return 0, errors.Wrapf(errTLSConfig, "version name %s", versionName)
|
|
}
|
|
}
|