feat: Add CNS API to retrieve VMUniqueID from IMDS (#2842)

* Add CNS API to retrieve VMUniqueID from IMDS

* Address the PR review comments

* Address the security comment from Evans to expose this API wherever needed

* fixed the linter error

* address the PR comments from Matt

* lowercase the struct json fields
This commit is contained in:
msvik 2024-07-18 14:32:02 -07:00 коммит произвёл GitHub
Родитель 0d294720c7
Коммит 6c50d0dcdd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
13 изменённых файлов: 173 добавлений и 13 удалений

Просмотреть файл

@ -18,6 +18,7 @@ import (
const ( const (
SetOrchestratorType = "/network/setorchestratortype" SetOrchestratorType = "/network/setorchestratortype"
GetHomeAz = "/homeaz" GetHomeAz = "/homeaz"
GetVMUniqueID = "/metadata/vmuniqueid"
CreateOrUpdateNetworkContainer = "/network/createorupdatenetworkcontainer" CreateOrUpdateNetworkContainer = "/network/createorupdatenetworkcontainer"
DeleteNetworkContainer = "/network/deletenetworkcontainer" DeleteNetworkContainer = "/network/deletenetworkcontainer"
PublishNetworkContainer = "/network/publishnetworkcontainer" PublishNetworkContainer = "/network/publishnetworkcontainer"

Просмотреть файл

@ -306,8 +306,8 @@ type IpamPoolMonitorStateSnapshot struct {
// Response describes generic response from CNS. // Response describes generic response from CNS.
type Response struct { type Response struct {
ReturnCode types.ResponseCode ReturnCode types.ResponseCode `json:"ReturnCode"`
Message string Message string `json:"Message"`
} }
// NumOfCPUCoresResponse describes num of cpu cores present on host. // NumOfCPUCoresResponse describes num of cpu cores present on host.
@ -373,3 +373,8 @@ type EndpointRequest struct {
HostVethName string `json:"hostVethName"` HostVethName string `json:"hostVethName"`
InterfaceName string `json:"InterfaceName"` InterfaceName string `json:"InterfaceName"`
} }
type GetVMUniqueIDResponse struct {
Response Response `json:"response"`
VMUniqueID string `json:"vmuniqueid"`
}

Просмотреть файл

@ -151,7 +151,9 @@ func TestMain(m *testing.M) {
logger.InitLogger(logName, 0, 0, tmpLogDir+"/") logger.InitLogger(logName, 0, 0, tmpLogDir+"/")
config := common.ServiceConfig{} config := common.ServiceConfig{}
httpRestService, err := restserver.NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.WireserverProxyFake{}, &fakes.NMAgentClientFake{}, nil, nil, nil) httpRestService, err := restserver.NewHTTPRestService(&config, &fakes.WireserverClientFake{},
&fakes.WireserverProxyFake{}, &fakes.NMAgentClientFake{}, nil, nil, nil,
fakes.NewMockIMDSClient())
svc = httpRestService svc = httpRestService
httpRestService.Name = "cns-test-server" httpRestService.Name = "cns-test-server"

Просмотреть файл

@ -9,6 +9,7 @@ package fakes
import ( import (
"context" "context"
"github.com/Azure/azure-container-networking/cns/imds"
"github.com/Azure/azure-container-networking/cns/wireserver" "github.com/Azure/azure-container-networking/cns/wireserver"
) )
@ -16,10 +17,13 @@ const (
// HostPrimaryIP 10.0.0.4 // HostPrimaryIP 10.0.0.4
HostPrimaryIP = "10.0.0.4" HostPrimaryIP = "10.0.0.4"
// HostSubnet 10.0.0.0/24 // HostSubnet 10.0.0.0/24
HostSubnet = "10.0.0.0/24" HostSubnet = "10.0.0.0/24"
SimulateError MockIMDSCtxKey = "simulate-error"
) )
type WireserverClientFake struct{} type WireserverClientFake struct{}
type MockIMDSCtxKey string
type MockIMDSClient struct{}
func (c *WireserverClientFake) GetInterfaces(ctx context.Context) (*wireserver.GetInterfacesResult, error) { func (c *WireserverClientFake) GetInterfaces(ctx context.Context) (*wireserver.GetInterfacesResult, error) {
return &wireserver.GetInterfacesResult{ return &wireserver.GetInterfacesResult{
@ -41,3 +45,15 @@ func (c *WireserverClientFake) GetInterfaces(ctx context.Context) (*wireserver.G
}, },
}, nil }, nil
} }
func NewMockIMDSClient() *MockIMDSClient {
return &MockIMDSClient{}
}
func (m *MockIMDSClient) GetVMUniqueID(ctx context.Context) (string, error) {
if ctx.Value(SimulateError) != nil {
return "", imds.ErrUnexpectedStatusCode
}
return "55b8499d-9b42-4f85-843f-24ff69f4a643", nil
}

Просмотреть файл

@ -74,3 +74,29 @@ func TestIMDSInvalidJSON(t *testing.T) {
_, err := imdsClient.GetVMUniqueID(context.Background()) _, err := imdsClient.GetVMUniqueID(context.Background())
require.Error(t, err, "expected json decoding error") require.Error(t, err, "expected json decoding error")
} }
func TestInvalidVMUniqueID(t *testing.T) {
computeMetadata, err := os.ReadFile("testdata/invalidComputeMetadata.json")
require.NoError(t, err, "error reading testdata compute metadata file")
mockIMDSServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// request header "Metadata: true" must be present
metadataHeader := r.Header.Get("Metadata")
assert.Equal(t, "true", metadataHeader)
// query params should include apiversion and json format
apiVersion := r.URL.Query().Get("api-version")
assert.Equal(t, "2021-01-01", apiVersion)
format := r.URL.Query().Get("format")
assert.Equal(t, "json", format)
w.WriteHeader(http.StatusOK)
_, writeErr := w.Write(computeMetadata)
require.NoError(t, writeErr, "error writing response")
}))
defer mockIMDSServer.Close()
imdsClient := imds.NewClient(imds.Endpoint(mockIMDSServer.URL))
vmUniqueID, err := imdsClient.GetVMUniqueID(context.Background())
require.Error(t, err, "error querying testserver")
require.Equal(t, "", vmUniqueID)
}

