azure-container-networking/cns/service.go

343 lines
10 KiB
Go

// Copyright 2017 Microsoft. All rights reserved.
// MIT License
package cns
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/Azure/azure-container-networking/cns/common"
"github.com/Azure/azure-container-networking/cns/logger"
acn "github.com/Azure/azure-container-networking/common"
"github.com/Azure/azure-container-networking/keyvault"
localtls "github.com/Azure/azure-container-networking/server/tls"
"github.com/Azure/azure-container-networking/store"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/pkg/errors"
)
// defaultAPIServerPort is the port CNS listens on when no -p port option is supplied.
const defaultAPIServerPort = "10090"

// genericData is the key under which generic options are nested in a libnetwork option map.
const genericData = "com.microsoft.azure.network.generic"

// errTLSConfig is wrapped by parseTLSVersionName when the configured TLS version name is unknown.
var errTLSConfig = errors.New("unsupported TLS version name from config")
// Service defines Container Networking Service.
// It embeds the shared base service created by common.NewService (see NewService) and
// adds the CNS-specific fields below.
type Service struct {
	*common.Service
	// EndpointType is not read or written in this file. NOTE(review): confirm its meaning with callers.
	EndpointType string
	// Listener is the node listener built by AddListener; it stays nil until AddListener succeeds.
	Listener *acn.Listener
}
// NewService creates a new Service object wrapping a base service built by
// common.NewService from the given name, version, channel mode and store.
// Any error from common.NewService is returned unchanged.
func NewService(name, version, channelMode string, store store.KeyValueStore) (*Service, error) {
	base, err := common.NewService(name, version, channelMode, store)
	if err != nil {
		return nil, err
	}

	svc := &Service{}
	svc.Service = base
	return svc, nil
}
// AddListener builds the CNS node listener from the service options and config and
// stores it in service.Listener.
//
// Address selection:
//   - no -c URL (acn.OptCnsURL is ""): listen on config.Server.PrimaryInterfaceIP using the
//     -p port (acn.OptCnsPort) when given, otherwise defaultAPIServerPort; the local
//     server is enabled.
//   - -c URL given: listen on that URL and disable the local server.
//
// When config.TLSSettings.TLSPort is set, a TLS endpoint is also started on the same host
// with the TLS port. config.Server is mutated (EnableLocalServer, and Port when -p is set).
// An error is returned if either option has the wrong type, a URL cannot be parsed, the
// listener cannot be constructed, or TLS configuration/start fails.
func (service *Service) AddListener(config *common.ServiceConfig) error {
	var (
		err     error
		nodeURL *url.URL
	)

	// if cnsURL is empty the VM primary interface IP will be used
	// if customer specifies -c option, then use this URL with warning message and it will be deprecated soon
	cnsURL, ok := service.GetOption(acn.OptCnsURL).(string)
	if !ok {
		return errors.New("cnsURL type is wrong")
	}

	// if customer provides port number by -p option, then use VM IP with this port and localhost server also uses this port
	// otherwise it will use defaultAPIServerPort 10090
	cnsPort, ok := service.GetOption(acn.OptCnsPort).(string)
	if !ok {
		return errors.New("cnsPort type is wrong")
	}

	if cnsURL == "" {
		config.Server.EnableLocalServer = true

		// get VM primary interface's private IP; honor the -p port when the customer provides one
		port := defaultAPIServerPort
		if cnsPort != "" {
			config.Server.Port = cnsPort
			port = cnsPort
		}

		nodeURL, err = url.Parse(fmt.Sprintf("tcp://%s:%s", config.Server.PrimaryInterfaceIP, port))
		if err != nil {
			return errors.Wrap(err, "Failed to parse URL for legacy server")
		}
	} else {
		// use the URL that customer provides by -c
		logger.Printf("user specifies -c option")

		// do not enable local server if customer uses -c option
		config.Server.EnableLocalServer = false
		nodeURL, err = url.Parse(cnsURL)
		if err != nil {
			return errors.Wrap(err, "Failed to parse URL that customer provides")
		}
	}

	logger.Debugf("CNS remote server url: %+v", nodeURL)
	nodeListener, err := acn.NewListener(nodeURL)
	if err != nil {
		return errors.Wrap(err, "Failed to construct url for node listener")
	}

	// only use TLS connection for DNC/CNS listener:
	if config.TLSSettings.TLSPort != "" {
		// listener.URL.Host is expected to be host:port (passed in to CNS via CNS command,
		// else defaulting to localhost); keep the host and substitute the TLS port.
		tlsAddress := tlsAddressForHost(nodeListener.URL.Host, config.TLSSettings.TLSPort)

		// Start the listener and HTTP and HTTPS server.
		tlsConfig, err := getTLSConfig(config.TLSSettings, config.ErrChan) //nolint
		if err != nil {
			logger.Printf("Failed to compose Tls Configuration with error: %+v", err)
			return errors.Wrap(err, "could not get tls config")
		}

		if err := nodeListener.StartTLS(config.ErrChan, tlsConfig, tlsAddress); err != nil {
			return errors.Wrap(err, "could not start tls")
		}
	}

	service.Listener = nodeListener
	logger.Debugf("[Azure CNS] Successfully initialized a service with config: %+v", config)
	return nil
}