5
cns/imds/testdata/invalidComputeMetadata.json поставляемый Normal file
Просмотреть файл

@ -0,0 +1,5 @@
{
"azEnvironment": "AzurePublicCloud",
"location": "westus2",
"vmId": ""
}

Просмотреть файл

@ -1522,3 +1522,42 @@ func (service *HTTPRestService) nmAgentSupportedApisHandler(w http.ResponseWrite
logger.Response(service.Name, nmAgentSupportedApisResponse, resp.ReturnCode, serviceErr) logger.Response(service.Name, nmAgentSupportedApisResponse, resp.ReturnCode, serviceErr)
} }
// getVMUniqueID retrieves VMUniqueID from the IMDS
func (service *HTTPRestService) getVMUniqueID(w http.ResponseWriter, r *http.Request) {
logger.Request(service.Name, "getVMUniqueID", nil)
ctx := r.Context()
switch r.Method {
case http.MethodGet:
vmUniqueID, err := service.imdsClient.GetVMUniqueID(ctx)
if err != nil {
resp := cns.GetVMUniqueIDResponse{
Response: cns.Response{
ReturnCode: types.UnexpectedError,
Message: errors.Wrap(err, "failed to get vmuniqueid").Error(),
},
}
respondJSON(w, http.StatusInternalServerError, resp)
logger.Response(service.Name, resp, resp.Response.ReturnCode, err)
return
}
resp := cns.GetVMUniqueIDResponse{
Response: cns.Response{
ReturnCode: types.Success,
},
VMUniqueID: vmUniqueID,
}
respondJSON(w, http.StatusOK, resp)
logger.Response(service.Name, resp, resp.Response.ReturnCode, err)
default:
returnMessage := fmt.Sprintf("[Azure CNS] Error. getVMUniqueID did not receive a GET."+
" Received: %s", r.Method)
returnCode := types.UnsupportedVerb
service.setResponse(w, returnCode, cns.GetHomeAzResponse{
Response: cns.Response{ReturnCode: returnCode, Message: returnMessage},
})
}
}

Просмотреть файл

@ -20,6 +20,7 @@ import (
"github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns"
"github.com/Azure/azure-container-networking/cns/common" "github.com/Azure/azure-container-networking/cns/common"
"github.com/Azure/azure-container-networking/cns/configuration"
"github.com/Azure/azure-container-networking/cns/fakes" "github.com/Azure/azure-container-networking/cns/fakes"
"github.com/Azure/azure-container-networking/cns/logger" "github.com/Azure/azure-container-networking/cns/logger"
"github.com/Azure/azure-container-networking/cns/types" "github.com/Azure/azure-container-networking/cns/types"
@ -29,6 +30,7 @@ import (
"github.com/Azure/azure-container-networking/store" "github.com/Azure/azure-container-networking/store"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -172,8 +174,10 @@ func TestMain(m *testing.M) {
var err error var err error
logger.InitLogger("testlogs", 0, 0, "./") logger.InitLogger("testlogs", 0, 0, "./")
// Create the service. // Create the service. If CRD channel mode is needed, then at the start of the test,
if err = startService(); err != nil { // it can stop the service (service.Stop), invoke startService again with new ServiceConfig (with CRD mode)
// perform the test and then restore the service again.
if err = startService(common.ServiceConfig{ChannelMode: cns.Direct}, configuration.CNSConfig{}); err != nil {
fmt.Printf("Failed to start CNS Service. Error: %v", err) fmt.Printf("Failed to start CNS Service. Error: %v", err)
os.Exit(1) os.Exit(1)
} }
@ -1666,9 +1670,9 @@ func setEnv(t *testing.T) *httptest.ResponseRecorder {
return w return w
} }
func startService() error { func startService(serviceConfig common.ServiceConfig, _ configuration.CNSConfig) error {
// Create the service. // Create the service.
config := common.ServiceConfig{} config := serviceConfig
// Create the key value fileStore. // Create the key value fileStore.
fileStore, err := store.NewJsonFileStore(cnsJsonFileName, processlock.NewMockFileLock(false), nil) fileStore, err := store.NewJsonFileStore(cnsJsonFileName, processlock.NewMockFileLock(false), nil)
@ -1679,7 +1683,8 @@ func startService() error {
config.Store = fileStore config.Store = fileStore
nmagentClient := &fakes.NMAgentClientFake{} nmagentClient := &fakes.NMAgentClientFake{}
service, err = NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.WireserverProxyFake{}, nmagentClient, nil, nil, nil) service, err = NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.WireserverProxyFake{},
nmagentClient, nil, nil, nil, fakes.NewMockIMDSClient())
if err != nil { if err != nil {
return err return err
} }
@ -1758,6 +1763,43 @@ func contains(networkContainers []cns.GetNetworkContainerResponse, str string) b
return false return false
} }
// Testing GetVMUniqueID API handler with success
func TestGetVMUniqueIDSuccess(t *testing.T) {
req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, cns.GetVMUniqueID, http.NoBody)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
var vmIDResp cns.GetVMUniqueIDResponse
err = decodeResponse(w, &vmIDResp)
require.NoError(t, err)
assert.Equal(t, types.Success, vmIDResp.Response.ReturnCode)
assert.Equal(t, "55b8499d-9b42-4f85-843f-24ff69f4a643", vmIDResp.VMUniqueID)
}
// Testing GetVMUniqueID API handler with failure
func TestGetVMUniqueIDFailed(t *testing.T) {
ctx := context.TODO()
ctx = context.WithValue(ctx, fakes.SimulateError, Interface{})
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cns.GetVMUniqueID, http.NoBody)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
var vmIDResp cns.GetVMUniqueIDResponse
err = json.NewDecoder(w.Body).Decode(&vmIDResp)
require.NoError(t, err)
assert.Equal(t, types.UnexpectedError, vmIDResp.Response.ReturnCode)
}
// IGNORE TEST AS IT IS FAILING. TODO:- Fix it https://msazure.visualstudio.com/One/_workitems/edit/7720083 // IGNORE TEST AS IT IS FAILING. TODO:- Fix it https://msazure.visualstudio.com/One/_workitems/edit/7720083
// // Tests CreateNetwork functionality. // // Tests CreateNetwork functionality.

Просмотреть файл