// tlsAddressForHost returns the host part of hostport joined with tlsPort.
// hostport is normally "host:port", with IPv6 hosts bracketed (e.g. "[fe80::1]:10090");
// net.SplitHostPort handles both forms, whereas naively splitting on ":" would truncate
// IPv6 hosts. If hostport carries no port, the whole value (minus any IPv6 brackets) is
// used as the host.
func tlsAddressForHost(hostport, tlsPort string) string {
	host, _, err := net.SplitHostPort(hostport)
	if err != nil {
		host = strings.TrimSuffix(strings.TrimPrefix(hostport, "["), "]")
	}
	return net.JoinHostPort(host, tlsPort)
}
// Initialize initializes the service and starts the listener.
// It first initializes the embedded base service with config, then builds the node
// listener via AddListener; the first failure is returned wrapped with context.
func (service *Service) Initialize(config *common.ServiceConfig) error {
	logger.Debugf("[Azure CNS] Going to initialize a service with config: %+v", config)

	// Initialize the base service.
	err := service.Service.Initialize(config)
	if err != nil {
		return errors.Wrap(err, "failed to initialize")
	}

	err = service.AddListener(config)
	if err != nil {
		return errors.Wrap(err, "failed to initialize listener")
	}

	return nil
}
// getTLSConfig chooses the certificate source for the TLS listener.
// A certificate file path takes precedence over a Key Vault URL. errChan is only used by
// the Key Vault path. Settings with neither source set are rejected.
func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) {
	switch {
	case tlsSettings.TLSCertificatePath != "":
		return getTLSConfigFromFile(tlsSettings)
	case tlsSettings.KeyVaultURL != "":
		return getTLSConfigFromKeyVault(tlsSettings, errChan)
	default:
		return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings)
	}
}
// getTLSConfigFromFile builds a server TLS config from a certificate and private key
// loaded through localtls.GetTlsCertificateRetriever.
//
// The served chain contains only the leaf certificate returned by the retriever.
// MinVersion comes from tlsSettings.MinTLSVersion and MaxVersion is TLS 1.3.
// With tlsSettings.UseMTLS, client certificates are required and verified against the
// pool produced by mtlsRootCAsFromCertificate; the same pool is also set as RootCAs.
func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) {
	retriever, err := localtls.GetTlsCertificateRetriever(tlsSettings)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get certificate retriever")
	}

	leaf, err := retriever.GetCertificate()
	switch {
	case err != nil:
		return nil, errors.Wrap(err, "failed to get certificate")
	case leaf == nil:
		return nil, errors.New("certificate retrieval returned empty")
	}

	key, err := retriever.GetPrivateKey()
	if err != nil {
		return nil, errors.Wrap(err, "failed to get certificate private key")
	}

	serverCert := tls.Certificate{
		Certificate: [][]byte{leaf.Raw},
		PrivateKey:  key,
		Leaf:        leaf,
	}

	minVersion, err := parseTLSVersionName(tlsSettings.MinTLSVersion)
	if err != nil {
		return nil, errors.Wrap(err, "parsing MinTLSVersion from config")
	}

	cfg := &tls.Config{
		MinVersion:   minVersion,
		MaxVersion:   tls.VersionTLS13,
		Certificates: []tls.Certificate{serverCert},
	}

	if tlsSettings.UseMTLS {
		pool, err := mtlsRootCAsFromCertificate(&serverCert)
		if err != nil {
			return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
		}
		cfg.ClientAuth = tls.RequireAndVerifyClientCert
		cfg.ClientCAs = pool
		cfg.RootCAs = pool
	}

	logger.Debugf("TLS configured successfully from file: %+v", tlsSettings)
	return cfg, nil
}
// getTLSConfigFromKeyVault builds a server TLS config whose certificate is served by a
// Key Vault certificate refresher. It authenticates with the managed identity named by
// tlsSettings.MSIResourceID.
//
// On success, a background goroutine runs cr.Refresh with
// tlsSettings.KeyVaultCertificateRefreshInterval and sends its result to errChan.
// NOTE(review): that goroutine uses context.TODO() and is never cancelled. Its send may
// block if nothing receives from errChan.
// With tlsSettings.UseMTLS, client certificates are required and verified against CAs
// derived from the refresher's certificate at startup.
// NOTE(review): those CAs are not updated if the certificate later rotates.
func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls.Config, error) {
	// Validate static settings before touching Key Vault. A bad MinTLSVersion then fails
	// fast, before any refresher goroutine exists.
	minTLSVersionNumber, err := parseTLSVersionName(tlsSettings.MinTLSVersion)
	if err != nil {
		return nil, errors.Wrap(err, "parsing MinTLSVersion from config")
	}

	credOpts := azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ResourceID(tlsSettings.MSIResourceID)}
	cred, err := azidentity.NewManagedIdentityCredential(&credOpts)
	if err != nil {
		return nil, errors.Wrap(err, "could not create managed identity credential")
	}

	kvs, err := keyvault.NewShim(tlsSettings.KeyVaultURL, cred)
	if err != nil {
		return nil, errors.Wrap(err, "could not create new keyvault shim")
	}

	ctx := context.TODO()
	cr, err := keyvault.NewCertRefresher(ctx, kvs, logger.Log, tlsSettings.KeyVaultCertificateName)
	if err != nil {
		return nil, errors.Wrap(err, "could not create new cert refresher")
	}

	tlsConfig := tls.Config{
		MinVersion: minTLSVersionNumber,
		MaxVersion: tls.VersionTLS13,
		GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
			return cr.GetCertificate(), nil
		},
	}

	if tlsSettings.UseMTLS {
		// assumes NewCertRefresher has already loaded an initial certificate — TODO confirm
		rootCAs, err := mtlsRootCAsFromCertificate(cr.GetCertificate())
		if err != nil {
			return nil, errors.Wrap(err, "failed to get root CAs for configuring mTLS")
		}
		tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
		tlsConfig.ClientCAs = rootCAs
		tlsConfig.RootCAs = rootCAs
	}

	// Start refreshing only after every fallible step has succeeded. Starting it before
	// an error return would leave an orphaned refresher running with no owner.
	go func() {
		errChan <- cr.Refresh(ctx, tlsSettings.KeyVaultCertificateRefreshInterval)
	}()

	logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings)
	return &tlsConfig, nil
}
// mtlsRootCAsFromCertificate returns the CA pool to trust for mTLS given a server cert.
//
// A single-certificate chain is treated as self-signed, and that certificate is the only
// pool member. For a longer chain, the leaf at index 0 is skipped and every remaining
// (intermediate/root) certificate is added. A nil certificate or an empty chain is an
// error, as is any DER entry that fails to parse.
func mtlsRootCAsFromCertificate(tlsCert *tls.Certificate) (*x509.CertPool, error) {
	if tlsCert == nil || len(tlsCert.Certificate) == 0 {
		return nil, errors.New("no certificate provided")
	}

	chain := tlsCert.Certificate
	if len(chain) == 1 {
		selfSigned, err := x509.ParseCertificate(chain[0])
		if err != nil {
			return nil, errors.Wrap(err, "parsing self signed cert")
		}
		pool := x509.NewCertPool()
		pool.AddCert(selfSigned)
		return pool, nil
	}

	pool := x509.NewCertPool()
	// Index 0 is the leaf; only the issuing certificates belong in the trust pool.
	for i := 1; i < len(chain); i++ {
		issuer, err := x509.ParseCertificate(chain[i])
		if err != nil {
			return nil, errors.Wrap(err, "parsing root certs")
		}
		pool.AddCert(issuer)
	}
	return pool, nil
}
// StartListener starts the plain (non-TLS) node listener created by AddListener.
// The listener reports asynchronous errors on config.ErrChan. An error is returned if
// AddListener has not set service.Listener, or if starting it fails.
func (service *Service) StartListener(config *common.ServiceConfig) error {
	logger.Debugf("[Azure CNS] Going to start listener: %+v", config)

	listener := service.Listener
	if listener == nil {
		return fmt.Errorf("Failed to start a listener, it is not initialized, config %+v", config)
	}

	logger.Debugf("[Azure CNS] Starting listener: %+v", config)
	// Keep serving plain HTTP on the normal endpoint for now; partners are still
	// migrating to HTTPS.
	return listener.Start(config.ErrChan)
}
// Uninitialize cleans up the plugin: it stops the node listener, if one was created,
// and then uninitializes the embedded base service.
func (service *Service) Uninitialize() {
	// Listener is only set once AddListener succeeds. Guard against nil so cleanup after a
	// failed or skipped Initialize does not panic; this matches the nil check in StartListener.
	if service.Listener != nil {
		service.Listener.Stop()
	}
	service.Service.Uninitialize()
}
// ParseOptions returns generic options from a libnetwork request: the OptionMap stored
// under the genericData key. It returns nil when the key is absent or holds a value of
// a different type.
func (service *Service) ParseOptions(options OptionMap) OptionMap {
	generic, ok := options[genericData].(OptionMap)
	if !ok {
		return nil
	}
	return generic
}
// SendErrorResponse sends and logs an error response.
// errMsg's text is encoded to w as an errorResponse. The response is always logged at
// error level; if encoding fails, the encoding error is appended to the log line.
// errMsg must be non-nil.
func (service *Service) SendErrorResponse(w http.ResponseWriter, errMsg error) {
	resp := errorResponse{errMsg.Error()}

	// Only dereference the encode error when there is one. Calling Error() on a nil
	// error panics, which would otherwise happen on every successful send.
	if err := acn.Encode(w, &resp); err != nil {
		logger.Errorf("[%s] %+v %s.", service.Name, &resp, err.Error())
		return
	}
	logger.Errorf("[%s] %+v.", service.Name, &resp)
}
// parseTLSVersionName returns the crypto/tls version number for the provided TLS
// version name. Only "TLS 1.2" (tls.VersionTLS12, 0x0303) and "TLS 1.3"
// (tls.VersionTLS13, 0x0304) are accepted. Any other name, including "", yields an
// error wrapping errTLSConfig.
func parseTLSVersionName(versionName string) (uint16, error) {
	supported := map[string]uint16{
		"TLS 1.2": tls.VersionTLS12,
		"TLS 1.3": tls.VersionTLS13,
	}
	if version, ok := supported[versionName]; ok {
		return version, nil
	}
	return 0, errors.Wrapf(errTLSConfig, "version name %s", versionName)
}