@ -16,6 +16,8 @@ import (
"time" "time"
"github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns"
"github.com/Azure/azure-container-networking/cns/common"
"github.com/Azure/azure-container-networking/cns/configuration"
"github.com/Azure/azure-container-networking/cns/fakes" "github.com/Azure/azure-container-networking/cns/fakes"
"github.com/Azure/azure-container-networking/cns/types" "github.com/Azure/azure-container-networking/cns/types"
"github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha" "github.com/Azure/azure-container-networking/crd/nodenetworkconfig/api/v1alpha"
@ -1056,7 +1058,7 @@ func restartService() {
fmt.Println("Restart Service") fmt.Println("Restart Service")
service.Stop() service.Stop()
if err := startService(); err != nil { if err := startService(common.ServiceConfig{}, configuration.CNSConfig{}); err != nil {
fmt.Printf("Failed to restart CNS Service. Error: %v", err) fmt.Printf("Failed to restart CNS Service. Error: %v", err)
os.Exit(1) os.Exit(1)
} }

Просмотреть файл

@ -72,7 +72,9 @@ type ncState struct {
func getTestService() *HTTPRestService { func getTestService() *HTTPRestService {
var config common.ServiceConfig var config common.ServiceConfig
httpsvc, _ := NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.WireserverProxyFake{}, &fakes.NMAgentClientFake{}, store.NewMockStore(""), nil, nil) httpsvc, _ := NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.WireserverProxyFake{},
&fakes.NMAgentClientFake{}, store.NewMockStore(""), nil, nil,
fakes.NewMockIMDSClient())
svc = httpsvc svc = httpsvc
setOrchestratorTypeInternal(cns.KubernetesCRD) setOrchestratorTypeInternal(cns.KubernetesCRD)

Просмотреть файл

@ -49,6 +49,10 @@ type wireserverProxy interface {
UnpublishNC(ctx context.Context, ncParams cns.NetworkContainerParameters, payload []byte) (*http.Response, error) UnpublishNC(ctx context.Context, ncParams cns.NetworkContainerParameters, payload []byte) (*http.Response, error)
} }
type imdsClient interface {
GetVMUniqueID(ctx context.Context) (string, error)
}
// HTTPRestService represents http listener for CNS - Container Networking Service. // HTTPRestService represents http listener for CNS - Container Networking Service.
type HTTPRestService struct { type HTTPRestService struct {
*cns.Service *cns.Service
@ -73,6 +77,7 @@ type HTTPRestService struct {
generateCNIConflistOnce sync.Once generateCNIConflistOnce sync.Once
IPConfigsHandlerMiddleware cns.IPConfigsHandlerMiddleware IPConfigsHandlerMiddleware cns.IPConfigsHandlerMiddleware
PnpIDByMacAddress map[string]string PnpIDByMacAddress map[string]string
imdsClient imdsClient
} }
type CNIConflistGenerator interface { type CNIConflistGenerator interface {
@ -163,6 +168,7 @@ type networkInfo struct {
// NewHTTPRestService creates a new HTTP Service object. // NewHTTPRestService creates a new HTTP Service object.
func NewHTTPRestService(config *common.ServiceConfig, wscli interfaceGetter, wsproxy wireserverProxy, nmagentClient nmagentClient, func NewHTTPRestService(config *common.ServiceConfig, wscli interfaceGetter, wsproxy wireserverProxy, nmagentClient nmagentClient,
endpointStateStore store.KeyValueStore, gen CNIConflistGenerator, homeAzMonitor *HomeAzMonitor, endpointStateStore store.KeyValueStore, gen CNIConflistGenerator, homeAzMonitor *HomeAzMonitor,
imdsClient imdsClient,
) (*HTTPRestService, error) { ) (*HTTPRestService, error) {
service, err := cns.NewService(config.Name, config.Version, config.ChannelMode, config.Store) service, err := cns.NewService(config.Name, config.Version, config.ChannelMode, config.Store)
if err != nil { if err != nil {
@ -225,6 +231,7 @@ func NewHTTPRestService(config *common.ServiceConfig, wscli interfaceGetter, wsp
EndpointState: make(map[string]*EndpointInfo), EndpointState: make(map[string]*EndpointInfo),
homeAzMonitor: homeAzMonitor, homeAzMonitor: homeAzMonitor,
cniConflistGenerator: gen, cniConflistGenerator: gen,
imdsClient: imdsClient,
}, nil }, nil
} }
@ -280,6 +287,11 @@ func (service *HTTPRestService) Init(config *common.ServiceConfig) error {
listener.AddHandler(cns.NetworkContainersURLPath, service.getOrRefreshNetworkContainers) listener.AddHandler(cns.NetworkContainersURLPath, service.getOrRefreshNetworkContainers)
listener.AddHandler(cns.GetHomeAz, service.getHomeAz) listener.AddHandler(cns.GetHomeAz, service.getHomeAz)
listener.AddHandler(cns.EndpointPath, service.EndpointHandlerAPI) listener.AddHandler(cns.EndpointPath, service.EndpointHandlerAPI)
// This API is only needed for Direct channel mode with Swift v2.
if config.ChannelMode == cns.Direct {
listener.AddHandler(cns.GetVMUniqueID, service.getVMUniqueID)
}
// handlers for v0.2 // handlers for v0.2
listener.AddHandler(cns.V2Prefix+cns.SetEnvironmentPath, service.setEnvironment) listener.AddHandler(cns.V2Prefix+cns.SetEnvironmentPath, service.setEnvironment)
listener.AddHandler(cns.V2Prefix+cns.CreateNetworkPath, service.createNetwork) listener.AddHandler(cns.V2Prefix+cns.CreateNetworkPath, service.createNetwork)
@ -305,6 +317,10 @@ func (service *HTTPRestService) Init(config *common.ServiceConfig) error {
listener.AddHandler(cns.V2Prefix+cns.NmAgentSupportedApisPath, service.nmAgentSupportedApisHandler) listener.AddHandler(cns.V2Prefix+cns.NmAgentSupportedApisPath, service.nmAgentSupportedApisHandler)
listener.AddHandler(cns.V2Prefix+cns.GetHomeAz, service.getHomeAz) listener.AddHandler(cns.V2Prefix+cns.GetHomeAz, service.getHomeAz)
listener.AddHandler(cns.V2Prefix+cns.EndpointPath, service.EndpointHandlerAPI) listener.AddHandler(cns.V2Prefix+cns.EndpointPath, service.EndpointHandlerAPI)
// This API is only needed for Direct channel mode with Swift v2.
if config.ChannelMode == cns.Direct {
listener.AddHandler(cns.V2Prefix+cns.GetVMUniqueID, service.getVMUniqueID)
}
// Initialize HTTP client to be reused in CNS // Initialize HTTP client to be reused in CNS
connectionTimeout, _ := service.GetOption(acn.OptHttpConnectionTimeout).(int) connectionTimeout, _ := service.GetOption(acn.OptHttpConnectionTimeout).(int)

Просмотреть файл

@ -49,7 +49,9 @@ func startService(cnsPort, cnsURL string) error {
config := common.ServiceConfig{} config := common.ServiceConfig{}
nmagentClient := &fakes.NMAgentClientFake{} nmagentClient := &fakes.NMAgentClientFake{}
service, err := restserver.NewHTTPRestService(&config, &fakes.WireserverClientFake{}, &fakes.WireserverProxyFake{}, nmagentClient, nil, nil, nil) service, err := restserver.NewHTTPRestService(&config, &fakes.WireserverClientFake{},
&fakes.WireserverProxyFake{}, nmagentClient, nil, nil, nil,
fakes.NewMockIMDSClient())
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to initialize service") return errors.Wrap(err, "Failed to initialize service")
} }

Просмотреть файл

@ -748,8 +748,10 @@ func main() {
Logger: logger.Log, Logger: logger.Log,
} }
imdsClient := imds.NewClient()
httpRemoteRestService, err := restserver.NewHTTPRestService(&config, wsclient, &wsProxy, nmaClient, httpRemoteRestService, err := restserver.NewHTTPRestService(&config, wsclient, &wsProxy, nmaClient,
endpointStateStore, conflistGenerator, homeAzMonitor) endpointStateStore, conflistGenerator, homeAzMonitor, imdsClient)
if err != nil { if err != nil {
logger.Errorf("Failed to create CNS object, err:%v.\n", err) logger.Errorf("Failed to create CNS object, err:%v.\n", err)
return return